Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| de1a317890 | |||
| d64ee2e1cc | |||
| 13281d8235 |
@@ -18,35 +18,35 @@ The 5.x era resolves years of module system ambiguity and cleans house on legacy
|
||||
|
||||
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
|
||||
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}()\[\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
| Old pattern | Modern replacement | Since |
|
||||
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | -------------------------------- | ------ |
|
||||
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
|
||||
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` — `const` modifier on type parameter | 5.0 |
|
||||
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
|
||||
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
|
||||
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
|
||||
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
|
||||
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
|
||||
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
|
||||
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
|
||||
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
|
||||
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
|
||||
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
|
||||
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
|
||||
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
|
||||
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
|
||||
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
|
||||
| `new RegExp(str.replace(/[.\*+?^${}() | [\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
|
||||
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
|
||||
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
|
||||
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
|
||||
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
|
||||
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
|
||||
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
|
||||
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
|
||||
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
|
||||
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
|
||||
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
|
||||
|
||||
## New capabilities
|
||||
|
||||
|
||||
@@ -91,6 +91,12 @@ updates:
|
||||
emotion:
|
||||
patterns:
|
||||
- "@emotion*"
|
||||
exclude-patterns:
|
||||
- "jest-runner-eslint"
|
||||
jest:
|
||||
patterns:
|
||||
- "jest"
|
||||
- "@types/jest"
|
||||
vite:
|
||||
patterns:
|
||||
- "vite*"
|
||||
|
||||
@@ -84,7 +84,6 @@ jobs:
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
@@ -140,7 +139,6 @@ jobs:
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
@@ -173,6 +171,4 @@ jobs:
|
||||
--base "$RELEASE_VERSION" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
--body "$BODY"
|
||||
|
||||
@@ -42,7 +42,6 @@ jobs:
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
@@ -117,7 +116,6 @@ jobs:
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
@@ -134,19 +132,8 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
NEW_PR_URL=$(
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
)
|
||||
|
||||
# Comment on the original PR to notify the author.
|
||||
COMMENT="Cherry-pick PR created: ${NEW_PR_URL}"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
COMMENT="${COMMENT} (⚠️ conflicts need manual resolution)"
|
||||
fi
|
||||
gh pr comment "$PR_NUMBER" --body "$COMMENT"
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY"
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
# Ensures that only bug fixes are cherry-picked to release branches.
|
||||
# PRs targeting release/* must have a title starting with "fix:" or "fix(scope):".
|
||||
name: PR Cherry-Pick Check
|
||||
|
||||
on:
|
||||
# zizmor: ignore[dangerous-triggers] Only reads PR metadata and comments; does not checkout PR code.
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited]
|
||||
branches:
|
||||
- "release/*"
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-cherry-pick:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Check PR title for bug fix
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const baseBranch = context.payload.pull_request.base.ref;
|
||||
const author = context.payload.pull_request.user.login;
|
||||
|
||||
console.log(`PR #${prNumber}: "${title}" -> ${baseBranch}`);
|
||||
|
||||
// Match conventional commit "fix:" or "fix(scope):" prefix.
|
||||
const isBugFix = /^fix(\(.+\))?:/.test(title);
|
||||
|
||||
if (isBugFix) {
|
||||
console.log("PR title indicates a bug fix. No action needed.");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("PR title does not indicate a bug fix. Commenting.");
|
||||
|
||||
// Check for an existing comment from this bot to avoid duplicates
|
||||
// on title edits.
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const marker = "<!-- cherry-pick-check -->";
|
||||
const existingComment = comments.find(
|
||||
(c) => c.body && c.body.includes(marker),
|
||||
);
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
`👋 Hey @${author}!`,
|
||||
"",
|
||||
`This PR is targeting the \`${baseBranch}\` release branch, but its title does not start with \`fix:\` or \`fix(scope):\`.`,
|
||||
"",
|
||||
"Only **bug fixes** should be cherry-picked to release branches. If this is a bug fix, please update the PR title to match the conventional commit format:",
|
||||
"",
|
||||
"```",
|
||||
"fix: description of the bug fix",
|
||||
"fix(scope): description of the bug fix",
|
||||
"```",
|
||||
"",
|
||||
"If this is **not** a bug fix, it likely should not target a release branch.",
|
||||
].join("\n");
|
||||
|
||||
if (existingComment) {
|
||||
console.log(`Updating existing comment ${existingComment.id}.`);
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existingComment.id,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
core.warning(
|
||||
`PR #${prNumber} targets ${baseBranch} but is not a bug fix. Title must start with "fix:" or "fix(scope):".`,
|
||||
);
|
||||
@@ -91,59 +91,6 @@ define atomic_write
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Helper binary targets. Built with go build -o to avoid caching
|
||||
# link-stage executables in GOCACHE. Each binary is a real Make
|
||||
# target so parallel -j builds serialize correctly instead of
|
||||
# racing on the same output path.
|
||||
|
||||
_gen/bin/apitypings: $(wildcard scripts/apitypings/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apitypings
|
||||
|
||||
_gen/bin/auditdocgen: $(wildcard scripts/auditdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/auditdocgen
|
||||
|
||||
_gen/bin/check-scopes: $(wildcard scripts/check-scopes/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/check-scopes
|
||||
|
||||
_gen/bin/clidocgen: $(wildcard scripts/clidocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/clidocgen
|
||||
|
||||
_gen/bin/dbdump: $(wildcard coderd/database/gen/dump/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./coderd/database/gen/dump
|
||||
|
||||
_gen/bin/examplegen: $(wildcard scripts/examplegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/examplegen
|
||||
|
||||
_gen/bin/gensite: $(wildcard scripts/gensite/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/gensite
|
||||
|
||||
_gen/bin/apikeyscopesgen: $(wildcard scripts/apikeyscopesgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apikeyscopesgen
|
||||
|
||||
_gen/bin/metricsdocgen: $(wildcard scripts/metricsdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen
|
||||
|
||||
_gen/bin/metricsdocgen-scanner: $(wildcard scripts/metricsdocgen/scanner/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen/scanner
|
||||
|
||||
_gen/bin/modeloptionsgen: $(wildcard scripts/modeloptionsgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/modeloptionsgen
|
||||
|
||||
_gen/bin/typegen: $(wildcard scripts/typegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/typegen
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
@@ -254,7 +201,6 @@ endif
|
||||
|
||||
clean:
|
||||
rm -rf build/ site/build/ site/out/
|
||||
rm -rf _gen/bin
|
||||
mkdir -p build/
|
||||
git restore site/out/
|
||||
.PHONY: clean
|
||||
@@ -708,8 +654,8 @@ lint/go:
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
|
||||
lint/examples: | _gen/bin/examplegen
|
||||
_gen/bin/examplegen -lint
|
||||
lint/examples:
|
||||
go run ./scripts/examplegen/main.go -lint
|
||||
.PHONY: lint/examples
|
||||
|
||||
# Use shfmt to determine the shell files, takes editorconfig into consideration.
|
||||
@@ -747,8 +693,8 @@ lint/actions/zizmor:
|
||||
.PHONY: lint/actions/zizmor
|
||||
|
||||
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
|
||||
lint/check-scopes: coderd/database/dump.sql | _gen/bin/check-scopes
|
||||
_gen/bin/check-scopes
|
||||
lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
@@ -788,8 +734,8 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
|
||||
#
|
||||
# pre-commit uses two phases: gen+fmt first, then lint+build. This
|
||||
# avoids races where gen creates temporary .go files that lint's
|
||||
# find-based checks pick up. Within each phase, targets run in
|
||||
# avoids races where gen's `go run` creates temporary .go files that
|
||||
# lint's find-based checks pick up. Within each phase, targets run in
|
||||
# parallel via -j. It fails if any tracked files have unstaged
|
||||
# changes afterward.
|
||||
|
||||
@@ -1003,8 +949,8 @@ gen/mark-fresh:
|
||||
|
||||
# Runs migrations to output a dump of the database schema after migrations are
|
||||
# applied.
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql) | _gen/bin/dbdump
|
||||
_gen/bin/dbdump
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
|
||||
go run ./coderd/database/gen/dump/main.go
|
||||
touch "$@"
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
@@ -1121,88 +1067,88 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen _gen/bin/apitypings
|
||||
$(call atomic_write,_gen/bin/apitypings,./scripts/biome_format.sh)
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen _gen/bin/gensite
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
_gen/bin/gensite -icons "$$tmpfile" && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen _gen/bin/examplegen
|
||||
$(call atomic_write,_gen/bin/examplegen)
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac object)
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because the generator build
|
||||
# compiles coderd/rbac which includes it.
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go | _gen _gen/bin/typegen
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,_gen/bin/typegen rbac scopenames)
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,_gen/bin/typegen rbac codersdk)
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/apikeyscopesgen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
# Generate SDK constants for external API key scopes.
|
||||
$(call atomic_write,_gen/bin/apikeyscopesgen)
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac typescript,./scripts/biome_format.sh)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen countries,./scripts/biome_format.sh)
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen _gen/bin/modeloptionsgen
|
||||
$(call atomic_write,_gen/bin/modeloptionsgen | tail -n +2,./scripts/biome_format.sh)
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen _gen/bin/metricsdocgen-scanner
|
||||
$(call atomic_write,_gen/bin/metricsdocgen-scanner)
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen _gen/bin/metricsdocgen
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
_gen/bin/metricsdocgen --prometheus-doc-file="$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen _gen/bin/clidocgen
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" _gen/bin/clidocgen && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen _gen/bin/auditdocgen
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
_gen/bin/auditdocgen --audit-doc-file="$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
@@ -2862,126 +2862,6 @@ func TestAPI(t *testing.T) {
|
||||
"rebuilt agent should include updated display apps")
|
||||
})
|
||||
|
||||
// Verify that when a terraform-managed subagent is injected into
|
||||
// a devcontainer, the Directory field sent to Create reflects
|
||||
// the container-internal workspaceFolder from devcontainer
|
||||
// read-configuration, not the host-side workspace_folder from
|
||||
// the terraform resource. This is the scenario described in
|
||||
// https://linear.app/codercom/issue/PRODUCT-259:
|
||||
// 1. Non-terraform subagent → directory = /workspaces/foo (correct)
|
||||
// 2. Terraform subagent → directory was stuck on host path (bug)
|
||||
t.Run("TerraformDefinedSubAgentUsesContainerInternalDirectory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitMedium)
|
||||
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
mCtrl = gomock.NewController(t)
|
||||
|
||||
terraformAgentID = uuid.New()
|
||||
containerID = "test-container-id"
|
||||
|
||||
// Given: A container with a host-side workspace folder.
|
||||
terraformContainer = codersdk.WorkspaceAgentContainer{
|
||||
ID: containerID,
|
||||
FriendlyName: "test-container",
|
||||
Image: "test-image",
|
||||
Running: true,
|
||||
CreatedAt: time.Now(),
|
||||
Labels: map[string]string{
|
||||
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project",
|
||||
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
},
|
||||
}
|
||||
|
||||
// Given: A terraform-defined devcontainer whose
|
||||
// workspace_folder is the HOST-side path (set by provisioner).
|
||||
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
|
||||
ID: uuid.New(),
|
||||
Name: "terraform-devcontainer",
|
||||
WorkspaceFolder: "/home/coder/project",
|
||||
ConfigPath: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
|
||||
}
|
||||
|
||||
fCCLI = &fakeContainerCLI{
|
||||
containers: codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
|
||||
},
|
||||
arch: runtime.GOARCH,
|
||||
}
|
||||
|
||||
// Given: devcontainer read-configuration returns the
|
||||
// CONTAINER-INTERNAL workspace folder.
|
||||
fDCCLI = &fakeDevcontainerCLI{
|
||||
upID: containerID,
|
||||
readConfig: agentcontainers.DevcontainerConfig{
|
||||
Workspace: agentcontainers.DevcontainerWorkspace{
|
||||
WorkspaceFolder: "/workspaces/project",
|
||||
},
|
||||
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
|
||||
Customizations: agentcontainers.DevcontainerMergedCustomizations{
|
||||
Coder: []agentcontainers.CoderCustomization{{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mSAC = acmock.NewMockSubAgentClient(mCtrl)
|
||||
createCalls = make(chan agentcontainers.SubAgent, 1)
|
||||
closed bool
|
||||
)
|
||||
|
||||
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
|
||||
|
||||
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
|
||||
agent.AuthToken = uuid.New()
|
||||
createCalls <- agent
|
||||
return agent, nil
|
||||
},
|
||||
).Times(1)
|
||||
|
||||
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
|
||||
assert.True(t, closed, "Delete should only be called after Close")
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
|
||||
api := agentcontainers.NewAPI(logger,
|
||||
agentcontainers.WithContainerCLI(fCCLI),
|
||||
agentcontainers.WithDevcontainerCLI(fDCCLI),
|
||||
agentcontainers.WithDevcontainers(
|
||||
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
|
||||
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
|
||||
),
|
||||
agentcontainers.WithSubAgentClient(mSAC),
|
||||
agentcontainers.WithSubAgentURL("test-subagent-url"),
|
||||
agentcontainers.WithWatcher(watcher.NewNoop()),
|
||||
)
|
||||
api.Start()
|
||||
defer func() {
|
||||
closed = true
|
||||
api.Close()
|
||||
}()
|
||||
|
||||
// When: The devcontainer is created (triggering injection).
|
||||
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: The subagent sent to Create has the correct
|
||||
// container-internal directory, not the host path.
|
||||
createdAgent := testutil.RequireReceive(ctx, t, createCalls)
|
||||
assert.Equal(t, terraformAgentID, createdAgent.ID,
|
||||
"agent should use terraform-defined ID")
|
||||
assert.Equal(t, "/workspaces/project", createdAgent.Directory,
|
||||
"directory should be the container-internal path from devcontainer "+
|
||||
"read-configuration, not the host-side workspace_folder")
|
||||
})
|
||||
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -134,33 +134,6 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
|
||||
}, ResolvePaths(mcpConfigFile, workingDir)
|
||||
}
|
||||
|
||||
// ContextPartsFromDir reads instruction files and discovers skills
|
||||
// from a specific directory, using default file names. This is used
|
||||
// by the CLI chat context commands to read context from an arbitrary
|
||||
// directory without consulting agent env vars.
|
||||
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
|
||||
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
|
||||
parts = append(parts, entry)
|
||||
}
|
||||
|
||||
// Reuse ResolvePaths so CLI skill discovery follows the same
|
||||
// project-relative path handling as agent config resolution.
|
||||
skillParts := discoverSkills(
|
||||
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
|
||||
DefaultSkillMetaFile,
|
||||
)
|
||||
parts = append(parts, skillParts...)
|
||||
|
||||
// Guarantee non-nil slice.
|
||||
if parts == nil {
|
||||
parts = []codersdk.ChatMessagePart{}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// MCPConfigFiles returns the resolved MCP configuration file
|
||||
// paths for the agent's MCP manager.
|
||||
func (api *API) MCPConfigFiles() []string {
|
||||
|
||||
@@ -23,144 +23,18 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
|
||||
return out
|
||||
}
|
||||
|
||||
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
|
||||
t.Helper()
|
||||
|
||||
skillDir := filepath.Join(skillsRoot, name)
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
return skillDir
|
||||
}
|
||||
|
||||
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
|
||||
t.Helper()
|
||||
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
|
||||
}
|
||||
|
||||
func TestContextPartsFromDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Empty(t, skillParts)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
|
||||
require.False(t, contextParts[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFileInRoot(
|
||||
t,
|
||||
filepath.Join(dir, "skills"),
|
||||
"my-skill",
|
||||
"A test skill",
|
||||
)
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
|
||||
|
||||
require.NotNil(t, parts)
|
||||
require.Empty(t, parts)
|
||||
})
|
||||
|
||||
t.Run("ReturnsCombinedResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
|
||||
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 2)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
|
||||
require.Equal(t, "combined-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
})
|
||||
}
|
||||
|
||||
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
for key, value := range overrides {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
|
||||
return fakeHome
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -172,18 +46,20 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := t.TempDir()
|
||||
optSkills := t.TempDir()
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: optInstructions,
|
||||
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
|
||||
agentcontextconfig.EnvSkillsDirs: optSkills,
|
||||
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
|
||||
// Create files matching the custom names so we can
|
||||
// verify the env vars actually change lookup behavior.
|
||||
@@ -209,12 +85,15 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
|
||||
})
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// Create a file matching the trimmed name.
|
||||
@@ -227,13 +106,19 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
a := t.TempDir()
|
||||
b := t.TempDir()
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
|
||||
})
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
// Put instruction files in both dirs.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
|
||||
@@ -248,10 +133,17 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsInstructionFiles", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
// Create ~/.coder/AGENTS.md
|
||||
coderDir := filepath.Join(fakeHome, ".coder")
|
||||
@@ -272,9 +164,16 @@ func TestConfig(t *testing.T) {
|
||||
require.False(t, ctxFiles[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
|
||||
// Create AGENTS.md in the working directory.
|
||||
@@ -294,9 +193,16 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
largeContent := strings.Repeat("a", 64*1024+100)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
|
||||
@@ -309,47 +215,79 @@ func TestConfig(t *testing.T) {
|
||||
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
|
||||
})
|
||||
|
||||
sanitizationTests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SanitizesHTMLComments",
|
||||
input: "visible\n<!-- hidden -->content",
|
||||
expected: "visible\ncontent",
|
||||
},
|
||||
{
|
||||
name: "SanitizesInvisibleUnicode",
|
||||
input: "before\u200bafter",
|
||||
expected: "beforeafter",
|
||||
},
|
||||
{
|
||||
name: "NormalizesCRLF",
|
||||
input: "line1\r\nline2\rline3",
|
||||
expected: "line1\nline2\nline3",
|
||||
},
|
||||
}
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
for _, tt := range sanitizationTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte(tt.input),
|
||||
0o600,
|
||||
))
|
||||
t.Run("SanitizesHTMLComments", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("visible\n<!-- hidden -->content"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
}
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// U+200B (zero-width space) should be stripped.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("before\u200bafter"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("NormalizesCRLF", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("line1\r\nline2\rline3"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DiscoversSkills", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
@@ -382,13 +320,17 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkipsMissingDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: nonExistent,
|
||||
agentcontextconfig.EnvSkillsDirs: nonExistent,
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
@@ -398,13 +340,17 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, cfg.Parts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
|
||||
workDir := t.TempDir()
|
||||
_, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -412,10 +358,14 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, "skills")
|
||||
@@ -435,10 +385,14 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, skillParts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir1 := filepath.Join(workDir, "skills1")
|
||||
|
||||
+1036
-1139
File diff suppressed because it is too large
Load Diff
@@ -98,21 +98,6 @@ message Manifest {
|
||||
repeated WorkspaceApp apps = 11;
|
||||
repeated WorkspaceAgentMetadata.Description metadata = 12;
|
||||
repeated WorkspaceAgentDevcontainer devcontainers = 17;
|
||||
repeated WorkspaceSecret secrets = 19;
|
||||
}
|
||||
|
||||
// WorkspaceSecret is a secret included in the agent manifest
|
||||
// for injection into a workspace.
|
||||
message WorkspaceSecret {
|
||||
// Environment variable name to inject (e.g. "GITHUB_TOKEN").
|
||||
// Empty string means this secret is not injected as an env var.
|
||||
string env_name = 1;
|
||||
// File path to write the secret value to (e.g.
|
||||
// "~/.aws/credentials"). Empty string means this secret is not
|
||||
// written to a file.
|
||||
string file_path = 2;
|
||||
// The decrypted secret value.
|
||||
bytes value = 3;
|
||||
}
|
||||
|
||||
message WorkspaceAgentDevcontainer {
|
||||
|
||||
@@ -5,9 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -622,11 +620,6 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
defer artifact.Reader.Close()
|
||||
defer func() {
|
||||
if artifact.ThumbnailReader != nil {
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
@@ -640,60 +633,10 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Discard the thumbnail if it exceeds the maximum size.
|
||||
// The server-side consumer also enforces this per-part, but
|
||||
// rejecting it here avoids streaming a large thumbnail over
|
||||
// the wire for nothing.
|
||||
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
|
||||
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.ThumbnailSize),
|
||||
slog.F("max_size", workspacesdk.MaxThumbnailSize),
|
||||
)
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
artifact.ThumbnailReader = nil
|
||||
artifact.ThumbnailSize = 0
|
||||
}
|
||||
|
||||
// The multipart response is best-effort: once WriteHeader(200) is
|
||||
// called, CreatePart failures produce a truncated response without
|
||||
// the closing boundary. The server-side consumer handles this
|
||||
// gracefully, preserving any parts read before the error.
|
||||
mw := multipart.NewWriter(rw)
|
||||
defer mw.Close()
|
||||
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
|
||||
rw.Header().Set("Content-Type", "video/mp4")
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
// Part 1: video/mp4 (always present).
|
||||
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
|
||||
a.logger.Warn(ctx, "failed to write video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Part 2: image/jpeg (present only when thumbnail was extracted).
|
||||
if artifact.ThumbnailReader != nil {
|
||||
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
|
||||
}
|
||||
_, _ = io.Copy(rw, artifact.Reader)
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
|
||||
@@ -4,17 +4,12 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -64,8 +59,6 @@ type fakeDesktop struct {
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
thumbnailData []byte // if set, StopRecording includes a thumbnail
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
@@ -194,15 +187,10 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
artifact := &agentdesktop.RecordingArtifact{
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}
|
||||
if f.thumbnailData != nil {
|
||||
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
|
||||
artifact.ThumbnailSize = int64(len(f.thumbnailData))
|
||||
}
|
||||
return artifact, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
@@ -797,8 +785,8 @@ func TestRecordingStartStop(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
@@ -859,8 +847,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
@@ -884,7 +872,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var videoParts [2][]byte
|
||||
var bodies [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
@@ -892,10 +880,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
|
||||
videoParts[i] = parts["video/mp4"]
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
}
|
||||
assert.Equal(t, videoParts[0], videoParts[1])
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
@@ -1016,8 +1004,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, expected[id], parts["video/mp4"])
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1124,8 +1112,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
firstData := firstParts["video/mp4"]
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
@@ -1140,8 +1128,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
secondData := secondParts["video/mp4"]
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
@@ -1247,166 +1235,3 @@ func TestRecordingStopCorrupted(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
// parseMultipartParts parses a multipart/mixed response and returns
|
||||
// a map from Content-Type to body bytes.
|
||||
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
require.NoError(t, err, "parse Content-Type")
|
||||
boundary := params["boundary"]
|
||||
require.NotEmpty(t, boundary, "missing boundary")
|
||||
mr := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
parts := make(map[string][]byte)
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err, "unexpected multipart parse error")
|
||||
ct := part.Header.Get("Content-Type")
|
||||
data, readErr := io.ReadAll(part)
|
||||
require.NoError(t, readErr)
|
||||
parts[ct] = data
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
|
||||
thumbnail := make([]byte, 512)
|
||||
thumbnail[0] = 0xff
|
||||
thumbnail[1] = 0xd8
|
||||
thumbnail[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: thumbnail,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
|
||||
|
||||
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
assert.Equal(t, thumbnail, parts["image/jpeg"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create thumbnail data that exceeds MaxThumbnailSize.
|
||||
oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
|
||||
oversizedThumb[0] = 0xff
|
||||
oversizedThumb[1] = 0xd8
|
||||
oversizedThumb[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: oversizedThumb,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response contains only the video part.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
@@ -105,11 +105,6 @@ type RecordingArtifact struct {
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
// ThumbnailReader is the JPEG thumbnail. May be nil if no
|
||||
// thumbnail was produced. Callers must close it when done.
|
||||
ThumbnailReader io.ReadCloser
|
||||
// ThumbnailSize is the byte length of the thumbnail.
|
||||
ThumbnailSize int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
|
||||
@@ -56,7 +56,6 @@ type screenshotOutput struct {
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
thumbPath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
@@ -384,20 +383,13 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old thumbnail file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
@@ -414,7 +406,6 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
@@ -428,7 +419,6 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
"--thumbnail", thumbPath,
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
@@ -437,10 +427,9 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
thumbPath: thumbPath,
|
||||
done: make(chan struct{}),
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
@@ -510,35 +499,10 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
artifact := &RecordingArtifact{
|
||||
return &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}
|
||||
// Attach thumbnail if the subprocess wrote one.
|
||||
thumbFile, err := os.Open(rec.thumbPath)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "thumbnail not available",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
thumbInfo, err := thumbFile.Stat()
|
||||
if err != nil {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail stat failed",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
if thumbInfo.Size() == 0 {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail file is empty",
|
||||
slog.F("thumbnail_path", rec.thumbPath))
|
||||
return artifact, nil
|
||||
}
|
||||
artifact.ThumbnailReader = thumbFile
|
||||
artifact.ThumbnailSize = thumbInfo.Size()
|
||||
return artifact, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
@@ -607,33 +571,18 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; clean up
|
||||
// any leftover thumbnail and drop the entry.
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
// File already removed or inaccessible; drop entry.
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
@@ -654,14 +603,13 @@ func (p *portableDesktop) Close() error {
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
thumbPath string
|
||||
idleDone chan struct{}
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
@@ -682,20 +630,13 @@ func (p *portableDesktop) Close() error {
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("thumbnail_path", entry.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
|
||||
@@ -2,7 +2,6 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -585,7 +584,6 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -668,66 +666,6 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// No thumbnail file exists, so ThumbnailReader should be nil.
|
||||
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write a dummy MP4 file at the expected path.
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(filePath) })
|
||||
|
||||
// Write a thumbnail file at the expected path.
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
|
||||
thumbContent := []byte("fake-jpeg-thumbnail")
|
||||
require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(thumbPath) })
|
||||
|
||||
artifact, err := pd.StopRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
defer artifact.Reader.Close()
|
||||
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// Thumbnail should be attached.
|
||||
require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
|
||||
defer artifact.ThumbnailReader.Close()
|
||||
assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)
|
||||
|
||||
// Read and verify thumbnail content.
|
||||
thumbData, err := io.ReadAll(artifact.ThumbnailReader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, thumbContent, thumbData)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
@@ -812,18 +750,12 @@ func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout to trigger the stop-all.
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Wait for the stop timer to be created, then release it.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Advance past the 15s stop timeout so the process is
|
||||
// forcibly killed. Without this the test depends on the real
|
||||
// shell handling SIGINT promptly, which is unreliable on
|
||||
// macOS CI runners (the flake in #1461).
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
|
||||
// The recording process should now be stopped.
|
||||
require.Eventually(t, func() bool {
|
||||
pd.mu.Lock()
|
||||
@@ -945,17 +877,11 @@ func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
|
||||
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
|
||||
|
||||
// Advance past idle timeout.
|
||||
clk.Advance(idleTimeout).MustWait(ctx)
|
||||
clk.Advance(idleTimeout)
|
||||
|
||||
// Each idle monitor goroutine serializes on p.mu, so the
|
||||
// second stop timer is only created after the first stop
|
||||
// completes. Advance past the 15s stop timeout after each
|
||||
// release so the process is forcibly killed instead of
|
||||
// depending on SIGINT (unreliable on macOS — see #1461).
|
||||
// Wait for both stop timers.
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.MustWait(ctx).MustRelease(ctx)
|
||||
clk.Advance(15 * time.Second).MustWait(ctx)
|
||||
stopTrap.Close()
|
||||
|
||||
// Both recordings should be stopped.
|
||||
|
||||
@@ -87,12 +87,6 @@ func IsDevVersion(v string) bool {
|
||||
return strings.Contains(v, "-"+develPreRelease)
|
||||
}
|
||||
|
||||
// IsRCVersion returns true if the version has a release candidate
|
||||
// pre-release tag, e.g. "v2.31.0-rc.0".
|
||||
func IsRCVersion(v string) bool {
|
||||
return strings.Contains(v, "-rc.")
|
||||
}
|
||||
|
||||
// IsDev returns true if this is a development build.
|
||||
// CI builds are also considered development builds.
|
||||
func IsDev() bool {
|
||||
|
||||
@@ -102,29 +102,3 @@ func TestBuildInfo(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRCVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected bool
|
||||
}{
|
||||
{"RC0", "v2.31.0-rc.0", true},
|
||||
{"RC1WithBuild", "v2.31.0-rc.1+abc123", true},
|
||||
{"RC10", "v2.31.0-rc.10", true},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", true},
|
||||
{"DevelVersion", "v2.31.0-devel+abc123", false},
|
||||
{"StableVersion", "v2.31.0", false},
|
||||
{"DevNoVersion", "v0.0.0-devel+abc123", false},
|
||||
{"BetaVersion", "v2.31.0-beta.1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, buildinfo.IsRCVersion(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
-194
@@ -1,194 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) chatCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "chat",
|
||||
Short: "Manage agent chats",
|
||||
Long: "Commands for interacting with chats from within a workspace.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) chatContextCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "context",
|
||||
Short: "Manage chat context",
|
||||
Long: "Add or clear context files and skills for an active chat session.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextAddCommand(),
|
||||
r.chatContextClearCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextAddCommand() *serpent.Command {
|
||||
var (
|
||||
dir string
|
||||
chatID string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "add",
|
||||
Short: "Add context to an active chat",
|
||||
Long: "Read instruction files and discover skills from a directory, then add " +
|
||||
"them as context to an active chat session. Multiple calls " +
|
||||
"are additive.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
if dir == "" && inv.Environ.Get("CODER") != "true" {
|
||||
return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
|
||||
}
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedDir := dir
|
||||
if resolvedDir == "" {
|
||||
resolvedDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get working directory: %w", err)
|
||||
}
|
||||
}
|
||||
resolvedDir, err = filepath.Abs(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve directory: %w", err)
|
||||
}
|
||||
info, err := os.Stat(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return xerrors.Errorf("%q is not a directory", resolvedDir)
|
||||
}
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
|
||||
if len(parts) == 0 {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve chat ID from flag or auto-detect.
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("add chat context: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "Directory",
|
||||
Flag: "dir",
|
||||
Description: "Directory to read context files and skills from. Defaults to the current working directory.",
|
||||
Value: serpent.StringOf(&dir),
|
||||
},
|
||||
{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
},
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextClearCommand() *serpent.Command {
|
||||
var chatID string
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear context from an active chat",
|
||||
Long: "Soft-delete all context-file and skill messages from an active chat. " +
|
||||
"The next turn will re-fetch default context from the agent.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear chat context: %w", err)
|
||||
}
|
||||
|
||||
if resp.ChatID == uuid.Nil {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
}},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// parseChatID returns the chat UUID from the flag value (which
|
||||
// serpent already populates from --chat or CODER_CHAT_ID). Returns
|
||||
// uuid.Nil if empty (the server will auto-detect).
|
||||
func parseChatID(flagValue string) (uuid.UUID, error) {
|
||||
if flagValue == "" {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(flagValue)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
func TestExpChatContextAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
})
|
||||
|
||||
t.Run("AllowsExplicitDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
inv.Environ.Set("CODER", "true")
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
+5
-29
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -149,7 +148,6 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.chatCommand(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
@@ -712,7 +710,7 @@ func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
transport = wrapTransportWithUserAgentHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
// otherwise it creates an infinite loop!
|
||||
basicClient := codersdk.New(serverURL)
|
||||
@@ -1436,21 +1434,6 @@ func defaultUpgradeMessage(version string) string {
|
||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||
}
|
||||
|
||||
// serverVersionMessage returns a warning message if the server version
|
||||
// is a release candidate or development build. Returns empty string
|
||||
// for stable versions. RC is checked before devel because RC dev
|
||||
// builds (e.g. v2.33.0-rc.1-devel+hash) contain both tags.
|
||||
func serverVersionMessage(serverVersion string) string {
|
||||
switch {
|
||||
case buildinfo.IsRCVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a release candidate of Coder (%s)", serverVersion)
|
||||
case buildinfo.IsDevVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a development version of Coder (%s)", serverVersion)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||
// that checks for entitlement warnings and prints them to the user.
|
||||
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||
@@ -1469,10 +1452,10 @@ func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithVersionCheck adds a middleware to the HTTP transport
|
||||
// that checks the server version and warns about development builds,
|
||||
// release candidates, and client/server version mismatches.
|
||||
func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||
// that checks for version mismatches between the client and server. If a mismatch
|
||||
// is detected, a warning is printed to the user.
|
||||
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
@@ -1484,16 +1467,9 @@ func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation
|
||||
if serverVersion == "" {
|
||||
return
|
||||
}
|
||||
// Warn about non-stable server versions. Skip
|
||||
// during tests to avoid polluting golden files.
|
||||
if msg := serverVersionMessage(serverVersion); msg != "" && flag.Lookup("test.v") == nil {
|
||||
warning := pretty.Sprint(cliui.DefaultStyles.Warn, msg)
|
||||
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||
}
|
||||
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||
if serverInfo, err := getBuildInfo(inv.Context()); err == nil {
|
||||
switch {
|
||||
|
||||
@@ -91,7 +91,7 @@ func Test_formatExamples(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoOutput", func(t *testing.T) {
|
||||
@@ -102,7 +102,7 @@ func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -131,7 +131,7 @@ func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -159,53 +159,6 @@ func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
|
||||
t.Run("ServerStableVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &RootCmd{}
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
codersdk.BuildVersionHeader: []string{"v2.31.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.31.0", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Empty(t, buf.String())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_serverVersionMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{"Stable", "v2.31.0", ""},
|
||||
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
|
||||
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
|
||||
{"Empty", "", ""},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, serverVersionMessage(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||
|
||||
@@ -768,10 +768,30 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("create pubsub: %w", err)
|
||||
}
|
||||
options.Pubsub = ps
|
||||
options.ChatPubsub = ps
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(ps)
|
||||
}
|
||||
defer options.Pubsub.Close()
|
||||
chatPubsub, err := pubsub.NewBatching(
|
||||
ctx,
|
||||
logger.Named("chatd").Named("pubsub_batch"),
|
||||
ps,
|
||||
sqlDB,
|
||||
dbURL,
|
||||
pubsub.BatchingConfig{
|
||||
FlushInterval: options.DeploymentValues.AI.Chat.PubsubFlushInterval.Value(),
|
||||
QueueSize: int(options.DeploymentValues.AI.Chat.PubsubQueueSize.Value()),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create chat pubsub batcher: %w", err)
|
||||
}
|
||||
options.ChatPubsub = chatPubsub
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(chatPubsub)
|
||||
}
|
||||
defer options.ChatPubsub.Close()
|
||||
psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
|
||||
pubsubWatchdogTimeout = psWatchdog.Timeout()
|
||||
defer psWatchdog.Close()
|
||||
|
||||
+4
-6
@@ -69,17 +69,15 @@ var (
|
||||
// isRetryableError checks for transient connection errors worth
|
||||
// retrying: DNS failures, connection refused, and server 5xx.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil || xerrors.Is(err, context.Canceled) {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
// Check connection errors before context.DeadlineExceeded because
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both.
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
if xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
|
||||
@@ -516,23 +516,6 @@ func TestIsRetryableError(t *testing.T) {
|
||||
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both
|
||||
// IsConnectionError and context.DeadlineExceeded. Verify it is retryable.
|
||||
t.Run("DialTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
|
||||
defer cancel()
|
||||
<-ctx.Done() // ensure deadline has fired
|
||||
_, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:1")
|
||||
require.Error(t, err)
|
||||
// Proves the ambiguity: this error matches BOTH checks.
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorAs(t, err, new(*net.OpError))
|
||||
assert.True(t, isRetryableError(err))
|
||||
// Also when wrapped, as runCoderConnectStdio does.
|
||||
assert.True(t, isRetryableError(xerrors.Errorf("dial coder connect: %w", err)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
|
||||
+1
-1
@@ -11,7 +11,7 @@ OPTIONS:
|
||||
-O, --org string, $CODER_ORGANIZATION
|
||||
Select which organization (uuid or name) to use.
|
||||
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|workspace build transition|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
Columns to display in table output.
|
||||
|
||||
-i, --initiator string, $CODER_PROVISIONER_JOB_LIST_INITIATOR
|
||||
|
||||
@@ -58,8 +58,7 @@
|
||||
"template_display_name": "",
|
||||
"template_icon": "",
|
||||
"workspace_id": "===========[workspace ID]===========",
|
||||
"workspace_name": "test-workspace",
|
||||
"workspace_build_transition": "start"
|
||||
"workspace_name": "test-workspace"
|
||||
},
|
||||
"logs_overflowed": false,
|
||||
"organization_name": "Coder"
|
||||
|
||||
-7
@@ -211,13 +211,6 @@ AI BRIDGE PROXY OPTIONS:
|
||||
certificates not trusted by the system. If not provided, the system
|
||||
certificate pool is used.
|
||||
|
||||
CHAT OPTIONS:
|
||||
Configure the background chat processing daemon.
|
||||
|
||||
--chat-debug-logging-enabled bool, $CODER_CHAT_DEBUG_LOGGING_ENABLED (default: false)
|
||||
Force chat debug logging on for every chat, bypassing the runtime
|
||||
admin and user opt-in settings.
|
||||
|
||||
CLIENT OPTIONS:
|
||||
These options change the behavior of how clients interact with the Coder.
|
||||
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
|
||||
|
||||
-4
@@ -757,10 +757,6 @@ chat:
|
||||
# How many pending chats a worker should acquire per polling cycle.
|
||||
# (default: 10, type: int)
|
||||
acquireBatchSize: 10
|
||||
# Force chat debug logging on for every chat, bypassing the runtime admin and user
|
||||
# opt-in settings.
|
||||
# (default: false, type: bool)
|
||||
debugLoggingEnabled: false
|
||||
aibridge:
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# (default: false, type: bool)
|
||||
|
||||
@@ -71,7 +71,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
// An ID is only given in the request when it is a terraform-defined devcontainer
|
||||
// that has attached resources. These subagents are pre-provisioned by terraform
|
||||
// (the agent record already exists), so we update configurable fields like
|
||||
// display_apps and directory rather than creating a new agent.
|
||||
// display_apps rather than creating a new agent.
|
||||
if req.Id != nil {
|
||||
id, err := uuid.FromBytes(req.Id)
|
||||
if err != nil {
|
||||
@@ -97,16 +97,6 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
return nil, xerrors.Errorf("update workspace agent display apps: %w", err)
|
||||
}
|
||||
|
||||
if req.Directory != "" {
|
||||
if err := a.Database.UpdateWorkspaceAgentDirectoryByID(ctx, database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: id,
|
||||
Directory: req.Directory,
|
||||
UpdatedAt: createdAt,
|
||||
}); err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &agentproto.CreateSubAgentResponse{
|
||||
Agent: &agentproto.SubAgent{
|
||||
Name: subAgent.Name,
|
||||
|
||||
@@ -1267,11 +1267,11 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// And: The database agent's name, architecture, and OS are unchanged.
|
||||
// And: The database agent's other fields are unchanged.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, baseChildAgent.Name, updatedAgent.Name)
|
||||
require.Equal(t, "/different/path", updatedAgent.Directory)
|
||||
require.Equal(t, baseChildAgent.Directory, updatedAgent.Directory)
|
||||
require.Equal(t, baseChildAgent.Architecture, updatedAgent.Architecture)
|
||||
require.Equal(t, baseChildAgent.OperatingSystem, updatedAgent.OperatingSystem)
|
||||
|
||||
@@ -1280,42 +1280,6 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
require.Equal(t, database.DisplayAppWebTerminal, updatedAgent.DisplayApps[0])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OK_DirectoryUpdated",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
// Given: An existing child agent with a stale host-side
|
||||
// directory (as set by the provisioner at build time).
|
||||
childAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID},
|
||||
ResourceID: agent.ResourceID,
|
||||
Name: baseChildAgent.Name,
|
||||
Directory: "/home/coder/project",
|
||||
Architecture: baseChildAgent.Architecture,
|
||||
OperatingSystem: baseChildAgent.OperatingSystem,
|
||||
DisplayApps: baseChildAgent.DisplayApps,
|
||||
})
|
||||
|
||||
// When: Agent injection sends the correct
|
||||
// container-internal path.
|
||||
return &proto.CreateSubAgentRequest{
|
||||
Id: childAgent.ID[:],
|
||||
Directory: "/workspaces/project",
|
||||
DisplayApps: []proto.CreateSubAgentRequest_DisplayApp{
|
||||
proto.CreateSubAgentRequest_WEB_TERMINAL,
|
||||
},
|
||||
}
|
||||
},
|
||||
check: func(t *testing.T, ctx context.Context, db database.Store, resp *proto.CreateSubAgentResponse, agent database.WorkspaceAgent) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: Directory is updated to the container-internal
|
||||
// path.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/workspaces/project", updatedAgent.Directory)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error/MalformedID",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
|
||||
Generated
-284
@@ -9514,212 +9514,6 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": [
|
||||
@@ -13445,12 +13239,6 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -14691,9 +14479,6 @@ const docTemplate = `{
|
||||
"properties": {
|
||||
"acquire_batch_size": {
|
||||
"type": "integer"
|
||||
},
|
||||
"debug_logging_enabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -15357,26 +15142,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -19152,9 +18917,6 @@ const docTemplate = `{
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -21509,23 +21271,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21981,35 +21726,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
-262
@@ -8431,190 +8431,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11993,12 +11809,6 @@
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -13204,9 +13014,6 @@
|
||||
"properties": {
|
||||
"acquire_batch_size": {
|
||||
"type": "integer"
|
||||
},
|
||||
"debug_logging_enabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13836,26 +13643,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -17512,9 +17299,6 @@
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -19761,23 +19545,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -20208,35 +19975,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "dormant", "suspended"],
|
||||
|
||||
+10
-15
@@ -159,7 +159,10 @@ type Options struct {
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
RuntimeConfig *runtimeconfig.Manager
|
||||
// ChatPubsub allows chatd to use a dedicated publish path without changing
|
||||
// the shared pubsub used by the rest of coderd.
|
||||
ChatPubsub pubsub.Pubsub
|
||||
RuntimeConfig *runtimeconfig.Manager
|
||||
|
||||
// CacheDir is used for caching files served by the API.
|
||||
CacheDir string
|
||||
@@ -777,6 +780,11 @@ func New(options *Options) *API {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
chatPubsub := options.ChatPubsub
|
||||
if chatPubsub == nil {
|
||||
chatPubsub = options.Pubsub
|
||||
}
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
@@ -789,7 +797,7 @@ func New(options *Options) *API {
|
||||
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
Pubsub: chatPubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
})
|
||||
@@ -1608,15 +1616,6 @@ func New(options *Options) *API {
|
||||
|
||||
r.Get("/gitsshkey", api.gitSSHKey)
|
||||
r.Put("/gitsshkey", api.regenerateGitSSHKey)
|
||||
r.Route("/secrets", func(r chi.Router) {
|
||||
r.Post("/", api.postUserSecret)
|
||||
r.Get("/", api.getUserSecrets)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
r.Get("/", api.getUserSecret)
|
||||
r.Patch("/", api.patchUserSecret)
|
||||
r.Delete("/", api.deleteUserSecret)
|
||||
})
|
||||
})
|
||||
r.Route("/notifications", func(r chi.Router) {
|
||||
r.Route("/preferences", func(r chi.Router) {
|
||||
r.Get("/", api.userNotificationPreferences)
|
||||
@@ -1662,10 +1661,6 @@ func New(options *Options) *API {
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Post("/log-source", api.workspaceAgentPostLogSource)
|
||||
r.Get("/reinit", api.workspaceAgentReinit)
|
||||
r.Route("/experimental", func(r chi.Router) {
|
||||
r.Post("/chat-context", api.workspaceAgentAddChatContext)
|
||||
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
|
||||
})
|
||||
r.Route("/tasks/{task}", func(r chi.Router) {
|
||||
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
|
||||
})
|
||||
|
||||
@@ -147,10 +147,6 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
|
||||
return c
|
||||
}
|
||||
|
||||
func isExperimentalEndpoint(route string) bool {
|
||||
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
|
||||
}
|
||||
|
||||
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
|
||||
assertUniqueRoutes(t, swaggerComments)
|
||||
assertSingleAnnotations(t, swaggerComments)
|
||||
@@ -169,9 +165,6 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
|
||||
if strings.HasSuffix(route, "/*") {
|
||||
return
|
||||
}
|
||||
if isExperimentalEndpoint(route) {
|
||||
return
|
||||
}
|
||||
|
||||
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
|
||||
assert.NotNil(t, c, "Missing @Router annotation")
|
||||
|
||||
@@ -538,12 +538,6 @@ func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
|
||||
switch {
|
||||
case workspaceAgent.Status != codersdk.WorkspaceAgentConnected && workspaceAgent.LifecycleState == codersdk.WorkspaceAgentLifecycleOff:
|
||||
workspaceAgent.Health.Reason = "agent is not running"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentConnecting:
|
||||
// Note: the case above catches connecting+off as "not running".
|
||||
// This case handles connecting agents with a non-off lifecycle
|
||||
// (e.g. "created" or "starting"), where the agent binary has
|
||||
// not yet established a connection to coderd.
|
||||
workspaceAgent.Health.Reason = "agent has not yet connected"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentTimeout:
|
||||
workspaceAgent.Health.Reason = "agent is taking too long to connect"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentDisconnected:
|
||||
@@ -1240,8 +1234,6 @@ func buildAIBridgeThread(
|
||||
if rootIntc != nil {
|
||||
thread.Model = rootIntc.Model
|
||||
thread.Provider = rootIntc.Provider
|
||||
thread.CredentialKind = string(rootIntc.CredentialKind)
|
||||
thread.CredentialHint = rootIntc.CredentialHint
|
||||
// Get first user prompt from root interception.
|
||||
// A thread can only have one prompt, by definition, since we currently
|
||||
// only store the last prompt observed in an interception.
|
||||
@@ -1533,22 +1525,6 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
func nullStringPtr(v sql.NullString) *string {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.String
|
||||
return &value
|
||||
}
|
||||
|
||||
func nullTimePtr(v sql.NullTime) *time.Time {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
value := v.Time
|
||||
return &value
|
||||
}
|
||||
|
||||
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
|
||||
// nil slices and maps to empty values for JSON serialization and
|
||||
// derives RootChatID from the parent chain when not explicitly set.
|
||||
@@ -1635,88 +1611,6 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database
|
||||
return chat
|
||||
}
|
||||
|
||||
func chatDebugAttempts(raw json.RawMessage) []map[string]any {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var attempts []map[string]any
|
||||
if err := json.Unmarshal(raw, &attempts); err != nil {
|
||||
return []map[string]any{{
|
||||
"error": "malformed attempts payload",
|
||||
"raw": string(raw),
|
||||
}}
|
||||
}
|
||||
return attempts
|
||||
}
|
||||
|
||||
// rawJSONObject deserializes a JSON object payload for debug display.
|
||||
// If the payload is malformed, it returns a map with "error" and "raw"
|
||||
// keys preserving the original content for diagnostics. Callers that
|
||||
// consume the result programmatically should check for the "error" key.
|
||||
func rawJSONObject(raw json.RawMessage) map[string]any {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var object map[string]any
|
||||
if err := json.Unmarshal(raw, &object); err != nil {
|
||||
return map[string]any{
|
||||
"error": "malformed debug payload",
|
||||
"raw": string(raw),
|
||||
}
|
||||
}
|
||||
return object
|
||||
}
|
||||
|
||||
func nullRawJSONObject(raw pqtype.NullRawMessage) map[string]any {
|
||||
if !raw.Valid {
|
||||
return nil
|
||||
}
|
||||
return rawJSONObject(raw.RawMessage)
|
||||
}
|
||||
|
||||
// ChatDebugRunSummary converts a database.ChatDebugRun to a
|
||||
// codersdk.ChatDebugRunSummary.
|
||||
func ChatDebugRunSummary(r database.ChatDebugRun) codersdk.ChatDebugRunSummary {
|
||||
return codersdk.ChatDebugRunSummary{
|
||||
ID: r.ID,
|
||||
ChatID: r.ChatID,
|
||||
Kind: codersdk.ChatDebugRunKind(r.Kind),
|
||||
Status: codersdk.ChatDebugStatus(r.Status),
|
||||
Provider: nullStringPtr(r.Provider),
|
||||
Model: nullStringPtr(r.Model),
|
||||
Summary: rawJSONObject(r.Summary),
|
||||
StartedAt: r.StartedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
FinishedAt: nullTimePtr(r.FinishedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ChatDebugStep converts a database.ChatDebugStep to a
|
||||
// codersdk.ChatDebugStep.
|
||||
func ChatDebugStep(s database.ChatDebugStep) codersdk.ChatDebugStep {
|
||||
return codersdk.ChatDebugStep{
|
||||
ID: s.ID,
|
||||
RunID: s.RunID,
|
||||
ChatID: s.ChatID,
|
||||
StepNumber: s.StepNumber,
|
||||
Operation: codersdk.ChatDebugStepOperation(s.Operation),
|
||||
Status: codersdk.ChatDebugStatus(s.Status),
|
||||
HistoryTipMessageID: nullInt64Ptr(s.HistoryTipMessageID),
|
||||
AssistantMessageID: nullInt64Ptr(s.AssistantMessageID),
|
||||
NormalizedRequest: rawJSONObject(s.NormalizedRequest),
|
||||
NormalizedResponse: nullRawJSONObject(s.NormalizedResponse),
|
||||
Usage: nullRawJSONObject(s.Usage),
|
||||
Attempts: chatDebugAttempts(s.Attempts),
|
||||
Error: nullRawJSONObject(s.Error),
|
||||
Metadata: rawJSONObject(s.Metadata),
|
||||
StartedAt: s.StartedAt,
|
||||
UpdatedAt: s.UpdatedAt,
|
||||
FinishedAt: nullTimePtr(s.FinishedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ChatRows converts a slice of database.GetChatsRow (which embeds
|
||||
// Chat plus HasUnread) to codersdk.Chat, looking up diff statuses
|
||||
// from the provided map. When diffStatusesByChatID is non-nil,
|
||||
|
||||
@@ -210,231 +210,6 @@ func TestTemplateVersionParameter_BadDescription(t *testing.T) {
|
||||
req.NotEmpty(sdk.DescriptionPlaintext, "broke the markdown parser with %v", desc)
|
||||
}
|
||||
|
||||
func TestChatDebugRunSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
startedAt := time.Now().UTC().Round(time.Second)
|
||||
finishedAt := startedAt.Add(5 * time.Second)
|
||||
|
||||
run := database.ChatDebugRun{
|
||||
ID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
Kind: "chat_turn",
|
||||
Status: "completed",
|
||||
Provider: sql.NullString{String: "openai", Valid: true},
|
||||
Model: sql.NullString{String: "gpt-4o", Valid: true},
|
||||
Summary: json.RawMessage(`{"step_count":3,"has_error":false}`),
|
||||
StartedAt: startedAt,
|
||||
UpdatedAt: finishedAt,
|
||||
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugRunSummary(run)
|
||||
|
||||
require.Equal(t, run.ID, sdk.ID)
|
||||
require.Equal(t, run.ChatID, sdk.ChatID)
|
||||
require.Equal(t, codersdk.ChatDebugRunKindChatTurn, sdk.Kind)
|
||||
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
|
||||
require.NotNil(t, sdk.Provider)
|
||||
require.Equal(t, "openai", *sdk.Provider)
|
||||
require.NotNil(t, sdk.Model)
|
||||
require.Equal(t, "gpt-4o", *sdk.Model)
|
||||
require.Equal(t, map[string]any{"step_count": float64(3), "has_error": false}, sdk.Summary)
|
||||
require.Equal(t, startedAt, sdk.StartedAt)
|
||||
require.Equal(t, finishedAt, sdk.UpdatedAt)
|
||||
require.NotNil(t, sdk.FinishedAt)
|
||||
require.Equal(t, finishedAt, *sdk.FinishedAt)
|
||||
}
|
||||
|
||||
func TestChatDebugRunSummary_NullableFieldsNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
run := database.ChatDebugRun{
|
||||
ID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
Kind: "title_generation",
|
||||
Status: "in_progress",
|
||||
Summary: json.RawMessage(`{}`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugRunSummary(run)
|
||||
|
||||
require.Nil(t, sdk.Provider, "NULL Provider should map to nil")
|
||||
require.Nil(t, sdk.Model, "NULL Model should map to nil")
|
||||
require.Nil(t, sdk.FinishedAt, "NULL FinishedAt should map to nil")
|
||||
}
|
||||
|
||||
func TestChatDebugStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
startedAt := time.Now().UTC().Round(time.Second)
|
||||
finishedAt := startedAt.Add(2 * time.Second)
|
||||
attempts := json.RawMessage(`[
|
||||
{
|
||||
"attempt_number": 1,
|
||||
"status": "completed",
|
||||
"raw_request": {"url": "https://example.com"},
|
||||
"raw_response": {"status": "200"},
|
||||
"duration_ms": 123,
|
||||
"started_at": "2026-03-01T10:00:01Z",
|
||||
"finished_at": "2026-03-01T10:00:02Z"
|
||||
}
|
||||
]`)
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
|
||||
Attempts: attempts,
|
||||
Metadata: json.RawMessage(`{"provider":"openai"}`),
|
||||
StartedAt: startedAt,
|
||||
UpdatedAt: finishedAt,
|
||||
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
|
||||
// Verify all scalar fields are mapped correctly.
|
||||
require.Equal(t, step.ID, sdk.ID)
|
||||
require.Equal(t, step.RunID, sdk.RunID)
|
||||
require.Equal(t, step.ChatID, sdk.ChatID)
|
||||
require.Equal(t, step.StepNumber, sdk.StepNumber)
|
||||
require.Equal(t, codersdk.ChatDebugStepOperationStream, sdk.Operation)
|
||||
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
|
||||
require.Equal(t, startedAt, sdk.StartedAt)
|
||||
require.Equal(t, finishedAt, sdk.UpdatedAt)
|
||||
require.Equal(t, &finishedAt, sdk.FinishedAt)
|
||||
|
||||
// Verify JSON object fields are deserialized.
|
||||
require.NotNil(t, sdk.NormalizedRequest)
|
||||
require.Equal(t, map[string]any{"messages": []any{}}, sdk.NormalizedRequest)
|
||||
require.NotNil(t, sdk.Metadata)
|
||||
require.Equal(t, map[string]any{"provider": "openai"}, sdk.Metadata)
|
||||
|
||||
// Verify nullable fields are nil when the DB row has NULL values.
|
||||
require.Nil(t, sdk.HistoryTipMessageID, "NULL HistoryTipMessageID should map to nil")
|
||||
require.Nil(t, sdk.AssistantMessageID, "NULL AssistantMessageID should map to nil")
|
||||
require.Nil(t, sdk.NormalizedResponse, "NULL NormalizedResponse should map to nil")
|
||||
require.Nil(t, sdk.Usage, "NULL Usage should map to nil")
|
||||
require.Nil(t, sdk.Error, "NULL Error should map to nil")
|
||||
|
||||
// Verify attempts are preserved with all fields.
|
||||
require.Len(t, sdk.Attempts, 1)
|
||||
require.Equal(t, float64(1), sdk.Attempts[0]["attempt_number"])
|
||||
require.Equal(t, "completed", sdk.Attempts[0]["status"])
|
||||
require.Equal(t, float64(123), sdk.Attempts[0]["duration_ms"])
|
||||
require.Equal(t, map[string]any{"url": "https://example.com"}, sdk.Attempts[0]["raw_request"])
|
||||
require.Equal(t, map[string]any{"status": "200"}, sdk.Attempts[0]["raw_response"])
|
||||
}
|
||||
|
||||
func TestChatDebugStep_NullableFieldsPopulated(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tipID := int64(42)
|
||||
asstID := int64(99)
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 2,
|
||||
Operation: "generate",
|
||||
Status: "completed",
|
||||
HistoryTipMessageID: sql.NullInt64{Int64: tipID, Valid: true},
|
||||
AssistantMessageID: sql.NullInt64{Int64: asstID, Valid: true},
|
||||
NormalizedRequest: json.RawMessage(`{}`),
|
||||
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"hi"}`), Valid: true},
|
||||
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":10}`), Valid: true},
|
||||
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"rate_limit"}`), Valid: true},
|
||||
Attempts: json.RawMessage(`[]`),
|
||||
Metadata: json.RawMessage(`{}`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
|
||||
require.NotNil(t, sdk.HistoryTipMessageID)
|
||||
require.Equal(t, tipID, *sdk.HistoryTipMessageID)
|
||||
require.NotNil(t, sdk.AssistantMessageID)
|
||||
require.Equal(t, asstID, *sdk.AssistantMessageID)
|
||||
require.NotNil(t, sdk.NormalizedResponse)
|
||||
require.Equal(t, map[string]any{"text": "hi"}, sdk.NormalizedResponse)
|
||||
require.NotNil(t, sdk.Usage)
|
||||
require.Equal(t, map[string]any{"tokens": float64(10)}, sdk.Usage)
|
||||
require.NotNil(t, sdk.Error)
|
||||
require.Equal(t, map[string]any{"code": "rate_limit"}, sdk.Error)
|
||||
}
|
||||
|
||||
func TestChatDebugStep_PreservesMalformedAttempts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
|
||||
Attempts: json.RawMessage(`{"bad":true}`),
|
||||
Metadata: json.RawMessage(`{"provider":"openai"}`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
require.Len(t, sdk.Attempts, 1)
|
||||
require.Equal(t, "malformed attempts payload", sdk.Attempts[0]["error"])
|
||||
require.Equal(t, `{"bad":true}`, sdk.Attempts[0]["raw"])
|
||||
}
|
||||
|
||||
func TestChatDebugRunSummary_PreservesMalformedSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
run := database.ChatDebugRun{
|
||||
ID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
Kind: "chat_turn",
|
||||
Status: "completed",
|
||||
Summary: json.RawMessage(`not-an-object`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugRunSummary(run)
|
||||
require.Equal(t, "malformed debug payload", sdk.Summary["error"])
|
||||
require.Equal(t, "not-an-object", sdk.Summary["raw"])
|
||||
}
|
||||
|
||||
func TestChatDebugStep_PreservesMalformedRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
step := database.ChatDebugStep{
|
||||
ID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 1,
|
||||
Operation: "stream",
|
||||
Status: "completed",
|
||||
NormalizedRequest: json.RawMessage(`[1,2,3]`),
|
||||
Attempts: json.RawMessage(`[]`),
|
||||
Metadata: json.RawMessage(`"just-a-string"`),
|
||||
StartedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
sdk := db2sdk.ChatDebugStep(step)
|
||||
require.Equal(t, "malformed debug payload", sdk.NormalizedRequest["error"])
|
||||
require.Equal(t, "[1,2,3]", sdk.NormalizedRequest["raw"])
|
||||
require.Equal(t, "malformed debug payload", sdk.Metadata["error"])
|
||||
require.Equal(t, `"just-a-string"`, sdk.Metadata["raw"])
|
||||
}
|
||||
|
||||
func TestAIBridgeInterception(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1708,17 +1708,6 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -1860,28 +1849,6 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
|
||||
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteChatDebugDataAfterMessageID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteChatDebugDataByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -2202,10 +2169,10 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
@@ -2369,14 +2336,6 @@ func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context,
|
||||
return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt)
|
||||
}
|
||||
|
||||
func (q *querier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
// Background sweep operates across all chats.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return database.FinalizeStaleChatDebugRowsRow{}, err
|
||||
}
|
||||
return q.db.FinalizeStaleChatDebugRows(ctx, updatedBefore)
|
||||
}
|
||||
|
||||
func (q *querier) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
|
||||
_, err := q.GetTemplateVersionByID(ctx, arg.TemplateVersionID)
|
||||
if err != nil {
|
||||
@@ -2454,10 +2413,6 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -2585,59 +2540,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
|
||||
return q.db.GetChatCostSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
|
||||
// The allow-users flag is a deployment-wide setting read by any
|
||||
// authenticated chat user. We only require that an explicit actor
|
||||
// is present in the context so unauthenticated calls fail closed.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return false, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatDebugLoggingAllowUsers(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
|
||||
run, err := q.db.GetChatDebugRunByID(ctx, id)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
// Authorize via the owning chat.
|
||||
chat, err := q.db.GetChatByID(ctx, run.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
return run, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatDebugRunsByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
|
||||
run, err := q.db.GetChatDebugRunByID(ctx, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Authorize via the owning chat.
|
||||
chat, err := q.db.GetChatByID(ctx, run.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatDebugStepsByRunID(ctx, runID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
// The desktop-enabled flag is a deployment-wide setting read by any
|
||||
// authenticated chat user and by chatd when deciding whether to expose
|
||||
@@ -3484,11 +3386,11 @@ func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRI
|
||||
return q.db.GetPRInsightsPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsPullRequests(ctx, arg)
|
||||
return q.db.GetPRInsightsRecentPRs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
@@ -4186,17 +4088,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return q.db.GetUserChatDebugLoggingEnabled(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -4943,33 +4834,6 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams)
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
return q.db.InsertChatDebugRun(ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatDebugStep creates a new step in a debug run. The underlying
|
||||
// SQL uses INSERT ... SELECT ... FROM chat_debug_runs to enforce that the
|
||||
// run exists and belongs to the specified chat. If the run_id is invalid
|
||||
// or the chat_id doesn't match, the INSERT produces 0 rows and SQLC
|
||||
// returns sql.ErrNoRows.
|
||||
func (q *querier) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
return q.db.InsertChatDebugStep(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
// Authorize create on chat resource scoped to the owner and org.
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
|
||||
@@ -5864,17 +5728,6 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
@@ -5968,28 +5821,6 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
return q.db.UpdateChatDebugRun(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
return q.db.UpdateChatDebugStep(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
@@ -6926,19 +6757,6 @@ func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg da
|
||||
return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdateAgent, workspace); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return q.db.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -7222,13 +7040,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -7459,17 +7270,6 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return q.db.UpsertTemplateUsageStats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertUserChatDebugLoggingEnabled(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
u, err := q.db.GetUserByID(ctx, arg.UserID)
|
||||
if err != nil {
|
||||
|
||||
@@ -461,89 +461,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
|
||||
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatDebugDataAfterMessageID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.DeleteChatDebugDataAfterMessageIDParams{ChatID: chat.ID, MessageID: 123}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatDebugDataAfterMessageID(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
}))
|
||||
s.Run("DeleteChatDebugDataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatDebugDataByChatID(gomock.Any(), chat.ID).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
}))
|
||||
s.Run("FinalizeStaleChatDebugRows", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
updatedBefore := dbtime.Now()
|
||||
row := database.FinalizeStaleChatDebugRowsRow{RunsFinalized: 1, StepsFinalized: 2}
|
||||
dbm.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), updatedBefore).Return(row, nil).AnyTimes()
|
||||
check.Args(updatedBefore).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(row)
|
||||
}))
|
||||
s.Run("GetChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatDebugLoggingAllowUsers(gomock.Any()).Return(true, nil).AnyTimes()
|
||||
check.Args().Asserts().Returns(true)
|
||||
}))
|
||||
s.Run("GetChatDebugRunByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(run)
|
||||
}))
|
||||
s.Run("GetChatDebugRunsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
runs := []database.ChatDebugRun{{ID: uuid.New(), ChatID: chat.ID}}
|
||||
arg := database.GetChatDebugRunsByChatIDParams{ChatID: chat.ID, LimitVal: 100}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDebugRunsByChatID(gomock.Any(), arg).Return(runs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(runs)
|
||||
}))
|
||||
s.Run("GetChatDebugStepsByRunID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
|
||||
steps := []database.ChatDebugStep{{ID: uuid.New(), RunID: run.ID, ChatID: chat.ID}}
|
||||
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), run.ID).Return(steps, nil).AnyTimes()
|
||||
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(steps)
|
||||
}))
|
||||
s.Run("InsertChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.InsertChatDebugRunParams{ChatID: chat.ID, Kind: "chat_turn", Status: "in_progress"}
|
||||
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
|
||||
}))
|
||||
s.Run("InsertChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.InsertChatDebugStepParams{RunID: uuid.New(), ChatID: chat.ID, StepNumber: 1, Operation: "stream", Status: "in_progress"}
|
||||
step := database.ChatDebugStep{ID: uuid.New(), RunID: arg.RunID, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
|
||||
}))
|
||||
s.Run("UpdateChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatDebugRunParams{ID: uuid.New(), ChatID: chat.ID}
|
||||
run := database.ChatDebugRun{ID: arg.ID, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
|
||||
}))
|
||||
s.Run("UpdateChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatDebugStepParams{ID: uuid.New(), ChatID: chat.ID}
|
||||
step := database.ChatDebugStep{ID: arg.ID, ChatID: chat.ID}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
|
||||
}))
|
||||
s.Run("UpsertChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatDebugLoggingAllowUsers(gomock.Any(), true).Return(nil).AnyTimes()
|
||||
check.Args(true).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
@@ -561,24 +478,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
agentID := uuid.New()
|
||||
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
@@ -2344,9 +2243,9 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsPullRequests", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsPullRequestsParams{}
|
||||
dbm.EXPECT().GetPRInsightsPullRequests(gomock.Any(), arg).Return([]database.GetPRInsightsPullRequestsRow{}, nil).AnyTimes()
|
||||
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsRecentPRsParams{}
|
||||
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
@@ -2577,19 +2476,6 @@ func (s *MethodTestSuite) TestUser() {
|
||||
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
|
||||
}))
|
||||
s.Run("GetUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), u.ID).Return(true, nil).AnyTimes()
|
||||
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns(true)
|
||||
}))
|
||||
s.Run("UpsertUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.UpsertUserChatDebugLoggingEnabledParams{UserID: u.ID, DebugLoggingEnabled: true}
|
||||
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
|
||||
dbm.EXPECT().UpsertUserChatDebugLoggingEnabled(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal)
|
||||
}))
|
||||
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
|
||||
@@ -3031,17 +2917,6 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDirectoryByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
arg := database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: agt.ID,
|
||||
Directory: "/workspaces/project",
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateWorkspaceAgentDirectoryByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdateAgent).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDisplayAppsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
@@ -5538,10 +5413,10 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns(int64(1))
|
||||
Returns()
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -280,14 +280,6 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -416,22 +408,6 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteChatDebugDataAfterMessageID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatDebugDataAfterMessageID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataAfterMessageID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteChatDebugDataByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatDebugDataByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
|
||||
@@ -752,12 +728,12 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0, r1
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -888,14 +864,6 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.FinalizeStaleChatDebugRows(ctx, updatedBefore)
|
||||
m.queryLatencies.WithLabelValues("FinalizeStaleChatDebugRows").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "FinalizeStaleChatDebugRows").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.FindMatchingPresetID(ctx, arg)
|
||||
@@ -1000,14 +968,6 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
|
||||
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -1152,38 +1112,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugLoggingAllowUsers(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugLoggingAllowUsers").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugRunByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugRunByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugRunsByChatID(ctx context.Context, chatID database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugRunsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugRunsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunsByChatID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDebugStepsByRunID(ctx, runID)
|
||||
m.queryLatencies.WithLabelValues("GetChatDebugStepsByRunID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugStepsByRunID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
|
||||
@@ -2048,11 +1976,11 @@ func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg databa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsPullRequests(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsPullRequests").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPullRequests").Inc()
|
||||
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -2672,14 +2600,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatDebugLoggingEnabled(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatDebugLoggingEnabled").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
|
||||
@@ -3376,22 +3296,6 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatDebugRun(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatDebugRun").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugRun").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatDebugStep(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatDebugStep").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugStep").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatFile(ctx, arg)
|
||||
@@ -4200,14 +4104,6 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
@@ -4288,22 +4184,6 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatDebugRun(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatDebugRun").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugRun").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatDebugStep(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatDebugStep").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugStep").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
@@ -4936,14 +4816,6 @@ func (m queryMetricsStore) UpdateWorkspaceAgentConnectionByID(ctx context.Contex
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateWorkspaceAgentDirectoryByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateWorkspaceAgentDirectoryByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDisplayAppsByID(ctx, arg)
|
||||
@@ -5144,14 +5016,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDebugLoggingAllowUsers").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
@@ -5384,14 +5248,6 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertUserChatDebugLoggingEnabled(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatDebugLoggingEnabled").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
|
||||
|
||||
@@ -363,20 +363,6 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
|
||||
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
|
||||
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -671,36 +657,6 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataAfterMessageID mocks base method.
|
||||
func (m *MockStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatDebugDataAfterMessageID", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataAfterMessageID indicates an expected call of DeleteChatDebugDataAfterMessageID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatDebugDataAfterMessageID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataAfterMessageID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataAfterMessageID), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataByChatID mocks base method.
|
||||
func (m *MockStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatDebugDataByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteChatDebugDataByChatID indicates an expected call of DeleteChatDebugDataByChatID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatDebugDataByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1274,12 +1230,11 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
@@ -1517,21 +1472,6 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, u
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsUpdatedAfter), ctx, updatedAt)
|
||||
}
|
||||
|
||||
// FinalizeStaleChatDebugRows mocks base method.
|
||||
func (m *MockStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FinalizeStaleChatDebugRows", ctx, updatedBefore)
|
||||
ret0, _ := ret[0].(database.FinalizeStaleChatDebugRowsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FinalizeStaleChatDebugRows indicates an expected call of FinalizeStaleChatDebugRows.
|
||||
func (mr *MockStoreMockRecorder) FinalizeStaleChatDebugRows(ctx, updatedBefore any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinalizeStaleChatDebugRows", reflect.TypeOf((*MockStore)(nil).FinalizeStaleChatDebugRows), ctx, updatedBefore)
|
||||
}
|
||||
|
||||
// FindMatchingPresetID mocks base method.
|
||||
func (m *MockStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1727,21 +1667,6 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID mocks base method.
|
||||
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
|
||||
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2117,66 +2042,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDebugLoggingAllowUsers mocks base method.
|
||||
func (m *MockStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugLoggingAllowUsers", ctx)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugLoggingAllowUsers indicates an expected call of GetChatDebugLoggingAllowUsers.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugLoggingAllowUsers(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).GetChatDebugLoggingAllowUsers), ctx)
|
||||
}
|
||||
|
||||
// GetChatDebugRunByID mocks base method.
|
||||
func (m *MockStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugRunByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugRunByID indicates an expected call of GetChatDebugRunByID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugRunByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunByID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatDebugRunsByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugRunsByChatID", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugRunsByChatID indicates an expected call of GetChatDebugRunsByChatID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugRunsByChatID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunsByChatID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunsByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDebugStepsByRunID mocks base method.
|
||||
func (m *MockStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDebugStepsByRunID", ctx, runID)
|
||||
ret0, _ := ret[0].([]database.ChatDebugStep)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDebugStepsByRunID indicates an expected call of GetChatDebugStepsByRunID.
|
||||
func (mr *MockStoreMockRecorder) GetChatDebugStepsByRunID(ctx, runID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugStepsByRunID", reflect.TypeOf((*MockStore)(nil).GetChatDebugStepsByRunID), ctx, runID)
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3797,19 +3662,19 @@ func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsPullRequests mocks base method.
|
||||
func (m *MockStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
// GetPRInsightsRecentPRs mocks base method.
|
||||
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsPullRequests", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsPullRequestsRow)
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsPullRequests indicates an expected call of GetPRInsightsPullRequests.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsPullRequests(ctx, arg any) *gomock.Call {
|
||||
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPullRequests", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPullRequests), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary mocks base method.
|
||||
@@ -4997,21 +4862,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatDebugLoggingEnabled mocks base method.
|
||||
func (m *MockStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatDebugLoggingEnabled", ctx, userID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatDebugLoggingEnabled indicates an expected call of GetUserChatDebugLoggingEnabled.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatDebugLoggingEnabled(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).GetUserChatDebugLoggingEnabled), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatProviderKeys mocks base method.
|
||||
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6331,36 +6181,6 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatDebugRun mocks base method.
|
||||
func (m *MockStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatDebugRun", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatDebugRun indicates an expected call of InsertChatDebugRun.
|
||||
func (mr *MockStoreMockRecorder) InsertChatDebugRun(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugRun", reflect.TypeOf((*MockStore)(nil).InsertChatDebugRun), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatDebugStep mocks base method.
|
||||
func (m *MockStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatDebugStep", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugStep)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatDebugStep indicates an expected call of InsertChatDebugStep.
|
||||
func (mr *MockStoreMockRecorder) InsertChatDebugStep(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugStep", reflect.TypeOf((*MockStore)(nil).InsertChatDebugStep), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatFile mocks base method.
|
||||
func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7960,20 +7780,6 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages mocks base method.
|
||||
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8119,36 +7925,6 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatDebugRun mocks base method.
|
||||
func (m *MockStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatDebugRun", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugRun)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatDebugRun indicates an expected call of UpdateChatDebugRun.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatDebugRun(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugRun", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugRun), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatDebugStep mocks base method.
|
||||
func (m *MockStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatDebugStep", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDebugStep)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatDebugStep indicates an expected call of UpdateChatDebugStep.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatDebugStep(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugStep", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugStep), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9300,20 +9076,6 @@ func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentConnectionByID(ctx, arg any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentConnectionByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentConnectionByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateWorkspaceAgentDirectoryByID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID indicates an expected call of UpdateWorkspaceAgentDirectoryByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentDirectoryByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentDirectoryByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentDirectoryByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDisplayAppsByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9669,20 +9431,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDebugLoggingAllowUsers mocks base method.
|
||||
func (m *MockStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDebugLoggingAllowUsers", ctx, allowUsers)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatDebugLoggingAllowUsers indicates an expected call of UpsertChatDebugLoggingAllowUsers.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDebugLoggingAllowUsers(ctx, allowUsers any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).UpsertChatDebugLoggingAllowUsers), ctx, allowUsers)
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -10100,20 +9848,6 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
|
||||
}
|
||||
|
||||
// UpsertUserChatDebugLoggingEnabled mocks base method.
|
||||
func (m *MockStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertUserChatDebugLoggingEnabled", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertUserChatDebugLoggingEnabled indicates an expected call of UpsertUserChatDebugLoggingEnabled.
|
||||
func (mr *MockStoreMockRecorder) UpsertUserChatDebugLoggingEnabled(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).UpsertUserChatDebugLoggingEnabled), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertUserChatProviderKey mocks base method.
|
||||
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+2
-69
@@ -1255,44 +1255,6 @@ COMMENT ON COLUMN boundary_usage_stats.window_start IS 'Start of the time window
|
||||
|
||||
COMMENT ON COLUMN boundary_usage_stats.updated_at IS 'Timestamp of the last update to this row.';
|
||||
|
||||
CREATE TABLE chat_debug_runs (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
root_chat_id uuid,
|
||||
parent_chat_id uuid,
|
||||
model_config_id uuid,
|
||||
trigger_message_id bigint,
|
||||
history_tip_message_id bigint,
|
||||
kind text NOT NULL,
|
||||
status text NOT NULL,
|
||||
provider text,
|
||||
model text,
|
||||
summary jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
started_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
finished_at timestamp with time zone
|
||||
);
|
||||
|
||||
CREATE TABLE chat_debug_steps (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
run_id uuid NOT NULL,
|
||||
chat_id uuid NOT NULL,
|
||||
step_number integer NOT NULL,
|
||||
operation text NOT NULL,
|
||||
status text NOT NULL,
|
||||
history_tip_message_id bigint,
|
||||
assistant_message_id bigint,
|
||||
normalized_request jsonb NOT NULL,
|
||||
normalized_response jsonb,
|
||||
usage jsonb,
|
||||
attempts jsonb DEFAULT '[]'::jsonb NOT NULL,
|
||||
error jsonb,
|
||||
metadata jsonb DEFAULT '{}'::jsonb NOT NULL,
|
||||
started_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
finished_at timestamp with time zone
|
||||
);
|
||||
|
||||
CREATE TABLE chat_diff_statuses (
|
||||
chat_id uuid NOT NULL,
|
||||
url text,
|
||||
@@ -3397,12 +3359,6 @@ ALTER TABLE ONLY audit_logs
|
||||
ALTER TABLE ONLY boundary_usage_stats
|
||||
ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
|
||||
ALTER TABLE ONLY chat_debug_runs
|
||||
ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_debug_steps
|
||||
ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
|
||||
@@ -3797,20 +3753,6 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
|
||||
|
||||
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
|
||||
|
||||
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs USING btree (chat_id, started_at DESC);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs USING btree (updated_at) WHERE (finished_at IS NULL);
|
||||
|
||||
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps USING btree (chat_id, assistant_message_id) WHERE (assistant_message_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps USING btree (chat_id, history_tip_message_id);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
|
||||
|
||||
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps USING btree (updated_at) WHERE (finished_at IS NULL);
|
||||
|
||||
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
|
||||
|
||||
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
|
||||
@@ -3841,14 +3783,14 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats USING btree (owner_id, updated_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
@@ -4114,12 +4056,6 @@ ALTER TABLE ONLY aibridge_interceptions
|
||||
ALTER TABLE ONLY api_keys
|
||||
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_debug_runs
|
||||
ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_debug_steps
|
||||
ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_diff_statuses
|
||||
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -4192,9 +4128,6 @@ ALTER TABLE ONLY connection_logs
|
||||
ALTER TABLE ONLY crypto_keys
|
||||
ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY chat_debug_steps
|
||||
ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY oauth2_provider_app_tokens
|
||||
ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ const (
|
||||
ForeignKeyAiSeatStateUserID ForeignKeyConstraint = "ai_seat_state_user_id_fkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDebugRunsChatID ForeignKeyConstraint = "chat_debug_runs_chat_id_fkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDebugStepsChatID ForeignKeyConstraint = "chat_debug_steps_chat_id_fkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
@@ -35,7 +33,6 @@ const (
|
||||
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyFkChatDebugStepsRunChat ForeignKeyConstraint = "fk_chat_debug_steps_run_chat" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
|
||||
ForeignKeyFkOauth2ProviderAppTokensUserID ForeignKeyConstraint = "fk_oauth2_provider_app_tokens_user_id" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyGitAuthLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_chats_agent_id;
|
||||
@@ -1 +0,0 @@
|
||||
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
|
||||
@@ -1 +0,0 @@
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats (owner_id, updated_at DESC, id DESC);
|
||||
@@ -1,5 +0,0 @@
|
||||
-- The GetChats ORDER BY changed from (updated_at, id) DESC to a 4-column
|
||||
-- expression sort (pinned-first flag, negated pin_order, updated_at, id).
|
||||
-- This index was purpose-built for the old sort and no longer provides
|
||||
-- read benefit. The simpler idx_chats_owner covers the owner_id filter.
|
||||
DROP INDEX IF EXISTS idx_chats_owner_updated_id;
|
||||
@@ -1,2 +0,0 @@
|
||||
DROP TABLE IF EXISTS chat_debug_steps;
|
||||
DROP TABLE IF EXISTS chat_debug_runs;
|
||||
@@ -1,59 +0,0 @@
|
||||
CREATE TABLE chat_debug_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
-- root_chat_id and parent_chat_id are intentionally NOT
|
||||
-- foreign-keyed to chats(id). They are snapshot values that
|
||||
-- record the subchat hierarchy at run time. The referenced
|
||||
-- chat may be archived or deleted independently, and we want
|
||||
-- to preserve the historical lineage in debug rows rather
|
||||
-- than cascade-delete them.
|
||||
root_chat_id UUID,
|
||||
parent_chat_id UUID,
|
||||
model_config_id UUID,
|
||||
trigger_message_id BIGINT,
|
||||
history_tip_message_id BIGINT,
|
||||
kind TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
provider TEXT,
|
||||
model TEXT,
|
||||
summary JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
finished_at TIMESTAMPTZ
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs(id, chat_id);
|
||||
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs(chat_id, started_at DESC);
|
||||
|
||||
CREATE TABLE chat_debug_steps (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
run_id UUID NOT NULL,
|
||||
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
|
||||
step_number INT NOT NULL,
|
||||
operation TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
history_tip_message_id BIGINT,
|
||||
assistant_message_id BIGINT,
|
||||
normalized_request JSONB NOT NULL,
|
||||
normalized_response JSONB,
|
||||
usage JSONB,
|
||||
attempts JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
error JSONB,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
finished_at TIMESTAMPTZ,
|
||||
CONSTRAINT fk_chat_debug_steps_run_chat
|
||||
FOREIGN KEY (run_id, chat_id)
|
||||
REFERENCES chat_debug_runs(id, chat_id)
|
||||
ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps(run_id, step_number);
|
||||
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps(chat_id, history_tip_message_id);
|
||||
-- Supports DeleteChatDebugDataAfterMessageID assistant_message_id branch.
|
||||
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps(chat_id, assistant_message_id) WHERE assistant_message_id IS NOT NULL;
|
||||
|
||||
-- Supports FinalizeStaleChatDebugRows worker query.
|
||||
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs(updated_at) WHERE finished_at IS NULL;
|
||||
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps(updated_at) WHERE finished_at IS NULL;
|
||||
-65
@@ -1,65 +0,0 @@
|
||||
INSERT INTO chat_debug_runs (
|
||||
id,
|
||||
chat_id,
|
||||
model_config_id,
|
||||
history_tip_message_id,
|
||||
kind,
|
||||
status,
|
||||
provider,
|
||||
model,
|
||||
summary,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
) VALUES (
|
||||
'c98518f8-9fb3-458b-a642-57552af1db63',
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
|
||||
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
|
||||
'chat_turn',
|
||||
'completed',
|
||||
'openai',
|
||||
'gpt-5.2',
|
||||
'{"step_count":1,"has_error":false}'::jsonb,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:01+00',
|
||||
'2024-01-01 00:00:01+00'
|
||||
);
|
||||
|
||||
INSERT INTO chat_debug_steps (
|
||||
id,
|
||||
run_id,
|
||||
chat_id,
|
||||
step_number,
|
||||
operation,
|
||||
status,
|
||||
history_tip_message_id,
|
||||
assistant_message_id,
|
||||
normalized_request,
|
||||
normalized_response,
|
||||
usage,
|
||||
attempts,
|
||||
error,
|
||||
metadata,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
) VALUES (
|
||||
'59471c60-7851-4fa6-bf05-e21dd939721f',
|
||||
'c98518f8-9fb3-458b-a642-57552af1db63',
|
||||
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
|
||||
1,
|
||||
'stream',
|
||||
'completed',
|
||||
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
|
||||
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
|
||||
'{"messages":[]}'::jsonb,
|
||||
'{"finish_reason":"stop"}'::jsonb,
|
||||
'{"input_tokens":1,"output_tokens":1}'::jsonb,
|
||||
'[]'::jsonb,
|
||||
NULL,
|
||||
'{"provider":"openai"}'::jsonb,
|
||||
'2024-01-01 00:00:00+00',
|
||||
'2024-01-01 00:00:01+00',
|
||||
'2024-01-01 00:00:01+00'
|
||||
);
|
||||
@@ -4248,44 +4248,6 @@ type Chat struct {
|
||||
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
|
||||
}
|
||||
|
||||
type ChatDebugRun struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
|
||||
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
|
||||
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
Kind string `db:"kind" json:"kind"`
|
||||
Status string `db:"status" json:"status"`
|
||||
Provider sql.NullString `db:"provider" json:"provider"`
|
||||
Model sql.NullString `db:"model" json:"model"`
|
||||
Summary json.RawMessage `db:"summary" json:"summary"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
}
|
||||
|
||||
type ChatDebugStep struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
RunID uuid.UUID `db:"run_id" json:"run_id"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
StepNumber int32 `db:"step_number" json:"step_number"`
|
||||
Operation string `db:"operation" json:"operation"`
|
||||
Status string `db:"status" json:"status"`
|
||||
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
|
||||
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
|
||||
NormalizedRequest json.RawMessage `db:"normalized_request" json:"normalized_request"`
|
||||
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
|
||||
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
|
||||
Attempts json.RawMessage `db:"attempts" json:"attempts"`
|
||||
Error pqtype.NullRawMessage `db:"error" json:"error"`
|
||||
Metadata json.RawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Url sql.NullString `db:"url" json:"url"`
|
||||
|
||||
@@ -0,0 +1,749 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultBatchingFlushInterval is the default upper bound on how long chatd
|
||||
// publishes wait before a scheduled flush when nearby publishes do not
|
||||
// naturally coalesce sooner.
|
||||
DefaultBatchingFlushInterval = 50 * time.Millisecond
|
||||
// DefaultBatchingQueueSize is the default number of buffered chatd publish
|
||||
// requests waiting to be flushed.
|
||||
DefaultBatchingQueueSize = 8192
|
||||
|
||||
defaultBatchingPressureWait = 10 * time.Millisecond
|
||||
defaultBatchingFinalFlushLimit = 15 * time.Second
|
||||
batchingWarnInterval = 10 * time.Second
|
||||
|
||||
batchFlushScheduled = "scheduled"
|
||||
batchFlushShutdown = "shutdown"
|
||||
|
||||
batchFlushStageNone = "none"
|
||||
batchFlushStageBegin = "begin"
|
||||
batchFlushStageExec = "exec"
|
||||
batchFlushStageCommit = "commit"
|
||||
|
||||
batchDelegateFallbackReasonQueueFull = "queue_full"
|
||||
batchDelegateFallbackReasonFlushError = "flush_error"
|
||||
|
||||
batchChannelClassStreamNotify = "stream_notify"
|
||||
batchChannelClassOwnerEvent = "owner_event"
|
||||
batchChannelClassConfigChange = "config_change"
|
||||
batchChannelClassOther = "other"
|
||||
)
|
||||
|
||||
// ErrBatchingPubsubClosed is returned when a batched pubsub publish is
|
||||
// attempted after shutdown has started.
|
||||
var ErrBatchingPubsubClosed = xerrors.New("batched pubsub is closed")
|
||||
|
||||
// BatchingConfig controls the chatd-specific PostgreSQL pubsub batching path.
|
||||
// Flush timing is automatic: the run loop wakes every FlushInterval (or on
|
||||
// backpressure) and drains everything currently queued into a single
|
||||
// transaction. There is no fixed batch-size knob — the batch size is simply
|
||||
// whatever accumulated since the last flush, which naturally adapts to load.
|
||||
type BatchingConfig struct {
|
||||
FlushInterval time.Duration
|
||||
QueueSize int
|
||||
PressureWait time.Duration
|
||||
FinalFlushTimeout time.Duration
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
type queuedPublish struct {
|
||||
event string
|
||||
channelClass string
|
||||
message []byte
|
||||
}
|
||||
|
||||
type batchSender interface {
|
||||
Flush(ctx context.Context, batch []queuedPublish) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type batchFlushError struct {
|
||||
stage string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *batchFlushError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *batchFlushError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// BatchingPubsub batches chatd publish traffic onto a dedicated PostgreSQL
|
||||
// sender connection while delegating subscribe behavior to the shared listener
|
||||
// pubsub instance.
|
||||
type BatchingPubsub struct {
|
||||
logger slog.Logger
|
||||
delegate *PGPubsub
|
||||
// sender is only accessed from the run() goroutine (including
|
||||
// flushBatch and resetSender which it calls). Do not read or
|
||||
// write this field from Publish or any other goroutine.
|
||||
sender batchSender
|
||||
newSender func(context.Context) (batchSender, error)
|
||||
clock quartz.Clock
|
||||
|
||||
publishCh chan queuedPublish
|
||||
flushCh chan struct{}
|
||||
closeCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
|
||||
spaceMu sync.Mutex
|
||||
spaceSignal chan struct{}
|
||||
|
||||
warnTicker *quartz.Ticker
|
||||
|
||||
flushInterval time.Duration
|
||||
pressureWait time.Duration
|
||||
finalFlushTimeout time.Duration
|
||||
|
||||
queuedCount atomic.Int64
|
||||
closed atomic.Bool
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
runErr error
|
||||
|
||||
runCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
metrics batchingMetrics
|
||||
}
|
||||
|
||||
type batchingMetrics struct {
|
||||
QueueDepth prometheus.Gauge
|
||||
BatchSize prometheus.Histogram
|
||||
FlushDuration *prometheus.HistogramVec
|
||||
DelegateFallbacksTotal *prometheus.CounterVec
|
||||
SenderResetsTotal prometheus.Counter
|
||||
SenderResetFailuresTotal prometheus.Counter
|
||||
}
|
||||
|
||||
func newBatchingMetrics() batchingMetrics {
|
||||
return batchingMetrics{
|
||||
QueueDepth: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_queue_depth",
|
||||
Help: "The number of chatd notifications waiting in the batching queue.",
|
||||
}),
|
||||
BatchSize: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_size",
|
||||
Help: "The number of logical notifications sent in each chatd batch flush.",
|
||||
Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192},
|
||||
}),
|
||||
FlushDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_flush_duration_seconds",
|
||||
Help: "The time spent flushing one chatd batch to PostgreSQL.",
|
||||
Buckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 30},
|
||||
}, []string{"reason"}),
|
||||
DelegateFallbacksTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_delegate_fallbacks_total",
|
||||
Help: "The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage.",
|
||||
}, []string{"channel_class", "reason", "stage"}),
|
||||
SenderResetsTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_sender_resets_total",
|
||||
Help: "The number of successful batched pubsub sender resets after flush failures.",
|
||||
}),
|
||||
SenderResetFailuresTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_sender_reset_failures_total",
|
||||
Help: "The number of batched pubsub sender reset attempts that failed.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m batchingMetrics) Describe(descs chan<- *prometheus.Desc) {
|
||||
m.QueueDepth.Describe(descs)
|
||||
m.BatchSize.Describe(descs)
|
||||
m.FlushDuration.Describe(descs)
|
||||
m.DelegateFallbacksTotal.Describe(descs)
|
||||
m.SenderResetsTotal.Describe(descs)
|
||||
m.SenderResetFailuresTotal.Describe(descs)
|
||||
}
|
||||
|
||||
func (m batchingMetrics) Collect(metrics chan<- prometheus.Metric) {
|
||||
m.QueueDepth.Collect(metrics)
|
||||
m.BatchSize.Collect(metrics)
|
||||
m.FlushDuration.Collect(metrics)
|
||||
m.DelegateFallbacksTotal.Collect(metrics)
|
||||
m.SenderResetsTotal.Collect(metrics)
|
||||
m.SenderResetFailuresTotal.Collect(metrics)
|
||||
}
|
||||
|
||||
// NewBatching creates a chatd-specific batched pubsub wrapper around the
|
||||
// shared PostgreSQL listener implementation.
|
||||
func NewBatching(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
delegate *PGPubsub,
|
||||
prototype *sql.DB,
|
||||
connectURL string,
|
||||
cfg BatchingConfig,
|
||||
) (*BatchingPubsub, error) {
|
||||
if delegate == nil {
|
||||
return nil, xerrors.New("delegate pubsub is nil")
|
||||
}
|
||||
if prototype == nil {
|
||||
return nil, xerrors.New("prototype database is nil")
|
||||
}
|
||||
if connectURL == "" {
|
||||
return nil, xerrors.New("connect URL is empty")
|
||||
}
|
||||
|
||||
newSender := func(ctx context.Context) (batchSender, error) {
|
||||
return newPGBatchSender(ctx, logger.Named("sender"), prototype, connectURL)
|
||||
}
|
||||
|
||||
sender, err := newSender(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ps, err := newBatchingPubsub(logger, delegate, sender, cfg)
|
||||
if err != nil {
|
||||
_ = sender.Close()
|
||||
return nil, err
|
||||
}
|
||||
ps.newSender = newSender
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
func newBatchingPubsub(
|
||||
logger slog.Logger,
|
||||
delegate *PGPubsub,
|
||||
sender batchSender,
|
||||
cfg BatchingConfig,
|
||||
) (*BatchingPubsub, error) {
|
||||
if delegate == nil {
|
||||
return nil, xerrors.New("delegate pubsub is nil")
|
||||
}
|
||||
if sender == nil {
|
||||
return nil, xerrors.New("batch sender is nil")
|
||||
}
|
||||
|
||||
flushInterval := cfg.FlushInterval
|
||||
if flushInterval == 0 {
|
||||
flushInterval = DefaultBatchingFlushInterval
|
||||
}
|
||||
if flushInterval < 0 {
|
||||
return nil, xerrors.New("flush interval must be positive")
|
||||
}
|
||||
|
||||
queueSize := cfg.QueueSize
|
||||
if queueSize == 0 {
|
||||
queueSize = DefaultBatchingQueueSize
|
||||
}
|
||||
if queueSize < 0 {
|
||||
return nil, xerrors.New("queue size must be positive")
|
||||
}
|
||||
|
||||
pressureWait := cfg.PressureWait
|
||||
if pressureWait == 0 {
|
||||
pressureWait = defaultBatchingPressureWait
|
||||
}
|
||||
if pressureWait < 0 {
|
||||
return nil, xerrors.New("pressure wait must be positive")
|
||||
}
|
||||
|
||||
finalFlushTimeout := cfg.FinalFlushTimeout
|
||||
if finalFlushTimeout == 0 {
|
||||
finalFlushTimeout = defaultBatchingFinalFlushLimit
|
||||
}
|
||||
if finalFlushTimeout < 0 {
|
||||
return nil, xerrors.New("final flush timeout must be positive")
|
||||
}
|
||||
|
||||
clock := cfg.Clock
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(context.Background())
|
||||
ps := &BatchingPubsub{
|
||||
logger: logger,
|
||||
delegate: delegate,
|
||||
sender: sender,
|
||||
clock: clock,
|
||||
publishCh: make(chan queuedPublish, queueSize),
|
||||
flushCh: make(chan struct{}, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
spaceSignal: make(chan struct{}),
|
||||
warnTicker: clock.NewTicker(batchingWarnInterval, "pubsubBatcher", "warn"),
|
||||
flushInterval: flushInterval,
|
||||
pressureWait: pressureWait,
|
||||
finalFlushTimeout: finalFlushTimeout,
|
||||
runCtx: runCtx,
|
||||
cancel: cancel,
|
||||
metrics: newBatchingMetrics(),
|
||||
}
|
||||
ps.metrics.QueueDepth.Set(0)
|
||||
|
||||
go ps.run()
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
// Describe implements prometheus.Collector.
|
||||
func (p *BatchingPubsub) Describe(descs chan<- *prometheus.Desc) {
|
||||
p.metrics.Describe(descs)
|
||||
}
|
||||
|
||||
// Collect implements prometheus.Collector.
|
||||
func (p *BatchingPubsub) Collect(metrics chan<- prometheus.Metric) {
|
||||
p.metrics.Collect(metrics)
|
||||
}
|
||||
|
||||
// Subscribe delegates to the shared PostgreSQL listener pubsub.
|
||||
func (p *BatchingPubsub) Subscribe(event string, listener Listener) (func(), error) {
|
||||
return p.delegate.Subscribe(event, listener)
|
||||
}
|
||||
|
||||
// SubscribeWithErr delegates to the shared PostgreSQL listener pubsub.
|
||||
func (p *BatchingPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (func(), error) {
|
||||
return p.delegate.SubscribeWithErr(event, listener)
|
||||
}
|
||||
|
||||
// Publish enqueues a logical notification for asynchronous batched delivery.
|
||||
func (p *BatchingPubsub) Publish(event string, message []byte) error {
|
||||
channelClass := batchChannelClass(event)
|
||||
if p.closed.Load() {
|
||||
return ErrBatchingPubsubClosed
|
||||
}
|
||||
|
||||
req := queuedPublish{
|
||||
event: event,
|
||||
channelClass: channelClass,
|
||||
message: bytes.Clone(message),
|
||||
}
|
||||
if p.tryEnqueue(req) {
|
||||
return nil
|
||||
}
|
||||
|
||||
timer := p.clock.NewTimer(p.pressureWait, "pubsubBatcher", "pressureWait")
|
||||
defer timer.Stop("pubsubBatcher", "pressureWait")
|
||||
|
||||
for {
|
||||
if p.closed.Load() {
|
||||
return ErrBatchingPubsubClosed
|
||||
}
|
||||
p.signalPressureFlush()
|
||||
spaceSignal := p.currentSpaceSignal()
|
||||
if p.tryEnqueue(req) {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-spaceSignal:
|
||||
continue
|
||||
case <-timer.C:
|
||||
if p.tryEnqueue(req) {
|
||||
return nil
|
||||
}
|
||||
// The batching queue is still full after a pressure
|
||||
// flush and brief wait. Fall back to the shared
|
||||
// pubsub pool so the notification is still delivered
|
||||
// rather than dropped.
|
||||
p.observeDelegateFallback(channelClass, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)
|
||||
p.logPublishRejection(event)
|
||||
return p.delegate.Publish(event, message)
|
||||
case <-p.doneCh:
|
||||
return ErrBatchingPubsubClosed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops accepting new publishes, performs a bounded best-effort drain,
|
||||
// and then closes the dedicated sender connection.
|
||||
func (p *BatchingPubsub) Close() error {
|
||||
p.closeOnce.Do(func() {
|
||||
p.closed.Store(true)
|
||||
p.cancel()
|
||||
p.notifySpaceAvailable()
|
||||
close(p.closeCh)
|
||||
<-p.doneCh
|
||||
p.closeErr = p.runErr
|
||||
})
|
||||
return p.closeErr
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) tryEnqueue(req queuedPublish) bool {
|
||||
if p.closed.Load() {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case p.publishCh <- req:
|
||||
queuedDepth := p.queuedCount.Add(1)
|
||||
p.observeQueueDepth(queuedDepth)
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeQueueDepth(depth int64) {
|
||||
p.metrics.QueueDepth.Set(float64(depth))
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) signalPressureFlush() {
|
||||
select {
|
||||
case p.flushCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) currentSpaceSignal() <-chan struct{} {
|
||||
p.spaceMu.Lock()
|
||||
defer p.spaceMu.Unlock()
|
||||
return p.spaceSignal
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) notifySpaceAvailable() {
|
||||
p.spaceMu.Lock()
|
||||
defer p.spaceMu.Unlock()
|
||||
close(p.spaceSignal)
|
||||
p.spaceSignal = make(chan struct{})
|
||||
}
|
||||
|
||||
func batchChannelClass(event string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(event, "chat:stream:"):
|
||||
return batchChannelClassStreamNotify
|
||||
case strings.HasPrefix(event, "chat:owner:"):
|
||||
return batchChannelClassOwnerEvent
|
||||
case event == "chat:config_change":
|
||||
return batchChannelClassConfigChange
|
||||
default:
|
||||
return batchChannelClassOther
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeDelegateFallback(channelClass string, reason string, stage string) {
|
||||
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Inc()
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeDelegateFallbackBatch(batch []queuedPublish, reason string, stage string) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
counts := make(map[string]int)
|
||||
for _, item := range batch {
|
||||
counts[item.channelClass]++
|
||||
}
|
||||
for channelClass, count := range counts {
|
||||
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Add(float64(count))
|
||||
}
|
||||
}
|
||||
|
||||
func batchFlushStage(err error) string {
|
||||
var flushErr *batchFlushError
|
||||
if errors.As(err, &flushErr) {
|
||||
return flushErr.stage
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) run() {
|
||||
defer close(p.doneCh)
|
||||
defer p.warnTicker.Stop("pubsubBatcher", "warn")
|
||||
|
||||
batch := make([]queuedPublish, 0, 64)
|
||||
timer := p.clock.NewTimer(p.flushInterval, "pubsubBatcher", "scheduledFlush")
|
||||
defer timer.Stop("pubsubBatcher", "scheduledFlush")
|
||||
|
||||
flush := func(reason string) {
|
||||
batch = p.drainIntoBatch(batch)
|
||||
batch, _ = p.flushBatch(p.runCtx, batch, reason)
|
||||
timer.Reset(p.flushInterval, "pubsubBatcher", reason+"Flush")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case item := <-p.publishCh:
|
||||
// An item arrived before the timer fired. Append it and
|
||||
// let the timer or pressure signal trigger the actual
|
||||
// flush so that nearby publishes coalesce naturally.
|
||||
batch = append(batch, item)
|
||||
p.notifySpaceAvailable()
|
||||
case <-timer.C:
|
||||
flush(batchFlushScheduled)
|
||||
case <-p.flushCh:
|
||||
flush("pressure")
|
||||
case <-p.closeCh:
|
||||
p.runErr = errors.Join(p.drain(batch), p.sender.Close())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) drainIntoBatch(batch []queuedPublish) []queuedPublish {
|
||||
drained := false
|
||||
for {
|
||||
select {
|
||||
case item := <-p.publishCh:
|
||||
batch = append(batch, item)
|
||||
drained = true
|
||||
default:
|
||||
if drained {
|
||||
p.notifySpaceAvailable()
|
||||
}
|
||||
return batch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) flushBatch(
|
||||
ctx context.Context,
|
||||
batch []queuedPublish,
|
||||
reason string,
|
||||
) ([]queuedPublish, error) {
|
||||
if len(batch) == 0 {
|
||||
return batch[:0], nil
|
||||
}
|
||||
|
||||
count := len(batch)
|
||||
totalBytes := 0
|
||||
for _, item := range batch {
|
||||
totalBytes += len(item.message)
|
||||
}
|
||||
|
||||
p.metrics.BatchSize.Observe(float64(count))
|
||||
start := p.clock.Now()
|
||||
senderErr := p.sender.Flush(ctx, batch)
|
||||
elapsed := p.clock.Since(start)
|
||||
p.metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
|
||||
|
||||
var err error
|
||||
if senderErr != nil {
|
||||
stage := batchFlushStage(senderErr)
|
||||
delivered, failed, fallbackErr := p.replayBatchViaDelegate(batch, batchDelegateFallbackReasonFlushError, stage)
|
||||
var resetErr error
|
||||
if reason != batchFlushShutdown {
|
||||
resetErr = p.resetSender()
|
||||
}
|
||||
p.logFlushFailure(reason, stage, count, totalBytes, delivered, failed, senderErr, fallbackErr, resetErr)
|
||||
if fallbackErr != nil || resetErr != nil {
|
||||
err = errors.Join(senderErr, fallbackErr, resetErr)
|
||||
}
|
||||
} else if p.delegate != nil {
|
||||
p.delegate.publishesTotal.WithLabelValues("true").Add(float64(count))
|
||||
p.delegate.publishedBytesTotal.Add(float64(totalBytes))
|
||||
}
|
||||
|
||||
queuedDepth := p.queuedCount.Add(-int64(count))
|
||||
p.observeQueueDepth(queuedDepth)
|
||||
clear(batch)
|
||||
return batch[:0], err
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) replayBatchViaDelegate(batch []queuedPublish, reason string, stage string) (delivered int, failed int, err error) {
|
||||
if len(batch) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
p.observeDelegateFallbackBatch(batch, reason, stage)
|
||||
if p.delegate == nil {
|
||||
return 0, len(batch), xerrors.New("delegate pubsub is nil")
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, item := range batch {
|
||||
if err := p.delegate.Publish(item.event, item.message); err != nil {
|
||||
failed++
|
||||
errs = append(errs, xerrors.Errorf("delegate publish %q: %w", item.event, err))
|
||||
continue
|
||||
}
|
||||
delivered++
|
||||
}
|
||||
return delivered, failed, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) resetSender() error {
|
||||
if p.newSender == nil {
|
||||
return nil
|
||||
}
|
||||
newSender, err := p.newSender(context.Background())
|
||||
if err != nil {
|
||||
p.metrics.SenderResetFailuresTotal.Inc()
|
||||
return err
|
||||
}
|
||||
oldSender := p.sender
|
||||
p.sender = newSender
|
||||
p.metrics.SenderResetsTotal.Inc()
|
||||
if oldSender == nil {
|
||||
return nil
|
||||
}
|
||||
if err := oldSender.Close(); err != nil {
|
||||
p.logger.Warn(context.Background(), "failed to close old batched pubsub sender after reset", slog.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) logFlushFailure(reason string, stage string, count int, totalBytes int, delivered int, failed int, senderErr error, fallbackErr error, resetErr error) {
|
||||
fields := []slog.Field{
|
||||
slog.F("reason", reason),
|
||||
slog.F("stage", stage),
|
||||
slog.F("count", count),
|
||||
slog.F("total_bytes", totalBytes),
|
||||
slog.F("delegate_delivered", delivered),
|
||||
slog.F("delegate_failed", failed),
|
||||
slog.Error(senderErr),
|
||||
}
|
||||
if fallbackErr != nil {
|
||||
fields = append(fields, slog.F("delegate_error", fallbackErr.Error()))
|
||||
}
|
||||
if resetErr != nil {
|
||||
fields = append(fields, slog.F("sender_reset_error", resetErr.Error()))
|
||||
}
|
||||
p.logger.Error(context.Background(), "batched pubsub flush failed", fields...)
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) drain(batch []queuedPublish) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.finalFlushTimeout)
|
||||
defer cancel()
|
||||
|
||||
var errs []error
|
||||
for {
|
||||
batch = p.drainIntoBatch(batch)
|
||||
if len(batch) == 0 {
|
||||
break
|
||||
}
|
||||
var err error
|
||||
batch, err = p.flushBatch(ctx, batch, batchFlushShutdown)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dropped := p.dropPendingPublishes()
|
||||
if dropped > 0 {
|
||||
errs = append(errs, xerrors.Errorf("dropped %d queued notifications during shutdown", dropped))
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
errs = append(errs, xerrors.Errorf("shutdown flush timed out: %w", ctx.Err()))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) dropPendingPublishes() int {
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case <-p.publishCh:
|
||||
count++
|
||||
default:
|
||||
if count > 0 {
|
||||
queuedDepth := p.queuedCount.Add(-int64(count))
|
||||
p.observeQueueDepth(queuedDepth)
|
||||
}
|
||||
return count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) logPublishRejection(event string) {
|
||||
fields := []slog.Field{
|
||||
slog.F("event", event),
|
||||
slog.F("queue_size", cap(p.publishCh)),
|
||||
slog.F("queued", p.queuedCount.Load()),
|
||||
}
|
||||
select {
|
||||
case <-p.warnTicker.C:
|
||||
p.logger.Warn(context.Background(), "batched pubsub queue is full", fields...)
|
||||
default:
|
||||
p.logger.Debug(context.Background(), "batched pubsub queue is full", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
type pgBatchSender struct {
|
||||
logger slog.Logger
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func newPGBatchSender(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
prototype *sql.DB,
|
||||
connectURL string,
|
||||
) (*pgBatchSender, error) {
|
||||
connector, err := newConnector(ctx, logger, prototype, connectURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := sql.OpenDB(connector)
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxIdleTime(0)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(pingCtx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, xerrors.Errorf("ping batched pubsub sender database: %w", err)
|
||||
}
|
||||
|
||||
return &pgBatchSender{logger: logger, db: db}, nil
|
||||
}
|
||||
|
||||
func (s *pgBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return &batchFlushError{stage: batchFlushStageBegin, err: xerrors.Errorf("begin batched pubsub transaction: %w", err)}
|
||||
}
|
||||
committed := false
|
||||
defer func() {
|
||||
if !committed {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, item := range batch {
|
||||
// This is safe because we are calling pq.QuoteLiteral. pg_notify does
|
||||
// not support the first parameter being a prepared statement.
|
||||
//nolint:gosec
|
||||
_, err = tx.ExecContext(ctx, `select pg_notify(`+pq.QuoteLiteral(item.event)+`, $1)`, item.message)
|
||||
if err != nil {
|
||||
return &batchFlushError{stage: batchFlushStageExec, err: xerrors.Errorf("exec pg_notify: %w", err)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &batchFlushError{stage: batchFlushStageCommit, err: xerrors.Errorf("commit batched pubsub transaction: %w", err)}
|
||||
}
|
||||
committed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *pgBatchSender) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
@@ -0,0 +1,520 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
prom_testutil "github.com/prometheus/client_golang/prometheus/testutil"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestBatchingPubsubScheduledFlush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
|
||||
require.Empty(t, sender.Batches())
|
||||
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
|
||||
batch := testutil.TryReceive(ctx, t, sender.flushes)
|
||||
require.Len(t, batch, 2)
|
||||
require.Equal(t, []byte("one"), batch[0].message)
|
||||
require.Equal(t, []byte("two"), batch[1].message)
|
||||
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
|
||||
require.Equal(t, uint64(1), batchSizeCount)
|
||||
require.InDelta(t, 2, batchSizeSum, 0.000001)
|
||||
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
|
||||
require.Equal(t, uint64(1), flushDurationCount)
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: clock})
|
||||
|
||||
require.Equal(t, DefaultBatchingFlushInterval, ps.flushInterval)
|
||||
require.Equal(t, DefaultBatchingQueueSize, cap(ps.publishCh))
|
||||
require.Equal(t, defaultBatchingPressureWait, ps.pressureWait)
|
||||
require.Equal(t, defaultBatchingFinalFlushLimit, ps.finalFlushTimeout)
|
||||
}
|
||||
|
||||
func TestBatchChannelClass(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
event string
|
||||
want string
|
||||
}{
|
||||
{name: "stream notify", event: "chat:stream:123", want: batchChannelClassStreamNotify},
|
||||
{name: "owner event", event: "chat:owner:123", want: batchChannelClassOwnerEvent},
|
||||
{name: "config change", event: "chat:config_change", want: batchChannelClassConfigChange},
|
||||
{name: "fallback", event: "workspace:owner:123", want: batchChannelClassOther},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, batchChannelClass(tt.event))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchingPubsubTimerFlushDrainsAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 64,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
// Enqueue many messages before the timer fires — all should be
|
||||
// drained and flushed in a single batch.
|
||||
for _, msg := range []string{"one", "two", "three", "four", "five"} {
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
|
||||
}
|
||||
require.Empty(t, sender.Batches())
|
||||
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
|
||||
batch := testutil.TryReceive(ctx, t, sender.flushes)
|
||||
require.Len(t, batch, 5)
|
||||
require.Equal(t, []byte("one"), batch[0].message)
|
||||
require.Equal(t, []byte("five"), batch[4].message)
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubQueueFullFallsBackToDelegate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
pressureTrap := clock.Trap().NewTimer("pubsubBatcher", "pressureWait")
|
||||
defer pressureTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.blockCh = make(chan struct{})
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 1,
|
||||
PressureWait: 10 * time.Millisecond,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
// Fill the queue (capacity 1).
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
|
||||
// Fire the timer so the run loop starts flushing "one" — the
|
||||
// sender blocks on blockCh so the flush stays in-flight.
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
<-sender.started
|
||||
|
||||
// The run loop is blocked in flushBatch. Fill the queue again.
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
|
||||
|
||||
// A third publish should fall back to the delegate (which has a
|
||||
// closed db, so the delegate Publish itself will error — but we
|
||||
// verify the fallback metric was incremented).
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- ps.Publish("chat:stream:a", []byte("three"))
|
||||
}()
|
||||
|
||||
pressureCall, err := pressureTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
pressureCall.MustRelease(ctx)
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
|
||||
err = testutil.TryReceive(ctx, t, errCh)
|
||||
// The delegate has a closed db so it returns an error from the
|
||||
// shared pool, not a batching-specific sentinel.
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)))
|
||||
|
||||
close(sender.blockCh)
|
||||
// Let the run loop finish the blocked flush and process "two".
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
require.NoError(t, ps.Close())
|
||||
}
|
||||
|
||||
func TestBatchingPubsubCloseDrainsQueue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: time.Hour,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("three")))
|
||||
|
||||
require.NoError(t, ps.Close())
|
||||
batches := sender.Batches()
|
||||
require.Len(t, batches, 1)
|
||||
require.Len(t, batches[0], 3)
|
||||
require.Equal(t, []byte("one"), batches[0][0].message)
|
||||
require.Equal(t, []byte("two"), batches[0][1].message)
|
||||
require.Equal(t, []byte("three"), batches[0][2].message)
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
require.Equal(t, 1, sender.CloseCalls())
|
||||
}
|
||||
|
||||
func TestBatchingPubsubPreservesOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: time.Hour,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
for _, msg := range []string{"one", "two", "three", "four", "five"} {
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
|
||||
}
|
||||
|
||||
require.NoError(t, ps.Close())
|
||||
batches := sender.Batches()
|
||||
require.NotEmpty(t, batches)
|
||||
|
||||
messages := make([]string, 0, 5)
|
||||
for _, batch := range batches {
|
||||
for _, item := range batch {
|
||||
messages = append(messages, string(item.message))
|
||||
}
|
||||
}
|
||||
require.Equal(t, []string{"one", "two", "three", "four", "five"}, messages)
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = batchFlushStageExec
|
||||
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
|
||||
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
|
||||
require.Equal(t, uint64(1), batchSizeCount)
|
||||
require.InDelta(t, 1, batchSizeSum, 0.000001)
|
||||
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
|
||||
require.Equal(t, uint64(1), flushDurationCount)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
|
||||
require.Zero(t, prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("true")))
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureStageAccounting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
stages := []string{batchFlushStageBegin, batchFlushStageExec, batchFlushStageCommit}
|
||||
for _, stage := range stages {
|
||||
stage := stage
|
||||
t.Run(stage, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = stage
|
||||
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
|
||||
|
||||
batch := []queuedPublish{{
|
||||
event: "chat:stream:test",
|
||||
channelClass: batchChannelClass("chat:stream:test"),
|
||||
message: []byte("fallback-" + stage),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(batch)))
|
||||
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, stage)))
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureResetSender(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
firstSender := newFakeBatchSender()
|
||||
firstSender.err = context.DeadlineExceeded
|
||||
firstSender.errStage = batchFlushStageExec
|
||||
secondSender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, firstSender, BatchingConfig{Clock: clock})
|
||||
ps.newSender = func(context.Context) (batchSender, error) {
|
||||
return secondSender, nil
|
||||
}
|
||||
|
||||
firstBatch := []queuedPublish{{
|
||||
event: "chat:stream:first",
|
||||
channelClass: batchChannelClass("chat:stream:first"),
|
||||
message: []byte("first"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(firstBatch)))
|
||||
_, err := ps.flushBatch(context.Background(), firstBatch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.SenderResetsTotal))
|
||||
require.Equal(t, 1, firstSender.CloseCalls())
|
||||
|
||||
secondBatch := []queuedPublish{{
|
||||
event: "chat:stream:second",
|
||||
channelClass: batchChannelClass("chat:stream:second"),
|
||||
message: []byte("second"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(secondBatch)))
|
||||
_, err = ps.flushBatch(context.Background(), secondBatch, batchFlushScheduled)
|
||||
require.NoError(t, err)
|
||||
batches := secondSender.Batches()
|
||||
require.Len(t, batches, 1)
|
||||
require.Len(t, batches[0], 1)
|
||||
require.Equal(t, []byte("second"), batches[0][0].message)
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureReturnsJoinedErrorWhenReplayFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = batchFlushStageExec
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
|
||||
|
||||
batch := []queuedPublish{{
|
||||
event: "chat:stream:error",
|
||||
channelClass: batchChannelClass("chat:stream:error"),
|
||||
message: []byte("error"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(batch)))
|
||||
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, context.DeadlineExceeded.Error())
|
||||
require.ErrorContains(t, err, `delegate publish "chat:stream:error"`)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
|
||||
}
|
||||
|
||||
func newTestBatchingPubsub(t *testing.T, sender batchSender, cfg BatchingConfig) (*BatchingPubsub, *PGPubsub) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
// Use a closed *sql.DB so that delegate.Publish returns a real
|
||||
// error instead of panicking on a nil pointer when the batching
|
||||
// queue falls back to the shared pool under pressure.
|
||||
closedDB := newClosedDB(t)
|
||||
delegate := newWithoutListener(logger.Named("delegate"), closedDB)
|
||||
ps, err := newBatchingPubsub(logger.Named("batcher"), delegate, sender, cfg)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = ps.Close()
|
||||
})
|
||||
return ps, delegate
|
||||
}
|
||||
|
||||
// newClosedDB returns an *sql.DB whose connections have been closed,
|
||||
// so any ExecContext call returns an error rather than panicking.
|
||||
func newClosedDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
db, err := sql.Open("postgres", "host=localhost dbname=closed_db_stub sslmode=disable connect_timeout=1")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
return db
|
||||
}
|
||||
|
||||
type fakeBatchSender struct {
|
||||
mu sync.Mutex
|
||||
batches [][]queuedPublish
|
||||
flushes chan []queuedPublish
|
||||
started chan struct{}
|
||||
blockCh chan struct{}
|
||||
err error
|
||||
errStage string
|
||||
closeErr error
|
||||
closeCall int
|
||||
}
|
||||
|
||||
func newFakeBatchSender() *fakeBatchSender {
|
||||
return &fakeBatchSender{
|
||||
flushes: make(chan []queuedPublish, 16),
|
||||
started: make(chan struct{}, 16),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
|
||||
select {
|
||||
case s.started <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
if s.blockCh != nil {
|
||||
select {
|
||||
case <-s.blockCh:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
clone := make([]queuedPublish, len(batch))
|
||||
for i, item := range batch {
|
||||
clone[i] = queuedPublish{
|
||||
event: item.event,
|
||||
message: bytes.Clone(item.message),
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.batches = append(s.batches, clone)
|
||||
s.mu.Unlock()
|
||||
|
||||
select {
|
||||
case s.flushes <- clone:
|
||||
default:
|
||||
}
|
||||
if s.err == nil {
|
||||
return nil
|
||||
}
|
||||
if s.errStage != "" {
|
||||
return &batchFlushError{stage: s.errStage, err: s.err}
|
||||
}
|
||||
return s.err
|
||||
}
|
||||
|
||||
type metricWriter interface {
|
||||
Write(*dto.Metric) error
|
||||
}
|
||||
|
||||
func histogramCountAndSum(t *testing.T, observer any) (uint64, float64) {
|
||||
t.Helper()
|
||||
writer, ok := observer.(metricWriter)
|
||||
require.True(t, ok)
|
||||
|
||||
metric := &dto.Metric{}
|
||||
require.NoError(t, writer.Write(metric))
|
||||
histogram := metric.GetHistogram()
|
||||
require.NotNil(t, histogram)
|
||||
return histogram.GetSampleCount(), histogram.GetSampleSum()
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.closeCall++
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Batches() [][]queuedPublish {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
clone := make([][]queuedPublish, len(s.batches))
|
||||
for i, batch := range s.batches {
|
||||
clone[i] = make([]queuedPublish, len(batch))
|
||||
copy(clone[i], batch)
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) CloseCalls() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.closeCall
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package pubsub_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestBatchingPubsubDedicatedSenderConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
connectionURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
trackedDriver := dbtestutil.NewDriver()
|
||||
defer trackedDriver.Close()
|
||||
tconn, err := trackedDriver.Connector(connectionURL)
|
||||
require.NoError(t, err)
|
||||
trackedDB := sql.OpenDB(tconn)
|
||||
defer trackedDB.Close()
|
||||
|
||||
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer base.Close()
|
||||
|
||||
listenerConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
|
||||
QueueSize: 8,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer batched.Close()
|
||||
|
||||
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
require.NotEqual(t, fmt.Sprintf("%p", listenerConn), fmt.Sprintf("%p", senderConn))
|
||||
|
||||
event := t.Name()
|
||||
messageCh := make(chan []byte, 1)
|
||||
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageCh <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, batched.Publish(event, []byte("hello")))
|
||||
require.Equal(t, []byte("hello"), testutil.TryReceive(ctx, t, messageCh))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubReconnectsAfterSenderDisconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
connectionURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
trackedDriver := dbtestutil.NewDriver()
|
||||
defer trackedDriver.Close()
|
||||
tconn, err := trackedDriver.Connector(connectionURL)
|
||||
require.NoError(t, err)
|
||||
trackedDB := sql.OpenDB(tconn)
|
||||
defer trackedDB.Close()
|
||||
|
||||
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer base.Close()
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, trackedDriver.Connections) // listener connection
|
||||
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
|
||||
QueueSize: 8,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer batched.Close()
|
||||
|
||||
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
event := t.Name()
|
||||
messageCh := make(chan []byte, 4)
|
||||
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageCh <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, batched.Publish(event, []byte("before-disconnect")))
|
||||
require.Equal(t, []byte("before-disconnect"), testutil.TryReceive(ctx, t, messageCh))
|
||||
require.NoError(t, senderConn.Close())
|
||||
|
||||
reconnected := false
|
||||
delivered := false
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
if !reconnected {
|
||||
select {
|
||||
case conn := <-trackedDriver.Connections:
|
||||
reconnected = conn != nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-messageCh:
|
||||
default:
|
||||
}
|
||||
if err := batched.Publish(event, []byte("after-disconnect")); err != nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case msg := <-messageCh:
|
||||
delivered = string(msg) == "after-disconnect"
|
||||
case <-time.After(testutil.IntervalFast):
|
||||
delivered = false
|
||||
}
|
||||
return reconnected && delivered
|
||||
}, testutil.IntervalMedium, "batched sender did not recover after disconnect")
|
||||
}
|
||||
@@ -487,12 +487,14 @@ func (d logDialer) DialContext(ctx context.Context, network, address string) (ne
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
p.connected.Set(0)
|
||||
// Creates a new listener using pq.
|
||||
func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (driver.Connector, error) {
|
||||
if db == nil {
|
||||
return nil, xerrors.New("database is nil")
|
||||
}
|
||||
|
||||
var (
|
||||
dialer = logDialer{
|
||||
logger: p.logger,
|
||||
logger: logger,
|
||||
// pq.defaultDialer uses a zero net.Dialer as well.
|
||||
d: net.Dialer{},
|
||||
}
|
||||
@@ -501,28 +503,38 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
)
|
||||
|
||||
// Create a custom connector if the database driver supports it.
|
||||
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
|
||||
connectorCreator, ok := db.Driver().(database.ConnectorCreator)
|
||||
if ok {
|
||||
connector, err = connectorCreator.Connector(connectURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create custom connector: %w", err)
|
||||
return nil, xerrors.Errorf("create custom connector: %w", err)
|
||||
}
|
||||
} else {
|
||||
// use the default pq connector otherwise
|
||||
// Use the default pq connector otherwise.
|
||||
connector, err = pq.NewConnector(connectURL)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create pq connector: %w", err)
|
||||
return nil, xerrors.Errorf("create pq connector: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the dialer if the connector supports it.
|
||||
dc, ok := connector.(database.DialerConnector)
|
||||
if !ok {
|
||||
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
|
||||
logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
|
||||
} else {
|
||||
dc.Dialer(dialer)
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
p.connected.Set(0)
|
||||
connector, err := newConnector(ctx, p.logger, p.db, connectURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
errCh = make(chan error, 1)
|
||||
sentErrCh = false
|
||||
|
||||
@@ -76,7 +76,6 @@ type sqlcQuerier interface {
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
@@ -102,8 +101,6 @@ type sqlcQuerier interface {
|
||||
// be recreated.
|
||||
DeleteAllWebpushSubscriptions(ctx context.Context) error
|
||||
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteChatDebugDataAfterMessageID(ctx context.Context, arg DeleteChatDebugDataAfterMessageIDParams) (int64, error)
|
||||
DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error)
|
||||
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
|
||||
@@ -171,7 +168,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error)
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -196,16 +193,6 @@ type sqlcQuerier interface {
|
||||
FetchNewMessageMetadata(ctx context.Context, arg FetchNewMessageMetadataParams) (FetchNewMessageMetadataRow, error)
|
||||
FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error)
|
||||
FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error)
|
||||
// Marks orphaned in-progress rows as interrupted so they do not stay
|
||||
// in a non-terminal state forever. The NOT IN list must match the
|
||||
// terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
|
||||
//
|
||||
// The steps CTE also catches steps whose parent run was just finalized
|
||||
// (via run_id IN), because PostgreSQL data-modifying CTEs share the
|
||||
// same snapshot and cannot see each other's row updates. Without this,
|
||||
// a step with a recent updated_at would survive its run's finalization
|
||||
// and remain in 'in_progress' state permanently.
|
||||
FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error)
|
||||
// FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters.
|
||||
// It returns the preset ID if a match is found, or NULL if no match is found.
|
||||
// The query finds presets where all preset parameters are present in the provided parameters,
|
||||
@@ -228,7 +215,6 @@ type sqlcQuerier interface {
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveAISeatCount(ctx context.Context) (int64, error)
|
||||
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
@@ -270,15 +256,6 @@ type sqlcQuerier interface {
|
||||
// Aggregate cost summary for a single user within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
|
||||
// GetChatDebugLoggingAllowUsers returns the runtime admin setting that
|
||||
// allows users to opt into chat debug logging when the deployment does
|
||||
// not already force debug logging on globally.
|
||||
GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error)
|
||||
GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (ChatDebugRun, error)
|
||||
// Returns the most recent debug runs for a chat, ordered newest-first.
|
||||
// Callers must supply an explicit limit to avoid unbounded result sets.
|
||||
GetChatDebugRunsByChatID(ctx context.Context, arg GetChatDebugRunsByChatIDParams) ([]ChatDebugRun, error)
|
||||
GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]ChatDebugStep, error)
|
||||
GetChatDesktopEnabled(ctx context.Context) (bool, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
@@ -439,12 +416,11 @@ type sqlcQuerier interface {
|
||||
// per PR for state/additions/deletions/model (model comes from the
|
||||
// most recent chat).
|
||||
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
|
||||
// Returns all individual PR rows with cost for the selected time range.
|
||||
// Returns individual PR rows with cost for the recent PRs table.
|
||||
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
// direct children (that lack their own PR), and deduped picks one row
|
||||
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
// large result sets from direct API callers.
|
||||
GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error)
|
||||
// per PR for metadata.
|
||||
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
|
||||
// PR Insights queries for the /agents analytics dashboard.
|
||||
// These aggregate data from chat_diff_statuses (PR metadata) joined
|
||||
// with chats and chat_messages (cost) to power the PR Insights view.
|
||||
@@ -640,7 +616,6 @@ type sqlcQuerier interface {
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
|
||||
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
|
||||
GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error)
|
||||
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
|
||||
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
|
||||
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
@@ -760,8 +735,6 @@ type sqlcQuerier interface {
|
||||
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
|
||||
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
|
||||
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
|
||||
InsertChatDebugRun(ctx context.Context, arg InsertChatDebugRunParams) (ChatDebugRun, error)
|
||||
InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error)
|
||||
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
|
||||
InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error)
|
||||
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
|
||||
@@ -920,7 +893,6 @@ type sqlcQuerier interface {
|
||||
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
|
||||
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
|
||||
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
|
||||
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
@@ -940,16 +912,6 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Uses COALESCE so that passing NULL from Go means "keep the
|
||||
// existing value." This is intentional: debug rows follow a
|
||||
// write-once-finalize pattern where fields are set at creation
|
||||
// or finalization and never cleared back to NULL.
|
||||
UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error)
|
||||
// Uses COALESCE so that passing NULL from Go means "keep the
|
||||
// existing value." This is intentional: debug rows follow a
|
||||
// write-once-finalize pattern where fields are set at creation
|
||||
// or finalization and never cleared back to NULL.
|
||||
UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
@@ -1046,7 +1008,6 @@ type sqlcQuerier interface {
|
||||
UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (WorkspaceTable, error)
|
||||
UpdateWorkspaceACLByID(ctx context.Context, arg UpdateWorkspaceACLByIDParams) error
|
||||
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
|
||||
UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error
|
||||
UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg UpdateWorkspaceAgentDisplayAppsByIDParams) error
|
||||
UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg UpdateWorkspaceAgentLifecycleStateByIDParams) error
|
||||
UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg UpdateWorkspaceAgentLogOverflowByIDParams) error
|
||||
@@ -1078,9 +1039,6 @@ type sqlcQuerier interface {
|
||||
// cumulative values for unique counts (accurate period totals). Request counts
|
||||
// are always deltas, accumulated in DB. Returns true if insert, false if update.
|
||||
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
|
||||
// UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
|
||||
// allows users to opt into chat debug logging.
|
||||
UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error
|
||||
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
@@ -1118,7 +1076,6 @@ type sqlcQuerier interface {
|
||||
// used to store the data, and the minutes are summed for each user and template
|
||||
// combination. The result is stored in the template_usage_stats table.
|
||||
UpsertTemplateUsageStats(ctx context.Context) error
|
||||
UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error
|
||||
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
|
||||
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
|
||||
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
|
||||
|
||||
+21
-980
File diff suppressed because it is too large
Load Diff
+53
-824
File diff suppressed because it is too large
Load Diff
@@ -1,205 +0,0 @@
|
||||
-- name: InsertChatDebugRun :one
|
||||
INSERT INTO chat_debug_runs (
|
||||
chat_id,
|
||||
root_chat_id,
|
||||
parent_chat_id,
|
||||
model_config_id,
|
||||
trigger_message_id,
|
||||
history_tip_message_id,
|
||||
kind,
|
||||
status,
|
||||
provider,
|
||||
model,
|
||||
summary,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
)
|
||||
VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('root_chat_id')::uuid,
|
||||
sqlc.narg('parent_chat_id')::uuid,
|
||||
sqlc.narg('model_config_id')::uuid,
|
||||
sqlc.narg('trigger_message_id')::bigint,
|
||||
sqlc.narg('history_tip_message_id')::bigint,
|
||||
@kind::text,
|
||||
@status::text,
|
||||
sqlc.narg('provider')::text,
|
||||
sqlc.narg('model')::text,
|
||||
COALESCE(sqlc.narg('summary')::jsonb, '{}'::jsonb),
|
||||
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
|
||||
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
|
||||
sqlc.narg('finished_at')::timestamptz
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatDebugRun :one
|
||||
-- Uses COALESCE so that passing NULL from Go means "keep the
|
||||
-- existing value." This is intentional: debug rows follow a
|
||||
-- write-once-finalize pattern where fields are set at creation
|
||||
-- or finalization and never cleared back to NULL.
|
||||
UPDATE chat_debug_runs
|
||||
SET
|
||||
root_chat_id = COALESCE(sqlc.narg('root_chat_id')::uuid, root_chat_id),
|
||||
parent_chat_id = COALESCE(sqlc.narg('parent_chat_id')::uuid, parent_chat_id),
|
||||
model_config_id = COALESCE(sqlc.narg('model_config_id')::uuid, model_config_id),
|
||||
trigger_message_id = COALESCE(sqlc.narg('trigger_message_id')::bigint, trigger_message_id),
|
||||
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
|
||||
status = COALESCE(sqlc.narg('status')::text, status),
|
||||
provider = COALESCE(sqlc.narg('provider')::text, provider),
|
||||
model = COALESCE(sqlc.narg('model')::text, model),
|
||||
summary = COALESCE(sqlc.narg('summary')::jsonb, summary),
|
||||
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
AND chat_id = @chat_id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: InsertChatDebugStep :one
|
||||
INSERT INTO chat_debug_steps (
|
||||
run_id,
|
||||
chat_id,
|
||||
step_number,
|
||||
operation,
|
||||
status,
|
||||
history_tip_message_id,
|
||||
assistant_message_id,
|
||||
normalized_request,
|
||||
normalized_response,
|
||||
usage,
|
||||
attempts,
|
||||
error,
|
||||
metadata,
|
||||
started_at,
|
||||
updated_at,
|
||||
finished_at
|
||||
)
|
||||
SELECT
|
||||
@run_id::uuid,
|
||||
run.chat_id,
|
||||
@step_number::int,
|
||||
@operation::text,
|
||||
@status::text,
|
||||
sqlc.narg('history_tip_message_id')::bigint,
|
||||
sqlc.narg('assistant_message_id')::bigint,
|
||||
COALESCE(sqlc.narg('normalized_request')::jsonb, '{}'::jsonb),
|
||||
sqlc.narg('normalized_response')::jsonb,
|
||||
sqlc.narg('usage')::jsonb,
|
||||
COALESCE(sqlc.narg('attempts')::jsonb, '[]'::jsonb),
|
||||
sqlc.narg('error')::jsonb,
|
||||
COALESCE(sqlc.narg('metadata')::jsonb, '{}'::jsonb),
|
||||
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
|
||||
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
|
||||
sqlc.narg('finished_at')::timestamptz
|
||||
FROM chat_debug_runs run
|
||||
WHERE run.id = @run_id::uuid
|
||||
AND run.chat_id = @chat_id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: UpdateChatDebugStep :one
|
||||
-- Uses COALESCE so that passing NULL from Go means "keep the
|
||||
-- existing value." This is intentional: debug rows follow a
|
||||
-- write-once-finalize pattern where fields are set at creation
|
||||
-- or finalization and never cleared back to NULL.
|
||||
UPDATE chat_debug_steps
|
||||
SET
|
||||
status = COALESCE(sqlc.narg('status')::text, status),
|
||||
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
|
||||
assistant_message_id = COALESCE(sqlc.narg('assistant_message_id')::bigint, assistant_message_id),
|
||||
normalized_request = COALESCE(sqlc.narg('normalized_request')::jsonb, normalized_request),
|
||||
normalized_response = COALESCE(sqlc.narg('normalized_response')::jsonb, normalized_response),
|
||||
usage = COALESCE(sqlc.narg('usage')::jsonb, usage),
|
||||
attempts = COALESCE(sqlc.narg('attempts')::jsonb, attempts),
|
||||
error = COALESCE(sqlc.narg('error')::jsonb, error),
|
||||
metadata = COALESCE(sqlc.narg('metadata')::jsonb, metadata),
|
||||
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid
|
||||
AND chat_id = @chat_id::uuid
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetChatDebugRunsByChatID :many
|
||||
-- Returns the most recent debug runs for a chat, ordered newest-first.
|
||||
-- Callers must supply an explicit limit to avoid unbounded result sets.
|
||||
SELECT *
|
||||
FROM chat_debug_runs
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
ORDER BY started_at DESC, id DESC
|
||||
LIMIT @limit_val::int;
|
||||
|
||||
-- name: GetChatDebugRunByID :one
|
||||
SELECT *
|
||||
FROM chat_debug_runs
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
-- name: GetChatDebugStepsByRunID :many
|
||||
SELECT *
|
||||
FROM chat_debug_steps
|
||||
WHERE run_id = @run_id::uuid
|
||||
ORDER BY step_number ASC, started_at ASC;
|
||||
|
||||
-- name: DeleteChatDebugDataByChatID :execrows
|
||||
DELETE FROM chat_debug_runs
|
||||
WHERE chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: DeleteChatDebugDataAfterMessageID :execrows
|
||||
WITH affected_runs AS (
|
||||
SELECT DISTINCT run.id
|
||||
FROM chat_debug_runs run
|
||||
WHERE run.chat_id = @chat_id::uuid
|
||||
AND (
|
||||
run.history_tip_message_id > @message_id::bigint
|
||||
OR run.trigger_message_id > @message_id::bigint
|
||||
)
|
||||
|
||||
UNION
|
||||
|
||||
SELECT DISTINCT step.run_id AS id
|
||||
FROM chat_debug_steps step
|
||||
WHERE step.chat_id = @chat_id::uuid
|
||||
AND (
|
||||
step.assistant_message_id > @message_id::bigint
|
||||
OR step.history_tip_message_id > @message_id::bigint
|
||||
)
|
||||
)
|
||||
DELETE FROM chat_debug_runs
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND id IN (SELECT id FROM affected_runs);
|
||||
|
||||
-- name: FinalizeStaleChatDebugRows :one
|
||||
-- Marks orphaned in-progress rows as interrupted so they do not stay
|
||||
-- in a non-terminal state forever. The NOT IN list must match the
|
||||
-- terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
|
||||
--
|
||||
-- The steps CTE also catches steps whose parent run was just finalized
|
||||
-- (via run_id IN), because PostgreSQL data-modifying CTEs share the
|
||||
-- same snapshot and cannot see each other's row updates. Without this,
|
||||
-- a step with a recent updated_at would survive its run's finalization
|
||||
-- and remain in 'in_progress' state permanently.
|
||||
WITH finalized_runs AS (
|
||||
UPDATE chat_debug_runs
|
||||
SET
|
||||
status = 'interrupted',
|
||||
updated_at = NOW(),
|
||||
finished_at = NOW()
|
||||
WHERE updated_at < @updated_before::timestamptz
|
||||
AND finished_at IS NULL
|
||||
AND status NOT IN ('completed', 'error', 'interrupted')
|
||||
RETURNING id
|
||||
), finalized_steps AS (
|
||||
UPDATE chat_debug_steps
|
||||
SET
|
||||
status = 'interrupted',
|
||||
updated_at = NOW(),
|
||||
finished_at = NOW()
|
||||
WHERE (
|
||||
updated_at < @updated_before::timestamptz
|
||||
OR run_id IN (SELECT id FROM finalized_runs)
|
||||
)
|
||||
AND finished_at IS NULL
|
||||
AND status NOT IN ('completed', 'error', 'interrupted')
|
||||
RETURNING 1
|
||||
)
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM finalized_runs)::bigint AS runs_finalized,
|
||||
(SELECT COUNT(*) FROM finalized_steps)::bigint AS steps_finalized;
|
||||
@@ -173,12 +173,11 @@ JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
|
||||
ORDER BY total_prs DESC;
|
||||
|
||||
-- name: GetPRInsightsPullRequests :many
|
||||
-- Returns all individual PR rows with cost for the selected time range.
|
||||
-- name: GetPRInsightsRecentPRs :many
|
||||
-- Returns individual PR rows with cost for the recent PRs table.
|
||||
-- Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
-- direct children (that lack their own PR), and deduped picks one row
|
||||
-- per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
-- large result sets from direct API callers.
|
||||
-- per PR for metadata.
|
||||
WITH pr_costs AS (
|
||||
SELECT
|
||||
prc.pr_key,
|
||||
@@ -265,4 +264,4 @@ SELECT * FROM (
|
||||
JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
) sub
|
||||
ORDER BY sub.created_at DESC
|
||||
LIMIT 500;
|
||||
LIMIT @limit_val::int;
|
||||
|
||||
@@ -353,18 +353,20 @@ WHERE
|
||||
ELSE chats.archived = sqlc.narg('archived') :: boolean
|
||||
END
|
||||
AND CASE
|
||||
-- Cursor pagination: the last element on a page acts as the cursor.
|
||||
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
|
||||
-- (pin_order is negated so lower values sort first in DESC order),
|
||||
-- which lets us use a single tuple < comparison.
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
|
||||
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the updated_at field, so select all
|
||||
-- rows before the cursor.
|
||||
(updated_at, id) < (
|
||||
SELECT
|
||||
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
|
||||
updated_at, id
|
||||
FROM
|
||||
chats c2
|
||||
chats
|
||||
WHERE
|
||||
c2.id = @after_id
|
||||
id = @after_id
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
@@ -376,15 +378,9 @@ WHERE
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
|
||||
-- pinned chats, lower pin_order values come first. The negation
|
||||
-- trick (-pin_order) keeps all sort columns DESC so the cursor
|
||||
-- tuple < comparison works with uniform direction.
|
||||
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
|
||||
-pin_order DESC,
|
||||
updated_at DESC,
|
||||
id DESC
|
||||
OFFSET @offset_opt
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(updated_at, id) DESC OFFSET @offset_opt
|
||||
LIMIT
|
||||
-- The chat list is unbounded and expected to grow large.
|
||||
-- Default to 50 to prevent accidental excessively large queries.
|
||||
@@ -1297,26 +1293,3 @@ GROUP BY cm.chat_id;
|
||||
SELECT id, provider, model, context_limit, enabled, is_default
|
||||
FROM chat_model_configs
|
||||
WHERE deleted = false;
|
||||
-- name: GetActiveChatsByAgentID :many
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE agent_id = @agent_id::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC;
|
||||
|
||||
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL;
|
||||
|
||||
-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]';
|
||||
|
||||
@@ -195,8 +195,7 @@ SELECT
|
||||
w.id AS workspace_id,
|
||||
COALESCE(w.name, '') AS workspace_name,
|
||||
-- Include the name of the provisioner_daemon associated to the job
|
||||
COALESCE(pd.name, '') AS worker_name,
|
||||
wb.transition as workspace_build_transition
|
||||
COALESCE(pd.name, '') AS worker_name
|
||||
FROM
|
||||
provisioner_jobs pj
|
||||
LEFT JOIN
|
||||
@@ -241,8 +240,7 @@ GROUP BY
|
||||
t.icon,
|
||||
w.id,
|
||||
w.name,
|
||||
pd.name,
|
||||
wb.transition
|
||||
pd.name
|
||||
ORDER BY
|
||||
pj.created_at DESC
|
||||
LIMIT
|
||||
|
||||
@@ -179,31 +179,6 @@ SET value = CASE
|
||||
END
|
||||
WHERE site_configs.key = 'agents_desktop_enabled';
|
||||
|
||||
-- GetChatDebugLoggingAllowUsers returns the runtime admin setting that
|
||||
-- allows users to opt into chat debug logging when the deployment does
|
||||
-- not already force debug logging on globally.
|
||||
-- name: GetChatDebugLoggingAllowUsers :one
|
||||
SELECT
|
||||
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_debug_logging_allow_users'), false) :: boolean AS allow_users;
|
||||
|
||||
-- UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
|
||||
-- allows users to opt into chat debug logging.
|
||||
-- name: UpsertChatDebugLoggingAllowUsers :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES (
|
||||
'agents_chat_debug_logging_allow_users',
|
||||
CASE
|
||||
WHEN sqlc.arg(allow_users)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = CASE
|
||||
WHEN sqlc.arg(allow_users)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE site_configs.key = 'agents_chat_debug_logging_allow_users';
|
||||
|
||||
-- GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
|
||||
-- Returns an empty string when no allowlist has been configured (all templates allowed).
|
||||
-- name: GetChatTemplateAllowlist :one
|
||||
|
||||
@@ -56,6 +56,6 @@ SET
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
@@ -213,31 +213,6 @@ RETURNING *;
|
||||
-- name: DeleteUserChatCompactionThreshold :exec
|
||||
DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;
|
||||
|
||||
-- name: GetUserChatDebugLoggingEnabled :one
|
||||
SELECT
|
||||
value = 'true' AS debug_logging_enabled
|
||||
FROM user_configs
|
||||
WHERE user_id = @user_id
|
||||
AND key = 'chat_debug_logging_enabled';
|
||||
|
||||
-- name: UpsertUserChatDebugLoggingEnabled :exec
|
||||
INSERT INTO user_configs (user_id, key, value)
|
||||
VALUES (
|
||||
@user_id,
|
||||
'chat_debug_logging_enabled',
|
||||
CASE
|
||||
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
)
|
||||
ON CONFLICT ON CONSTRAINT user_configs_pkey
|
||||
DO UPDATE SET value = CASE
|
||||
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
|
||||
ELSE 'false'
|
||||
END
|
||||
WHERE user_configs.user_id = @user_id
|
||||
AND user_configs.key = 'chat_debug_logging_enabled';
|
||||
|
||||
-- name: GetUserTaskNotificationAlertDismissed :one
|
||||
SELECT
|
||||
value::boolean as task_notification_alert_dismissed
|
||||
|
||||
@@ -190,14 +190,6 @@ SET
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateWorkspaceAgentDirectoryByID :exec
|
||||
UPDATE
|
||||
workspace_agents
|
||||
SET
|
||||
directory = $2, updated_at = $3
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetWorkspaceAgentLogsAfter :many
|
||||
SELECT
|
||||
*
|
||||
|
||||
@@ -15,8 +15,6 @@ const (
|
||||
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
|
||||
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
|
||||
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
|
||||
UniqueChatDebugRunsPkey UniqueConstraint = "chat_debug_runs_pkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
|
||||
UniqueChatDebugStepsPkey UniqueConstraint = "chat_debug_steps_pkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
|
||||
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
|
||||
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
|
||||
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
|
||||
@@ -130,8 +128,6 @@ const (
|
||||
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
|
||||
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
|
||||
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
|
||||
UniqueIndexChatDebugRunsIDChat UniqueConstraint = "idx_chat_debug_runs_id_chat" // CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
|
||||
UniqueIndexChatDebugStepsRunStep UniqueConstraint = "idx_chat_debug_steps_run_step" // CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
|
||||
UniqueIndexChatModelConfigsSingleDefault UniqueConstraint = "idx_chat_model_configs_single_default" // CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
|
||||
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
|
||||
UniqueIndexCustomRolesNameLowerOrganizationID UniqueConstraint = "idx_custom_roles_name_lower_organization_id" // CREATE UNIQUE INDEX idx_custom_roles_name_lower_organization_id ON custom_roles USING btree (lower(name), COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid));
|
||||
|
||||
+77
-70
@@ -137,9 +137,8 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
logger := api.Logger.Named("chat_watcher")
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat watch stream.",
|
||||
@@ -147,44 +146,54 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// The encoder is only written from the SubscribeWithErr callback,
|
||||
// which delivers serially per subscription. Do not add a second
|
||||
// write path without introducing synchronization.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatWatchEvent(
|
||||
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatEvent(
|
||||
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
||||
if err != nil {
|
||||
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
|
||||
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := encoder.Encode(payload); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
|
||||
}
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Internal error subscribing to chat events.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
defer cancelSubscribe()
|
||||
|
||||
<-ctx.Done()
|
||||
// Send initial ping to signal the connection is ready.
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypePing,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
@@ -1810,9 +1819,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
// - pinOrder > 0 && already pinned: reorder (shift
|
||||
// neighbors, clamp to [1, count]).
|
||||
// - pinOrder > 0 && not pinned: append to end. The
|
||||
// requested value is intentionally ignored; the
|
||||
// SQL ORDER BY sorts pinned chats first so they
|
||||
// appear on page 1 of the paginated sidebar.
|
||||
// requested value is intentionally ignored because
|
||||
// PinChatByID also bumps updated_at to keep the
|
||||
// chat visible in the paginated sidebar.
|
||||
var err error
|
||||
errMsg := "Failed to pin chat."
|
||||
switch {
|
||||
@@ -2167,7 +2176,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -2190,22 +2198,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
@@ -2213,30 +2206,41 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
|
||||
}
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
return encoder.Encode(batch)
|
||||
return sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: batch,
|
||||
})
|
||||
}
|
||||
|
||||
drainChatStreamBatch := func(
|
||||
@@ -2269,7 +2273,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
end = len(snapshot)
|
||||
}
|
||||
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2278,6 +2282,8 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
case firstEvent, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
@@ -2287,7 +2293,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chatStreamBatchSize,
|
||||
)
|
||||
if err := sendChatStreamBatch(batch); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if streamClosed {
|
||||
@@ -2302,7 +2308,6 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon != nil {
|
||||
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
||||
@@ -2316,7 +2321,8 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
@@ -5626,7 +5632,7 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
previousSummary database.GetPRInsightsSummaryRow
|
||||
timeSeries []database.GetPRInsightsTimeSeriesRow
|
||||
byModel []database.GetPRInsightsPerModelRow
|
||||
recentPRs []database.GetPRInsightsPullRequestsRow
|
||||
recentPRs []database.GetPRInsightsRecentPRsRow
|
||||
)
|
||||
|
||||
eg, egCtx := errgroup.WithContext(ctx)
|
||||
@@ -5674,10 +5680,11 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
recentPRs, err = api.Database.GetPRInsightsPullRequests(egCtx, database.GetPRInsightsPullRequestsParams{
|
||||
recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
LimitVal: 20,
|
||||
})
|
||||
return err
|
||||
})
|
||||
@@ -5787,10 +5794,10 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
PullRequests: prEntries,
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
RecentPRs: prEntries,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
+98
-199
@@ -876,186 +876,6 @@ func TestListChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, allChats, totalChats)
|
||||
})
|
||||
|
||||
// Test that a pinned chat with an old updated_at appears on page 1.
|
||||
t.Run("PinnedOnFirstPage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create the chat that will later be pinned. It gets the
|
||||
// earliest updated_at because it is inserted first.
|
||||
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "pinned-chat",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fill page 1 with newer chats so the pinned chat would
|
||||
// normally be pushed off the first page (default limit 50).
|
||||
const fillerCount = 51
|
||||
fillerChats := make([]codersdk.Chat, 0, fillerCount)
|
||||
for i := range fillerCount {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("filler-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
fillerChats = append(fillerChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach a terminal status so
|
||||
// updated_at is stable before paginating. A single
|
||||
// polling loop checks every chat per tick to avoid
|
||||
// O(N) separate Eventually loops.
|
||||
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
|
||||
pending := make(map[uuid.UUID]struct{}, len(allCreated))
|
||||
for _, c := range allCreated {
|
||||
pending[c.ID] = struct{}{}
|
||||
}
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
|
||||
})
|
||||
if listErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, ch := range all {
|
||||
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
|
||||
delete(pending, ch.ID)
|
||||
}
|
||||
}
|
||||
return len(pending) == 0
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the earliest chat.
|
||||
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch page 1 with default limit (50).
|
||||
page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: 50},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The pinned chat must appear on page 1.
|
||||
page1IDs := make(map[uuid.UUID]struct{}, len(page1))
|
||||
for _, c := range page1 {
|
||||
page1IDs[c.ID] = struct{}{}
|
||||
}
|
||||
_, found := page1IDs[pinnedChat.ID]
|
||||
require.True(t, found, "pinned chat should appear on page 1")
|
||||
|
||||
// The pinned chat should be the first item in the list.
|
||||
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
|
||||
})
|
||||
|
||||
// Test cursor pagination with a mix of pinned and unpinned chats.
|
||||
t.Run("CursorWithPins", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create 5 chats: 2 will be pinned, 3 unpinned.
|
||||
const totalChats = 5
|
||||
createdChats := make([]codersdk.Chat, 0, totalChats)
|
||||
for i := range totalChats {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("cursor-pin-chat-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
createdChats = append(createdChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach terminal status.
|
||||
// Check each chat by ID rather than fetching the full list.
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
for _, c := range createdChats {
|
||||
ch, err := client.GetChat(ctx, c.ID)
|
||||
require.NoError(t, err, "GetChat should succeed for just-created chat %s", c.ID)
|
||||
if ch.Status == codersdk.ChatStatusPending || ch.Status == codersdk.ChatStatusRunning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the first two chats (oldest updated_at).
|
||||
err := client.UpdateChat(ctx, createdChats[0].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = client.UpdateChat(ctx, createdChats[1].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Paginate with limit=2 using cursor (after_id).
|
||||
const pageSize = 2
|
||||
maxPages := totalChats/pageSize + 2
|
||||
var allPaginated []codersdk.Chat
|
||||
var afterID uuid.UUID
|
||||
for range maxPages {
|
||||
opts := &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: pageSize},
|
||||
}
|
||||
if afterID != uuid.Nil {
|
||||
opts.Pagination.AfterID = afterID
|
||||
}
|
||||
page, listErr := client.ListChats(ctx, opts)
|
||||
require.NoError(t, listErr)
|
||||
if len(page) == 0 {
|
||||
break
|
||||
}
|
||||
allPaginated = append(allPaginated, page...)
|
||||
afterID = page[len(page)-1].ID
|
||||
}
|
||||
|
||||
// All chats should appear exactly once.
|
||||
seenIDs := make(map[uuid.UUID]struct{}, len(allPaginated))
|
||||
for _, c := range allPaginated {
|
||||
_, dup := seenIDs[c.ID]
|
||||
require.False(t, dup, "chat %s appeared more than once", c.ID)
|
||||
seenIDs[c.ID] = struct{}{}
|
||||
}
|
||||
require.Len(t, seenIDs, totalChats, "all chats should appear in paginated results")
|
||||
|
||||
// Pinned chats should come before unpinned ones, and
|
||||
// within the pinned group, lower pin_order sorts first.
|
||||
pinnedSeen := false
|
||||
unpinnedSeen := false
|
||||
for _, c := range allPaginated {
|
||||
if c.PinOrder > 0 {
|
||||
require.False(t, unpinnedSeen, "pinned chat %s appeared after unpinned chat", c.ID)
|
||||
pinnedSeen = true
|
||||
} else {
|
||||
unpinnedSeen = true
|
||||
}
|
||||
}
|
||||
require.True(t, pinnedSeen, "at least one pinned chat should exist")
|
||||
|
||||
// Verify within-pinned ordering: pin_order=1 before
|
||||
// pin_order=2 (the -pin_order DESC column).
|
||||
require.Equal(t, createdChats[0].ID, allPaginated[0].ID,
|
||||
"pin_order=1 chat should be first")
|
||||
require.Equal(t, createdChats[1].ID, allPaginated[1].ID,
|
||||
"pin_order=2 chat should be second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChatModels(t *testing.T) {
|
||||
@@ -1294,6 +1114,17 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1305,16 +1136,25 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1334,6 +1174,18 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Skip the initial ping.
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1346,11 +1198,18 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
var got codersdk.Chat
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
|
||||
var update watchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
|
||||
return false
|
||||
}
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
if update.Type != codersdk.ServerSentEventTypeData {
|
||||
return false
|
||||
}
|
||||
var payload coderdpubsub.ChatEvent
|
||||
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
|
||||
return false
|
||||
}
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
got = payload.Chat
|
||||
return true
|
||||
@@ -1423,14 +1282,25 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Read the initial ping.
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
// Publish a diff_status_change event via pubsub,
|
||||
// mimicking what PublishDiffStatusChange does after
|
||||
// it reads the diff status from the DB.
|
||||
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
|
||||
Chat: codersdk.Chat{
|
||||
ID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
@@ -1443,15 +1313,25 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read events until we find the diff_status_change.
|
||||
for {
|
||||
var received codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &received)
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var received coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
|
||||
received.Chat.ID != chat.ID {
|
||||
continue
|
||||
}
|
||||
@@ -1470,6 +1350,7 @@ func TestWatchChats(t *testing.T) {
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1512,13 +1393,31 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
|
||||
t.Helper()
|
||||
|
||||
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
|
||||
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
|
||||
for len(events) < 3 {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind != expectedKind {
|
||||
continue
|
||||
@@ -1528,7 +1427,7 @@ func TestWatchChats(t *testing.T) {
|
||||
return events
|
||||
}
|
||||
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
|
||||
t.Helper()
|
||||
|
||||
require.Len(t, events, 3)
|
||||
@@ -1541,12 +1440,12 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
|
||||
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
|
||||
assertLifecycleEvents(deletedEvents, true)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
|
||||
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
|
||||
assertLifecycleEvents(createdEvents, false)
|
||||
})
|
||||
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
package coderd
|
||||
|
||||
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
|
||||
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -147,35 +146,12 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
cancel := params.redirectURL
|
||||
cancelQuery := params.redirectURL.Query()
|
||||
cancelQuery.Add("error", "access_denied")
|
||||
cancelQuery.Add("error_description", "The resource owner or authorization server denied the request")
|
||||
if params.state != "" {
|
||||
cancelQuery.Add("state", params.state)
|
||||
}
|
||||
cancel.RawQuery = cancelQuery.Encode()
|
||||
|
||||
cancelURI := cancel.String()
|
||||
if err := codersdk.ValidateRedirectURIScheme(cancel); err != nil {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadRequest,
|
||||
HideStatus: false,
|
||||
Title: "Invalid Callback URL",
|
||||
Description: "The application's registered callback URL has an invalid scheme.",
|
||||
Actions: []site.Action{
|
||||
{
|
||||
URL: accessURL.String(),
|
||||
Text: "Back to site",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
// #nosec G203 -- The scheme is validated by
|
||||
// codersdk.ValidateRedirectURIScheme above.
|
||||
CancelURI: htmltemplate.URL(cancelURI),
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
CancelURI: cancel.String(),
|
||||
RedirectURI: r.URL.String(),
|
||||
CSRFToken: nosurf.Token(r),
|
||||
Username: ua.FriendlyName,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package oauth2provider_test
|
||||
|
||||
import (
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -21,7 +20,7 @@ func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
|
||||
|
||||
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
|
||||
AppName: "Test OAuth App",
|
||||
CancelURI: htmltemplate.URL("https://coder.com/cancel"),
|
||||
CancelURI: "https://coder.com/cancel",
|
||||
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
CSRFToken: csrfFieldValue,
|
||||
Username: "test-user",
|
||||
|
||||
@@ -435,9 +435,6 @@ func convertProvisionerJobWithQueuePosition(pj database.GetProvisionerJobsByOrga
|
||||
if pj.WorkspaceID.Valid {
|
||||
job.Metadata.WorkspaceID = &pj.WorkspaceID.UUID
|
||||
}
|
||||
if pj.WorkspaceBuildTransition.Valid {
|
||||
job.Metadata.WorkspaceBuildTransition = codersdk.WorkspaceTransition(pj.WorkspaceBuildTransition.WorkspaceTransition)
|
||||
}
|
||||
return job
|
||||
}
|
||||
|
||||
|
||||
@@ -97,14 +97,13 @@ func TestProvisionerJobs(t *testing.T) {
|
||||
|
||||
// Verify that job metadata is correct.
|
||||
assert.Equal(t, job2.Metadata, codersdk.ProvisionerJobMetadata{
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
WorkspaceBuildTransition: codersdk.WorkspaceTransitionStart,
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
const ChatConfigEventChannel = "chat:config_change"
|
||||
|
||||
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
|
||||
// messages, following the same pattern as HandleChatWatchEvent.
|
||||
// messages, following the same pattern as HandleChatEvent.
|
||||
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
|
||||
ChatEventKindActionRequired ChatEventKind = "action_required"
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ChatWatchEventChannel returns the pubsub channel for chat
|
||||
// lifecycle events scoped to a single user.
|
||||
func ChatWatchEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
// HandleChatWatchEvent wraps a typed callback for
|
||||
// ChatWatchEvent messages delivered via pubsub.
|
||||
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
@@ -1,280 +0,0 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary Create a new user secret
|
||||
// @ID create-a-new-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param request body codersdk.CreateUserSecretRequest true "Create secret request"
|
||||
// @Success 201 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [post]
|
||||
func (api *API) postUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
var req codersdk.CreateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Name is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Value == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Value is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := codersdk.UserSecretFilePathValid(req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
secret, err := api.Database.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Value: req.Value,
|
||||
ValueKeyID: sql.NullString{},
|
||||
EnvName: req.EnvName,
|
||||
FilePath: req.FilePath,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "A secret with that name, environment variable, or file path already exists.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary List user secrets
|
||||
// @ID list-user-secrets
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Success 200 {array} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [get]
|
||||
func (api *API) getUserSecrets(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
secrets, err := api.Database.ListUserSecrets(ctx, user.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error listing secrets.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecrets(secrets))
|
||||
}
|
||||
|
||||
// @Summary Get a user secret by name
|
||||
// @ID get-a-user-secret-by-name
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [get]
|
||||
func (api *API) getUserSecret(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
secret, err := api.Database.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Update a user secret
|
||||
// @ID update-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Param request body codersdk.UpdateUserSecretRequest true "Update secret request"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [patch]
|
||||
func (api *API) patchUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var req codersdk.UpdateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Value == nil && req.Description == nil && req.EnvName == nil && req.FilePath == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one field must be provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(*req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
if err := codersdk.UserSecretFilePathValid(*req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
params := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
UpdateValue: req.Value != nil,
|
||||
Value: "",
|
||||
ValueKeyID: sql.NullString{},
|
||||
UpdateDescription: req.Description != nil,
|
||||
Description: "",
|
||||
UpdateEnvName: req.EnvName != nil,
|
||||
EnvName: "",
|
||||
UpdateFilePath: req.FilePath != nil,
|
||||
FilePath: "",
|
||||
}
|
||||
if req.Value != nil {
|
||||
params.Value = *req.Value
|
||||
}
|
||||
if req.Description != nil {
|
||||
params.Description = *req.Description
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
params.EnvName = *req.EnvName
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
params.FilePath = *req.FilePath
|
||||
}
|
||||
|
||||
secret, err := api.Database.UpdateUserSecretByUserIDAndName(ctx, params)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Update would conflict with an existing secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Delete a user secret
|
||||
// @ID delete-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 204
|
||||
// @Router /users/{user}/secrets/{name} [delete]
|
||||
func (api *API) deleteUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
rowsAffected, err := api.Database.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -1,413 +0,0 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPostUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
secret, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub PAT",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
FilePath: "~/.github-token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "github-token", secret.Name)
|
||||
assert.Equal(t, "Personal GitHub PAT", secret.Description)
|
||||
assert.Equal(t, "GITHUB_TOKEN", secret.EnvName)
|
||||
assert.Equal(t, "~/.github-token", secret.FilePath)
|
||||
assert.NotZero(t, secret.ID)
|
||||
assert.NotZero(t, secret.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("MissingName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Value: "some-value",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Name is required")
|
||||
})
|
||||
|
||||
t.Run("MissingValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "missing-value-secret",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Value is required")
|
||||
})
|
||||
|
||||
t.Run("DuplicateName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value2",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-1",
|
||||
Value: "value1",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-2",
|
||||
Value: "value2",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-2",
|
||||
Value: "value2",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "invalid-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "1INVALID",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ReservedEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "reserved-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "PATH",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("CoderPrefixEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "coder-prefix-secret",
|
||||
Value: "value",
|
||||
EnvName: "CODER_AGENT_TOKEN",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "bad-path-secret",
|
||||
Value: "value",
|
||||
FilePath: "relative/path",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Verify no secrets exist on a fresh user.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, secrets)
|
||||
|
||||
t.Run("WithSecrets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-a",
|
||||
Value: "value-a",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-b",
|
||||
Value: "value-b",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 2)
|
||||
// Sorted by name.
|
||||
assert.Equal(t, "list-secret-a", secrets[0].Name)
|
||||
assert.Equal(t, "list-secret-b", secrets[1].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
created, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "get-found-secret",
|
||||
Value: "my-value",
|
||||
EnvName: "GET_FOUND_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := client.UserSecretByName(ctx, codersdk.Me, "get-found-secret")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "get-found-secret", got.Name)
|
||||
assert.Equal(t, "GET_FOUND_SECRET", got.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.UserSecretByName(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPatchUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("UpdateDescription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-desc-secret",
|
||||
Value: "my-value",
|
||||
Description: "original",
|
||||
EnvName: "PATCH_DESC_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newDesc := "updated"
|
||||
updated, err := client.UpdateUserSecret(ctx, codersdk.Me, "patch-desc-secret", codersdk.UpdateUserSecretRequest{
|
||||
Description: &newDesc,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated", updated.Description)
|
||||
// Other fields unchanged.
|
||||
assert.Equal(t, "PATCH_DESC_ENV", updated.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NoFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-nofields-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "patch-nofields-secret", codersdk.UpdateUserSecretRequest{})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
newVal := "new-value"
|
||||
_, err := client.UpdateUserSecret(ctx, codersdk.Me, "nonexistent", codersdk.UpdateUserSecretRequest{
|
||||
Value: &newVal,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-1",
|
||||
Value: "value1",
|
||||
EnvName: "CONFLICT_TAKEN_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "CONFLICT_TAKEN_ENV"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-env-2", codersdk.UpdateUserSecretRequest{
|
||||
EnvName: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/conflict-taken",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "/tmp/conflict-taken"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-fp-2", codersdk.UpdateUserSecretRequest{
|
||||
FilePath: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "delete-me-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.DeleteUserSecret(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone.
|
||||
_, err = client.UserSecretByName(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
err := client.DeleteUserSecret(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
@@ -42,8 +42,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -2395,598 +2393,3 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor
|
||||
}
|
||||
return sdk
|
||||
}
|
||||
|
||||
// maxChatContextParts caps the number of parts per request to
|
||||
// prevent unbounded message payloads.
|
||||
const maxChatContextParts = 100
|
||||
|
||||
// maxChatContextFileBytes caps each context-file part to the same
|
||||
// 64KiB budget used when the agent reads instruction files from disk.
|
||||
const maxChatContextFileBytes = 64 * 1024
|
||||
|
||||
// maxChatContextRequestBodyBytes caps the JSON request body size for
|
||||
// agent-added context to roughly the same per-part budget used when
|
||||
// reading instruction files from disk.
|
||||
const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes
|
||||
|
||||
// sanitizeWorkspaceAgentContextFileContent applies prompt
|
||||
// sanitization, then enforces the 64KiB per-file budget. The
|
||||
// truncated flag is preserved when the caller already capped the
|
||||
// file before sending it.
|
||||
func sanitizeWorkspaceAgentContextFileContent(
|
||||
content string,
|
||||
truncated bool,
|
||||
) (string, bool) {
|
||||
content = chatd.SanitizePromptText(content)
|
||||
if len(content) > maxChatContextFileBytes {
|
||||
content = content[:maxChatContextFileBytes]
|
||||
truncated = true
|
||||
}
|
||||
return content, truncated
|
||||
}
|
||||
|
||||
// readChatContextBody reads and validates the request body for chat
|
||||
// context endpoints. It handles MaxBytesReader wrapping, error
|
||||
// responses, and body rewind. If the body is empty or whitespace-only
|
||||
// and allowEmpty is true, it returns false without writing an error.
|
||||
//
|
||||
//nolint:revive // Add and clear endpoints only differ by empty-body handling.
|
||||
func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool {
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Request body too large.",
|
||||
Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes),
|
||||
})
|
||||
return false
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return false
|
||||
}
|
||||
if allowEmpty && len(bytes.TrimSpace(body)) == 0 {
|
||||
r.Body = http.NoBody
|
||||
return false
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return httpapi.Read(ctx, rw, r, dst)
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.AddChatContextRequest
|
||||
if !readChatContextBody(ctx, rw, r, &req, false) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) > maxChatContextParts {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Filter to only non-empty context-file and skill parts.
|
||||
filtered := chatd.FilterContextParts(req.Parts, false)
|
||||
if len(filtered) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
req.Parts = filtered
|
||||
responsePartCount := 0
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
// We verify agent-to-chat ownership explicitly below.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Stamp each persisted part with the agent identity. Context-file
|
||||
// parts also get server-authoritative workspace metadata.
|
||||
directory := workspaceAgent.ExpandedDirectory
|
||||
if directory == "" {
|
||||
directory = workspaceAgent.Directory
|
||||
}
|
||||
for i := range req.Parts {
|
||||
req.Parts[i].ContextFileAgentID = uuid.NullUUID{
|
||||
UUID: workspaceAgent.ID,
|
||||
Valid: true,
|
||||
}
|
||||
if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile {
|
||||
continue
|
||||
}
|
||||
req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent(
|
||||
req.Parts[i].ContextFileContent,
|
||||
req.Parts[i].ContextFileTruncated,
|
||||
)
|
||||
req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem
|
||||
req.Parts[i].ContextFileDirectory = directory
|
||||
}
|
||||
req.Parts = chatd.FilterContextParts(req.Parts, false)
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
responsePartCount = len(req.Parts)
|
||||
|
||||
// Skill-only messages need a sentinel context-file part so the turn
|
||||
// pipeline trusts the associated skill metadata.
|
||||
req.Parts = prependAgentChatContextSentinelIfNeeded(
|
||||
req.Parts,
|
||||
workspaceAgent.ID,
|
||||
workspaceAgent.OperatingSystem,
|
||||
directory,
|
||||
)
|
||||
|
||||
content, err := chatprompt.MarshalParts(req.Parts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal context parts.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspace.OwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams(
|
||||
chat.ID,
|
||||
database.ChatMessageRoleUser,
|
||||
content,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
locked.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
uuid.Nil,
|
||||
)); err != nil {
|
||||
return xerrors.Errorf("insert context message: %w", err)
|
||||
}
|
||||
if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("rebuild injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to persist context message.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
Count: responsePartCount,
|
||||
})
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.ClearChatContextRequest
|
||||
populated := readChatContextBody(ctx, rw, r, &req, true)
|
||||
if !populated && r.Body != http.NoBody {
|
||||
return
|
||||
}
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
// Zero active chats is not an error for clear.
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{})
|
||||
return
|
||||
}
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to clear context from chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
errNoActiveChats = xerrors.New("no active chats found")
|
||||
errChatNotFound = xerrors.New("chat not found")
|
||||
errChatNotActive = xerrors.New("chat is not active")
|
||||
errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent")
|
||||
errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner")
|
||||
)
|
||||
|
||||
type multipleActiveChatsError struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (e *multipleActiveChatsError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"multiple active chats (%d) found for this agent, specify a chat ID",
|
||||
e.count,
|
||||
)
|
||||
}
|
||||
|
||||
func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) {
|
||||
switch len(chats) {
|
||||
case 0:
|
||||
return database.Chat{}, errNoActiveChats
|
||||
case 1:
|
||||
return chats[0], nil
|
||||
}
|
||||
|
||||
var rootChat *database.Chat
|
||||
for i := range chats {
|
||||
chat := &chats[i]
|
||||
if chat.ParentChatID.Valid {
|
||||
continue
|
||||
}
|
||||
if rootChat != nil {
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
rootChat = chat
|
||||
}
|
||||
if rootChat != nil {
|
||||
return *rootChat, nil
|
||||
}
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
|
||||
// resolveAgentChat finds the target chat from either an explicit ID
|
||||
// or auto-detection via the agent's active chats.
|
||||
func resolveAgentChat(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
explicitChatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
if explicitChatID == uuid.Nil {
|
||||
chats, err := db.GetActiveChatsByAgentID(ctx, agentID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("list active chats: %w", err)
|
||||
}
|
||||
ownerChats := make([]database.Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
continue
|
||||
}
|
||||
ownerChats = append(ownerChats, chat)
|
||||
}
|
||||
return resolveDefaultAgentChat(ownerChats)
|
||||
}
|
||||
|
||||
chat, err := db.GetChatByID(ctx, explicitChatID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return database.Chat{}, errChatNotFound
|
||||
}
|
||||
return database.Chat{}, xerrors.Errorf("get chat by id: %w", err)
|
||||
}
|
||||
if !chat.AgentID.Valid || chat.AgentID.UUID != agentID {
|
||||
return database.Chat{}, errChatDoesNotBelongToAgent
|
||||
}
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if !isActiveAgentChat(chat) {
|
||||
return database.Chat{}, errChatNotActive
|
||||
}
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func isActiveAgentChat(chat database.Chat) bool {
|
||||
if chat.Archived {
|
||||
return false
|
||||
}
|
||||
|
||||
switch chat.Status {
|
||||
case database.ChatStatusWaiting,
|
||||
database.ChatStatusPending,
|
||||
database.ChatStatusRunning,
|
||||
database.ChatStatusPaused,
|
||||
database.ChatStatusRequiresAction:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func clearAgentChatContext(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
) error {
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != agentID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspaceOwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
hadInjectedContext := locked.LastInjectedContext.Valid
|
||||
var skillOnlyMessageIDs []int64
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile)
|
||||
hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill)
|
||||
if hasContextFile || hasSkill {
|
||||
hadInjectedContext = true
|
||||
}
|
||||
if hasSkill && !hasContextFile {
|
||||
skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID)
|
||||
}
|
||||
}
|
||||
if !hadInjectedContext {
|
||||
return nil
|
||||
}
|
||||
if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("soft delete context-file messages: %w", err)
|
||||
}
|
||||
for _, messageID := range skillOnlyMessageIDs {
|
||||
if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil {
|
||||
return xerrors.Errorf("soft delete context message %d: %w", messageID, err)
|
||||
}
|
||||
}
|
||||
// Reset provider-side Responses chaining so the next turn replays
|
||||
// the post-clear history instead of inheriting cleared context.
|
||||
if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("clear provider response chain: %w", err)
|
||||
}
|
||||
// Clear the injected-context cache inside the transaction so it is
|
||||
// atomic with the soft-deletes.
|
||||
param, err := chatd.BuildLastInjectedContext(nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// prependAgentChatContextSentinelIfNeeded adds an empty context-file
|
||||
// part when the request only carries skills. The turn pipeline uses
|
||||
// the sentinel's agent metadata to trust the skill parts.
|
||||
func prependAgentChatContextSentinelIfNeeded(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
agentID uuid.UUID,
|
||||
operatingSystem string,
|
||||
directory string,
|
||||
) []codersdk.ChatMessagePart {
|
||||
hasContextFile := false
|
||||
hasSkill := false
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFile = true
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
hasSkill = true
|
||||
}
|
||||
if hasContextFile && hasSkill {
|
||||
return parts
|
||||
}
|
||||
}
|
||||
if !hasSkill || hasContextFile {
|
||||
return parts
|
||||
}
|
||||
return append([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: chatd.AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
ContextFileOS: operatingSystem,
|
||||
ContextFileDirectory: directory,
|
||||
}}, parts...)
|
||||
}
|
||||
|
||||
func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) {
|
||||
sort.SliceStable(messages, func(i, j int) bool {
|
||||
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
|
||||
return messages[i].ID < messages[j].ID
|
||||
}
|
||||
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
|
||||
})
|
||||
}
|
||||
|
||||
// updateAgentChatLastInjectedContextFromMessages rebuilds the
|
||||
// injected-context cache from all persisted context-file and skill parts.
|
||||
func updateAgentChatLastInjectedContextFromMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
) error {
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("load context messages for injected context: %w", err)
|
||||
}
|
||||
|
||||
sortChatMessagesByCreatedAtAndID(messages)
|
||||
|
||||
parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("collect injected context parts: %w", err)
|
||||
}
|
||||
parts = chatd.FilterContextPartsToLatestAgent(parts)
|
||||
|
||||
param, err := chatd.BuildLastInjectedContext(parts)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
for _, typ := range types {
|
||||
if part.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeAgentChatError translates resolveAgentChat errors to HTTP
|
||||
// responses.
|
||||
func writeAgentChatError(
|
||||
ctx context.Context,
|
||||
rw http.ResponseWriter,
|
||||
err error,
|
||||
) {
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "No active chats found for this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotFound) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Chat not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToAgent) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this workspace owner.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotActive) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Cannot modify context: this chat is no longer active.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var multipleErr *multipleActiveChatsError
|
||||
if errors.As(err, &multipleErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to resolve chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestActiveAgentChatDefinitionsAgree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: owner.ID,
|
||||
}).WithAgent().Do()
|
||||
modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID)
|
||||
|
||||
insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2)
|
||||
for _, archived := range []bool{false, true} {
|
||||
for _, status := range database.AllChatStatusValues() {
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: status,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: fmt.Sprintf("%s-archived-%t", status, archived),
|
||||
AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
insertedChats = append(insertedChats, chat)
|
||||
}
|
||||
}
|
||||
|
||||
activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeByID := make(map[uuid.UUID]bool, len(activeChats))
|
||||
for _, chat := range activeChats {
|
||||
activeByID[chat.ID] = true
|
||||
}
|
||||
|
||||
for _, chat := range insertedChats {
|
||||
require.Equalf(
|
||||
t,
|
||||
isActiveAgentChat(chat),
|
||||
activeByID[chat.ID],
|
||||
"status=%s archived=%t",
|
||||
chat.Status,
|
||||
chat.Archived,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC)
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
|
||||
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
newContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{
|
||||
{
|
||||
ID: 2,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: newContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: oldContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
require.Equal(t, chatID, arg.ID)
|
||||
require.True(t, arg.LastInjectedContext.Valid)
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 1)
|
||||
require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath)
|
||||
require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID)
|
||||
return database.Chat{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
err = updateAgentChatLastInjectedContextFromMessages(
|
||||
context.Background(),
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
db,
|
||||
chatID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func insertAgentChatTestModelConfig(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
|
||||
|
||||
_, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: createdBy,
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: createdBy,
|
||||
UpdatedBy: createdBy,
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return model
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,7 +91,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
require.Equal(t, tmpDir, workspace.LatestBuild.Resources[0].Agents[0].Directory)
|
||||
_, err = anotherClient.WorkspaceAgent(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
require.True(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
})
|
||||
t.Run("HasFallbackTroubleshootingURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
+12
-69
@@ -213,39 +213,6 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Connecting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
@@ -280,10 +247,10 @@ func TestWorkspace(t *testing.T) {
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{agent.ID}, workspace.Health.FailingAgents)
|
||||
assert.False(t, agent.Health.Healthy)
|
||||
assert.Equal(t, "agent has not yet connected", agent.Health.Reason)
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Unhealthy", func(t *testing.T) {
|
||||
@@ -335,7 +302,6 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
a1AuthToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -347,9 +313,7 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "a1",
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: a1AuthToken,
|
||||
},
|
||||
Auth: &proto.Agent_Token{},
|
||||
}, {
|
||||
Id: uuid.NewString(),
|
||||
Name: "a2",
|
||||
@@ -366,21 +330,13 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, a1AuthToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Wait for the mixed state: a1 connected (healthy)
|
||||
// and workspace unhealthy (because a2 timed out).
|
||||
agent1 := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return agent1.Health.Healthy && !workspace.Health.Healthy
|
||||
return assert.NoError(t, err) && !workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
@@ -404,7 +360,6 @@ func TestWorkspace(t *testing.T) {
|
||||
// disconnected, but this should not make the workspace unhealthy.
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -416,9 +371,7 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "parent",
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
Auth: &proto.Agent_Token{},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
@@ -430,23 +383,14 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Wait for the parent agent to connect and be healthy.
|
||||
var parentAgent codersdk.WorkspaceAgent
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
parentAgent = workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return parentAgent.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy")
|
||||
// Get the workspace and parent agent.
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
parentAgent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy initially")
|
||||
|
||||
// Create a sub-agent with a short connection timeout so it becomes
|
||||
// unhealthy quickly (simulating a devcontainer rebuild scenario).
|
||||
@@ -460,7 +404,6 @@ func TestWorkspace(t *testing.T) {
|
||||
// Wait for the sub-agent to become unhealthy due to timeout.
|
||||
var subAgentUnhealthy bool
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
+53
-129
@@ -73,9 +73,9 @@ const (
|
||||
|
||||
// maxConcurrentRecordingUploads caps the number of recording
|
||||
// stop-and-store operations that can run concurrently. Each
|
||||
// slot buffers up to MaxRecordingSize + MaxThumbnailSize
|
||||
// (110 MB) in memory, so this value implicitly bounds memory
|
||||
// to roughly maxConcurrentRecordingUploads * 110 MB.
|
||||
// slot buffers up to MaxRecordingSize (100 MB) in memory, so
|
||||
// this value implicitly bounds memory to roughly
|
||||
// maxConcurrentRecordingUploads * 100 MB.
|
||||
maxConcurrentRecordingUploads = 25
|
||||
|
||||
// staleRecoveryIntervalDivisor determines how often the stale
|
||||
@@ -996,7 +996,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
return database.Chat{}, txErr
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
|
||||
p.signalWake()
|
||||
return chat, nil
|
||||
}
|
||||
@@ -1158,7 +1158,7 @@ func (p *Server) SendMessage(
|
||||
|
||||
p.publishMessage(opts.ChatID, result.Message)
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
return result, nil
|
||||
}
|
||||
@@ -1301,7 +1301,7 @@ func (p *Server) EditMessage(
|
||||
QueueUpdate: true,
|
||||
})
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -1355,10 +1355,10 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
|
||||
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1373,7 +1373,7 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
ctx,
|
||||
chat.ID,
|
||||
"unarchive",
|
||||
codersdk.ChatWatchEventKindCreated,
|
||||
coderdpubsub.ChatEventKindCreated,
|
||||
p.db.UnarchiveChatByID,
|
||||
)
|
||||
}
|
||||
@@ -1382,7 +1382,7 @@ func (p *Server) applyChatLifecycleTransition(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
action string,
|
||||
kind codersdk.ChatWatchEventKind,
|
||||
kind coderdpubsub.ChatEventKind,
|
||||
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
|
||||
) error {
|
||||
updatedChats, err := transition(ctx, chatID)
|
||||
@@ -1545,7 +1545,7 @@ func (p *Server) PromoteQueued(
|
||||
})
|
||||
p.publishMessage(opts.ChatID, promoted)
|
||||
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -2092,7 +2092,7 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -2347,7 +2347,7 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
|
||||
return database.Chat{}, err
|
||||
}
|
||||
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -2461,33 +2461,6 @@ type chainModeInfo struct {
|
||||
// trailingUserCount is the number of contiguous user messages
|
||||
// at the end of the conversation that form the current turn.
|
||||
trailingUserCount int
|
||||
// contributingTrailingUserCount counts the trailing user
|
||||
// messages that materially change the provider input.
|
||||
contributingTrailingUserCount int
|
||||
}
|
||||
|
||||
func userMessageContributesToChainMode(msg database.ChatMessage) bool {
|
||||
parts, err := chatprompt.ParseContent(msg)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeText,
|
||||
codersdk.ChatMessagePartTypeReasoning:
|
||||
if strings.TrimSpace(part.Text) != "" {
|
||||
return true
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeFile,
|
||||
codersdk.ChatMessagePartTypeFileReference:
|
||||
return true
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if part.ContextFileContent != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveChainMode scans DB messages from the end to count trailing user
|
||||
@@ -2497,13 +2470,11 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
|
||||
var info chainModeInfo
|
||||
i := len(messages) - 1
|
||||
for ; i >= 0; i-- {
|
||||
if messages[i].Role != database.ChatMessageRoleUser {
|
||||
break
|
||||
}
|
||||
info.trailingUserCount++
|
||||
if userMessageContributesToChainMode(messages[i]) {
|
||||
info.contributingTrailingUserCount++
|
||||
if messages[i].Role == database.ChatMessageRoleUser {
|
||||
info.trailingUserCount++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
for ; i >= 0; i-- {
|
||||
switch messages[i].Role {
|
||||
@@ -2526,15 +2497,15 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// filterPromptForChainMode keeps only system messages and the trailing
|
||||
// user messages that still contribute model-visible content to the
|
||||
// current turn. Assistant and tool messages are dropped because the
|
||||
// provider already has them via the previous_response_id chain.
|
||||
// filterPromptForChainMode keeps only system messages and the last
|
||||
// trailingUserCount user messages from the prompt. Assistant and tool
|
||||
// messages are dropped because the provider already has them via the
|
||||
// previous_response_id chain.
|
||||
func filterPromptForChainMode(
|
||||
prompt []fantasy.Message,
|
||||
info chainModeInfo,
|
||||
trailingUserCount int,
|
||||
) []fantasy.Message {
|
||||
if info.contributingTrailingUserCount <= 0 {
|
||||
if trailingUserCount <= 0 {
|
||||
return prompt
|
||||
}
|
||||
|
||||
@@ -2545,12 +2516,7 @@ func filterPromptForChainMode(
|
||||
}
|
||||
}
|
||||
|
||||
// Prompt construction already drops user turns with no model-visible
|
||||
// content, such as skill-only sentinel messages. That means the user
|
||||
// count here stays aligned with contributingTrailingUserCount even
|
||||
// when non-contributing DB turns are interleaved in the trailing
|
||||
// block.
|
||||
usersToSkip := totalUsers - info.contributingTrailingUserCount
|
||||
usersToSkip := totalUsers - trailingUserCount
|
||||
if usersToSkip < 0 {
|
||||
usersToSkip = 0
|
||||
}
|
||||
@@ -2596,28 +2562,6 @@ func appendChatMessage(
|
||||
params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
|
||||
}
|
||||
|
||||
// BuildSingleChatMessageInsertParams creates batch insert params for one
|
||||
// message using the shared chat message builder.
|
||||
func BuildSingleChatMessageInsertParams(
|
||||
chatID uuid.UUID,
|
||||
role database.ChatMessageRole,
|
||||
content pqtype.NullRawMessage,
|
||||
visibility database.ChatMessageVisibility,
|
||||
modelConfigID uuid.UUID,
|
||||
contentVersion int16,
|
||||
createdBy uuid.UUID,
|
||||
) database.InsertChatMessagesParams {
|
||||
params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: chatID,
|
||||
}
|
||||
msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion)
|
||||
if createdBy != uuid.Nil {
|
||||
msg = msg.withCreatedBy(createdBy)
|
||||
}
|
||||
appendChatMessage(¶ms, msg)
|
||||
return params
|
||||
}
|
||||
|
||||
func insertUserMessageAndSetPending(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
@@ -3627,7 +3571,7 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
|
||||
}
|
||||
|
||||
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
|
||||
for _, chat := range chats {
|
||||
p.publishChatPubsubEvent(chat, kind, nil)
|
||||
}
|
||||
@@ -3635,7 +3579,7 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.Ch
|
||||
|
||||
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
||||
// pubsub so that all replicas can push updates to watching clients.
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
if p.pubsub == nil {
|
||||
return
|
||||
}
|
||||
@@ -3647,7 +3591,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWa
|
||||
if diffStatus != nil {
|
||||
sdkChat.DiffStatus = diffStatus
|
||||
}
|
||||
event := codersdk.ChatWatchEvent{
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: kind,
|
||||
Chat: sdkChat,
|
||||
}
|
||||
@@ -3659,7 +3603,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWa
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("kind", kind),
|
||||
@@ -3692,8 +3636,8 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
toolCalls := pendingToStreamToolCalls(pending)
|
||||
sdkChat := db2sdk.Chat(chat, nil, nil)
|
||||
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindActionRequired,
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindActionRequired,
|
||||
Chat: sdkChat,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
@@ -3705,7 +3649,7 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
@@ -3733,7 +3677,7 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID)
|
||||
}
|
||||
|
||||
sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange, &sdkStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4215,7 +4159,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
if title, ok := generatedTitle.Load(); ok {
|
||||
updatedChat.Title = title
|
||||
}
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
|
||||
// When the chat is parked in requires_action,
|
||||
// publish the stream event and global pubsub event
|
||||
@@ -4486,21 +4430,13 @@ func (p *Server) runChat(
|
||||
// the workspace agent has changed (e.g. workspace rebuilt).
|
||||
needsInstructionPersist := false
|
||||
hasContextFiles := false
|
||||
persistedSkills := skillsFromParts(messages)
|
||||
latestInjectedAgentID, hasLatestInjectedAgent := latestContextAgentID(messages)
|
||||
currentWorkspaceAgentID := uuid.Nil
|
||||
hasCurrentWorkspaceAgent := false
|
||||
if chat.WorkspaceID.Valid {
|
||||
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
|
||||
currentWorkspaceAgentID = agent.ID
|
||||
hasCurrentWorkspaceAgent = true
|
||||
}
|
||||
persistedAgentID, found := contextFileAgentID(messages)
|
||||
hasContextFiles = found
|
||||
if !hasPersistedInstructionFiles(messages) {
|
||||
if !hasContextFiles {
|
||||
needsInstructionPersist = true
|
||||
} else if hasCurrentWorkspaceAgent && currentWorkspaceAgentID != persistedAgentID {
|
||||
// Agent changed. Persist fresh instruction files.
|
||||
} else if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID != persistedAgentID {
|
||||
// Agent changed — persist fresh instruction files.
|
||||
// Old context-file messages remain in the conversation
|
||||
// to preserve the prompt cache prefix.
|
||||
needsInstructionPersist = true
|
||||
@@ -4523,8 +4459,7 @@ func (p *Server) runChat(
|
||||
if needsInstructionPersist {
|
||||
g2.Go(func() error {
|
||||
var persistErr error
|
||||
var discoveredSkills []chattool.SkillMeta
|
||||
instruction, discoveredSkills, persistErr = p.persistInstructionFiles(
|
||||
instruction, skills, persistErr = p.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
modelConfig.ID,
|
||||
@@ -4536,12 +4471,6 @@ func (p *Server) runChat(
|
||||
return workspaceCtx.getWorkspaceConn(instructionCtx)
|
||||
},
|
||||
)
|
||||
skills = selectSkillMetasForInstructionRefresh(
|
||||
persistedSkills,
|
||||
discoveredSkills,
|
||||
uuid.NullUUID{UUID: currentWorkspaceAgentID, Valid: hasCurrentWorkspaceAgent},
|
||||
uuid.NullUUID{UUID: latestInjectedAgentID, Valid: hasLatestInjectedAgent},
|
||||
)
|
||||
if persistErr != nil {
|
||||
p.logger.Warn(ctx, "failed to persist instruction files",
|
||||
slog.F("chat_id", chat.ID),
|
||||
@@ -4556,7 +4485,7 @@ func (p *Server) runChat(
|
||||
// re-injected via InsertSystem after compaction drops
|
||||
// those messages. No workspace dial needed.
|
||||
instruction = instructionFromContextFiles(messages)
|
||||
skills = persistedSkills
|
||||
skills = skillsFromParts(messages)
|
||||
}
|
||||
g2.Go(func() error {
|
||||
resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID)
|
||||
@@ -5174,14 +5103,14 @@ func (p *Server) runChat(
|
||||
// assistant and tool messages that the provider already has.
|
||||
chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) &&
|
||||
chainInfo.previousResponseID != "" &&
|
||||
chainInfo.contributingTrailingUserCount > 0 &&
|
||||
chainInfo.trailingUserCount > 0 &&
|
||||
chainInfo.modelConfigID == modelConfig.ID
|
||||
if chainModeActive {
|
||||
providerOptions = chatprovider.CloneWithPreviousResponseID(
|
||||
providerOptions,
|
||||
chainInfo.previousResponseID,
|
||||
)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
|
||||
}
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
@@ -5235,7 +5164,7 @@ func (p *Server) runChat(
|
||||
if chainModeActive {
|
||||
reloadedPrompt = filterPromptForChainMode(
|
||||
reloadedPrompt,
|
||||
chainInfo,
|
||||
chainInfo.trailingUserCount,
|
||||
)
|
||||
}
|
||||
return reloadedPrompt, nil
|
||||
@@ -5608,9 +5537,8 @@ func refreshChatWorkspaceSnapshot(
|
||||
}
|
||||
|
||||
// contextFileAgentID extracts the workspace agent ID from the most
|
||||
// recent persisted instruction-file parts. The skill-only sentinel is
|
||||
// ignored because it does not represent persisted instruction content.
|
||||
// Returns uuid.Nil, false if no instruction-file parts exist.
|
||||
// recent persisted context-file parts. Returns uuid.Nil, false if no
|
||||
// context-file parts exist.
|
||||
func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
var lastID uuid.UUID
|
||||
found := false
|
||||
@@ -5623,14 +5551,11 @@ func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
continue
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!p.ContextFileAgentID.Valid ||
|
||||
p.ContextFilePath == AgentChatContextSentinelPath {
|
||||
continue
|
||||
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFileAgentID.Valid {
|
||||
lastID = p.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
lastID = p.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return lastID, found
|
||||
@@ -5700,14 +5625,13 @@ func (p *Server) persistInstructionFiles(
|
||||
// agent cannot know its own UUID, OS metadata, or
|
||||
// directory — those are added here at the trust boundary.
|
||||
var discoveredSkills []chattool.SkillMeta
|
||||
var hasContent, hasContextFilePart bool
|
||||
var hasContent bool
|
||||
agentID := uuid.NullUUID{UUID: agent.ID, Valid: true}
|
||||
|
||||
for i := range agentParts {
|
||||
agentParts[i].ContextFileAgentID = agentID
|
||||
switch agentParts[i].Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFilePart = true
|
||||
agentParts[i].ContextFileContent = SanitizePromptText(agentParts[i].ContextFileContent)
|
||||
agentParts[i].ContextFileOS = agent.OperatingSystem
|
||||
agentParts[i].ContextFileDirectory = directory
|
||||
@@ -5728,13 +5652,13 @@ func (p *Server) persistInstructionFiles(
|
||||
if !workspaceConnOK {
|
||||
return "", nil, nil
|
||||
}
|
||||
// Persist a blank context-file marker (plus any skill-only
|
||||
// parts) so subsequent turns skip the workspace agent dial.
|
||||
if !hasContextFilePart {
|
||||
agentParts = append([]codersdk.ChatMessagePart{{
|
||||
// Persist a sentinel (plus any skill-only parts) so
|
||||
// subsequent turns skip the workspace agent dial.
|
||||
if len(agentParts) == 0 {
|
||||
agentParts = []codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFileAgentID: agentID,
|
||||
}}, agentParts...)
|
||||
}}
|
||||
}
|
||||
content, err := chatprompt.MarshalParts(agentParts)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -71,14 +70,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
updatedChat.Title = wantTitle
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload codersdk.ChatWatchEvent
|
||||
payload coderdpubsub.ChatEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload codersdk.ChatWatchEvent
|
||||
payload coderdpubsub.ChatEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -184,7 +183,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
@@ -234,14 +233,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
unlockedChat.StartedAt = sql.NullTime{}
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload codersdk.ChatWatchEvent
|
||||
payload coderdpubsub.ChatEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatWatchEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload codersdk.ChatWatchEvent
|
||||
payload coderdpubsub.ChatEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
@@ -373,7 +372,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
@@ -704,33 +703,7 @@ func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
|
||||
gomock.Any(),
|
||||
agentID,
|
||||
).Return(workspaceAgent, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.InsertChatMessagesParams)
|
||||
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
|
||||
return false
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
foundMarker := false
|
||||
foundSkill := false
|
||||
for _, p := range parts {
|
||||
switch p.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
|
||||
foundMarker = true
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
|
||||
foundSkill = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return foundMarker && foundSkill
|
||||
}),
|
||||
).Return(nil, nil).Times(1)
|
||||
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
|
||||
gomock.Cond(func(x any) bool {
|
||||
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
|
||||
@@ -2047,30 +2020,6 @@ func TestContextFileAgentID(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
instructionAgentID := uuid.New()
|
||||
sentinelAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/workspace/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: sentinelAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
}}),
|
||||
}
|
||||
id, ok := contextFileAgentID(msgs)
|
||||
require.Equal(t, instructionAgentID, id)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []database.ChatMessage{
|
||||
@@ -2087,492 +2036,6 @@ func TestContextFileAgentID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasPersistedInstructionFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
agentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
}}),
|
||||
}
|
||||
require.False(t, hasPersistedInstructionFiles(msgs))
|
||||
})
|
||||
|
||||
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
agentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/workspace/AGENTS.md",
|
||||
ContextFileContent: "repo instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
require.True(t, hasPersistedInstructionFiles(msgs))
|
||||
})
|
||||
}
|
||||
|
||||
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
|
||||
got := instructionFromContextFiles(msgs)
|
||||
require.Contains(t, got, "new instructions")
|
||||
require.Contains(t, got, "Operating System: linux")
|
||||
require.Contains(t, got, "Working Directory: /new")
|
||||
require.NotContains(t, got, "old instructions")
|
||||
require.NotContains(t, got, "Operating System: darwin")
|
||||
}
|
||||
|
||||
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/legacy/AGENTS.md",
|
||||
ContextFileContent: "legacy instructions",
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}}),
|
||||
}
|
||||
|
||||
got := instructionFromContextFiles(msgs)
|
||||
require.Contains(t, got, "legacy instructions")
|
||||
require.Contains(t, got, "new instructions")
|
||||
require.Contains(t, got, "Operating System: linux")
|
||||
require.Contains(t, got, "Working Directory: /new")
|
||||
require.NotContains(t, got, "old instructions")
|
||||
require.NotContains(t, got, "Operating System: darwin")
|
||||
}
|
||||
|
||||
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-legacy",
|
||||
SkillDir: "/skills/repo-helper-legacy",
|
||||
}}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
SkillDir: "/skills/repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
SkillDir: "/skills/repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
got := skillsFromParts(msgs)
|
||||
require.Equal(t, []chattool.SkillMeta{
|
||||
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
|
||||
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
msgs := []database.ChatMessage{
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
SkillDir: "/skills/repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
SkillDir: "/skills/repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
got := skillsFromParts(msgs)
|
||||
require.Equal(t, []chattool.SkillMeta{{
|
||||
Name: "repo-helper-new",
|
||||
Dir: "/skills/repo-helper-new",
|
||||
}}, got)
|
||||
}
|
||||
|
||||
func TestMergeSkillMetas(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persisted := []chattool.SkillMeta{{
|
||||
Name: "repo-helper",
|
||||
Description: "Persisted skill",
|
||||
Dir: "/skills/repo-helper-old",
|
||||
}}
|
||||
discovered := []chattool.SkillMeta{
|
||||
{
|
||||
Name: "repo-helper",
|
||||
Description: "Discovered replacement",
|
||||
Dir: "/skills/repo-helper-new",
|
||||
MetaFile: "SKILL.md",
|
||||
},
|
||||
{
|
||||
Name: "deep-review",
|
||||
Description: "Discovered skill",
|
||||
Dir: "/skills/deep-review",
|
||||
},
|
||||
}
|
||||
|
||||
got := mergeSkillMetas(persisted, discovered)
|
||||
require.Equal(t, []chattool.SkillMeta{
|
||||
discovered[0],
|
||||
discovered[1],
|
||||
}, got)
|
||||
}
|
||||
|
||||
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
|
||||
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
|
||||
currentAgentID := uuid.New()
|
||||
otherAgentID := uuid.New()
|
||||
|
||||
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
discovered,
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
|
||||
})
|
||||
|
||||
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
discovered,
|
||||
uuid.NullUUID{UUID: currentAgentID, Valid: true},
|
||||
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, discovered, got)
|
||||
})
|
||||
|
||||
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := selectSkillMetasForInstructionRefresh(
|
||||
persisted,
|
||||
nil,
|
||||
uuid.NullUUID{},
|
||||
uuid.NullUUID{UUID: otherAgentID, Valid: true},
|
||||
)
|
||||
require.Equal(t, persisted, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
user := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "latest user message",
|
||||
}})
|
||||
user.Role = database.ChatMessageRoleUser
|
||||
|
||||
got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user})
|
||||
require.Equal(t, "resp-123", got.previousResponseID)
|
||||
require.Equal(t, modelConfigID, got.modelConfigID)
|
||||
require.Equal(t, 2, got.trailingUserCount)
|
||||
require.Equal(t, 1, got.contributingTrailingUserCount)
|
||||
}
|
||||
|
||||
func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "prior user message",
|
||||
}})
|
||||
priorUser.Role = database.ChatMessageRoleUser
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "first trailing user",
|
||||
}})
|
||||
firstTrailingUser.Role = database.ChatMessageRoleUser
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "last trailing user",
|
||||
}})
|
||||
lastTrailingUser.Role = database.ChatMessageRoleUser
|
||||
|
||||
chainInfo := resolveChainMode([]database.ChatMessage{
|
||||
priorUser,
|
||||
assistant,
|
||||
firstTrailingUser,
|
||||
skillOnly,
|
||||
lastTrailingUser,
|
||||
})
|
||||
require.Equal(t, 3, chainInfo.trailingUserCount)
|
||||
require.Equal(t, 2, chainInfo.contributingTrailingUserCount)
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "system instruction"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "prior user message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "assistant reply"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "first trailing user"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "last trailing user"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := filterPromptForChainMode(prompt, chainInfo)
|
||||
require.Len(t, got, 3)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
|
||||
|
||||
firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "first trailing user", firstPart.Text)
|
||||
lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "last trailing user", lastPart.Text)
|
||||
}
|
||||
|
||||
func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
modelConfigID := uuid.New()
|
||||
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "prior user message",
|
||||
}})
|
||||
priorUser.Role = database.ChatMessageRoleUser
|
||||
assistant := database.ChatMessage{
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
}
|
||||
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper",
|
||||
SkillDir: "/skills/repo-helper",
|
||||
},
|
||||
})
|
||||
skillOnly.Role = database.ChatMessageRoleUser
|
||||
latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: "latest user message",
|
||||
}})
|
||||
latestUser.Role = database.ChatMessageRoleUser
|
||||
|
||||
chainInfo := resolveChainMode([]database.ChatMessage{
|
||||
priorUser,
|
||||
assistant,
|
||||
skillOnly,
|
||||
latestUser,
|
||||
})
|
||||
require.Equal(t, 2, chainInfo.trailingUserCount)
|
||||
require.Equal(t, 1, chainInfo.contributingTrailingUserCount)
|
||||
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "system instruction"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "prior user message"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "assistant reply"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "latest user message"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := filterPromptForChainMode(prompt, chainInfo)
|
||||
require.Len(t, got, 2)
|
||||
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
|
||||
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
|
||||
|
||||
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "latest user message", part.Text)
|
||||
}
|
||||
|
||||
func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage {
|
||||
raw, _ := json.Marshal(parts)
|
||||
return database.ChatMessage{
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type (
|
||||
runContextKey struct{}
|
||||
stepContextKey struct{}
|
||||
reuseStepKey struct{}
|
||||
reuseHolder struct {
|
||||
mu sync.Mutex
|
||||
handle *stepHandle
|
||||
}
|
||||
)
|
||||
|
||||
// ContextWithRun stores rc in ctx.
|
||||
//
|
||||
// Step counter cleanup is reference-counted per RunID: each live
|
||||
// RunContext increments a counter and runtime.AddCleanup decrements
|
||||
// it when the struct is garbage collected. Shared state (step
|
||||
// counters) is only deleted when the last RunContext for a given
|
||||
// RunID becomes unreachable, preventing premature cleanup when
|
||||
// multiple RunContext instances share the same RunID.
|
||||
func ContextWithRun(ctx context.Context, rc *RunContext) context.Context {
|
||||
if rc == nil {
|
||||
panic("chatdebug: nil RunContext")
|
||||
}
|
||||
|
||||
enriched := context.WithValue(ctx, runContextKey{}, rc)
|
||||
if rc.RunID != uuid.Nil {
|
||||
trackRunRef(rc.RunID)
|
||||
runtime.AddCleanup(rc, func(id uuid.UUID) {
|
||||
releaseRunRef(id)
|
||||
}, rc.RunID)
|
||||
}
|
||||
return enriched
|
||||
}
|
||||
|
||||
// RunFromContext returns the debug run context stored in ctx.
|
||||
func RunFromContext(ctx context.Context) (*RunContext, bool) {
|
||||
rc, ok := ctx.Value(runContextKey{}).(*RunContext)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return rc, true
|
||||
}
|
||||
|
||||
// ContextWithStep stores sc in ctx.
|
||||
func ContextWithStep(ctx context.Context, sc *StepContext) context.Context {
|
||||
if sc == nil {
|
||||
panic("chatdebug: nil StepContext")
|
||||
}
|
||||
return context.WithValue(ctx, stepContextKey{}, sc)
|
||||
}
|
||||
|
||||
// StepFromContext returns the debug step context stored in ctx.
|
||||
func StepFromContext(ctx context.Context) (*StepContext, bool) {
|
||||
sc, ok := ctx.Value(stepContextKey{}).(*StepContext)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return sc, true
|
||||
}
|
||||
|
||||
// ReuseStep marks ctx so wrapped model calls under it share one debug step.
|
||||
func ReuseStep(ctx context.Context) context.Context {
|
||||
if holder, ok := reuseHolderFromContext(ctx); ok {
|
||||
return context.WithValue(ctx, reuseStepKey{}, holder)
|
||||
}
|
||||
return context.WithValue(ctx, reuseStepKey{}, &reuseHolder{})
|
||||
}
|
||||
|
||||
func reuseHolderFromContext(ctx context.Context) (*reuseHolder, bool) {
|
||||
holder, ok := ctx.Value(reuseStepKey{}).(*reuseHolder)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return holder, true
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestReuseStep_PreservesExistingHolder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := ReuseStep(context.Background())
|
||||
first, ok := reuseHolderFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
|
||||
reused := ReuseStep(ctx)
|
||||
second, ok := reuseHolderFromContext(reused)
|
||||
require.True(t, ok)
|
||||
require.Same(t, first, second)
|
||||
}
|
||||
|
||||
func TestContextWithRun_CleansUpStepCounterAfterGC(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
func() {
|
||||
_ = ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok)
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
runtime.GC()
|
||||
runtime.Gosched()
|
||||
_, ok := stepCounters.Load(runID)
|
||||
return !ok
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestContextWithRun_MultipleInstancesSameRunID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
// rc2 is the surviving instance that should keep the step counter alive.
|
||||
rc2 := &RunContext{RunID: runID, ChatID: chatID}
|
||||
ctx2 := ContextWithRun(context.Background(), rc2)
|
||||
|
||||
// Create a second RunContext with the same RunID and let it become
|
||||
// unreachable. Its GC cleanup must NOT delete the step counter
|
||||
// because rc2 is still alive.
|
||||
func() {
|
||||
rc1 := &RunContext{RunID: runID, ChatID: chatID}
|
||||
ctx1 := ContextWithRun(context.Background(), rc1)
|
||||
h, _ := beginStep(ctx1, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
|
||||
require.NotNil(t, h)
|
||||
require.Equal(t, int32(1), h.stepCtx.StepNumber)
|
||||
}()
|
||||
|
||||
// Force GC to collect rc1.
|
||||
for range 5 {
|
||||
runtime.GC()
|
||||
runtime.Gosched()
|
||||
}
|
||||
|
||||
// The step counter must still be present because rc2 is alive.
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok, "step counter was prematurely cleaned up while another RunContext is still alive")
|
||||
|
||||
// Subsequent steps on the surviving context must continue numbering.
|
||||
h2, _ := beginStep(ctx2, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
|
||||
require.NotNil(t, h2)
|
||||
require.Equal(t, int32(2), h2.stepCtx.StepNumber)
|
||||
}
|
||||
|
||||
func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
// Run in a closure so the RunContext becomes unreachable after
|
||||
// context cancellation, allowing GC to trigger the cleanup.
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
|
||||
_, ok := stepCounters.Load(runID)
|
||||
require.True(t, ok)
|
||||
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// After the closure, the RunContext is unreachable.
|
||||
// runtime.AddCleanup fires during GC.
|
||||
require.Eventually(t, func() bool {
|
||||
runtime.GC()
|
||||
runtime.Gosched()
|
||||
_, ok := stepCounters.Load(runID)
|
||||
return !ok
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
require.Equal(t, int32(1), nextStepNumber(runID))
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
)
|
||||
|
||||
func TestContextWithRunRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rc := &chatdebug.RunContext{
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
RootChatID: uuid.New(),
|
||||
ParentChatID: uuid.New(),
|
||||
ModelConfigID: uuid.New(),
|
||||
TriggerMessageID: 11,
|
||||
HistoryTipMessageID: 22,
|
||||
Kind: chatdebug.KindChatTurn,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-sonnet",
|
||||
}
|
||||
|
||||
ctx := chatdebug.ContextWithRun(context.Background(), rc)
|
||||
got, ok := chatdebug.RunFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, rc, got)
|
||||
require.Equal(t, *rc, *got)
|
||||
}
|
||||
|
||||
func TestRunFromContextAbsent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := chatdebug.RunFromContext(context.Background())
|
||||
require.False(t, ok)
|
||||
require.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestContextWithStepRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sc := &chatdebug.StepContext{
|
||||
StepID: uuid.New(),
|
||||
RunID: uuid.New(),
|
||||
ChatID: uuid.New(),
|
||||
StepNumber: 7,
|
||||
Operation: chatdebug.OperationStream,
|
||||
HistoryTipMessageID: 33,
|
||||
}
|
||||
|
||||
ctx := chatdebug.ContextWithStep(context.Background(), sc)
|
||||
got, ok := chatdebug.StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, sc, got)
|
||||
require.Equal(t, *sc, *got)
|
||||
}
|
||||
|
||||
func TestStepFromContextAbsent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := chatdebug.StepFromContext(context.Background())
|
||||
require.False(t, ok)
|
||||
require.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestContextWithRunAndStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rc := &chatdebug.RunContext{RunID: uuid.New(), ChatID: uuid.New()}
|
||||
sc := &chatdebug.StepContext{StepID: uuid.New(), RunID: rc.RunID, ChatID: rc.ChatID}
|
||||
|
||||
ctx := chatdebug.ContextWithStep(
|
||||
chatdebug.ContextWithRun(context.Background(), rc),
|
||||
sc,
|
||||
)
|
||||
|
||||
gotRun, ok := chatdebug.RunFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, rc, gotRun)
|
||||
|
||||
gotStep, ok := chatdebug.StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Same(t, sc, gotStep)
|
||||
}
|
||||
|
||||
func TestContextWithRunPanicsOnNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = chatdebug.ContextWithRun(context.Background(), nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextWithStepPanicsOnNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Panics(t, func() {
|
||||
_ = chatdebug.ContextWithStep(context.Background(), nil)
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,331 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Checks unexported normalized structs against fantasy source types.
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// fieldDisposition documents whether a fantasy struct field is captured
|
||||
// by the corresponding normalized struct ("normalized") or
|
||||
// intentionally omitted ("skipped: <reason>"). The test fails when a
|
||||
// fantasy type gains a field that is not yet classified, forcing the
|
||||
// developer to decide whether to normalize or skip it.
|
||||
//
|
||||
// This mirrors the audit-table exhaustiveness check in
|
||||
// enterprise/audit/table.go — same idea, different domain.
|
||||
type fieldDisposition = map[string]string
|
||||
|
||||
// TestNormalizationFieldCoverage ensures every exported field on the
|
||||
// fantasy types that model.go normalizes is explicitly accounted for.
|
||||
// When the fantasy library adds a field the test fails, surfacing the
|
||||
// drift at `go test` time rather than silently dropping data.
|
||||
func TestNormalizationFieldCoverage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
typ reflect.Type
|
||||
fields fieldDisposition
|
||||
}{
|
||||
// ── struct-to-struct mappings ──────────────────────────
|
||||
|
||||
{
|
||||
name: "fantasy.Usage → normalizedUsage",
|
||||
typ: reflect.TypeFor[fantasy.Usage](),
|
||||
fields: fieldDisposition{
|
||||
"InputTokens": "normalized",
|
||||
"OutputTokens": "normalized",
|
||||
"TotalTokens": "normalized",
|
||||
"ReasoningTokens": "normalized",
|
||||
"CacheCreationTokens": "normalized",
|
||||
"CacheReadTokens": "normalized",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.Call → normalizedCallPayload",
|
||||
typ: reflect.TypeFor[fantasy.Call](),
|
||||
fields: fieldDisposition{
|
||||
"Prompt": "normalized",
|
||||
"MaxOutputTokens": "normalized",
|
||||
"Temperature": "normalized",
|
||||
"TopP": "normalized",
|
||||
"TopK": "normalized",
|
||||
"PresencePenalty": "normalized",
|
||||
"FrequencyPenalty": "normalized",
|
||||
"Tools": "normalized",
|
||||
"ToolChoice": "normalized",
|
||||
"UserAgent": "skipped: internal transport header, not useful for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider data, only count preserved",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ObjectCall → normalizedObjectCallPayload",
|
||||
typ: reflect.TypeFor[fantasy.ObjectCall](),
|
||||
fields: fieldDisposition{
|
||||
"Prompt": "normalized",
|
||||
"Schema": "skipped: full schema too large; SchemaName+SchemaDescription captured instead",
|
||||
"SchemaName": "normalized",
|
||||
"SchemaDescription": "normalized",
|
||||
"MaxOutputTokens": "normalized",
|
||||
"Temperature": "normalized",
|
||||
"TopP": "normalized",
|
||||
"TopK": "normalized",
|
||||
"PresencePenalty": "normalized",
|
||||
"FrequencyPenalty": "normalized",
|
||||
"UserAgent": "skipped: internal transport header, not useful for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider data, only count preserved",
|
||||
"RepairText": "skipped: function value, not serializable",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.Response → normalizedResponsePayload",
|
||||
typ: reflect.TypeFor[fantasy.Response](),
|
||||
fields: fieldDisposition{
|
||||
"Content": "normalized",
|
||||
"FinishReason": "normalized",
|
||||
"Usage": "normalized",
|
||||
"Warnings": "normalized",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ObjectResponse → normalizedObjectResponsePayload",
|
||||
typ: reflect.TypeFor[fantasy.ObjectResponse](),
|
||||
fields: fieldDisposition{
|
||||
"Object": "skipped: arbitrary user type, not serializable generically",
|
||||
"RawText": "normalized: as RawTextLength (length only, content unbounded)",
|
||||
"Usage": "normalized",
|
||||
"FinishReason": "normalized",
|
||||
"Warnings": "normalized",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.CallWarning → normalizedWarning",
|
||||
typ: reflect.TypeFor[fantasy.CallWarning](),
|
||||
fields: fieldDisposition{
|
||||
"Type": "normalized",
|
||||
"Setting": "normalized",
|
||||
"Tool": "skipped: interface value, warning message+type sufficient for debug panel",
|
||||
"Details": "normalized",
|
||||
"Message": "normalized",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.StreamPart → appendNormalizedStreamContent",
|
||||
typ: reflect.TypeFor[fantasy.StreamPart](),
|
||||
fields: fieldDisposition{
|
||||
"Type": "normalized",
|
||||
"ID": "normalized: as ToolCallID in content parts",
|
||||
"ToolCallName": "normalized: as ToolName in content parts",
|
||||
"ToolCallInput": "normalized: as Arguments or Result (bounded)",
|
||||
"Delta": "normalized: accumulated into text/reasoning content parts",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"Usage": "normalized: captured in stream finalize",
|
||||
"FinishReason": "normalized: captured in stream finalize",
|
||||
"Error": "normalized: captured in stream error handling",
|
||||
"Warnings": "normalized: captured in stream warning accumulation",
|
||||
"SourceType": "normalized",
|
||||
"URL": "normalized",
|
||||
"Title": "normalized",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ObjectStreamPart → wrapObjectStreamSeq",
|
||||
typ: reflect.TypeFor[fantasy.ObjectStreamPart](),
|
||||
fields: fieldDisposition{
|
||||
"Type": "normalized: drives switch in wrapObjectStreamSeq",
|
||||
"Object": "skipped: arbitrary user type, only ObjectPartCount tracked",
|
||||
"Delta": "normalized: accumulated into rawTextLength",
|
||||
"Error": "normalized: captured in stream error handling",
|
||||
"Usage": "normalized: captured in stream finalize",
|
||||
"FinishReason": "normalized: captured in stream finalize",
|
||||
"Warnings": "normalized: captured in stream warning accumulation",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
|
||||
// ── message part types (normalizeMessageParts) ────────
|
||||
|
||||
{
|
||||
name: "fantasy.TextPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.TextPart](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ReasoningPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.ReasoningPart](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.FilePart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.FilePart](),
|
||||
fields: fieldDisposition{
|
||||
"Filename": "normalized",
|
||||
"Data": "skipped: binary data never stored in debug records",
|
||||
"MediaType": "normalized",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolCallPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.ToolCallPart](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"ToolName": "normalized",
|
||||
"Input": "normalized: as Arguments (bounded)",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolResultPart → normalizedMessagePart",
|
||||
typ: reflect.TypeFor[fantasy.ToolResultPart](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"Output": "normalized: text extracted via normalizeToolResultOutput",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
|
||||
// ── response content types (normalizeContentParts) ────
|
||||
|
||||
{
|
||||
name: "fantasy.TextContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.TextContent](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ReasoningContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.ReasoningContent](),
|
||||
fields: fieldDisposition{
|
||||
"Text": "normalized: bounded to MaxMessagePartTextLength",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.FileContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.FileContent](),
|
||||
fields: fieldDisposition{
|
||||
"MediaType": "normalized",
|
||||
"Data": "skipped: binary data never stored in debug records",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.SourceContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.SourceContent](),
|
||||
fields: fieldDisposition{
|
||||
"SourceType": "normalized",
|
||||
"ID": "skipped: provider-internal identifier, not actionable in debug panel",
|
||||
"URL": "normalized",
|
||||
"Title": "normalized",
|
||||
"MediaType": "skipped: only relevant for document sources, rarely useful for debugging",
|
||||
"Filename": "skipped: only relevant for document sources, rarely useful for debugging",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolCallContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.ToolCallContent](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"ToolName": "normalized",
|
||||
"Input": "normalized: as Arguments (bounded), InputLength tracks original",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
"Invalid": "skipped: validation state not surfaced in debug panel",
|
||||
"ValidationError": "skipped: validation state not surfaced in debug panel",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ToolResultContent → normalizedContentPart",
|
||||
typ: reflect.TypeFor[fantasy.ToolResultContent](),
|
||||
fields: fieldDisposition{
|
||||
"ToolCallID": "normalized",
|
||||
"ToolName": "normalized",
|
||||
"Result": "normalized: text extracted via normalizeToolResultOutput",
|
||||
"ClientMetadata": "skipped: client execution metadata not needed for debug panel",
|
||||
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
|
||||
"ProviderMetadata": "skipped: opaque provider-specific metadata",
|
||||
},
|
||||
},
|
||||
|
||||
// ── tool types (normalizeTools) ───────────────────────
|
||||
|
||||
{
|
||||
name: "fantasy.FunctionTool → normalizedTool",
|
||||
typ: reflect.TypeFor[fantasy.FunctionTool](),
|
||||
fields: fieldDisposition{
|
||||
"Name": "normalized",
|
||||
"Description": "normalized",
|
||||
"InputSchema": "normalized: preserved as JSON for debug panel rendering",
|
||||
"ProviderOptions": "skipped: opaque provider-specific options",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fantasy.ProviderDefinedTool → normalizedTool",
|
||||
typ: reflect.TypeFor[fantasy.ProviderDefinedTool](),
|
||||
fields: fieldDisposition{
|
||||
"ID": "normalized",
|
||||
"Name": "normalized",
|
||||
"Args": "skipped: provider-specific configuration not needed for debug panel",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Every exported field on the fantasy type must be
|
||||
// registered as "normalized" or "skipped: <reason>".
|
||||
for i := range tt.typ.NumField() {
|
||||
field := tt.typ.Field(i)
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
disposition, ok := tt.fields[field.Name]
|
||||
if !ok {
|
||||
require.Failf(t, "unregistered field",
|
||||
"%s.%s is not in the coverage map — "+
|
||||
"add it as \"normalized\" or \"skipped: <reason>\"",
|
||||
tt.typ.Name(), field.Name)
|
||||
}
|
||||
require.NotEmptyf(t, disposition,
|
||||
"%s.%s has an empty disposition — "+
|
||||
"use \"normalized\" or \"skipped: <reason>\"",
|
||||
tt.typ.Name(), field.Name)
|
||||
}
|
||||
|
||||
// Catch stale entries that reference removed fields.
|
||||
for name := range tt.fields {
|
||||
found := false
|
||||
for i := range tt.typ.NumField() {
|
||||
if tt.typ.Field(i).Name == name {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Truef(t, found,
|
||||
"stale coverage entry %s.%s — "+
|
||||
"field no longer exists in fantasy, remove it",
|
||||
tt.typ.Name(), name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,987 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
type testError struct{ message string }
|
||||
|
||||
func (e *testError) Error() string { return e.message }
|
||||
|
||||
func expectDebugLoggingEnabled(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
ownerID uuid.UUID,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
db.EXPECT().GetChatDebugLoggingEnabled(gomock.Any()).Return(true, nil)
|
||||
db.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), ownerID).Return(true, nil)
|
||||
}
|
||||
|
||||
func expectCreateStepNumberWithRequestValidity(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
stepNumber int32,
|
||||
op Operation,
|
||||
normalizedRequestValid bool,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
stepID := uuid.New()
|
||||
db.EXPECT().
|
||||
InsertChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
require.Equal(t, runID, params.RunID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.Equal(t, stepNumber, params.StepNumber)
|
||||
require.Equal(t, string(op), params.Operation)
|
||||
require.Equal(t, string(StatusInProgress), params.Status)
|
||||
require.Equal(t, normalizedRequestValid, params.NormalizedRequest.Valid)
|
||||
|
||||
return database.ChatDebugStep{
|
||||
ID: stepID,
|
||||
RunID: runID,
|
||||
ChatID: chatID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: params.Operation,
|
||||
Status: params.Status,
|
||||
}, nil
|
||||
})
|
||||
|
||||
// CreateStep now touches the parent run's updated_at to prevent
|
||||
// premature stale finalization.
|
||||
db.EXPECT().
|
||||
UpdateChatDebugRun(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
|
||||
require.Equal(t, runID, params.ID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
return database.ChatDebugRun{ID: runID, ChatID: chatID}, nil
|
||||
})
|
||||
|
||||
return stepID
|
||||
}
|
||||
|
||||
func expectCreateStepNumber(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
stepNumber int32,
|
||||
op Operation,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
return expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
stepNumber,
|
||||
op,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func expectCreateStep(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
runID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
op Operation,
|
||||
) uuid.UUID {
|
||||
t.Helper()
|
||||
|
||||
return expectCreateStepNumber(t, db, runID, chatID, 1, op)
|
||||
}
|
||||
|
||||
func expectUpdateStep(
|
||||
t *testing.T,
|
||||
db *dbmock.MockStore,
|
||||
stepID uuid.UUID,
|
||||
chatID uuid.UUID,
|
||||
status Status,
|
||||
assertFn func(database.UpdateChatDebugStepParams),
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
db.EXPECT().
|
||||
UpdateChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugStepParams{})).
|
||||
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
|
||||
require.Equal(t, stepID, params.ID)
|
||||
require.Equal(t, chatID, params.ChatID)
|
||||
require.True(t, params.Status.Valid)
|
||||
require.Equal(t, string(status), params.Status.String)
|
||||
require.True(t, params.FinishedAt.Valid)
|
||||
|
||||
if assertFn != nil {
|
||||
assertFn(params)
|
||||
}
|
||||
|
||||
return database.ChatDebugStep{
|
||||
ID: stepID,
|
||||
ChatID: chatID,
|
||||
Status: params.Status.String,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestDebugModel_Provider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
|
||||
model := &debugModel{inner: inner}
|
||||
|
||||
require.Equal(t, inner.Provider(), model.Provider())
|
||||
}
|
||||
|
||||
func TestDebugModel_Model(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
|
||||
model := &debugModel{inner: inner}
|
||||
|
||||
require.Equal(t, inner.Model(), model.Model())
|
||||
}
|
||||
|
||||
func TestDebugModel_Disabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
respWant := &fantasy.Response{FinishReason: fantasy.FinishReasonStop}
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
_, ok := StepFromContext(ctx)
|
||||
require.False(t, ok)
|
||||
require.Nil(t, attemptSinkFromContext(ctx))
|
||||
return respWant, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{
|
||||
ChatID: chatID,
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := model.Generate(context.Background(), fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
require.Same(t, respWant, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_Generate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
call := fantasy.Call{
|
||||
Prompt: fantasy.Prompt{fantasy.NewUserMessage("hello")},
|
||||
MaxOutputTokens: int64Ptr(128),
|
||||
Temperature: float64Ptr(0.25),
|
||||
}
|
||||
respWant := &fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.TextContent{Text: "hello"},
|
||||
fantasy.ToolCallContent{ToolCallID: "tool-1", ToolName: "tool", Input: `{}`},
|
||||
fantasy.SourceContent{ID: "source-1", Title: "docs", URL: "https://example.com"},
|
||||
},
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{InputTokens: 10, OutputTokens: 4, TotalTokens: 14},
|
||||
Warnings: []fantasy.CallWarning{{Message: "warning"}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
// Clean successes (no prior error) leave the error column
|
||||
// as SQL NULL rather than sending jsonClear.
|
||||
require.False(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
|
||||
// Verify actual JSON content so a broken tag or field
|
||||
// rename is caught rather than only checking .Valid.
|
||||
var usage fantasy.Usage
|
||||
require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage))
|
||||
require.EqualValues(t, 10, usage.InputTokens)
|
||||
require.EqualValues(t, 4, usage.OutputTokens)
|
||||
require.EqualValues(t, 14, usage.TotalTokens)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp))
|
||||
require.Equal(t, "stop", resp["finish_reason"])
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, got fantasy.Call) (*fantasy.Response, error) {
|
||||
require.Equal(t, call, got)
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), stepCtx.StepNumber)
|
||||
require.Equal(t, OperationGenerate, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return respWant, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.Generate(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, respWant, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_GeneratePersistsAttemptsWithoutResponseClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"message":"hello","api_key":"super-secret"}`,
|
||||
string(body))
|
||||
require.Equal(t, "Bearer top-secret", req.Header.Get("Authorization"))
|
||||
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Header().Set("X-API-Key", "response-secret")
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
|
||||
var attempts []Attempt
|
||||
require.NoError(t, json.Unmarshal(params.Attempts.RawMessage, &attempts))
|
||||
require.Len(t, attempts, 1)
|
||||
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
|
||||
require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
server.URL,
|
||||
strings.NewReader(`{"message":"hello","api_key":"super-secret"}`),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", "Bearer top-secret")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
|
||||
require.NoError(t, resp.Body.Close())
|
||||
return &fantasy.Response{FinishReason: fantasy.FinishReasonStop}, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.Generate(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
wantErr := &testError{message: "boom"}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
GenerateFn: func(context.Context, fantasy.Call) (*fantasy.Response, error) {
|
||||
return nil, wantErr
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.Generate(ctx, fantasy.Call{})
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
}
|
||||
|
||||
func TestStepStatusForError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Canceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, stepStatusForError(context.Canceled))
|
||||
})
|
||||
|
||||
t.Run("DeadlineExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, stepStatusForError(context.DeadlineExceeded))
|
||||
})
|
||||
|
||||
t.Run("OtherError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusError, stepStatusForError(xerrors.New("boom")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestDebugModel_Stream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
errPart := xerrors.New("chunk failed")
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hel"},
|
||||
{Type: fantasy.StreamPartTypeToolCall, ID: "tool-call-1", ToolCallName: "tool"},
|
||||
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
|
||||
{Type: fantasy.StreamPartTypeWarnings, Warnings: []fantasy.CallWarning{{Message: "w1"}, {Message: "w2"}}},
|
||||
{Type: fantasy.StreamPartTypeError, Error: errPart},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 8, OutputTokens: 3, TotalTokens: 11}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), stepCtx.StepNumber)
|
||||
require.Equal(t, OperationStream, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := make([]fantasy.StreamPart, 0, len(parts))
|
||||
for part := range seq {
|
||||
got = append(got, part)
|
||||
}
|
||||
|
||||
require.Equal(t, parts, got)
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.ObjectStreamPart{
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
|
||||
{Type: fantasy.ObjectStreamPartTypeObject, Object: map[string]any{"value": "object"}},
|
||||
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.True(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
// Clean successes (no prior error) leave the error column
|
||||
// as SQL NULL rather than sending jsonClear.
|
||||
require.False(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), stepCtx.StepNumber)
|
||||
require.Equal(t, OperationStream, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return objectPartsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := make([]fantasy.ObjectStreamPart, 0, len(parts))
|
||||
for part := range seq {
|
||||
got = append(got, part)
|
||||
}
|
||||
|
||||
require.Equal(t, parts, got)
|
||||
}
|
||||
|
||||
// TestDebugModel_StreamCompletedAfterFinish verifies that when a consumer
|
||||
// stops iteration after receiving a finish part, the step is marked as
|
||||
// completed rather than interrupted.
|
||||
func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
|
||||
}
|
||||
|
||||
// The mock expectation for UpdateStep with StatusCompleted is the
|
||||
// assertion: if the wrapper chose StatusInterrupted instead, the
|
||||
// mock would reject the call.
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, nil)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer reads the finish part then breaks — this should still
|
||||
// be considered a completed stream, not interrupted.
|
||||
for part := range seq {
|
||||
if part.Type == fantasy.StreamPartTypeFinish {
|
||||
break
|
||||
}
|
||||
}
|
||||
// gomock verifies UpdateStep was called with StatusCompleted.
|
||||
}
|
||||
|
||||
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
|
||||
// stops iteration before receiving a finish part, the step is marked as
|
||||
// interrupted.
|
||||
func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: " world"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}
|
||||
|
||||
// The mock expectation for UpdateStep with StatusInterrupted is the
|
||||
// assertion: breaking before the finish part means interrupted.
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, nil)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer reads the first delta then breaks before finish.
|
||||
count := 0
|
||||
for range seq {
|
||||
count++
|
||||
if count == 1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, count)
|
||||
// gomock verifies UpdateStep was called with StatusInterrupted.
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.False(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
var nilStream fantasy.StreamResponse
|
||||
return nilStream, nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.Nil(t, seq)
|
||||
require.ErrorIs(t, err, ErrNilModelResult)
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
|
||||
require.False(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.True(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamObjectFn: func(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
var nilStream fantasy.ObjectStreamResponse
|
||||
return nilStream, nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
|
||||
require.Nil(t, seq)
|
||||
require.ErrorIs(t, err, ErrNilModelResult)
|
||||
}
|
||||
|
||||
func TestDebugModel_StreamEarlyStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "first"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "second"},
|
||||
}
|
||||
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
|
||||
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, func(params database.UpdateChatDebugStepParams) {
|
||||
require.True(t, params.NormalizedResponse.Valid)
|
||||
require.False(t, params.Usage.Valid)
|
||||
require.True(t, params.Attempts.Valid)
|
||||
require.False(t, params.Error.Valid)
|
||||
require.True(t, params.Metadata.Valid)
|
||||
})
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return partsToSeq(parts), nil
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
seq, err := model.Stream(ctx, fantasy.Call{})
|
||||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
for part := range seq {
|
||||
require.Equal(t, parts[0], part)
|
||||
count++
|
||||
break
|
||||
}
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestStreamErrorStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("CancellationBecomesInterrupted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.Canceled))
|
||||
})
|
||||
|
||||
t.Run("DeadlineExceededBecomesInterrupted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.DeadlineExceeded))
|
||||
})
|
||||
|
||||
t.Run("NilErrorBecomesError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusError, streamErrorStatus(StatusCompleted, nil))
|
||||
})
|
||||
|
||||
t.Run("ExistingErrorWins", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, StatusError, streamErrorStatus(StatusError, context.Canceled))
|
||||
})
|
||||
}
|
||||
|
||||
func objectPartsToSeq(parts []fantasy.ObjectStreamPart) fantasy.ObjectStreamResponse {
|
||||
return func(yield func(fantasy.ObjectStreamPart) bool) {
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func partsToSeq(parts []fantasy.StreamPart) fantasy.StreamResponse {
|
||||
return func(yield func(fantasy.StreamPart) bool) {
|
||||
for _, part := range parts {
|
||||
if !yield(part) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
call := fantasy.ObjectCall{
|
||||
Prompt: fantasy.Prompt{fantasy.NewUserMessage("summarize")},
|
||||
SchemaName: "Summary",
|
||||
MaxOutputTokens: int64Ptr(256),
|
||||
}
|
||||
respWant := &fantasy.ObjectResponse{
|
||||
RawText: `{"title":"test"}`,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
|
||||
}
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
inner := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(ctx context.Context, got fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
require.Equal(t, call, got)
|
||||
stepCtx, ok := StepFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runID, stepCtx.RunID)
|
||||
require.Equal(t, chatID, stepCtx.ChatID)
|
||||
require.Equal(t, OperationGenerate, stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
||||
require.NotNil(t, attemptSinkFromContext(ctx))
|
||||
return respWant, nil
|
||||
},
|
||||
}
|
||||
|
||||
model := &debugModel{
|
||||
inner: inner,
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.GenerateObject(ctx, call)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, respWant, resp)
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateObjectError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
runID := uuid.New()
|
||||
wantErr := &testError{message: "object boom"}
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return nil, wantErr
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
}
|
||||
|
||||
func TestDebugModel_GenerateObjectRejectsNilResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
runID := uuid.New()
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
model := &debugModel{
|
||||
inner: &chattest.FakeModel{
|
||||
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return nil, nil //nolint:nilnil // Intentionally testing nil response handling.
|
||||
},
|
||||
},
|
||||
svc: svc,
|
||||
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
||||
}
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
|
||||
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
|
||||
require.Nil(t, resp)
|
||||
require.ErrorIs(t, err, ErrNilModelResult)
|
||||
}
|
||||
|
||||
func TestWrapStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
|
||||
// Create a context that we cancel after the stream finishes.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
|
||||
}
|
||||
seq := wrapStreamSeq(ctx, handle, partsToSeq(parts))
|
||||
|
||||
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
|
||||
for range seq {
|
||||
}
|
||||
|
||||
// Cancel the context after the stream has been fully consumed
|
||||
// and finalized. The status should remain completed.
|
||||
cancel()
|
||||
|
||||
handle.mu.Lock()
|
||||
status := handle.status
|
||||
handle.mu.Unlock()
|
||||
require.Equal(t, StatusCompleted, status)
|
||||
}
|
||||
|
||||
func TestWrapObjectStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
parts := []fantasy.ObjectStreamPart{
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "obj"},
|
||||
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 1, TotalTokens: 4}},
|
||||
}
|
||||
seq := wrapObjectStreamSeq(ctx, handle, objectPartsToSeq(parts))
|
||||
|
||||
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
|
||||
for range seq {
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
handle.mu.Lock()
|
||||
status := handle.status
|
||||
handle.mu.Unlock()
|
||||
require.Equal(t, StatusCompleted, status)
|
||||
}
|
||||
|
||||
func TestWrapStreamSeq_DroppedStreamFinalizedOnCtxCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
parts := []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}
|
||||
|
||||
// Create the wrapped stream but never iterate it.
|
||||
_ = wrapStreamSeq(ctx, handle, partsToSeq(parts))
|
||||
|
||||
// Cancel the context — the AfterFunc safety net should finalize
|
||||
// the step as interrupted.
|
||||
cancel()
|
||||
|
||||
// AfterFunc fires asynchronously; give it a moment.
|
||||
require.Eventually(t, func() bool {
|
||||
handle.mu.Lock()
|
||||
defer handle.mu.Unlock()
|
||||
return handle.status == StatusInterrupted
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 { return &v }
|
||||
|
||||
func float64Ptr(v float64) *float64 { return &v }
|
||||
@@ -1,379 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported normalization helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
func TestNormalizeCall_PreservesToolSchemasAndMessageToolPayloads(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := normalizeCall(fantasy.Call{
|
||||
Prompt: fantasy.Prompt{
|
||||
{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolCallPart{
|
||||
ToolCallID: "call-search",
|
||||
ToolName: "search_docs",
|
||||
Input: `{"query":"debug panel"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.ToolResultPart{
|
||||
ToolCallID: "call-search",
|
||||
Output: fantasy.ToolResultOutputContentText{
|
||||
Text: `{"matches":["model.go","DebugStepCard.tsx"]}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Tools: []fantasy.Tool{
|
||||
fantasy.FunctionTool{
|
||||
Name: "search_docs",
|
||||
Description: "Searches documentation.",
|
||||
InputSchema: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, payload.Tools, 1)
|
||||
require.True(t, payload.Tools[0].HasInputSchema)
|
||||
require.JSONEq(t, `{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`,
|
||||
string(payload.Tools[0].InputSchema))
|
||||
|
||||
require.Len(t, payload.Messages, 2)
|
||||
require.Equal(t, "tool-call", payload.Messages[0].Parts[0].Type)
|
||||
require.Equal(t, `{"query":"debug panel"}`, payload.Messages[0].Parts[0].Arguments)
|
||||
require.Equal(t, "tool-result", payload.Messages[1].Parts[0].Type)
|
||||
require.Equal(t,
|
||||
`{"matches":["model.go","DebugStepCard.tsx"]}`,
|
||||
payload.Messages[1].Parts[0].Result,
|
||||
)
|
||||
}
|
||||
|
||||
func TestNormalizers_SkipTypedNilInterfaceValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("MessageParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilPart *fantasy.TextPart
|
||||
parts := normalizeMessageParts([]fantasy.MessagePart{
|
||||
nilPart,
|
||||
fantasy.TextPart{Text: "hello"},
|
||||
})
|
||||
require.Len(t, parts, 1)
|
||||
require.Equal(t, "text", parts[0].Type)
|
||||
require.Equal(t, "hello", parts[0].Text)
|
||||
})
|
||||
|
||||
t.Run("Tools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilTool *fantasy.FunctionTool
|
||||
tools := normalizeTools([]fantasy.Tool{
|
||||
nilTool,
|
||||
fantasy.FunctionTool{Name: "search_docs"},
|
||||
})
|
||||
require.Len(t, tools, 1)
|
||||
require.Equal(t, "function", tools[0].Type)
|
||||
require.Equal(t, "search_docs", tools[0].Name)
|
||||
})
|
||||
|
||||
t.Run("ContentParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilContent *fantasy.TextContent
|
||||
content := normalizeContentParts(fantasy.ResponseContent{
|
||||
nilContent,
|
||||
fantasy.TextContent{Text: "hello"},
|
||||
})
|
||||
require.Len(t, content, 1)
|
||||
require.Equal(t, "text", content[0].Type)
|
||||
require.Equal(t, "hello", content[0].Text)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppendNormalizedStreamContent_PreservesOrderAndCanonicalTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var content []normalizedContentPart
|
||||
streamDebugBytes := 0
|
||||
for _, part := range []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "before "},
|
||||
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"query":"debug"}`},
|
||||
{Type: fantasy.StreamPartTypeToolResult, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"matches":1}`},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "after"},
|
||||
} {
|
||||
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
|
||||
}
|
||||
|
||||
require.Equal(t, []normalizedContentPart{
|
||||
{Type: "text", Text: "before "},
|
||||
{Type: "tool-call", ToolCallID: "call-1", ToolName: "search_docs", Arguments: `{"query":"debug"}`, InputLength: len(`{"query":"debug"}`)},
|
||||
{Type: "tool-result", ToolCallID: "call-1", ToolName: "search_docs", Result: `{"matches":1}`},
|
||||
{Type: "text", Text: "after"},
|
||||
}, content)
|
||||
}
|
||||
|
||||
func TestAppendNormalizedStreamContent_GlobalTextCap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
streamDebugBytes := 0
|
||||
long := strings.Repeat("a", maxStreamDebugTextBytes)
|
||||
var content []normalizedContentPart
|
||||
for _, part := range []fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: long},
|
||||
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{}`},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, Delta: "tail"},
|
||||
} {
|
||||
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
|
||||
}
|
||||
|
||||
require.Len(t, content, 2)
|
||||
require.Equal(t, strings.Repeat("a", maxStreamDebugTextBytes), content[0].Text)
|
||||
require.Equal(t, "tool-call", content[1].Type)
|
||||
require.Equal(t, maxStreamDebugTextBytes, streamDebugBytes)
|
||||
}
|
||||
|
||||
func TestWrapStreamSeq_SourceCountExcludesToolResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
seq := wrapStreamSeq(context.Background(), handle, partsToSeq([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolResult, ID: "tool-1", ToolCallName: "search_docs"},
|
||||
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}))
|
||||
|
||||
partCount := 0
|
||||
for range seq {
|
||||
partCount++
|
||||
}
|
||||
require.Equal(t, 3, partCount)
|
||||
|
||||
metadata, ok := handle.metadata.(map[string]any)
|
||||
require.True(t, ok)
|
||||
summary, ok := metadata["stream_summary"].(streamSummary)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 1, summary.SourceCount)
|
||||
}
|
||||
|
||||
func TestWrapObjectStreamSeq_UsesStructuredOutputPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handle := &stepHandle{
|
||||
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
||||
sink: &attemptSink{},
|
||||
}
|
||||
usage := fantasy.Usage{InputTokens: 3, OutputTokens: 2, TotalTokens: 5}
|
||||
seq := wrapObjectStreamSeq(context.Background(), handle, objectPartsToSeq([]fantasy.ObjectStreamPart{
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
|
||||
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
|
||||
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: usage},
|
||||
}))
|
||||
|
||||
partCount := 0
|
||||
for range seq {
|
||||
partCount++
|
||||
}
|
||||
require.Equal(t, 3, partCount)
|
||||
|
||||
resp, ok := handle.response.(normalizedObjectResponsePayload)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, normalizedObjectResponsePayload{
|
||||
RawTextLength: len("object"),
|
||||
FinishReason: string(fantasy.FinishReasonStop),
|
||||
Usage: normalizeUsage(usage),
|
||||
StructuredOutput: true,
|
||||
}, resp)
|
||||
}
|
||||
|
||||
func TestNormalizeResponse_UsesCanonicalToolTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := normalizeResponse(&fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "call-calc",
|
||||
ToolName: "calculator",
|
||||
Input: `{"operation":"add","operands":[2,2]}`,
|
||||
},
|
||||
fantasy.ToolResultContent{
|
||||
ToolCallID: "call-calc",
|
||||
ToolName: "calculator",
|
||||
Result: fantasy.ToolResultOutputContentText{Text: `{"sum":4}`},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, payload.Content, 2)
|
||||
require.Equal(t, "tool-call", payload.Content[0].Type)
|
||||
require.Equal(t, "tool-result", payload.Content[1].Type)
|
||||
}
|
||||
|
||||
func TestBoundText_RespectsDocumentedRuneLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runes := make([]rune, MaxMessagePartTextLength+5)
|
||||
for i := range runes {
|
||||
runes[i] = 'a'
|
||||
}
|
||||
input := string(runes)
|
||||
got := boundText(input)
|
||||
require.Equal(t, MaxMessagePartTextLength, len([]rune(got)))
|
||||
require.Equal(t, '…', []rune(got)[len([]rune(got))-1])
|
||||
}
|
||||
|
||||
func TestNormalizeToolResultOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("TextValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{Text: "hello"})
|
||||
require.Equal(t, "hello", got)
|
||||
})
|
||||
|
||||
t.Run("TextPointer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentText{Text: "hello"})
|
||||
require.Equal(t, "hello", got)
|
||||
})
|
||||
|
||||
t.Run("TextPointerNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentText)(nil))
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New("tool failed"),
|
||||
})
|
||||
require.Equal(t, "tool failed", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorValueNilError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{Error: nil})
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorPointer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New("ptr fail"),
|
||||
})
|
||||
require.Equal(t, "ptr fail", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorPointerNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentError)(nil))
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("ErrorPointerNilError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{Error: nil})
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("MediaWithText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
|
||||
Text: "caption",
|
||||
MediaType: "image/png",
|
||||
})
|
||||
require.Equal(t, "caption", got)
|
||||
})
|
||||
|
||||
t.Run("MediaWithoutText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
|
||||
MediaType: "image/png",
|
||||
})
|
||||
require.Equal(t, "[media output: image/png]", got)
|
||||
})
|
||||
|
||||
t.Run("MediaWithoutTextOrType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{})
|
||||
require.Equal(t, "[media output]", got)
|
||||
})
|
||||
|
||||
t.Run("MediaPointerNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentMedia)(nil))
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("MediaPointerWithText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentMedia{
|
||||
Text: "ptr caption",
|
||||
MediaType: "image/jpeg",
|
||||
})
|
||||
require.Equal(t, "ptr caption", got)
|
||||
})
|
||||
|
||||
t.Run("NilOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeToolResultOutput(nil)
|
||||
require.Equal(t, "", got)
|
||||
})
|
||||
|
||||
t.Run("DefaultJSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// An unexpected type falls through to the default JSON
|
||||
// marshal branch.
|
||||
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{
|
||||
Text: "fallback",
|
||||
})
|
||||
require.Equal(t, "fallback", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeResponse_PreservesToolCallArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := normalizeResponse(&fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "call-calc",
|
||||
ToolName: "calculator",
|
||||
Input: `{"operation":"add","operands":[2,2]}`,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.Len(t, payload.Content, 1)
|
||||
require.Equal(t, "call-calc", payload.Content[0].ToolCallID)
|
||||
require.Equal(t, "calculator", payload.Content[0].ToolName)
|
||||
require.JSONEq(t,
|
||||
`{"operation":"add","operands":[2,2]}`,
|
||||
payload.Content[0].Arguments,
|
||||
)
|
||||
require.Equal(t, len(`{"operation":"add","operands":[2,2]}`), payload.Content[0].InputLength)
|
||||
}
|
||||
@@ -1,319 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// RecorderOptions identifies the chat/model context for debug recording.
|
||||
type RecorderOptions struct {
|
||||
ChatID uuid.UUID
|
||||
OwnerID uuid.UUID
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// WrapModel returns model unchanged when debug recording is disabled, or a
|
||||
// debug wrapper when a service is available.
|
||||
func WrapModel(
|
||||
model fantasy.LanguageModel,
|
||||
svc *Service,
|
||||
opts RecorderOptions,
|
||||
) fantasy.LanguageModel {
|
||||
if model == nil {
|
||||
panic("chatdebug: nil LanguageModel")
|
||||
}
|
||||
if svc == nil {
|
||||
return model
|
||||
}
|
||||
return &debugModel{inner: model, svc: svc, opts: opts}
|
||||
}
|
||||
|
||||
type attemptSink struct {
|
||||
mu sync.Mutex
|
||||
attempts []Attempt
|
||||
attemptCounter atomic.Int32
|
||||
}
|
||||
|
||||
func (s *attemptSink) nextAttemptNumber() int {
|
||||
if s == nil {
|
||||
panic("chatdebug: nil attemptSink")
|
||||
}
|
||||
return int(s.attemptCounter.Add(1))
|
||||
}
|
||||
|
||||
func (s *attemptSink) record(a Attempt) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.attempts = append(s.attempts, a)
|
||||
}
|
||||
|
||||
func (s *attemptSink) snapshot() []Attempt {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
attempts := make([]Attempt, len(s.attempts))
|
||||
copy(attempts, s.attempts)
|
||||
return attempts
|
||||
}
|
||||
|
||||
type attemptSinkKey struct{}
|
||||
|
||||
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
|
||||
if sink == nil {
|
||||
panic("chatdebug: nil attemptSink")
|
||||
}
|
||||
return context.WithValue(ctx, attemptSinkKey{}, sink)
|
||||
}
|
||||
|
||||
func attemptSinkFromContext(ctx context.Context) *attemptSink {
|
||||
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
|
||||
return sink
|
||||
}
|
||||
|
||||
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
// runRefCounts tracks how many live RunContext instances reference each
|
||||
// RunID. Cleanup of shared state (step counters) is deferred until the
|
||||
// last RunContext for a given RunID is garbage collected.
|
||||
var runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32
|
||||
|
||||
func trackRunRef(runID uuid.UUID) {
|
||||
val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter := val.(*atomic.Int32)
|
||||
counter.Add(1)
|
||||
}
|
||||
|
||||
// releaseRunRef decrements the reference count for runID and cleans up
|
||||
// shared state when the last reference is released.
|
||||
func releaseRunRef(runID uuid.UUID) {
|
||||
val, ok := runRefCounts.Load(runID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter := val.(*atomic.Int32)
|
||||
if counter.Add(-1) <= 0 {
|
||||
runRefCounts.Delete(runID)
|
||||
stepCounters.Delete(runID)
|
||||
}
|
||||
}
|
||||
|
||||
func nextStepNumber(runID uuid.UUID) int32 {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
if !ok {
|
||||
panic("chatdebug: invalid step counter type")
|
||||
}
|
||||
return counter.Add(1)
|
||||
}
|
||||
|
||||
// CleanupStepCounter removes per-run step counter and reference count
|
||||
// state. This is used by tests and later stacked branches that have a
|
||||
// real run lifecycle.
|
||||
func CleanupStepCounter(runID uuid.UUID) {
|
||||
stepCounters.Delete(runID)
|
||||
runRefCounts.Delete(runID)
|
||||
}
|
||||
|
||||
const stepFinalizeTimeout = 5 * time.Second
|
||||
|
||||
func stepFinalizeContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
panic("chatdebug: nil context")
|
||||
}
|
||||
return context.WithTimeout(context.WithoutCancel(ctx), stepFinalizeTimeout)
|
||||
}
|
||||
|
||||
func syncStepCounter(runID uuid.UUID, stepNumber int32) {
|
||||
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
|
||||
counter, ok := val.(*atomic.Int32)
|
||||
if !ok {
|
||||
panic("chatdebug: invalid step counter type")
|
||||
}
|
||||
for {
|
||||
current := counter.Load()
|
||||
if current >= stepNumber {
|
||||
return
|
||||
}
|
||||
if counter.CompareAndSwap(current, stepNumber) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type stepHandle struct {
|
||||
stepCtx *StepContext
|
||||
sink *attemptSink
|
||||
svc *Service
|
||||
opts RecorderOptions
|
||||
mu sync.Mutex
|
||||
status Status
|
||||
response any
|
||||
usage any
|
||||
err any
|
||||
metadata any
|
||||
// hadError tracks whether a prior finalization wrote an error
|
||||
// payload. Used to decide whether a successful retry needs to
|
||||
// explicitly clear the error field via jsonClear.
|
||||
hadError bool
|
||||
}
|
||||
|
||||
// beginStep validates preconditions, creates a debug step, and returns a
|
||||
// handle plus an enriched context carrying StepContext and attemptSink.
|
||||
// Returns (nil, original ctx) when debug recording should be skipped.
|
||||
func beginStep(
|
||||
ctx context.Context,
|
||||
svc *Service,
|
||||
opts RecorderOptions,
|
||||
op Operation,
|
||||
normalizedReq any,
|
||||
) (*stepHandle, context.Context) {
|
||||
if svc == nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
rc, ok := RunFromContext(ctx)
|
||||
if !ok || rc.RunID == uuid.Nil {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
chatID := opts.ChatID
|
||||
if chatID == uuid.Nil {
|
||||
chatID = rc.ChatID
|
||||
}
|
||||
if !svc.IsEnabled(ctx, chatID, opts.OwnerID) {
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
holder, reuseStep := reuseHolderFromContext(ctx)
|
||||
if reuseStep {
|
||||
holder.mu.Lock()
|
||||
defer holder.mu.Unlock()
|
||||
// Only reuse the cached handle if it belongs to the same run.
|
||||
// A different RunContext means a new logical run, so we must
|
||||
// create a fresh step to avoid cross-run attribution.
|
||||
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
|
||||
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, holder.handle.sink)
|
||||
return holder.handle, enriched
|
||||
}
|
||||
}
|
||||
|
||||
stepNum := nextStepNumber(rc.RunID)
|
||||
step, err := svc.CreateStep(ctx, CreateStepParams{
|
||||
RunID: rc.RunID,
|
||||
ChatID: chatID,
|
||||
StepNumber: stepNum,
|
||||
Operation: op,
|
||||
Status: StatusInProgress,
|
||||
HistoryTipMessageID: rc.HistoryTipMessageID,
|
||||
NormalizedRequest: normalizedReq,
|
||||
})
|
||||
if err != nil {
|
||||
svc.log.Warn(ctx, "failed to create chat debug step",
|
||||
slog.Error(err),
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("run_id", rc.RunID),
|
||||
slog.F("operation", op),
|
||||
)
|
||||
return nil, ctx
|
||||
}
|
||||
|
||||
syncStepCounter(rc.RunID, step.StepNumber)
|
||||
actualStepNumber := step.StepNumber
|
||||
if actualStepNumber == 0 {
|
||||
actualStepNumber = stepNum
|
||||
}
|
||||
|
||||
sc := &StepContext{
|
||||
StepID: step.ID,
|
||||
RunID: rc.RunID,
|
||||
ChatID: chatID,
|
||||
StepNumber: actualStepNumber,
|
||||
Operation: op,
|
||||
HistoryTipMessageID: rc.HistoryTipMessageID,
|
||||
}
|
||||
handle := &stepHandle{stepCtx: sc, sink: &attemptSink{}, svc: svc, opts: opts}
|
||||
enriched := ContextWithStep(ctx, handle.stepCtx)
|
||||
enriched = withAttemptSink(enriched, handle.sink)
|
||||
if reuseStep {
|
||||
holder.handle = handle
|
||||
}
|
||||
|
||||
return handle, enriched
|
||||
}
|
||||
|
||||
// finish updates the debug step with final status and data. A mutex
|
||||
// guards the write so concurrent callers (e.g. retried stream wrappers
|
||||
// sharing a reuse handle) don't race. Unlike sync.Once, later retries
|
||||
// are allowed to overwrite earlier failure results so the step reflects
|
||||
// the final outcome.
|
||||
func (h *stepHandle) finish(
|
||||
ctx context.Context,
|
||||
status Status,
|
||||
response any,
|
||||
usage any,
|
||||
errPayload any,
|
||||
metadata any,
|
||||
) {
|
||||
if h == nil || h.stepCtx == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
h.status = status
|
||||
h.response = response
|
||||
h.usage = usage
|
||||
h.err = errPayload
|
||||
h.metadata = metadata
|
||||
if errPayload != nil {
|
||||
h.hadError = true
|
||||
}
|
||||
if h.svc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
updateCtx, cancel := stepFinalizeContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
// When the step completes successfully after a prior failed
|
||||
// attempt, the error field must be explicitly cleared. A plain
|
||||
// nil would leave the COALESCE-based SQL untouched, so we send
|
||||
// jsonClear{} which serializes as a valid JSONB null. Only do
|
||||
// this when a prior error was actually recorded — otherwise
|
||||
// clean successes would get a spurious JSONB null that downstream
|
||||
// aggregation could misread as an error.
|
||||
errValue := errPayload
|
||||
if errValue == nil && status == StatusCompleted && h.hadError {
|
||||
errValue = jsonClear{}
|
||||
}
|
||||
|
||||
if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{
|
||||
ID: h.stepCtx.StepID,
|
||||
ChatID: h.stepCtx.ChatID,
|
||||
Status: status,
|
||||
NormalizedResponse: response,
|
||||
Usage: usage,
|
||||
Attempts: h.sink.snapshot(),
|
||||
Error: errValue,
|
||||
Metadata: metadata,
|
||||
FinishedAt: time.Now(),
|
||||
}); updateErr != nil {
|
||||
h.svc.log.Warn(updateCtx, "failed to finalize chat debug step",
|
||||
slog.Error(updateErr),
|
||||
slog.F("step_id", h.stepCtx.StepID),
|
||||
slog.F("chat_id", h.stepCtx.ChatID),
|
||||
slog.F("status", status),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestAttemptSink_ThreadSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const n = 256
|
||||
|
||||
sink := &attemptSink{}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
|
||||
for i := range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sink.record(Attempt{Number: i + 1, ResponseStatus: 200 + i})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
attempts := sink.snapshot()
|
||||
require.Len(t, attempts, n)
|
||||
|
||||
numbers := make([]int, 0, n)
|
||||
statuses := make([]int, 0, n)
|
||||
for _, attempt := range attempts {
|
||||
numbers = append(numbers, attempt.Number)
|
||||
statuses = append(statuses, attempt.ResponseStatus)
|
||||
}
|
||||
sort.Ints(numbers)
|
||||
sort.Ints(statuses)
|
||||
|
||||
for i := range n {
|
||||
require.Equal(t, i+1, numbers[i])
|
||||
require.Equal(t, 200+i, statuses[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttemptSinkContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
require.Nil(t, attemptSinkFromContext(ctx))
|
||||
|
||||
sink := &attemptSink{}
|
||||
ctx = withAttemptSink(ctx, sink)
|
||||
require.Same(t, sink, attemptSinkFromContext(ctx))
|
||||
}
|
||||
|
||||
func TestWrapModel_NilModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Panics(t, func() {
|
||||
WrapModel(nil, &Service{}, RecorderOptions{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestWrapModel_NilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
||||
wrapped := WrapModel(model, nil, RecorderOptions{})
|
||||
require.Same(t, model, wrapped)
|
||||
}
|
||||
|
||||
func TestNextStepNumber_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const n = 256
|
||||
|
||||
runID := uuid.New()
|
||||
results := make([]int, n)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
|
||||
for i := range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results[i] = int(nextStepNumber(runID))
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
sort.Ints(results)
|
||||
for i := range n {
|
||||
require.Equal(t, i+1, results[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStepFinalizeContext_StripsCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseCtx, cancelBase := context.WithCancel(context.Background())
|
||||
cancelBase()
|
||||
require.ErrorIs(t, baseCtx.Err(), context.Canceled)
|
||||
|
||||
finalizeCtx, cancelFinalize := stepFinalizeContext(baseCtx)
|
||||
defer cancelFinalize()
|
||||
|
||||
require.NoError(t, finalizeCtx.Err())
|
||||
_, hasDeadline := finalizeCtx.Deadline()
|
||||
require.True(t, hasDeadline)
|
||||
}
|
||||
|
||||
func TestSyncStepCounter_AdvancesCounter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
syncStepCounter(runID, 7)
|
||||
require.Equal(t, int32(8), nextStepNumber(runID))
|
||||
}
|
||||
|
||||
func TestStepHandleFinish_NilHandle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var handle *stepHandle
|
||||
handle.finish(context.Background(), StatusCompleted, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func TestBeginStep_NilService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
handle, enriched := beginStep(ctx, nil, RecorderOptions{}, OperationGenerate, nil)
|
||||
require.Nil(t, handle)
|
||||
require.Nil(t, attemptSinkFromContext(enriched))
|
||||
_, ok := StepFromContext(enriched)
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
runID := uuid.New()
|
||||
runChatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(t, db, runID, runChatID, 1, OperationGenerate, false)
|
||||
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
|
||||
handle, enriched := beginStep(ctx, svc, RecorderOptions{OwnerID: ownerID}, OperationGenerate, nil)
|
||||
require.NotNil(t, handle)
|
||||
require.Equal(t, runChatID, handle.stepCtx.ChatID)
|
||||
|
||||
stepCtx, ok := StepFromContext(enriched)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, runChatID, stepCtx.ChatID)
|
||||
}
|
||||
|
||||
func TestWrapModel_ReturnsDebugModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
||||
wrapped := WrapModel(model, &Service{}, RecorderOptions{})
|
||||
|
||||
require.NotSame(t, model, wrapped)
|
||||
require.IsType(t, &debugModel{}, wrapped)
|
||||
require.Implements(t, (*fantasy.LanguageModel)(nil), wrapped)
|
||||
require.Equal(t, model.Provider(), wrapped.Provider())
|
||||
require.Equal(t, model.Model(), wrapped.Model())
|
||||
}
|
||||
@@ -1,227 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// RedactedValue replaces sensitive values in debug payloads.
|
||||
const RedactedValue = "[REDACTED]"
|
||||
|
||||
var sensitiveHeaderNames = map[string]struct{}{
|
||||
"authorization": {},
|
||||
"x-api-key": {},
|
||||
"api-key": {},
|
||||
"proxy-authorization": {},
|
||||
"cookie": {},
|
||||
"set-cookie": {},
|
||||
}
|
||||
|
||||
// sensitiveJSONKeyFragments triggers redaction for JSON keys containing
|
||||
// these substrings. Notably, "token" is intentionally absent because it
|
||||
// false-positively redacts LLM token-usage fields (input_tokens,
|
||||
// output_tokens, prompt_tokens, completion_tokens, reasoning_tokens,
|
||||
// cache_creation_input_tokens, cache_read_input_tokens, etc.). Auth-
|
||||
// related token fields are caught by the exact-match set below.
|
||||
var sensitiveJSONKeyFragments = []string{
|
||||
"secret",
|
||||
"password",
|
||||
"authorization",
|
||||
"credential",
|
||||
}
|
||||
|
||||
// sensitiveJSONKeyExact matches auth-related token/key field names
|
||||
// without false-positiving on LLM usage counters. Includes both
|
||||
// snake_case originals and their camelCase-lowered equivalents
|
||||
// (e.g. "accessToken" → "accesstoken") so that providers using
|
||||
// either convention are caught.
|
||||
var sensitiveJSONKeyExact = map[string]struct{}{
|
||||
"token": {},
|
||||
"access_token": {},
|
||||
"accesstoken": {},
|
||||
"refresh_token": {},
|
||||
"refreshtoken": {},
|
||||
"id_token": {},
|
||||
"idtoken": {},
|
||||
"api_token": {},
|
||||
"apitoken": {},
|
||||
"api_key": {},
|
||||
"apikey": {},
|
||||
"api-key": {},
|
||||
"x-api-key": {},
|
||||
"auth_token": {},
|
||||
"authtoken": {},
|
||||
"bearer_token": {},
|
||||
"bearertoken": {},
|
||||
"session_token": {},
|
||||
"sessiontoken": {},
|
||||
"security_token": {},
|
||||
"securitytoken": {},
|
||||
"private_key": {},
|
||||
"privatekey": {},
|
||||
"signing_key": {},
|
||||
"signingkey": {},
|
||||
"secret_key": {},
|
||||
"secretkey": {},
|
||||
}
|
||||
|
||||
// RedactHeaders returns a flattened copy of h with sensitive values redacted.
|
||||
func RedactHeaders(h http.Header) map[string]string {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
redacted := make(map[string]string, len(h))
|
||||
for name, values := range h {
|
||||
if isSensitiveName(name) {
|
||||
redacted[name] = RedactedValue
|
||||
continue
|
||||
}
|
||||
redacted[name] = strings.Join(values, ", ")
|
||||
}
|
||||
return redacted
|
||||
}
|
||||
|
||||
// RedactJSONSecrets redacts sensitive JSON values by key name. When
|
||||
// the input is not valid JSON (truncated body, HTML error page, etc.)
|
||||
// the raw bytes are replaced entirely with a diagnostic placeholder
|
||||
// to avoid leaking credentials from malformed payloads.
|
||||
func RedactJSONSecrets(data []byte) []byte {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(data))
|
||||
decoder.UseNumber()
|
||||
|
||||
var value any
|
||||
if err := decoder.Decode(&value); err != nil {
|
||||
// Cannot parse: replace entirely to prevent credential leaks
|
||||
// from non-JSON error responses (HTML pages, partial bodies).
|
||||
return []byte(`{"error":"chatdebug: body is not valid JSON, redacted for safety"}`)
|
||||
}
|
||||
if err := consumeJSONEOF(decoder); err != nil {
|
||||
return []byte(`{"error":"chatdebug: body contains extra JSON values, redacted for safety"}`)
|
||||
}
|
||||
|
||||
redacted, changed := redactJSONValue(value)
|
||||
if !changed {
|
||||
return data
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(redacted)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return encoded
|
||||
}
|
||||
|
||||
func consumeJSONEOF(decoder *json.Decoder) error {
|
||||
var extra any
|
||||
err := decoder.Decode(&extra)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err == nil {
|
||||
return xerrors.New("chatdebug: extra JSON values")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var safeRateLimitHeaderNames = map[string]struct{}{
|
||||
"anthropic-ratelimit-requests-limit": {},
|
||||
"anthropic-ratelimit-requests-remaining": {},
|
||||
"anthropic-ratelimit-requests-reset": {},
|
||||
"anthropic-ratelimit-tokens-limit": {},
|
||||
"anthropic-ratelimit-tokens-remaining": {},
|
||||
"anthropic-ratelimit-tokens-reset": {},
|
||||
"x-ratelimit-limit-requests": {},
|
||||
"x-ratelimit-limit-tokens": {},
|
||||
"x-ratelimit-remaining-requests": {},
|
||||
"x-ratelimit-remaining-tokens": {},
|
||||
"x-ratelimit-reset-requests": {},
|
||||
"x-ratelimit-reset-tokens": {},
|
||||
}
|
||||
|
||||
// isSensitiveName reports whether a name (header or query parameter)
|
||||
// looks like a credential-carrying key. Exact-match headers are
|
||||
// checked first, then the rate-limit allowlist, then substring
|
||||
// patterns for API keys and auth tokens.
|
||||
func isSensitiveName(name string) bool {
|
||||
lowerName := strings.ToLower(name)
|
||||
if _, ok := sensitiveHeaderNames[lowerName]; ok {
|
||||
return true
|
||||
}
|
||||
if _, ok := safeRateLimitHeaderNames[lowerName]; ok {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(lowerName, "api-key") ||
|
||||
strings.Contains(lowerName, "api_key") ||
|
||||
strings.Contains(lowerName, "apikey") {
|
||||
return true
|
||||
}
|
||||
// Catch any header containing "token" (e.g. Token, X-Token,
|
||||
// X-Auth-Token). Safe rate-limit headers like
|
||||
// x-ratelimit-remaining-tokens are already allowlisted above
|
||||
// and will not reach this point.
|
||||
if strings.Contains(lowerName, "token") {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(lowerName, "secret") ||
|
||||
strings.Contains(lowerName, "bearer")
|
||||
}
|
||||
|
||||
func isSensitiveJSONKey(key string) bool {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if _, ok := sensitiveJSONKeyExact[lowerKey]; ok {
|
||||
return true
|
||||
}
|
||||
for _, fragment := range sensitiveJSONKeyFragments {
|
||||
if strings.Contains(lowerKey, fragment) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func redactJSONValue(value any) (any, bool) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
for key, child := range typed {
|
||||
if isSensitiveJSONKey(key) {
|
||||
if current, ok := child.(string); ok && current == RedactedValue {
|
||||
continue
|
||||
}
|
||||
typed[key] = RedactedValue
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
redactedChild, childChanged := redactJSONValue(child)
|
||||
if childChanged {
|
||||
typed[key] = redactedChild
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return typed, changed
|
||||
case []any:
|
||||
changed := false
|
||||
for i, child := range typed {
|
||||
redactedChild, childChanged := redactJSONValue(child)
|
||||
if childChanged {
|
||||
typed[i] = redactedChild
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return typed, changed
|
||||
default:
|
||||
return value, false
|
||||
}
|
||||
}
|
||||
@@ -1,277 +0,0 @@
|
||||
package chatdebug_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
)
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Nil(t, chatdebug.RedactHeaders(nil))
|
||||
})
|
||||
|
||||
t.Run("empty header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
redacted := chatdebug.RedactHeaders(http.Header{})
|
||||
require.NotNil(t, redacted)
|
||||
require.Empty(t, redacted)
|
||||
})
|
||||
|
||||
t.Run("authorization redacted and others preserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Authorization": {"Bearer secret-token"},
|
||||
"Accept": {"application/json"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
|
||||
require.Equal(t, "application/json", redacted["Accept"])
|
||||
})
|
||||
|
||||
t.Run("multi-value headers are flattened", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Accept": {"application/json", "text/plain"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "application/json, text/plain", redacted["Accept"])
|
||||
})
|
||||
|
||||
t.Run("header name matching is case insensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lowerAuthorization := "authorization"
|
||||
upperAuthorization := "AUTHORIZATION"
|
||||
headers := http.Header{
|
||||
lowerAuthorization: {"lower"},
|
||||
upperAuthorization: {"upper"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted[lowerAuthorization])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted[upperAuthorization])
|
||||
})
|
||||
|
||||
t.Run("token and secret substrings are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
traceHeader := "X-Trace-ID"
|
||||
headers := http.Header{
|
||||
"X-Auth-Token": {"abc"},
|
||||
"X-Custom-Secret": {"def"},
|
||||
"X-Bearer": {"ghi"},
|
||||
traceHeader: {"trace"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Auth-Token"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Bearer"])
|
||||
require.Equal(t, "trace", redacted[traceHeader])
|
||||
})
|
||||
|
||||
t.Run("known safe rate limit headers containing token are not redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Anthropic-Ratelimit-Tokens-Limit": {"1000000"},
|
||||
"Anthropic-Ratelimit-Tokens-Remaining": {"999000"},
|
||||
"Anthropic-Ratelimit-Tokens-Reset": {"2026-03-31T08:55:26Z"},
|
||||
"X-RateLimit-Limit-Tokens": {"120000"},
|
||||
"X-RateLimit-Remaining-Tokens": {"119500"},
|
||||
"X-RateLimit-Reset-Tokens": {"12ms"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "1000000", redacted["Anthropic-Ratelimit-Tokens-Limit"])
|
||||
require.Equal(t, "999000", redacted["Anthropic-Ratelimit-Tokens-Remaining"])
|
||||
require.Equal(t, "2026-03-31T08:55:26Z", redacted["Anthropic-Ratelimit-Tokens-Reset"])
|
||||
require.Equal(t, "120000", redacted["X-RateLimit-Limit-Tokens"])
|
||||
require.Equal(t, "119500", redacted["X-RateLimit-Remaining-Tokens"])
|
||||
require.Equal(t, "12ms", redacted["X-RateLimit-Reset-Tokens"])
|
||||
})
|
||||
|
||||
t.Run("non-standard headers with api-key pattern are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"X-Custom-Api-Key": {"secret-key"},
|
||||
"X-Custom-Secret": {"secret-val"},
|
||||
"X-Custom-Session-Token": {"session-id"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Api-Key"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Session-Token"])
|
||||
})
|
||||
|
||||
t.Run("rate limit headers with token in name are preserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Rate-limit headers containing "token" should NOT be redacted
|
||||
// because they carry usage/limit counts, not credentials.
|
||||
headers := http.Header{
|
||||
"X-Ratelimit-Limit-Tokens": {"1000000"},
|
||||
"X-Ratelimit-Remaining-Tokens": {"999000"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, "1000000", redacted["X-Ratelimit-Limit-Tokens"])
|
||||
require.Equal(t, "999000", redacted["X-Ratelimit-Remaining-Tokens"])
|
||||
})
|
||||
|
||||
t.Run("original header is not modified", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"Authorization": {"Bearer keep-me"},
|
||||
"X-Test": {"value"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
redacted["X-Test"] = "changed"
|
||||
|
||||
require.Equal(t, []string{"Bearer keep-me"}, headers["Authorization"])
|
||||
require.Equal(t, []string{"value"}, headers["X-Test"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
|
||||
})
|
||||
t.Run("api-key header variants are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{
|
||||
"X-Goog-Api-Key": {"secret"},
|
||||
"X-Api_Key": {"other-secret"},
|
||||
"X-Safe": {"ok"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Goog-Api-Key"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Api_Key"])
|
||||
require.Equal(t, "ok", redacted["X-Safe"])
|
||||
})
|
||||
|
||||
t.Run("plain token headers are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Headers like "Token" or "X-Token" should be redacted
|
||||
// even without auth/session/access qualifiers.
|
||||
headers := http.Header{
|
||||
"Token": {"my-secret-token"},
|
||||
"X-Token": {"another-secret"},
|
||||
"X-Safe": {"ok"},
|
||||
}
|
||||
|
||||
redacted := chatdebug.RedactHeaders(headers)
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["Token"])
|
||||
require.Equal(t, chatdebug.RedactedValue, redacted["X-Token"])
|
||||
require.Equal(t, "ok", redacted["X-Safe"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedactJSONSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("redacts top level secret fields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"api_key":"abc","token":"def","password":"ghi","safe":"ok"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"api_key":"[REDACTED]","token":"[REDACTED]","password":"[REDACTED]","safe":"ok"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("redacts security_token exact key", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"security_token":"s3cret","securityToken":"tok","safe":"ok"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"security_token":"[REDACTED]","securityToken":"[REDACTED]","safe":"ok"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("preserves LLM token usage fields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"input_tokens":100,"output_tokens":50,"prompt_tokens":80,"completion_tokens":20,"reasoning_tokens":10,"cache_creation_input_tokens":5,"cache_read_input_tokens":3,"total_tokens":150,"max_tokens":4096,"max_output_tokens":2048}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
// All usage/limit fields should be preserved, not redacted.
|
||||
require.Equal(t, input, redacted)
|
||||
})
|
||||
|
||||
t.Run("redacts nested objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"outer":{"nested_secret":"abc","safe":1},"keep":true}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"outer":{"nested_secret":"[REDACTED]","safe":1},"keep":true}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("redacts arrays of objects", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`[{"token":"abc"},{"value":1,"credentials":{"access_key":"def"}}]`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `[{"token":"[REDACTED]"},{"value":1,"credentials":"[REDACTED]"}]`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("concatenated JSON is replaced with diagnostic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"token":"abc"}{"safe":"ok"}`)
|
||||
result := chatdebug.RedactJSONSecrets(input)
|
||||
require.Contains(t, string(result), "extra JSON values")
|
||||
})
|
||||
|
||||
t.Run("non JSON input is replaced with diagnostic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte("not json")
|
||||
result := chatdebug.RedactJSONSecrets(input)
|
||||
require.Contains(t, string(result), "not valid JSON")
|
||||
})
|
||||
|
||||
t.Run("empty input is unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte{}
|
||||
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
|
||||
})
|
||||
|
||||
t.Run("JSON without sensitive keys is unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"safe":"ok","nested":{"value":1}}`)
|
||||
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
|
||||
})
|
||||
|
||||
t.Run("key matching is case insensitive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := []byte(`{"API_KEY":"abc","Token":"def","PASSWORD":"ghi"}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"API_KEY":"[REDACTED]","Token":"[REDACTED]","PASSWORD":"[REDACTED]"}`, string(redacted))
|
||||
})
|
||||
|
||||
t.Run("camelCase token field names are redacted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Providers may use camelCase (e.g. accessToken, refreshToken).
|
||||
// These should be redacted even though they don't match the
|
||||
// snake_case originals exactly.
|
||||
input := []byte(`{"accessToken":"abc","refreshToken":"def","authToken":"ghi","input_tokens":100,"output_tokens":50}`)
|
||||
redacted := chatdebug.RedactJSONSecrets(input)
|
||||
require.JSONEq(t, `{"accessToken":"[REDACTED]","refreshToken":"[REDACTED]","authToken":"[REDACTED]","input_tokens":100,"output_tokens":50}`, string(redacted))
|
||||
})
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestBeginStepReuseStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("reuses handle under ReuseStep", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
1,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
ctx = ReuseStep(ctx)
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
|
||||
firstHandle, firstEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
secondHandle, secondEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
|
||||
require.NotNil(t, firstHandle)
|
||||
require.Same(t, firstHandle, secondHandle)
|
||||
require.Same(t, firstHandle.stepCtx, secondHandle.stepCtx)
|
||||
require.Same(t, firstHandle.sink, secondHandle.sink)
|
||||
require.Equal(t, runID, firstHandle.stepCtx.RunID)
|
||||
require.Equal(t, chatID, firstHandle.stepCtx.ChatID)
|
||||
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
|
||||
require.Equal(t, OperationStream, firstHandle.stepCtx.Operation)
|
||||
require.NotEqual(t, uuid.Nil, firstHandle.stepCtx.StepID)
|
||||
|
||||
firstStepCtx, ok := StepFromContext(firstEnriched)
|
||||
require.True(t, ok)
|
||||
secondStepCtx, ok := StepFromContext(secondEnriched)
|
||||
require.True(t, ok)
|
||||
require.Same(t, firstStepCtx, secondStepCtx)
|
||||
require.Same(t, firstHandle.stepCtx, firstStepCtx)
|
||||
require.Same(t, attemptSinkFromContext(firstEnriched), attemptSinkFromContext(secondEnriched))
|
||||
})
|
||||
|
||||
t.Run("creates new handles without ReuseStep", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
runID := uuid.New()
|
||||
t.Cleanup(func() { CleanupStepCounter(runID) })
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
1,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
expectDebugLoggingEnabled(t, db, ownerID)
|
||||
expectCreateStepNumberWithRequestValidity(
|
||||
t,
|
||||
db,
|
||||
runID,
|
||||
chatID,
|
||||
2,
|
||||
OperationStream,
|
||||
false,
|
||||
)
|
||||
|
||||
svc := NewService(db, testutil.Logger(t), nil)
|
||||
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
||||
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
|
||||
|
||||
firstHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
secondHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
|
||||
|
||||
require.NotNil(t, firstHandle)
|
||||
require.NotNil(t, secondHandle)
|
||||
require.NotSame(t, firstHandle, secondHandle)
|
||||
require.NotSame(t, firstHandle.sink, secondHandle.sink)
|
||||
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
|
||||
require.Equal(t, int32(2), secondHandle.stepCtx.StepNumber)
|
||||
require.NotEqual(t, firstHandle.stepCtx.StepID, secondHandle.stepCtx.StepID)
|
||||
})
|
||||
}
|
||||
@@ -1,539 +0,0 @@
|
||||
package chatdebug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
)
|
||||
|
||||
// DefaultStaleThreshold is the fallback stale timeout for debug rows
|
||||
// when no caller-provided value is supplied.
|
||||
const DefaultStaleThreshold = 5 * time.Minute
|
||||
|
||||
// Service persists chat debug rows and fans out lightweight change events.
|
||||
type Service struct {
|
||||
db database.Store
|
||||
log slog.Logger
|
||||
pubsub pubsub.Pubsub
|
||||
alwaysEnable bool
|
||||
// staleAfterNanos stores the stale threshold as nanoseconds in an
|
||||
// atomic.Int64 so SetStaleAfter and FinalizeStale can be called
|
||||
// from concurrent goroutines without a data race.
|
||||
staleAfterNanos atomic.Int64
|
||||
}
|
||||
|
||||
// ServiceOption configures optional Service behavior.
|
||||
type ServiceOption func(*Service)
|
||||
|
||||
// WithStaleThreshold overrides the default stale-row finalization
|
||||
// threshold. Callers that already have a configurable in-flight chat
|
||||
// timeout (e.g. chatd's InFlightChatStaleAfter) should pass it here
|
||||
// so the two sweeps stay in sync.
|
||||
func WithStaleThreshold(d time.Duration) ServiceOption {
|
||||
return func(s *Service) {
|
||||
if d > 0 {
|
||||
s.staleAfterNanos.Store(d.Nanoseconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithAlwaysEnable forces debug logging on for every chat regardless
|
||||
// of the runtime admin and user opt-in settings. This is used for the
|
||||
// deployment-level serpent flag.
|
||||
func WithAlwaysEnable(always bool) ServiceOption {
|
||||
return func(s *Service) {
|
||||
s.alwaysEnable = always
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRunParams contains friendly inputs for creating a debug run.
|
||||
type CreateRunParams struct {
|
||||
ChatID uuid.UUID
|
||||
RootChatID uuid.UUID
|
||||
ParentChatID uuid.UUID
|
||||
ModelConfigID uuid.UUID
|
||||
TriggerMessageID int64
|
||||
HistoryTipMessageID int64
|
||||
Kind RunKind
|
||||
Status Status
|
||||
Provider string
|
||||
Model string
|
||||
Summary any
|
||||
}
|
||||
|
||||
// UpdateRunParams contains inputs for updating a debug run.
|
||||
// Zero-valued fields are treated as "keep the existing value" by the
|
||||
// COALESCE-based SQL query. Once a field is set it cannot be cleared
|
||||
// back to NULL — this is intentional for the write-once-finalize
|
||||
// lifecycle of debug rows.
|
||||
type UpdateRunParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
Summary any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
// CreateStepParams contains friendly inputs for creating a debug step.
|
||||
type CreateStepParams struct {
|
||||
RunID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
StepNumber int32
|
||||
Operation Operation
|
||||
Status Status
|
||||
HistoryTipMessageID int64
|
||||
NormalizedRequest any
|
||||
}
|
||||
|
||||
// UpdateStepParams contains optional inputs for updating a debug step.
|
||||
// Most payload fields are typed as any and serialized through nullJSON
|
||||
// because their shape varies by provider. The Attempts field uses a
|
||||
// concrete slice for compile-time safety where the schema is stable.
|
||||
// Zero-valued fields are treated as "keep the existing value" by the
|
||||
// COALESCE-based SQL query — once set, fields cannot be cleared back
|
||||
// to NULL. This is intentional for the write-once-finalize lifecycle
|
||||
// of debug rows.
|
||||
type UpdateStepParams struct {
|
||||
ID uuid.UUID
|
||||
ChatID uuid.UUID
|
||||
Status Status
|
||||
AssistantMessageID int64
|
||||
NormalizedResponse any
|
||||
Usage any
|
||||
Attempts []Attempt
|
||||
Error any
|
||||
Metadata any
|
||||
FinishedAt time.Time
|
||||
}
|
||||
|
||||
// NewService constructs a chat debug persistence service.
|
||||
func NewService(db database.Store, log slog.Logger, ps pubsub.Pubsub, opts ...ServiceOption) *Service {
|
||||
if db == nil {
|
||||
panic("chatdebug: nil database.Store")
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
db: db,
|
||||
log: log,
|
||||
pubsub: ps,
|
||||
}
|
||||
s.staleAfterNanos.Store(DefaultStaleThreshold.Nanoseconds())
|
||||
for _, opt := range opts {
|
||||
opt(s)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SetStaleAfter overrides the in-flight stale threshold used when
|
||||
// finalizing abandoned debug rows. Zero or negative durations keep the
|
||||
// default threshold.
|
||||
func (s *Service) SetStaleAfter(staleAfter time.Duration) {
|
||||
if s == nil || staleAfter <= 0 {
|
||||
return
|
||||
}
|
||||
s.staleAfterNanos.Store(staleAfter.Nanoseconds())
|
||||
}
|
||||
|
||||
func chatdContext(ctx context.Context) context.Context {
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
|
||||
// chat debug persistence reads and writes.
|
||||
return dbauthz.AsChatd(ctx)
|
||||
}
|
||||
|
||||
// IsEnabled returns whether debug logging is enabled for the given chat.
|
||||
func (s *Service) IsEnabled(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
ownerID uuid.UUID,
|
||||
) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.alwaysEnable {
|
||||
return true
|
||||
}
|
||||
if s.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
authCtx := chatdContext(ctx)
|
||||
|
||||
allowUsers, err := s.db.GetChatDebugLoggingEnabled(authCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
s.log.Warn(ctx, "failed to load runtime admin chat debug logging setting",
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
if !allowUsers {
|
||||
return false
|
||||
}
|
||||
|
||||
if ownerID == uuid.Nil {
|
||||
s.log.Warn(ctx, "missing chat owner for debug logging enablement check",
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
enabled, err := s.db.GetUserChatDebugLoggingEnabled(authCtx, ownerID)
|
||||
if err == nil {
|
||||
return enabled
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
|
||||
s.log.Warn(ctx, "failed to load user chat debug logging setting",
|
||||
slog.Error(err),
|
||||
slog.F("chat_id", chatID),
|
||||
slog.F("owner_id", ownerID),
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateRun inserts a new debug run and emits a run update event.
|
||||
func (s *Service) CreateRun(
|
||||
ctx context.Context,
|
||||
params CreateRunParams,
|
||||
) (database.ChatDebugRun, error) {
|
||||
run, err := s.db.InsertChatDebugRun(chatdContext(ctx),
|
||||
database.InsertChatDebugRunParams{
|
||||
ChatID: params.ChatID,
|
||||
RootChatID: nullUUID(params.RootChatID),
|
||||
ParentChatID: nullUUID(params.ParentChatID),
|
||||
ModelConfigID: nullUUID(params.ModelConfigID),
|
||||
TriggerMessageID: nullInt64(params.TriggerMessageID),
|
||||
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
|
||||
Kind: string(params.Kind),
|
||||
Status: string(params.Status),
|
||||
Provider: nullString(params.Provider),
|
||||
Model: nullString(params.Model),
|
||||
Summary: s.nullJSON(ctx, params.Summary),
|
||||
StartedAt: sql.NullTime{},
|
||||
UpdatedAt: sql.NullTime{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
|
||||
return run, nil
|
||||
}
|
||||
|
||||
// UpdateRun updates an existing debug run and emits a run update event.
|
||||
func (s *Service) UpdateRun(
|
||||
ctx context.Context,
|
||||
params UpdateRunParams,
|
||||
) (database.ChatDebugRun, error) {
|
||||
run, err := s.db.UpdateChatDebugRun(chatdContext(ctx),
|
||||
database.UpdateChatDebugRunParams{
|
||||
RootChatID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
ModelConfigID: uuid.NullUUID{},
|
||||
TriggerMessageID: sql.NullInt64{},
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
Status: nullString(string(params.Status)),
|
||||
Provider: sql.NullString{},
|
||||
Model: sql.NullString{},
|
||||
Summary: s.nullJSON(ctx, params.Summary),
|
||||
FinishedAt: nullTime(params.FinishedAt),
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugRun{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
|
||||
return run, nil
|
||||
}
|
||||
|
||||
// CreateStep inserts a new debug step and emits a step update event.
|
||||
func (s *Service) CreateStep(
|
||||
ctx context.Context,
|
||||
params CreateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
insert := database.InsertChatDebugStepParams{
|
||||
RunID: params.RunID,
|
||||
StepNumber: params.StepNumber,
|
||||
Operation: string(params.Operation),
|
||||
Status: string(params.Status),
|
||||
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
|
||||
AssistantMessageID: sql.NullInt64{},
|
||||
NormalizedRequest: s.nullJSON(ctx, params.NormalizedRequest),
|
||||
NormalizedResponse: pqtype.NullRawMessage{},
|
||||
Usage: pqtype.NullRawMessage{},
|
||||
Attempts: pqtype.NullRawMessage{},
|
||||
Error: pqtype.NullRawMessage{},
|
||||
Metadata: pqtype.NullRawMessage{},
|
||||
StartedAt: sql.NullTime{},
|
||||
UpdatedAt: sql.NullTime{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
ChatID: params.ChatID,
|
||||
}
|
||||
|
||||
// Cap retry attempts to prevent infinite loops under
|
||||
// pathological concurrency. Each iteration performs two DB
|
||||
// round-trips (insert + list), so 10 retries is generous.
|
||||
const maxCreateStepRetries = 10
|
||||
|
||||
for attempt := 0; attempt < maxCreateStepRetries; attempt++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
step, err := s.db.InsertChatDebugStep(chatdContext(ctx), insert)
|
||||
if err == nil {
|
||||
// Touch the parent run's updated_at so the stale-
|
||||
// finalization sweep does not prematurely interrupt
|
||||
// long-running runs that are still producing steps.
|
||||
if _, touchErr := s.db.UpdateChatDebugRun(chatdContext(ctx), database.UpdateChatDebugRunParams{
|
||||
RootChatID: uuid.NullUUID{},
|
||||
ParentChatID: uuid.NullUUID{},
|
||||
ModelConfigID: uuid.NullUUID{},
|
||||
TriggerMessageID: sql.NullInt64{},
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
Status: sql.NullString{},
|
||||
Provider: sql.NullString{},
|
||||
Model: sql.NullString{},
|
||||
Summary: pqtype.NullRawMessage{},
|
||||
FinishedAt: sql.NullTime{},
|
||||
ID: params.RunID,
|
||||
ChatID: params.ChatID,
|
||||
}); touchErr != nil {
|
||||
s.log.Warn(ctx, "failed to touch parent run updated_at",
|
||||
slog.F("run_id", params.RunID),
|
||||
slog.Error(touchErr),
|
||||
)
|
||||
}
|
||||
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
|
||||
return step, nil
|
||||
}
|
||||
if !database.IsUniqueViolation(err, database.UniqueIndexChatDebugStepsRunStep) {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
steps, listErr := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), params.RunID)
|
||||
if listErr != nil {
|
||||
return database.ChatDebugStep{}, listErr
|
||||
}
|
||||
nextStepNumber := insert.StepNumber + 1
|
||||
for _, existing := range steps {
|
||||
if existing.StepNumber >= nextStepNumber {
|
||||
nextStepNumber = existing.StepNumber + 1
|
||||
}
|
||||
}
|
||||
insert.StepNumber = nextStepNumber
|
||||
}
|
||||
|
||||
return database.ChatDebugStep{}, xerrors.Errorf(
|
||||
"failed to create debug step after %d attempts (run_id=%s)",
|
||||
maxCreateStepRetries, params.RunID,
|
||||
)
|
||||
}
|
||||
|
||||
// UpdateStep updates an existing debug step and emits a step update event.
|
||||
func (s *Service) UpdateStep(
|
||||
ctx context.Context,
|
||||
params UpdateStepParams,
|
||||
) (database.ChatDebugStep, error) {
|
||||
step, err := s.db.UpdateChatDebugStep(chatdContext(ctx),
|
||||
database.UpdateChatDebugStepParams{
|
||||
Status: nullString(string(params.Status)),
|
||||
HistoryTipMessageID: sql.NullInt64{},
|
||||
AssistantMessageID: nullInt64(params.AssistantMessageID),
|
||||
NormalizedRequest: pqtype.NullRawMessage{},
|
||||
NormalizedResponse: s.nullJSON(ctx, params.NormalizedResponse),
|
||||
Usage: s.nullJSON(ctx, params.Usage),
|
||||
Attempts: s.nullJSON(ctx, params.Attempts),
|
||||
Error: s.nullJSON(ctx, params.Error),
|
||||
Metadata: s.nullJSON(ctx, params.Metadata),
|
||||
FinishedAt: nullTime(params.FinishedAt),
|
||||
ID: params.ID,
|
||||
ChatID: params.ChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return database.ChatDebugStep{}, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
|
||||
return step, nil
|
||||
}
|
||||
|
||||
// DeleteByChatID deletes all debug data for a chat and emits a delete event.
|
||||
func (s *Service) DeleteByChatID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
) (int64, error) {
|
||||
deleted, err := s.db.DeleteChatDebugDataByChatID(chatdContext(ctx), chatID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// DeleteAfterMessageID deletes debug data newer than the given message.
|
||||
func (s *Service) DeleteAfterMessageID(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
messageID int64,
|
||||
) (int64, error) {
|
||||
deleted, err := s.db.DeleteChatDebugDataAfterMessageID(
|
||||
chatdContext(ctx),
|
||||
database.DeleteChatDebugDataAfterMessageIDParams{
|
||||
ChatID: chatID,
|
||||
MessageID: messageID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// FinalizeStale finalizes stale in-flight debug rows and emits a broadcast.
|
||||
func (s *Service) FinalizeStale(
|
||||
ctx context.Context,
|
||||
) (database.FinalizeStaleChatDebugRowsRow, error) {
|
||||
ns := s.staleAfterNanos.Load()
|
||||
staleAfter := time.Duration(ns)
|
||||
if staleAfter <= 0 {
|
||||
staleAfter = DefaultStaleThreshold
|
||||
}
|
||||
|
||||
result, err := s.db.FinalizeStaleChatDebugRows(
|
||||
chatdContext(ctx),
|
||||
time.Now().Add(-staleAfter),
|
||||
)
|
||||
if err != nil {
|
||||
return database.FinalizeStaleChatDebugRowsRow{}, err
|
||||
}
|
||||
|
||||
if result.RunsFinalized > 0 || result.StepsFinalized > 0 {
|
||||
s.publishEvent(ctx, uuid.Nil, EventKindFinalize, uuid.Nil, uuid.Nil)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func nullUUID(id uuid.UUID) uuid.NullUUID {
|
||||
return uuid.NullUUID{UUID: id, Valid: id != uuid.Nil}
|
||||
}
|
||||
|
||||
func nullInt64(v int64) sql.NullInt64 {
|
||||
return sql.NullInt64{Int64: v, Valid: v != 0}
|
||||
}
|
||||
|
||||
func nullString(value string) sql.NullString {
|
||||
return sql.NullString{String: value, Valid: value != ""}
|
||||
}
|
||||
|
||||
func nullTime(value time.Time) sql.NullTime {
|
||||
return sql.NullTime{Time: value, Valid: !value.IsZero()}
|
||||
}
|
||||
|
||||
// nullJSON marshals value to a NullRawMessage. When value is nil or
|
||||
// marshals to JSON "null", the result is {Valid: false}. Combined with
|
||||
// the COALESCE-based UPDATE queries, this means a caller cannot clear a
|
||||
// previously-set JSON column back to NULL — passing nil preserves the
|
||||
// existing value. This is acceptable for debug logs because fields
|
||||
// accumulate monotonically (request → response → usage → error) and
|
||||
// never need to be cleared during normal operation.
|
||||
// jsonClear is a sentinel value that tells nullJSON to emit a valid
|
||||
// JSON null (JSONB 'null') instead of SQL NULL. COALESCE treats SQL
|
||||
// NULL as "keep existing" but replaces with a non-NULL JSONB value,
|
||||
// so passing jsonClear explicitly overwrites a previously set field.
|
||||
type jsonClear struct{}
|
||||
|
||||
func (s *Service) nullJSON(ctx context.Context, value any) pqtype.NullRawMessage {
|
||||
if value == nil {
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
// Sentinel: emit a valid JSONB null so COALESCE replaces
|
||||
// any previously stored value.
|
||||
if _, ok := value.(jsonClear); ok {
|
||||
return pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage("null"),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
s.log.Warn(ctx, "failed to marshal chat debug JSON",
|
||||
slog.Error(err),
|
||||
slog.F("value_type", fmt.Sprintf("%T", value)),
|
||||
)
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
return pqtype.NullRawMessage{}
|
||||
}
|
||||
|
||||
return pqtype.NullRawMessage{RawMessage: data, Valid: true}
|
||||
}
|
||||
|
||||
func (s *Service) publishEvent(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
kind EventKind,
|
||||
runID uuid.UUID,
|
||||
stepID uuid.UUID,
|
||||
) {
|
||||
if s.pubsub == nil {
|
||||
s.log.Debug(ctx,
|
||||
"chat debug pubsub unavailable; skipping event",
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
event := DebugEvent{
|
||||
Kind: kind,
|
||||
ChatID: chatID,
|
||||
RunID: runID,
|
||||
StepID: stepID,
|
||||
}
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
s.log.Warn(ctx, "failed to marshal chat debug event",
|
||||
slog.Error(err),
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
channel := PubsubChannel(chatID)
|
||||
if err := s.pubsub.Publish(channel, data); err != nil {
|
||||
s.log.Warn(ctx, "failed to publish chat debug event",
|
||||
slog.Error(err),
|
||||
slog.F("channel", channel),
|
||||
slog.F("kind", kind),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user