Compare commits

..

3 Commits

Author SHA1 Message Date
Ethan Dickson de1a317890 refactor(coderd/database/pubsub): clean up batching implementation
- Trim metrics from 18 to 6, removing counters that duplicated info
  available from existing metrics or provided little operational signal.
- Inline single-use constants.
- Re-add hidden PubsubFlushInterval and PubsubQueueSize config knobs
  to ChatConfig for tuning without code changes.
- Remove dead ErrBatchingPubsubQueueFull export (never returned).
- Remove unreachable nil branch in batchFlushStage.
- Propagate resetErr through flushBatch even when delegate replay
  succeeds, so drain reports the broken sender state.
- Annotate sender field with goroutine-safety invariant.
- Regenerate metrics docs.
2026-04-09 14:43:59 +00:00
Ethan Dickson d64ee2e1cc chore(codersdk): remove pubsub batching config knobs 2026-04-09 05:17:48 +00:00
Ethan Dickson 13281d8235 feat(coderd/database/pubsub): add batched pubsub with flush-failure fallback and sender reset
Adds a chatd-specific BatchingPubsub that routes publishes through a
dedicated single-connection sender, coalescing notifications into
single transactions on a 50ms timer. Includes flush-failure fallback
to the shared delegate, automatic sender reset/recreate, expanded
histogram buckets, and focused recovery tests.
2026-04-09 02:19:21 +00:00
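The coalesce-then-fallback flow this commit describes can be sketched as follows. This is an illustrative TypeScript model with hypothetical names, not the actual Go implementation: the real code batches into a single transaction and also resets/recreates the broken sender.

```typescript
type Publish = (channel: string, message: string) => void;

// Hypothetical sketch of the batching pattern described above.
class BatchingPubsub {
  private queue: { channel: string; message: string }[] = [];
  private timer: ReturnType<typeof setTimeout> | undefined;

  constructor(
    private sender: Publish,   // dedicated single-connection sender
    private delegate: Publish, // shared fallback connection
    private flushIntervalMs = 50,
  ) {}

  publish(channel: string, message: string): void {
    this.queue.push({ channel, message });
    // Coalesce: the first publish arms the timer; later publishes join the batch.
    this.timer ??= setTimeout(() => this.flush(), this.flushIntervalMs);
  }

  flush(): void {
    clearTimeout(this.timer);
    this.timer = undefined;
    const batch = this.queue.splice(0);
    try {
      for (const n of batch) this.sender(n.channel, n.message);
    } catch {
      // Flush-failure fallback: replay the batch on the shared delegate.
      for (const n of batch) this.delegate(n.channel, n.message);
    }
  }
}
```

In this sketch a failed flush replays the whole batch on the delegate; per the refactor notes in the first commit, the production code additionally surfaces the sender-reset error to its drain path.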
373 changed files with 10355 additions and 27373 deletions
@@ -18,35 +18,35 @@ The 5.x era resolves years of module system ambiguity and cleans house on legacy
The left column reflects patterns still common before TypeScript 5.x. Write the right column instead. The "Since" column tells you the minimum TypeScript version required.
| Old pattern | Modern replacement | Since |
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------ |
| `--experimentalDecorators` + legacy decorator signatures | Standard decorators (TC39): `function dec(target, context: ClassMethodDecoratorContext)` — no flag needed | 5.0 |
| Requiring callers to add `as const` at call sites | `<const T extends HasNames>(arg: T)` (`const` modifier on the type parameter) | 5.0 |
| `--importsNotUsedAsValues` + `--preserveValueImports` | `--verbatimModuleSyntax` | 5.0 |
| `import { Foo } from "..."` when `Foo` is only used as a type | `import { type Foo } from "..."` or `import type { Foo } from "..."` | 5.0 |
| `"extends": "@tsconfig/strictest/tsconfig.json"` chain | `"extends": ["@tsconfig/strictest/tsconfig.json", "./tsconfig.base.json"]` (array form) | 5.0 |
| `try { ... } finally { resource.close(); resource.delete(); }` | `using resource = acquireResource()` — calls `[Symbol.dispose]()` automatically | 5.2 |
| `try { ... } finally { await resource.close() }` | `await using resource = acquireAsyncResource()` | 5.2 |
| Ad-hoc cleanup with multiple `try/finally` blocks | `using cleanup = new DisposableStack(); cleanup.defer(() => ...)` | 5.2 |
| `import data from "./data.json" assert { type: "json" }` | `import data from "./data.json" with { type: "json" }` | 5.3 |
| `.filter(Boolean)` or `.filter(x => !!x)` to remove nulls | `.filter(x => x !== undefined)` or `.filter(x => x !== null)` (infers type predicate) | 5.5 |
| Extra phantom type param to block inference bleed: `<C extends string, D extends C>` | `NoInfer<C>` on the parameter you don't want to drive inference | 5.4 |
| `/** @typedef {import("./types").Foo} Foo */` in JS files | `/** @import { Foo } from "./types" */` (JSDoc `@import` tag) | 5.5 |
| `myArray.reverse()` mutating in place | `myArray.toReversed()` (returns new array) | 5.2 |
| `myArray.sort(cmp)` mutating in place | `myArray.toSorted(cmp)` (returns new array) | 5.2 |
| `const copy = [...arr]; copy[i] = v` | `arr.with(i, v)` (returns new array) | 5.2 |
| Manual `has`/`get`/`set` pattern on `Map` | `map.getOrInsert(key, defaultValue)` or `getOrInsertComputed(key, fn)` | 6.0 RC |
| `new RegExp(str.replace(/[.\*+?^${}()\[\]\\]/g, '\\$&'))` | `new RegExp(RegExp.escape(str))` | 6.0 RC |
| `--moduleResolution node` (node10) | `--moduleResolution nodenext` (Node.js) or `--moduleResolution bundler` (bundlers/Bun) | 6.0 RC |
| `"baseUrl": "./src"` + `"@app/*": ["app/*"]` in paths | Remove `baseUrl`; use `"@app/*": ["./src/app/*"]` in paths directly | 6.0 RC |
| `module Foo { export const x = 1; }` | `namespace Foo { export const x = 1; }` | 6.0 RC |
| `export * from "..."` when all re-exported members are types | `export type * from "..."` (or `export type * as ns from "..."`) | 5.0 |
| `function f(): undefined { return undefined; }` — explicit return required in `: undefined`-returning function | Remove the `return` entirely; `undefined`-returning functions no longer require any return statement | 5.1 |
| Manual type predicate annotation on a simple arrow: `(x: T \| undefined): x is T => x !== undefined` | Remove the annotation; TypeScript infers `x is T` from `!== null/undefined` and `instanceof` checks automatically | 5.5 |
| `const val = obj[key]; if (typeof val === "string") { use(val); }` — extract to const to narrow indexed access | `if (typeof obj[key] === "string") { obj[key].toUpperCase(); }` directly — both `obj` and `key` must be effectively constant | 5.5 |
| Copy narrowed `let`/param to a `const`, or restructure code to escape stale closure narrowing after reassignment | Remove the copy; narrowing survives into closures created after the last assignment to the variable | 5.4 |
| `(arr as string[]).filter(...)` or restructure to avoid "not callable" errors on `string[] \| number[]` | Call `.filter`, `.find`, `.some`, `.every`, `.reduce` directly on union-of-array types | 5.2 |
| `if`/`else` chain used to work around lack of narrowing inside a `switch (true)` body | `switch (true)` — each `case` condition now narrows the tested variable in its clause | 5.3 |
## New capabilities
+6
@@ -91,6 +91,12 @@ updates:
emotion:
patterns:
- "@emotion*"
exclude-patterns:
- "jest-runner-eslint"
jest:
patterns:
- "jest"
- "@types/jest"
vite:
patterns:
- "vite*"
+1 -5
@@ -84,7 +84,6 @@ jobs:
PR_TITLE: ${{ github.event.pull_request.title }}
PR_URL: ${{ github.event.pull_request.html_url }}
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
SENDER: ${{ github.event.sender.login }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
@@ -140,7 +139,6 @@ jobs:
Original PR: #${PR_NUMBER} — ${PR_TITLE}
Merge commit: ${MERGE_SHA}
Requested by: @${SENDER}
EOF
)
@@ -173,6 +171,4 @@ jobs:
--base "$RELEASE_VERSION" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY" \
--assignee "$SENDER" \
--reviewer "$SENDER"
--body "$BODY"
+5 -18
@@ -42,7 +42,6 @@ jobs:
PR_TITLE: ${{ github.event.pull_request.title }}
PR_URL: ${{ github.event.pull_request.html_url }}
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
SENDER: ${{ github.event.sender.login }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
@@ -117,7 +116,6 @@ jobs:
Original PR: #${PR_NUMBER} — ${PR_TITLE}
Merge commit: ${MERGE_SHA}
Requested by: @${SENDER}
EOF
)
@@ -134,19 +132,8 @@ jobs:
exit 0
fi
NEW_PR_URL=$(
gh pr create \
--base "$RELEASE_BRANCH" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY" \
--assignee "$SENDER" \
--reviewer "$SENDER"
)
# Comment on the original PR to notify the author.
COMMENT="Cherry-pick PR created: ${NEW_PR_URL}"
if [ "$CONFLICT" = true ]; then
COMMENT="${COMMENT} (⚠️ conflicts need manual resolution)"
fi
gh pr comment "$PR_NUMBER" --body "$COMMENT"
gh pr create \
--base "$RELEASE_BRANCH" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY"
@@ -1,93 +0,0 @@
# Ensures that only bug fixes are cherry-picked to release branches.
# PRs targeting release/* must have a title starting with "fix:" or "fix(scope):".
name: PR Cherry-Pick Check
on:
# zizmor: ignore[dangerous-triggers] Only reads PR metadata and comments; does not checkout PR code.
pull_request_target:
types: [opened, reopened, edited]
branches:
- "release/*"
permissions:
pull-requests: write
jobs:
check-cherry-pick:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
with:
egress-policy: audit
- name: Check PR title for bug fix
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
script: |
const title = context.payload.pull_request.title;
const prNumber = context.payload.pull_request.number;
const baseBranch = context.payload.pull_request.base.ref;
const author = context.payload.pull_request.user.login;
console.log(`PR #${prNumber}: "${title}" -> ${baseBranch}`);
// Match conventional commit "fix:" or "fix(scope):" prefix.
const isBugFix = /^fix(\(.+\))?:/.test(title);
if (isBugFix) {
console.log("PR title indicates a bug fix. No action needed.");
return;
}
console.log("PR title does not indicate a bug fix. Commenting.");
// Check for an existing comment from this bot to avoid duplicates
// on title edits.
const { data: comments } = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
});
const marker = "<!-- cherry-pick-check -->";
const existingComment = comments.find(
(c) => c.body && c.body.includes(marker),
);
const body = [
marker,
`👋 Hey @${author}!`,
"",
`This PR is targeting the \`${baseBranch}\` release branch, but its title does not start with \`fix:\` or \`fix(scope):\`.`,
"",
"Only **bug fixes** should be cherry-picked to release branches. If this is a bug fix, please update the PR title to match the conventional commit format:",
"",
"```",
"fix: description of the bug fix",
"fix(scope): description of the bug fix",
"```",
"",
"If this is **not** a bug fix, it likely should not target a release branch.",
].join("\n");
if (existingComment) {
console.log(`Updating existing comment ${existingComment.id}.`);
await github.rest.issues.updateComment({
owner: context.repo.owner,
repo: context.repo.repo,
comment_id: existingComment.id,
body,
});
} else {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body,
});
}
core.warning(
`PR #${prNumber} targets ${baseBranch} but is not a bug fix. Title must start with "fix:" or "fix(scope):".`,
);
+41 -95
@@ -91,59 +91,6 @@ define atomic_write
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
endef
# Helper binary targets. Built with go build -o to avoid caching
# link-stage executables in GOCACHE. Each binary is a real Make
# target so parallel -j builds serialize correctly instead of
# racing on the same output path.
_gen/bin/apitypings: $(wildcard scripts/apitypings/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/apitypings
_gen/bin/auditdocgen: $(wildcard scripts/auditdocgen/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/auditdocgen
_gen/bin/check-scopes: $(wildcard scripts/check-scopes/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/check-scopes
_gen/bin/clidocgen: $(wildcard scripts/clidocgen/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/clidocgen
_gen/bin/dbdump: $(wildcard coderd/database/gen/dump/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./coderd/database/gen/dump
_gen/bin/examplegen: $(wildcard scripts/examplegen/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/examplegen
_gen/bin/gensite: $(wildcard scripts/gensite/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/gensite
_gen/bin/apikeyscopesgen: $(wildcard scripts/apikeyscopesgen/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/apikeyscopesgen
_gen/bin/metricsdocgen: $(wildcard scripts/metricsdocgen/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/metricsdocgen
_gen/bin/metricsdocgen-scanner: $(wildcard scripts/metricsdocgen/scanner/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/metricsdocgen/scanner
_gen/bin/modeloptionsgen: $(wildcard scripts/modeloptionsgen/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/modeloptionsgen
_gen/bin/typegen: $(wildcard scripts/typegen/*.go) | _gen
@mkdir -p _gen/bin
go build -o $@ ./scripts/typegen
# Shared temp directory for atomic writes. Lives at the project root
# so all targets share the same filesystem, and is gitignored.
# Order-only prerequisite: recipes that need it depend on | _gen
@@ -254,7 +201,6 @@ endif
clean:
rm -rf build/ site/build/ site/out/
rm -rf _gen/bin
mkdir -p build/
git restore site/out/
.PHONY: clean
@@ -708,8 +654,8 @@ lint/go:
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
.PHONY: lint/go
lint/examples: | _gen/bin/examplegen
_gen/bin/examplegen -lint
lint/examples:
go run ./scripts/examplegen/main.go -lint
.PHONY: lint/examples
# Use shfmt to determine the shell files, takes editorconfig into consideration.
@@ -747,8 +693,8 @@ lint/actions/zizmor:
.PHONY: lint/actions/zizmor
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
lint/check-scopes: coderd/database/dump.sql | _gen/bin/check-scopes
_gen/bin/check-scopes
lint/check-scopes: coderd/database/dump.sql
go run ./scripts/check-scopes
.PHONY: lint/check-scopes
# Verify migrations do not hardcode the public schema.
@@ -788,8 +734,8 @@ lint/typos: build/typos-$(TYPOS_VERSION)
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
#
# pre-commit uses two phases: gen+fmt first, then lint+build. This
# avoids races where gen creates temporary .go files that lint's
# find-based checks pick up. Within each phase, targets run in
# avoids races where gen's `go run` creates temporary .go files that
# lint's find-based checks pick up. Within each phase, targets run in
# parallel via -j. It fails if any tracked files have unstaged
# changes afterward.
@@ -1003,8 +949,8 @@ gen/mark-fresh:
# Runs migrations to output a dump of the database schema after migrations are
# applied.
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql) | _gen/bin/dbdump
_gen/bin/dbdump
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
go run ./coderd/database/gen/dump/main.go
touch "$@"
# Generates Go code for querying the database.
@@ -1121,88 +1067,88 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
--go-drpc_opt=paths=source_relative \
./enterprise/aibridged/proto/aibridged.proto
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen _gen/bin/apitypings
$(call atomic_write,_gen/bin/apitypings,./scripts/biome_format.sh)
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
(cd site/ && pnpm run gen:provisioner)
touch "$@"
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen _gen/bin/gensite
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
_gen/bin/gensite -icons "$$tmpfile" && \
go run ./scripts/gensite/ -icons "$$tmpfile" && \
./scripts/biome_format.sh "$$tmpfile" && \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen _gen/bin/examplegen
$(call atomic_write,_gen/bin/examplegen)
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
$(call atomic_write,go run ./scripts/examplegen/main.go)
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen _gen/bin/typegen
$(call atomic_write,_gen/bin/typegen rbac object)
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
touch "$@"
# NOTE: depends on object_gen.go because the generator build
# compiles coderd/rbac which includes it.
# NOTE: depends on object_gen.go because `go run` compiles
# coderd/rbac which includes it.
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
coderd/rbac/object_gen.go | _gen _gen/bin/typegen
coderd/rbac/object_gen.go | _gen
# Write to a temp file first to avoid truncating the package
# during build since the generator imports the rbac package.
$(call atomic_write,_gen/bin/typegen rbac scopenames)
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
touch "$@"
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
# the generator build compiles coderd/rbac which includes both.
# `go run` compiles coderd/rbac which includes both.
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
# Write to a temp file to avoid truncating the target, which
# would break the codersdk package and any parallel build targets.
$(call atomic_write,_gen/bin/typegen rbac codersdk)
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
touch "$@"
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
# the generator build compiles coderd/rbac which includes both.
# `go run` compiles coderd/rbac which includes both.
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/apikeyscopesgen
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
# Generate SDK constants for external API key scopes.
$(call atomic_write,_gen/bin/apikeyscopesgen)
$(call atomic_write,go run ./scripts/apikeyscopesgen)
touch "$@"
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
# the generator build compiles coderd/rbac which includes both.
# `go run` compiles coderd/rbac which includes both.
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
$(call atomic_write,_gen/bin/typegen rbac typescript,./scripts/biome_format.sh)
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen _gen/bin/typegen
$(call atomic_write,_gen/bin/typegen countries,./scripts/biome_format.sh)
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen _gen/bin/modeloptionsgen
$(call atomic_write,_gen/bin/modeloptionsgen | tail -n +2,./scripts/biome_format.sh)
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen _gen/bin/metricsdocgen-scanner
$(call atomic_write,_gen/bin/metricsdocgen-scanner)
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen _gen/bin/metricsdocgen
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
_gen/bin/metricsdocgen --prometheus-doc-file="$$tmpfile" && \
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
pnpm exec markdown-table-formatter "$$tmpfile" && \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen _gen/bin/clidocgen
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
tmpdir=$$(mktemp -d -p _gen) && \
tmpdir=$$(realpath "$$tmpdir") && \
mkdir -p "$$tmpdir/docs/reference/cli" && \
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
CI=true DOCS_DIR="$$tmpdir/docs" _gen/bin/clidocgen && \
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
rm -rf "$$tmpdir"
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen _gen/bin/auditdocgen
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
_gen/bin/auditdocgen --audit-doc-file="$$tmpfile" && \
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
pnpm exec markdown-table-formatter "$$tmpfile" && \
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
-120
@@ -2862,126 +2862,6 @@ func TestAPI(t *testing.T) {
"rebuilt agent should include updated display apps")
})
// Verify that when a terraform-managed subagent is injected into
// a devcontainer, the Directory field sent to Create reflects
// the container-internal workspaceFolder from devcontainer
// read-configuration, not the host-side workspace_folder from
// the terraform resource. This is the scenario described in
// https://linear.app/codercom/issue/PRODUCT-259:
// 1. Non-terraform subagent → directory = /workspaces/foo (correct)
// 2. Terraform subagent → directory was stuck on host path (bug)
t.Run("TerraformDefinedSubAgentUsesContainerInternalDirectory", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
}
var (
ctx = testutil.Context(t, testutil.WaitMedium)
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
mCtrl = gomock.NewController(t)
terraformAgentID = uuid.New()
containerID = "test-container-id"
// Given: A container with a host-side workspace folder.
terraformContainer = codersdk.WorkspaceAgentContainer{
ID: containerID,
FriendlyName: "test-container",
Image: "test-image",
Running: true,
CreatedAt: time.Now(),
Labels: map[string]string{
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project",
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json",
},
}
// Given: A terraform-defined devcontainer whose
// workspace_folder is the HOST-side path (set by provisioner).
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
ID: uuid.New(),
Name: "terraform-devcontainer",
WorkspaceFolder: "/home/coder/project",
ConfigPath: "/home/coder/project/.devcontainer/devcontainer.json",
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
}
fCCLI = &fakeContainerCLI{
containers: codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
},
arch: runtime.GOARCH,
}
// Given: devcontainer read-configuration returns the
// CONTAINER-INTERNAL workspace folder.
fDCCLI = &fakeDevcontainerCLI{
upID: containerID,
readConfig: agentcontainers.DevcontainerConfig{
Workspace: agentcontainers.DevcontainerWorkspace{
WorkspaceFolder: "/workspaces/project",
},
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
Customizations: agentcontainers.DevcontainerMergedCustomizations{
Coder: []agentcontainers.CoderCustomization{{}},
},
},
},
}
mSAC = acmock.NewMockSubAgentClient(mCtrl)
createCalls = make(chan agentcontainers.SubAgent, 1)
closed bool
)
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
agent.AuthToken = uuid.New()
createCalls <- agent
return agent, nil
},
).Times(1)
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
assert.True(t, closed, "Delete should only be called after Close")
return nil
}).AnyTimes()
api := agentcontainers.NewAPI(logger,
agentcontainers.WithContainerCLI(fCCLI),
agentcontainers.WithDevcontainerCLI(fDCCLI),
agentcontainers.WithDevcontainers(
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
),
agentcontainers.WithSubAgentClient(mSAC),
agentcontainers.WithSubAgentURL("test-subagent-url"),
agentcontainers.WithWatcher(watcher.NewNoop()),
)
api.Start()
defer func() {
closed = true
api.Close()
}()
// When: The devcontainer is created (triggering injection).
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
require.NoError(t, err)
// Then: The subagent sent to Create has the correct
// container-internal directory, not the host path.
createdAgent := testutil.RequireReceive(ctx, t, createCalls)
assert.Equal(t, terraformAgentID, createdAgent.ID,
"agent should use terraform-defined ID")
assert.Equal(t, "/workspaces/project", createdAgent.Directory,
"directory should be the container-internal path from devcontainer "+
"read-configuration, not the host-side workspace_folder")
})
t.Run("Error", func(t *testing.T) {
t.Parallel()
-27
@@ -134,33 +134,6 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
}, ResolvePaths(mcpConfigFile, workingDir)
}
// ContextPartsFromDir reads instruction files and discovers skills
// from a specific directory, using default file names. This is used
// by the CLI chat context commands to read context from an arbitrary
// directory without consulting agent env vars.
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
var parts []codersdk.ChatMessagePart
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
parts = append(parts, entry)
}
// Reuse ResolvePaths so CLI skill discovery follows the same
// project-relative path handling as agent config resolution.
skillParts := discoverSkills(
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
DefaultSkillMetaFile,
)
parts = append(parts, skillParts...)
// Guarantee non-nil slice.
if parts == nil {
parts = []codersdk.ChatMessagePart{}
}
return parts
}
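`ResolvePaths` itself does not appear in this diff, so the following is only an assumption-labeled sketch of the comma-splitting and project-relative joining the comments above describe (`resolvePaths` is a hypothetical stand-in, not the package's actual implementation):

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// resolvePaths splits a comma-separated path list and joins any
// relative entry onto dir, skipping blanks. This mirrors the
// "project-relative path handling" the comment above refers to,
// under the assumption that absolute paths pass through unchanged.
func resolvePaths(list, dir string) []string {
	var out []string
	for _, p := range strings.Split(list, ",") {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		if !filepath.IsAbs(p) {
			p = filepath.Join(dir, p)
		}
		out = append(out, p)
	}
	return out
}

func main() {
	fmt.Println(resolvePaths(" skills , /opt/skills ,", "/proj"))
}
```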
// MCPConfigFiles returns the resolved MCP configuration file
// paths for the agent's MCP manager.
func (api *API) MCPConfigFiles() []string {
+164 -210
@@ -23,144 +23,18 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
return out
}
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
t.Helper()
skillDir := filepath.Join(skillsRoot, name)
require.NoError(t, os.MkdirAll(skillDir, 0o755))
require.NoError(t, os.WriteFile(
filepath.Join(skillDir, "SKILL.md"),
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
0o600,
))
return skillDir
}
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
t.Helper()
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
}
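The helpers above write SKILL.md files with a minimal `---`-delimited frontmatter block. As a hedged illustration of that on-disk format (not the package's real parser, which is outside this diff), extracting the `name` and `description` fields can be sketched with the standard library alone:

```go
package main

import (
	"fmt"
	"strings"
)

// parseSkillMeta is a hypothetical helper that reads the
// "---\nname: ...\ndescription: ...\n---" frontmatter the test
// fixtures above write. It returns ok=false when the frontmatter
// delimiters or the name field are missing.
func parseSkillMeta(content string) (name, description string, ok bool) {
	rest, found := strings.CutPrefix(content, "---\n")
	if !found {
		return "", "", false
	}
	body, _, found := strings.Cut(rest, "\n---")
	if !found {
		return "", "", false
	}
	for _, line := range strings.Split(body, "\n") {
		if v, isName := strings.CutPrefix(line, "name: "); isName {
			name = v
		} else if v, isDesc := strings.CutPrefix(line, "description: "); isDesc {
			description = v
		}
	}
	return name, description, name != ""
}

func main() {
	n, d, ok := parseSkillMeta("---\nname: my-skill\ndescription: A test skill\n---\nSkill body")
	fmt.Println(n, d, ok)
}
```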
func TestContextPartsFromDir(t *testing.T) {
t.Parallel()
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
instructionPath := filepath.Join(dir, "AGENTS.md")
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 1)
require.Len(t, contextParts, 1)
require.Empty(t, skillParts)
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
require.False(t, contextParts[0].ContextFileTruncated)
})
t.Run("ReturnsSkillParts", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 1)
require.Empty(t, contextParts)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
require.Equal(t, skillDir, skillParts[0].SkillDir)
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
skillDir := writeSkillMetaFileInRoot(
t,
filepath.Join(dir, "skills"),
"my-skill",
"A test skill",
)
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 1)
require.Empty(t, contextParts)
require.Len(t, skillParts, 1)
require.Equal(t, "my-skill", skillParts[0].SkillName)
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
require.Equal(t, skillDir, skillParts[0].SkillDir)
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
t.Parallel()
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
require.NotNil(t, parts)
require.Empty(t, parts)
})
t.Run("ReturnsCombinedResults", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
instructionPath := filepath.Join(dir, "AGENTS.md")
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
parts := agentcontextconfig.ContextPartsFromDir(dir)
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
require.Len(t, parts, 2)
require.Len(t, contextParts, 1)
require.Len(t, skillParts, 1)
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
require.Equal(t, "combined-skill", skillParts[0].SkillName)
require.Equal(t, skillDir, skillParts[0].SkillDir)
})
}
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
t.Helper()
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
for key, value := range overrides {
t.Setenv(key, value)
}
return fakeHome
}
func TestConfig(t *testing.T) {
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("Defaults", func(t *testing.T) {
setupConfigTestEnv(t, nil)
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
// Clear all env vars so defaults are used.
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := platformAbsPath("work")
cfg, mcpFiles := agentcontextconfig.Config(workDir)
@@ -172,18 +46,20 @@ func TestConfig(t *testing.T) {
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("CustomEnvVars", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
optInstructions := t.TempDir()
optSkills := t.TempDir()
optMCP := platformAbsPath("opt", "mcp.json")
setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsDirs: optInstructions,
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
agentcontextconfig.EnvSkillsDirs: optSkills,
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
agentcontextconfig.EnvMCPConfigFiles: optMCP,
})
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
// Create files matching the custom names so we can
// verify the env vars actually change lookup behavior.
@@ -209,12 +85,15 @@ func TestConfig(t *testing.T) {
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("WhitespaceInFileNames", func(t *testing.T) {
fakeHome := setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
})
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// Create a file matching the trimmed name.
@@ -227,13 +106,19 @@ func TestConfig(t *testing.T) {
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("CommaSeparatedDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
a := t.TempDir()
b := t.TempDir()
setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
})
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
// Put instruction files in both dirs.
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
@@ -248,10 +133,17 @@ func TestConfig(t *testing.T) {
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("ReadsInstructionFiles", func(t *testing.T) {
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
fakeHome := setupConfigTestEnv(t, nil)
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
// Create ~/.coder/AGENTS.md
coderDir := filepath.Join(fakeHome, ".coder")
@@ -272,9 +164,16 @@ func TestConfig(t *testing.T) {
require.False(t, ctxFiles[0].ContextFileTruncated)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
setupConfigTestEnv(t, nil)
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// Create AGENTS.md in the working directory.
@@ -294,9 +193,16 @@ func TestConfig(t *testing.T) {
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
setupConfigTestEnv(t, nil)
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
largeContent := strings.Repeat("a", 64*1024+100)
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
@@ -309,47 +215,79 @@ func TestConfig(t *testing.T) {
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
})
sanitizationTests := []struct {
name string
input string
expected string
}{
{
name: "SanitizesHTMLComments",
input: "visible\n<!-- hidden -->content",
expected: "visible\ncontent",
},
{
name: "SanitizesInvisibleUnicode",
input: "before\u200bafter",
expected: "beforeafter",
},
{
name: "NormalizesCRLF",
input: "line1\r\nline2\rline3",
expected: "line1\nline2\nline3",
},
}
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
for _, tt := range sanitizationTests {
t.Run(tt.name, func(t *testing.T) {
setupConfigTestEnv(t, nil)
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte(tt.input),
0o600,
))
t.Run("SanitizesHTMLComments", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
cfg, _ := agentcontextconfig.Config(workDir)
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("visible\n<!-- hidden -->content"),
0o600,
))
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
})
}
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
})
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
// U+200B (zero-width space) should be stripped.
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("before\u200bafter"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
})
t.Run("NormalizesCRLF", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(workDir, "AGENTS.md"),
[]byte("line1\r\nline2\rline3"),
0o600,
))
cfg, _ := agentcontextconfig.Config(workDir)
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
require.Len(t, ctxFiles, 1)
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
})
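The three sanitization behaviors exercised above (stripping HTML comments, removing invisible Unicode, normalizing CRLF) can be sketched in a self-contained form; this is an illustrative approximation, not the package's actual sanitizer:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// htmlComment matches <!-- ... --> blocks, including multi-line ones.
var htmlComment = regexp.MustCompile(`(?s)<!--.*?-->`)

// sanitize approximates the instruction-file cleanup the tests
// above verify: CRLF/CR become LF, HTML comments are dropped, and
// U+200B (zero-width space) is stripped.
func sanitize(s string) string {
	s = strings.ReplaceAll(s, "\r\n", "\n")
	s = strings.ReplaceAll(s, "\r", "\n")
	s = htmlComment.ReplaceAllString(s, "")
	s = strings.ReplaceAll(s, "\u200b", "")
	return s
}

func main() {
	fmt.Println(sanitize("visible\n<!-- hidden -->content")) // visible\ncontent
	fmt.Println(sanitize("before\u200bafter"))               // beforeafter
	fmt.Println(sanitize("line1\r\nline2\rline3"))           // line1\nline2\nline3
}
```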
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("DiscoversSkills", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
@@ -382,13 +320,17 @@ func TestConfig(t *testing.T) {
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("SkipsMissingDirs", func(t *testing.T) {
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvInstructionsDirs: nonExistent,
agentcontextconfig.EnvSkillsDirs: nonExistent,
})
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
cfg, _ := agentcontextconfig.Config(workDir)
@@ -398,13 +340,17 @@ func TestConfig(t *testing.T) {
require.Empty(t, cfg.Parts)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
optMCP := platformAbsPath("opt", "custom.json")
fakeHome := setupConfigTestEnv(t, map[string]string{
agentcontextconfig.EnvMCPConfigFiles: optMCP,
})
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
optMCP := platformAbsPath("opt", "custom.json")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
workDir := t.TempDir()
_, mcpFiles := agentcontextconfig.Config(workDir)
@@ -412,10 +358,14 @@ func TestConfig(t *testing.T) {
require.Equal(t, []string{optMCP}, mcpFiles)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
fakeHome := setupConfigTestEnv(t, nil)
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir := filepath.Join(workDir, "skills")
@@ -435,10 +385,14 @@ func TestConfig(t *testing.T) {
require.Empty(t, skillParts)
})
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
fakeHome := setupConfigTestEnv(t, nil)
fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)
t.Setenv("USERPROFILE", fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
workDir := t.TempDir()
skillsDir1 := filepath.Join(workDir, "skills1")
+1036 -1139
File diff suppressed because it is too large
-15
@@ -98,21 +98,6 @@ message Manifest {
repeated WorkspaceApp apps = 11;
repeated WorkspaceAgentMetadata.Description metadata = 12;
repeated WorkspaceAgentDevcontainer devcontainers = 17;
repeated WorkspaceSecret secrets = 19;
}
// WorkspaceSecret is a secret included in the agent manifest
// for injection into a workspace.
message WorkspaceSecret {
// Environment variable name to inject (e.g. "GITHUB_TOKEN").
// Empty string means this secret is not injected as an env var.
string env_name = 1;
// File path to write the secret value to (e.g.
// "~/.aws/credentials"). Empty string means this secret is not
// written to a file.
string file_path = 2;
// The decrypted secret value.
bytes value = 3;
}
message WorkspaceAgentDevcontainer {
+3 -60
@@ -5,9 +5,7 @@ import (
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strconv"
"sync"
"time"
@@ -622,11 +620,6 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
return
}
defer artifact.Reader.Close()
defer func() {
if artifact.ThumbnailReader != nil {
_ = artifact.ThumbnailReader.Close()
}
}()
if artifact.Size > workspacesdk.MaxRecordingSize {
a.logger.Warn(ctx, "recording file exceeds maximum size",
@@ -640,60 +633,10 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
return
}
// Discard the thumbnail if it exceeds the maximum size.
// The server-side consumer also enforces this per-part, but
// rejecting it here avoids streaming a large thumbnail over
// the wire for nothing.
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
slog.F("recording_id", recordingID),
slog.F("size", artifact.ThumbnailSize),
slog.F("max_size", workspacesdk.MaxThumbnailSize),
)
_ = artifact.ThumbnailReader.Close()
artifact.ThumbnailReader = nil
artifact.ThumbnailSize = 0
}
// The multipart response is best-effort: once WriteHeader(200) is
// called, CreatePart failures produce a truncated response without
// the closing boundary. The server-side consumer handles this
// gracefully, preserving any parts read before the error.
mw := multipart.NewWriter(rw)
defer mw.Close()
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
rw.Header().Set("Content-Type", "video/mp4")
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
rw.WriteHeader(http.StatusOK)
// Part 1: video/mp4 (always present).
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"video/mp4"},
})
if err != nil {
a.logger.Warn(ctx, "failed to create video multipart part",
slog.F("recording_id", recordingID),
slog.Error(err))
return
}
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
a.logger.Warn(ctx, "failed to write video multipart part",
slog.F("recording_id", recordingID),
slog.Error(err))
return
}
// Part 2: image/jpeg (present only when thumbnail was extracted).
if artifact.ThumbnailReader != nil {
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
"Content-Type": {"image/jpeg"},
})
if err != nil {
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
slog.F("recording_id", recordingID),
slog.Error(err))
return
}
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
}
_, _ = io.Copy(rw, artifact.Reader)
}
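The multipart/mixed framing that this change removes (and that `parseMultipartParts` below consumes) round-trips cleanly with the standard library; a minimal sketch of writing two parts and reading them back:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/textproto"
)

// roundTrip writes a two-part multipart/mixed body (video plus
// thumbnail) and reads it back, returning each part's Content-Type
// mapped to its payload size.
func roundTrip() map[string]int {
	var buf bytes.Buffer
	mw := multipart.NewWriter(&buf)

	video, _ := mw.CreatePart(textproto.MIMEHeader{"Content-Type": {"video/mp4"}})
	_, _ = video.Write([]byte("fake-mp4-data"))

	thumb, _ := mw.CreatePart(textproto.MIMEHeader{"Content-Type": {"image/jpeg"}})
	_, _ = thumb.Write([]byte{0xff, 0xd8, 0xff})

	_ = mw.Close() // writes the closing boundary

	parts := make(map[string]int)
	mr := multipart.NewReader(&buf, mw.Boundary())
	for {
		part, err := mr.NextPart()
		if err != nil { // io.EOF after the closing boundary
			break
		}
		data, _ := io.ReadAll(part)
		parts[part.Header.Get("Content-Type")] = len(data)
	}
	return parts
}

func main() {
	fmt.Println(roundTrip())
}
```

Note that, as the removed comment observed, a truncated body missing the closing boundary makes `NextPart` return a non-EOF error after the parts that did arrive, which is why a reader can preserve earlier parts.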
// coordFromAction extracts the coordinate pair from a DesktopAction,
+16 -191
@@ -4,17 +4,12 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"mime/multipart"
"net"
"net/http"
"net/http/httptest"
"os"
"slices"
"strings"
"sync"
"testing"
"time"
@@ -64,8 +59,6 @@ type fakeDesktop struct {
lastKeyDown string
lastKeyUp string
thumbnailData []byte // if set, StopRecording includes a thumbnail
// Recording tracking (guarded by recMu).
recMu sync.Mutex
recordings map[string]string // ID → file path
@@ -194,15 +187,10 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
_ = file.Close()
return nil, err
}
artifact := &agentdesktop.RecordingArtifact{
return &agentdesktop.RecordingArtifact{
Reader: file,
Size: info.Size(),
}
if f.thumbnailData != nil {
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
artifact.ThumbnailSize = int64(len(f.thumbnailData))
}
return artifact, nil
}, nil
}
func (f *fakeDesktop) RecordActivity() {
@@ -797,8 +785,8 @@ func TestRecordingStartStop(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
}
func TestRecordingStartFails(t *testing.T) {
@@ -859,8 +847,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
}
func TestRecordingStopIdempotent(t *testing.T) {
@@ -884,7 +872,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
require.Equal(t, http.StatusOK, rr.Code)
// Stop twice - both should succeed with identical data.
var videoParts [2][]byte
var bodies [2][]byte
for i := range 2 {
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
require.NoError(t, err)
@@ -892,10 +880,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(recorder, request)
require.Equal(t, http.StatusOK, recorder.Code)
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
videoParts[i] = parts["video/mp4"]
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
bodies[i] = recorder.Body.Bytes()
}
assert.Equal(t, videoParts[0], videoParts[1])
assert.Equal(t, bodies[0], bodies[1])
}
func TestRecordingStopInvalidIDFormat(t *testing.T) {
@@ -1016,8 +1004,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
assert.Equal(t, expected[id], parts["video/mp4"])
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
assert.Equal(t, expected[id], rr.Body.Bytes())
}
}
@@ -1124,8 +1112,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
firstData := firstParts["video/mp4"]
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
firstData := rr.Body.Bytes()
require.NotEmpty(t, firstData)
// Step 3: Start again with the same ID - should succeed
@@ -1140,8 +1128,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
secondData := secondParts["video/mp4"]
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
secondData := rr.Body.Bytes()
require.NotEmpty(t, secondData)
// The two recordings should have different data because the
@@ -1247,166 +1235,3 @@ func TestRecordingStopCorrupted(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "Recording is corrupted.", respStop.Message)
}
// parseMultipartParts parses a multipart/mixed response and returns
// a map from Content-Type to body bytes.
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
t.Helper()
_, params, err := mime.ParseMediaType(contentType)
require.NoError(t, err, "parse Content-Type")
boundary := params["boundary"]
require.NotEmpty(t, boundary, "missing boundary")
mr := multipart.NewReader(bytes.NewReader(body), boundary)
parts := make(map[string][]byte)
for {
part, err := mr.NextPart()
if errors.Is(err, io.EOF) {
break
}
require.NoError(t, err, "unexpected multipart parse error")
ct := part.Header.Get("Content-Type")
data, readErr := io.ReadAll(part)
require.NoError(t, readErr)
parts[ct] = data
}
return parts
}
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
thumbnail := make([]byte, 512)
thumbnail[0] = 0xff
thumbnail[1] = 0xd8
thumbnail[2] = 0xff
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
thumbnailData: thumbnail,
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Verify multipart response.
ct := rr.Header().Get("Content-Type")
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
"expected multipart/mixed Content-Type, got %s", ct)
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
assert.Equal(t, expectedMP4, parts["video/mp4"])
assert.Equal(t, thumbnail, parts["image/jpeg"])
}
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Verify multipart response.
ct := rr.Header().Get("Content-Type")
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
"expected multipart/mixed Content-Type, got %s", ct)
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
assert.Len(t, parts, 1, "expected exactly 1 part (video only)")
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
assert.Equal(t, expectedMP4, parts["video/mp4"])
}
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// Create thumbnail data that exceeds MaxThumbnailSize.
oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
oversizedThumb[0] = 0xff
oversizedThumb[1] = 0xd8
oversizedThumb[2] = 0xff
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
thumbnailData: oversizedThumb,
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
handler := api.Routes()
// Start recording.
recID := uuid.New().String()
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Stop recording.
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
require.NoError(t, err)
rr = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code)
// Verify multipart response contains only the video part.
ct := rr.Header().Get("Content-Type")
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
"expected multipart/mixed Content-Type, got %s", ct)
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
assert.Equal(t, expectedMP4, parts["video/mp4"])
}
-5
@@ -105,11 +105,6 @@ type RecordingArtifact struct {
Reader io.ReadCloser
// Size is the byte length of the MP4 content.
Size int64
// ThumbnailReader is the JPEG thumbnail. May be nil if no
// thumbnail was produced. Callers must close it when done.
ThumbnailReader io.ReadCloser
// ThumbnailSize is the byte length of the thumbnail.
ThumbnailSize int64
}
// DisplayConfig describes a running desktop session.
+13 -72
@@ -56,7 +56,6 @@ type screenshotOutput struct {
type recordingProcess struct {
cmd *exec.Cmd
filePath string
thumbPath string
stopped bool
killed bool // true when the process was SIGKILLed
done chan struct{} // closed when cmd.Wait() returns
@@ -384,20 +383,13 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
}
}
// Completed recording - discard old file, start fresh.
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(ctx, "failed to remove old recording file",
slog.F("recording_id", recordingID),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove old thumbnail file",
slog.F("recording_id", recordingID),
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err),
)
}
delete(p.recordings, recordingID)
}
@@ -414,7 +406,6 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
}
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
// Use a background context so the process outlives the HTTP
// request that triggered it.
@@ -428,7 +419,6 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
"--idle-speedup", "20",
"--idle-min-duration", "0.35",
"--idle-noise-tolerance", "-38dB",
"--thumbnail", thumbPath,
filePath)
if err := cmd.Start(); err != nil {
@@ -437,10 +427,9 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
}
rec := &recordingProcess{
cmd: cmd,
filePath: filePath,
thumbPath: thumbPath,
done: make(chan struct{}),
cmd: cmd,
filePath: filePath,
done: make(chan struct{}),
}
go func() {
rec.waitErr = cmd.Wait()
@@ -510,35 +499,10 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
_ = f.Close()
return nil, xerrors.Errorf("stat recording artifact: %w", err)
}
artifact := &RecordingArtifact{
return &RecordingArtifact{
Reader: f,
Size: info.Size(),
}
// Attach thumbnail if the subprocess wrote one.
thumbFile, err := os.Open(rec.thumbPath)
if err != nil {
p.logger.Warn(ctx, "thumbnail not available",
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err))
return artifact, nil
}
thumbInfo, err := thumbFile.Stat()
if err != nil {
_ = thumbFile.Close()
p.logger.Warn(ctx, "thumbnail stat failed",
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err))
return artifact, nil
}
if thumbInfo.Size() == 0 {
_ = thumbFile.Close()
p.logger.Warn(ctx, "thumbnail file is empty",
slog.F("thumbnail_path", rec.thumbPath))
return artifact, nil
}
artifact.ThumbnailReader = thumbFile
artifact.ThumbnailSize = thumbInfo.Size()
return artifact, nil
}, nil
}
// lockedStopRecordingProcess stops a single recording via stopOnce.
@@ -607,33 +571,18 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
}
info, err := os.Stat(rec.filePath)
if err != nil {
// File already removed or inaccessible; clean up
// any leftover thumbnail and drop the entry.
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
slog.F("recording_id", id),
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err),
)
}
// File already removed or inaccessible; drop entry.
delete(p.recordings, id)
continue
}
if p.clock.Since(info.ModTime()) > time.Hour {
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(ctx, "failed to remove stale recording file",
slog.F("recording_id", id),
slog.F("file_path", rec.filePath),
slog.Error(err),
)
}
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
slog.F("recording_id", id),
slog.F("thumbnail_path", rec.thumbPath),
slog.Error(err),
)
}
delete(p.recordings, id)
}
}
@@ -654,14 +603,13 @@ func (p *portableDesktop) Close() error {
// Snapshot recording file paths and idle goroutine channels
// for cleanup, then clear the map.
type recEntry struct {
id string
filePath string
thumbPath string
idleDone chan struct{}
id string
filePath string
idleDone chan struct{}
}
var allRecs []recEntry
for id, rec := range p.recordings {
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
delete(p.recordings, id)
}
session := p.session
@@ -682,20 +630,13 @@ func (p *portableDesktop) Close() error {
go func() {
defer close(cleanupDone)
for _, entry := range allRecs {
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
p.logger.Warn(context.Background(), "failed to remove recording file on close",
slog.F("recording_id", entry.id),
slog.F("file_path", entry.filePath),
slog.Error(err),
)
}
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
slog.F("recording_id", entry.id),
slog.F("thumbnail_path", entry.thumbPath),
slog.Error(err),
)
}
}
if session != nil {
session.cancel()
@@ -2,7 +2,6 @@ package agentdesktop
import (
"context"
"io"
"os"
"os/exec"
"path/filepath"
@@ -585,7 +584,6 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
joined := strings.Join(cmd, " ")
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
found = true
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
break
}
}
@@ -668,66 +666,6 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
defer artifact.Reader.Close()
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
// No thumbnail file exists, so ThumbnailReader should be nil.
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
require.NoError(t, pd.Close())
}
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"record": `trap 'exit 0' INT; sleep 120 & wait`,
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
clk := quartz.NewReal()
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
clock: clk,
binPath: "portabledesktop",
recordings: make(map[string]*recordingProcess),
}
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
ctx := t.Context()
recID := uuid.New().String()
err := pd.StartRecording(ctx, recID)
require.NoError(t, err)
// Write a dummy MP4 file at the expected path.
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
t.Cleanup(func() { _ = os.Remove(filePath) })
// Write a thumbnail file at the expected path.
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
thumbContent := []byte("fake-jpeg-thumbnail")
require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
t.Cleanup(func() { _ = os.Remove(thumbPath) })
artifact, err := pd.StopRecording(ctx, recID)
require.NoError(t, err)
defer artifact.Reader.Close()
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
// Thumbnail should be attached.
require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
defer artifact.ThumbnailReader.Close()
assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)
// Read and verify thumbnail content.
thumbData, err := io.ReadAll(artifact.ThumbnailReader)
require.NoError(t, err)
assert.Equal(t, thumbContent, thumbData)
require.NoError(t, pd.Close())
}
@@ -812,18 +750,12 @@ func TestPortableDesktop_IdleTimeout_StopsRecordings(t *testing.T) {
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
// Advance past idle timeout to trigger the stop-all.
clk.Advance(idleTimeout).MustWait(ctx)
clk.Advance(idleTimeout)
// Wait for the stop timer to be created, then release it.
stopTrap.MustWait(ctx).MustRelease(ctx)
stopTrap.Close()
// Advance past the 15s stop timeout so the process is
// forcibly killed. Without this the test depends on the real
// shell handling SIGINT promptly, which is unreliable on
// macOS CI runners (the flake in #1461).
clk.Advance(15 * time.Second).MustWait(ctx)
// The recording process should now be stopped.
require.Eventually(t, func() bool {
pd.mu.Lock()
@@ -945,17 +877,11 @@ func TestPortableDesktop_IdleTimeout_MultipleRecordings(t *testing.T) {
stopTrap := clk.Trap().NewTimer("agentdesktop", "stop_timeout")
// Advance past idle timeout.
clk.Advance(idleTimeout).MustWait(ctx)
clk.Advance(idleTimeout)
// Each idle monitor goroutine serializes on p.mu, so the
// second stop timer is only created after the first stop
// completes. Advance past the 15s stop timeout after each
// release so the process is forcibly killed instead of
// depending on SIGINT (unreliable on macOS — see #1461).
// Wait for both stop timers.
stopTrap.MustWait(ctx).MustRelease(ctx)
clk.Advance(15 * time.Second).MustWait(ctx)
stopTrap.MustWait(ctx).MustRelease(ctx)
clk.Advance(15 * time.Second).MustWait(ctx)
stopTrap.Close()
// Both recordings should be stopped.
-6
@@ -87,12 +87,6 @@ func IsDevVersion(v string) bool {
return strings.Contains(v, "-"+develPreRelease)
}
// IsRCVersion returns true if the version has a release candidate
// pre-release tag, e.g. "v2.31.0-rc.0".
func IsRCVersion(v string) bool {
return strings.Contains(v, "-rc.")
}
// IsDev returns true if this is a development build.
// CI builds are also considered development builds.
func IsDev() bool {
-26
@@ -102,29 +102,3 @@ func TestBuildInfo(t *testing.T) {
}
})
}
func TestIsRCVersion(t *testing.T) {
t.Parallel()
cases := []struct {
name string
version string
expected bool
}{
{"RC0", "v2.31.0-rc.0", true},
{"RC1WithBuild", "v2.31.0-rc.1+abc123", true},
{"RC10", "v2.31.0-rc.10", true},
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", true},
{"DevelVersion", "v2.31.0-devel+abc123", false},
{"StableVersion", "v2.31.0", false},
{"DevNoVersion", "v0.0.0-devel+abc123", false},
{"BetaVersion", "v2.31.0-beta.1", false},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, c.expected, buildinfo.IsRCVersion(c.version))
})
}
}
-194
@@ -1,194 +0,0 @@
package cli
import (
"fmt"
"os"
"path/filepath"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/serpent"
)
func (r *RootCmd) chatCommand() *serpent.Command {
return &serpent.Command{
Use: "chat",
Short: "Manage agent chats",
Long: "Commands for interacting with chats from within a workspace.",
Handler: func(i *serpent.Invocation) error {
return i.Command.HelpHandler(i)
},
Children: []*serpent.Command{
r.chatContextCommand(),
},
}
}
func (r *RootCmd) chatContextCommand() *serpent.Command {
return &serpent.Command{
Use: "context",
Short: "Manage chat context",
Long: "Add or clear context files and skills for an active chat session.",
Handler: func(i *serpent.Invocation) error {
return i.Command.HelpHandler(i)
},
Children: []*serpent.Command{
r.chatContextAddCommand(),
r.chatContextClearCommand(),
},
}
}
func (*RootCmd) chatContextAddCommand() *serpent.Command {
var (
dir string
chatID string
)
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
Use: "add",
Short: "Add context to an active chat",
Long: "Read instruction files and discover skills from a directory, then add " +
"them as context to an active chat session. Multiple calls " +
"are additive.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
defer stop()
if dir == "" && inv.Environ.Get("CODER") != "true" {
return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
}
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
resolvedDir := dir
if resolvedDir == "" {
resolvedDir, err = os.Getwd()
if err != nil {
return xerrors.Errorf("get working directory: %w", err)
}
}
resolvedDir, err = filepath.Abs(resolvedDir)
if err != nil {
return xerrors.Errorf("resolve directory: %w", err)
}
info, err := os.Stat(resolvedDir)
if err != nil {
return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
}
if !info.IsDir() {
return xerrors.Errorf("%q is not a directory", resolvedDir)
}
parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
if len(parts) == 0 {
_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
return nil
}
// Resolve chat ID from flag or auto-detect.
resolvedChatID, err := parseChatID(chatID)
if err != nil {
return err
}
resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
ChatID: resolvedChatID,
Parts: parts,
})
if err != nil {
return xerrors.Errorf("add chat context: %w", err)
}
_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
return nil
},
Options: serpent.OptionSet{
{
Name: "Directory",
Flag: "dir",
Description: "Directory to read context files and skills from. Defaults to the current working directory.",
Value: serpent.StringOf(&dir),
},
{
Name: "Chat ID",
Flag: "chat",
Env: "CODER_CHAT_ID",
Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
Value: serpent.StringOf(&chatID),
},
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
func (*RootCmd) chatContextClearCommand() *serpent.Command {
var chatID string
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
Use: "clear",
Short: "Clear context from an active chat",
Long: "Soft-delete all context-file and skill messages from an active chat. " +
"The next turn will re-fetch default context from the agent.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
defer stop()
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
resolvedChatID, err := parseChatID(chatID)
if err != nil {
return err
}
resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
ChatID: resolvedChatID,
})
if err != nil {
return xerrors.Errorf("clear chat context: %w", err)
}
if resp.ChatID == uuid.Nil {
_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
} else {
_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
}
return nil
},
Options: serpent.OptionSet{{
Name: "Chat ID",
Flag: "chat",
Env: "CODER_CHAT_ID",
Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
Value: serpent.StringOf(&chatID),
}},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
// parseChatID returns the chat UUID from the flag value (which
// serpent already populates from --chat or CODER_CHAT_ID). Returns
// uuid.Nil if empty (the server will auto-detect).
func parseChatID(flagValue string) (uuid.UUID, error) {
if flagValue == "" {
return uuid.Nil, nil
}
parsed, err := uuid.Parse(flagValue)
if err != nil {
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
}
return parsed, nil
}
-46
@@ -1,46 +0,0 @@
package cli_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
)
func TestExpChatContextAdd(t *testing.T) {
t.Parallel()
t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
err := inv.Run()
require.Error(t, err)
require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
})
t.Run("AllowsExplicitDir", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())
err := inv.Run()
if err != nil {
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
}
})
t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
inv.Environ.Set("CODER", "true")
err := inv.Run()
if err != nil {
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
}
})
}
+5 -29
@@ -7,7 +7,6 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"net/http"
@@ -149,7 +148,6 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
return []*serpent.Command{
r.scaletestCmd(),
r.errorExample(),
r.chatCommand(),
r.mcpCommand(),
r.promptExample(),
r.rptyCommand(),
@@ -712,7 +710,7 @@ func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv
transport = wrapTransportWithTelemetryHeader(transport, inv)
transport = wrapTransportWithUserAgentHeader(transport, inv)
if !r.noVersionCheck {
transport = wrapTransportWithVersionCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
// Create a new client without any wrapped transport
// otherwise it creates an infinite loop!
basicClient := codersdk.New(serverURL)
@@ -1436,21 +1434,6 @@ func defaultUpgradeMessage(version string) string {
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
}
// serverVersionMessage returns a warning message if the server version
// is a release candidate or development build. Returns empty string
// for stable versions. RC is checked before devel because RC dev
// builds (e.g. v2.33.0-rc.1-devel+hash) contain both tags.
func serverVersionMessage(serverVersion string) string {
switch {
case buildinfo.IsRCVersion(serverVersion):
return fmt.Sprintf("the server is running a release candidate of Coder (%s)", serverVersion)
case buildinfo.IsDevVersion(serverVersion):
return fmt.Sprintf("the server is running a development version of Coder (%s)", serverVersion)
default:
return ""
}
}
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
// that checks for entitlement warnings and prints them to the user.
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
@@ -1469,10 +1452,10 @@ func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.
})
}
// wrapTransportWithVersionCheck adds a middleware to the HTTP transport
// that checks the server version and warns about development builds,
// release candidates, and client/server version mismatches.
func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
// that checks for version mismatches between the client and server. If a mismatch
// is detected, a warning is printed to the user.
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
var once sync.Once
return roundTripper(func(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTrip(req)
@@ -1484,16 +1467,9 @@ func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation
if serverVersion == "" {
return
}
// Warn about non-stable server versions. Skip
// during tests to avoid polluting golden files.
if msg := serverVersionMessage(serverVersion); msg != "" && flag.Lookup("test.v") == nil {
warning := pretty.Sprint(cliui.DefaultStyles.Warn, msg)
_, _ = fmt.Fprintln(inv.Stderr, warning)
}
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
return
}
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
if serverInfo, err := getBuildInfo(inv.Context()); err == nil {
switch {
+3 -50
@@ -91,7 +91,7 @@ func Test_formatExamples(t *testing.T) {
}
}
func Test_wrapTransportWithVersionCheck(t *testing.T) {
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
t.Parallel()
t.Run("NoOutput", func(t *testing.T) {
@@ -102,7 +102,7 @@ func Test_wrapTransportWithVersionCheck(t *testing.T) {
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
@@ -131,7 +131,7 @@ func Test_wrapTransportWithVersionCheck(t *testing.T) {
inv := cmd.Invoke()
inv.Stderr = &buf
expectedUpgradeMessage := "My custom upgrade message"
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
@@ -159,53 +159,6 @@ func Test_wrapTransportWithVersionCheck(t *testing.T) {
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
require.Equal(t, expectedOutput, buf.String())
})
t.Run("ServerStableVersion", func(t *testing.T) {
t.Parallel()
r := &RootCmd{}
cmd, err := r.Command(nil)
require.NoError(t, err)
var buf bytes.Buffer
inv := cmd.Invoke()
inv.Stderr = &buf
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
codersdk.BuildVersionHeader: []string{"v2.31.0"},
},
Body: io.NopCloser(nil),
}, nil
}), inv, "v2.31.0", nil)
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
res, err := rt.RoundTrip(req)
require.NoError(t, err)
defer res.Body.Close()
require.Empty(t, buf.String())
})
}
func Test_serverVersionMessage(t *testing.T) {
t.Parallel()
cases := []struct {
name string
version string
expected string
}{
{"Stable", "v2.31.0", ""},
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
{"Empty", "", ""},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, c.expected, serverVersionMessage(c.version))
})
}
}
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
+20
@@ -768,10 +768,30 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("create pubsub: %w", err)
}
options.Pubsub = ps
options.ChatPubsub = ps
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(ps)
}
defer options.Pubsub.Close()
chatPubsub, err := pubsub.NewBatching(
ctx,
logger.Named("chatd").Named("pubsub_batch"),
ps,
sqlDB,
dbURL,
pubsub.BatchingConfig{
FlushInterval: options.DeploymentValues.AI.Chat.PubsubFlushInterval.Value(),
QueueSize: int(options.DeploymentValues.AI.Chat.PubsubQueueSize.Value()),
},
)
if err != nil {
return xerrors.Errorf("create chat pubsub batcher: %w", err)
}
options.ChatPubsub = chatPubsub
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(chatPubsub)
}
defer options.ChatPubsub.Close()
psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
pubsubWatchdogTimeout = psWatchdog.Timeout()
defer psWatchdog.Close()
+4 -6
@@ -69,17 +69,15 @@ var (
// isRetryableError checks for transient connection errors worth
// retrying: DNS failures, connection refused, and server 5xx.
func isRetryableError(err error) bool {
if err == nil || xerrors.Is(err, context.Canceled) {
if err == nil {
return false
}
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return false
}
// Check connection errors before context.DeadlineExceeded because
// net.Dialer.Timeout produces *net.OpError that matches both.
if codersdk.IsConnectionError(err) {
return true
}
if xerrors.Is(err, context.DeadlineExceeded) {
return false
}
var sdkErr *codersdk.Error
if xerrors.As(err, &sdkErr) {
return sdkErr.StatusCode() >= 500
-17
@@ -516,23 +516,6 @@ func TestIsRetryableError(t *testing.T) {
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
})
}
// net.Dialer.Timeout produces *net.OpError that matches both
// IsConnectionError and context.DeadlineExceeded. Verify it is retryable.
t.Run("DialTimeout", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
defer cancel()
<-ctx.Done() // ensure deadline has fired
_, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:1")
require.Error(t, err)
// Proves the ambiguity: this error matches BOTH checks.
require.ErrorIs(t, err, context.DeadlineExceeded)
require.ErrorAs(t, err, new(*net.OpError))
assert.True(t, isRetryableError(err))
// Also when wrapped, as runCoderConnectStdio does.
assert.True(t, isRetryableError(xerrors.Errorf("dial coder connect: %w", err)))
})
}
func TestRetryWithInterval(t *testing.T) {
+1 -1
@@ -11,7 +11,7 @@ OPTIONS:
-O, --org string, $CODER_ORGANIZATION
Select which organization (uuid or name) to use.
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|workspace build transition|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
Columns to display in table output.
-i, --initiator string, $CODER_PROVISIONER_JOB_LIST_INITIATOR
@@ -58,8 +58,7 @@
"template_display_name": "",
"template_icon": "",
"workspace_id": "===========[workspace ID]===========",
"workspace_name": "test-workspace",
"workspace_build_transition": "start"
"workspace_name": "test-workspace"
},
"logs_overflowed": false,
"organization_name": "Coder"
-7
@@ -211,13 +211,6 @@ AI BRIDGE PROXY OPTIONS:
certificates not trusted by the system. If not provided, the system
certificate pool is used.
CHAT OPTIONS:
Configure the background chat processing daemon.
--chat-debug-logging-enabled bool, $CODER_CHAT_DEBUG_LOGGING_ENABLED (default: false)
Force chat debug logging on for every chat, bypassing the runtime
admin and user opt-in settings.
CLIENT OPTIONS:
These options change the behavior of how clients interact with the Coder.
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
-4
@@ -757,10 +757,6 @@ chat:
# How many pending chats a worker should acquire per polling cycle.
# (default: 10, type: int)
acquireBatchSize: 10
# Force chat debug logging on for every chat, bypassing the runtime admin and user
# opt-in settings.
# (default: false, type: bool)
debugLoggingEnabled: false
aibridge:
# Whether to start an in-memory aibridged instance.
# (default: false, type: bool)
+1 -11
@@ -71,7 +71,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
// An ID is only given in the request when it is a terraform-defined devcontainer
// that has attached resources. These subagents are pre-provisioned by terraform
// (the agent record already exists), so we update configurable fields like
// display_apps and directory rather than creating a new agent.
// display_apps rather than creating a new agent.
if req.Id != nil {
id, err := uuid.FromBytes(req.Id)
if err != nil {
@@ -97,16 +97,6 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
return nil, xerrors.Errorf("update workspace agent display apps: %w", err)
}
if req.Directory != "" {
if err := a.Database.UpdateWorkspaceAgentDirectoryByID(ctx, database.UpdateWorkspaceAgentDirectoryByIDParams{
ID: id,
Directory: req.Directory,
UpdatedAt: createdAt,
}); err != nil {
return nil, xerrors.Errorf("update workspace agent directory: %w", err)
}
}
return &agentproto.CreateSubAgentResponse{
Agent: &agentproto.SubAgent{
Name: subAgent.Name,
+2 -38
@@ -1267,11 +1267,11 @@ func TestSubAgentAPI(t *testing.T) {
agentID, err := uuid.FromBytes(resp.Agent.Id)
require.NoError(t, err)
// And: The database agent's name, architecture, and OS are unchanged.
// And: The database agent's other fields are unchanged.
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
require.NoError(t, err)
require.Equal(t, baseChildAgent.Name, updatedAgent.Name)
require.Equal(t, "/different/path", updatedAgent.Directory)
require.Equal(t, baseChildAgent.Directory, updatedAgent.Directory)
require.Equal(t, baseChildAgent.Architecture, updatedAgent.Architecture)
require.Equal(t, baseChildAgent.OperatingSystem, updatedAgent.OperatingSystem)
@@ -1280,42 +1280,6 @@ func TestSubAgentAPI(t *testing.T) {
require.Equal(t, database.DisplayAppWebTerminal, updatedAgent.DisplayApps[0])
},
},
{
name: "OK_DirectoryUpdated",
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
// Given: An existing child agent with a stale host-side
// directory (as set by the provisioner at build time).
childAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID},
ResourceID: agent.ResourceID,
Name: baseChildAgent.Name,
Directory: "/home/coder/project",
Architecture: baseChildAgent.Architecture,
OperatingSystem: baseChildAgent.OperatingSystem,
DisplayApps: baseChildAgent.DisplayApps,
})
// When: Agent injection sends the correct
// container-internal path.
return &proto.CreateSubAgentRequest{
Id: childAgent.ID[:],
Directory: "/workspaces/project",
DisplayApps: []proto.CreateSubAgentRequest_DisplayApp{
proto.CreateSubAgentRequest_WEB_TERMINAL,
},
}
},
check: func(t *testing.T, ctx context.Context, db database.Store, resp *proto.CreateSubAgentResponse, agent database.WorkspaceAgent) {
agentID, err := uuid.FromBytes(resp.Agent.Id)
require.NoError(t, err)
// Then: Directory is updated to the container-internal
// path.
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
require.NoError(t, err)
require.Equal(t, "/workspaces/project", updatedAgent.Directory)
},
},
{
name: "Error/MalformedID",
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
-284
@@ -9514,212 +9514,6 @@ const docTemplate = `{
]
}
},
"/users/{user}/secrets": {
"get": {
"produces": [
"application/json"
],
"tags": [
"Secrets"
],
"summary": "List user secrets",
"operationId": "list-user-secrets",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
},
"post": {
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Secrets"
],
"summary": "Create a new user secret",
"operationId": "create-a-new-user-secret",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"description": "Create secret request",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
}
}
],
"responses": {
"201": {
"description": "Created",
"schema": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/users/{user}/secrets/{name}": {
"get": {
"produces": [
"application/json"
],
"tags": [
"Secrets"
],
"summary": "Get a user secret by name",
"operationId": "get-a-user-secret-by-name",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Secret name",
"name": "name",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
},
"delete": {
"tags": [
"Secrets"
],
"summary": "Delete a user secret",
"operationId": "delete-a-user-secret",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Secret name",
"name": "name",
"in": "path",
"required": true
}
],
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
]
},
"patch": {
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"tags": [
"Secrets"
],
"summary": "Update a user secret",
"operationId": "update-a-user-secret",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Secret name",
"name": "name",
"in": "path",
"required": true
},
{
"description": "Update secret request",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/users/{user}/status/activate": {
"put": {
"produces": [
@@ -13445,12 +13239,6 @@ const docTemplate = `{
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
}
},
"credential_hint": {
"type": "string"
},
"credential_kind": {
"type": "string"
},
"ended_at": {
"type": "string",
"format": "date-time"
@@ -14691,9 +14479,6 @@ const docTemplate = `{
"properties": {
"acquire_batch_size": {
"type": "integer"
},
"debug_logging_enabled": {
"type": "boolean"
}
}
},
@@ -15357,26 +15142,6 @@ const docTemplate = `{
}
}
},
"codersdk.CreateUserSecretRequest": {
"type": "object",
"properties": {
"description": {
"type": "string"
},
"env_name": {
"type": "string"
},
"file_path": {
"type": "string"
},
"name": {
"type": "string"
},
"value": {
"type": "string"
}
}
},
"codersdk.CreateWorkspaceBuildReason": {
"type": "string",
"enum": [
@@ -19152,9 +18917,6 @@ const docTemplate = `{
"template_version_name": {
"type": "string"
},
"workspace_build_transition": {
"$ref": "#/definitions/codersdk.WorkspaceTransition"
},
"workspace_id": {
"type": "string",
"format": "uuid"
@@ -21509,23 +21271,6 @@ const docTemplate = `{
}
}
},
"codersdk.UpdateUserSecretRequest": {
"type": "object",
"properties": {
"description": {
"type": "string"
},
"env_name": {
"type": "string"
},
"file_path": {
"type": "string"
},
"value": {
"type": "string"
}
}
},
"codersdk.UpdateWorkspaceACL": {
"type": "object",
"properties": {
@@ -21981,35 +21726,6 @@ const docTemplate = `{
}
}
},
"codersdk.UserSecret": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"description": {
"type": "string"
},
"env_name": {
"type": "string"
},
"file_path": {
"type": "string"
},
"id": {
"type": "string",
"format": "uuid"
},
"name": {
"type": "string"
},
"updated_at": {
"type": "string",
"format": "date-time"
}
}
},
"codersdk.UserStatus": {
"type": "string",
"enum": [
@@ -8431,190 +8431,6 @@
]
}
},
"/users/{user}/secrets": {
"get": {
"produces": ["application/json"],
"tags": ["Secrets"],
"summary": "List user secrets",
"operationId": "list-user-secrets",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
},
"post": {
"consumes": ["application/json"],
"produces": ["application/json"],
"tags": ["Secrets"],
"summary": "Create a new user secret",
"operationId": "create-a-new-user-secret",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"description": "Create secret request",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
}
}
],
"responses": {
"201": {
"description": "Created",
"schema": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/users/{user}/secrets/{name}": {
"get": {
"produces": ["application/json"],
"tags": ["Secrets"],
"summary": "Get a user secret by name",
"operationId": "get-a-user-secret-by-name",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Secret name",
"name": "name",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
},
"delete": {
"tags": ["Secrets"],
"summary": "Delete a user secret",
"operationId": "delete-a-user-secret",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Secret name",
"name": "name",
"in": "path",
"required": true
}
],
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
]
},
"patch": {
"consumes": ["application/json"],
"produces": ["application/json"],
"tags": ["Secrets"],
"summary": "Update a user secret",
"operationId": "update-a-user-secret",
"parameters": [
{
"type": "string",
"description": "User ID, username, or me",
"name": "user",
"in": "path",
"required": true
},
{
"type": "string",
"description": "Secret name",
"name": "name",
"in": "path",
"required": true
},
{
"description": "Update secret request",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.UserSecret"
}
}
},
"security": [
{
"CoderSessionToken": []
}
]
}
},
"/users/{user}/status/activate": {
"put": {
"produces": ["application/json"],
@@ -11993,12 +11809,6 @@
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
}
},
"credential_hint": {
"type": "string"
},
"credential_kind": {
"type": "string"
},
"ended_at": {
"type": "string",
"format": "date-time"
@@ -13204,9 +13014,6 @@
"properties": {
"acquire_batch_size": {
"type": "integer"
},
"debug_logging_enabled": {
"type": "boolean"
}
}
},
@@ -13836,26 +13643,6 @@
}
}
},
"codersdk.CreateUserSecretRequest": {
"type": "object",
"properties": {
"description": {
"type": "string"
},
"env_name": {
"type": "string"
},
"file_path": {
"type": "string"
},
"name": {
"type": "string"
},
"value": {
"type": "string"
}
}
},
"codersdk.CreateWorkspaceBuildReason": {
"type": "string",
"enum": [
@@ -17512,9 +17299,6 @@
"template_version_name": {
"type": "string"
},
"workspace_build_transition": {
"$ref": "#/definitions/codersdk.WorkspaceTransition"
},
"workspace_id": {
"type": "string",
"format": "uuid"
@@ -19761,23 +19545,6 @@
}
}
},
"codersdk.UpdateUserSecretRequest": {
"type": "object",
"properties": {
"description": {
"type": "string"
},
"env_name": {
"type": "string"
},
"file_path": {
"type": "string"
},
"value": {
"type": "string"
}
}
},
"codersdk.UpdateWorkspaceACL": {
"type": "object",
"properties": {
@@ -20208,35 +19975,6 @@
}
}
},
"codersdk.UserSecret": {
"type": "object",
"properties": {
"created_at": {
"type": "string",
"format": "date-time"
},
"description": {
"type": "string"
},
"env_name": {
"type": "string"
},
"file_path": {
"type": "string"
},
"id": {
"type": "string",
"format": "uuid"
},
"name": {
"type": "string"
},
"updated_at": {
"type": "string",
"format": "date-time"
}
}
},
"codersdk.UserStatus": {
"type": "string",
"enum": ["active", "dormant", "suspended"],
@@ -159,7 +159,10 @@ type Options struct {
Logger slog.Logger
Database database.Store
Pubsub pubsub.Pubsub
RuntimeConfig *runtimeconfig.Manager
// ChatPubsub allows chatd to use a dedicated publish path without changing
// the shared pubsub used by the rest of coderd.
ChatPubsub pubsub.Pubsub
RuntimeConfig *runtimeconfig.Manager
// CacheDir is used for caching files served by the API.
CacheDir string
@@ -777,6 +780,11 @@ func New(options *Options) *API {
maxChatsPerAcquire = math.MinInt32
}
chatPubsub := options.ChatPubsub
if chatPubsub == nil {
chatPubsub = options.Pubsub
}
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
@@ -789,7 +797,7 @@ func New(options *Options) *API {
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
Pubsub: chatPubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
})
@@ -1608,15 +1616,6 @@ func New(options *Options) *API {
r.Get("/gitsshkey", api.gitSSHKey)
r.Put("/gitsshkey", api.regenerateGitSSHKey)
r.Route("/secrets", func(r chi.Router) {
r.Post("/", api.postUserSecret)
r.Get("/", api.getUserSecrets)
r.Route("/{name}", func(r chi.Router) {
r.Get("/", api.getUserSecret)
r.Patch("/", api.patchUserSecret)
r.Delete("/", api.deleteUserSecret)
})
})
r.Route("/notifications", func(r chi.Router) {
r.Route("/preferences", func(r chi.Router) {
r.Get("/", api.userNotificationPreferences)
@@ -1662,10 +1661,6 @@ func New(options *Options) *API {
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Post("/log-source", api.workspaceAgentPostLogSource)
r.Get("/reinit", api.workspaceAgentReinit)
r.Route("/experimental", func(r chi.Router) {
r.Post("/chat-context", api.workspaceAgentAddChatContext)
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
})
r.Route("/tasks/{task}", func(r chi.Router) {
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
})
@@ -147,10 +147,6 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
return c
}
func isExperimentalEndpoint(route string) bool {
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
}
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
assertUniqueRoutes(t, swaggerComments)
assertSingleAnnotations(t, swaggerComments)
@@ -169,9 +165,6 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
if strings.HasSuffix(route, "/*") {
return
}
if isExperimentalEndpoint(route) {
return
}
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
assert.NotNil(t, c, "Missing @Router annotation")
@@ -538,12 +538,6 @@ func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
switch {
case workspaceAgent.Status != codersdk.WorkspaceAgentConnected && workspaceAgent.LifecycleState == codersdk.WorkspaceAgentLifecycleOff:
workspaceAgent.Health.Reason = "agent is not running"
case workspaceAgent.Status == codersdk.WorkspaceAgentConnecting:
// Note: the case above catches connecting+off as "not running".
// This case handles connecting agents with a non-off lifecycle
// (e.g. "created" or "starting"), where the agent binary has
// not yet established a connection to coderd.
workspaceAgent.Health.Reason = "agent has not yet connected"
case workspaceAgent.Status == codersdk.WorkspaceAgentTimeout:
workspaceAgent.Health.Reason = "agent is taking too long to connect"
case workspaceAgent.Status == codersdk.WorkspaceAgentDisconnected:
@@ -1240,8 +1234,6 @@ func buildAIBridgeThread(
if rootIntc != nil {
thread.Model = rootIntc.Model
thread.Provider = rootIntc.Provider
thread.CredentialKind = string(rootIntc.CredentialKind)
thread.CredentialHint = rootIntc.CredentialHint
// Get first user prompt from root interception.
// A thread can only have one prompt, by definition, since we currently
// only store the last prompt observed in an interception.
@@ -1533,22 +1525,6 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
return &value
}
func nullStringPtr(v sql.NullString) *string {
if !v.Valid {
return nil
}
value := v.String
return &value
}
func nullTimePtr(v sql.NullTime) *time.Time {
if !v.Valid {
return nil
}
value := v.Time
return &value
}
// Chat converts a database.Chat to a codersdk.Chat. It coalesces
// nil slices and maps to empty values for JSON serialization and
// derives RootChatID from the parent chain when not explicitly set.
@@ -1635,88 +1611,6 @@ func Chat(c database.Chat, diffStatus *database.ChatDiffStatus, files []database
return chat
}
func chatDebugAttempts(raw json.RawMessage) []map[string]any {
if len(raw) == 0 {
return nil
}
var attempts []map[string]any
if err := json.Unmarshal(raw, &attempts); err != nil {
return []map[string]any{{
"error": "malformed attempts payload",
"raw": string(raw),
}}
}
return attempts
}
// rawJSONObject deserializes a JSON object payload for debug display.
// If the payload is malformed, it returns a map with "error" and "raw"
// keys preserving the original content for diagnostics. Callers that
// consume the result programmatically should check for the "error" key.
func rawJSONObject(raw json.RawMessage) map[string]any {
if len(raw) == 0 {
return nil
}
var object map[string]any
if err := json.Unmarshal(raw, &object); err != nil {
return map[string]any{
"error": "malformed debug payload",
"raw": string(raw),
}
}
return object
}
func nullRawJSONObject(raw pqtype.NullRawMessage) map[string]any {
if !raw.Valid {
return nil
}
return rawJSONObject(raw.RawMessage)
}
// ChatDebugRunSummary converts a database.ChatDebugRun to a
// codersdk.ChatDebugRunSummary.
func ChatDebugRunSummary(r database.ChatDebugRun) codersdk.ChatDebugRunSummary {
return codersdk.ChatDebugRunSummary{
ID: r.ID,
ChatID: r.ChatID,
Kind: codersdk.ChatDebugRunKind(r.Kind),
Status: codersdk.ChatDebugStatus(r.Status),
Provider: nullStringPtr(r.Provider),
Model: nullStringPtr(r.Model),
Summary: rawJSONObject(r.Summary),
StartedAt: r.StartedAt,
UpdatedAt: r.UpdatedAt,
FinishedAt: nullTimePtr(r.FinishedAt),
}
}
// ChatDebugStep converts a database.ChatDebugStep to a
// codersdk.ChatDebugStep.
func ChatDebugStep(s database.ChatDebugStep) codersdk.ChatDebugStep {
return codersdk.ChatDebugStep{
ID: s.ID,
RunID: s.RunID,
ChatID: s.ChatID,
StepNumber: s.StepNumber,
Operation: codersdk.ChatDebugStepOperation(s.Operation),
Status: codersdk.ChatDebugStatus(s.Status),
HistoryTipMessageID: nullInt64Ptr(s.HistoryTipMessageID),
AssistantMessageID: nullInt64Ptr(s.AssistantMessageID),
NormalizedRequest: rawJSONObject(s.NormalizedRequest),
NormalizedResponse: nullRawJSONObject(s.NormalizedResponse),
Usage: nullRawJSONObject(s.Usage),
Attempts: chatDebugAttempts(s.Attempts),
Error: nullRawJSONObject(s.Error),
Metadata: rawJSONObject(s.Metadata),
StartedAt: s.StartedAt,
UpdatedAt: s.UpdatedAt,
FinishedAt: nullTimePtr(s.FinishedAt),
}
}
// ChatRows converts a slice of database.GetChatsRow (which embeds
// Chat plus HasUnread) to codersdk.Chat, looking up diff statuses
// from the provided map. When diffStatusesByChatID is non-nil,
@@ -210,231 +210,6 @@ func TestTemplateVersionParameter_BadDescription(t *testing.T) {
req.NotEmpty(sdk.DescriptionPlaintext, "broke the markdown parser with %v", desc)
}
func TestChatDebugRunSummary(t *testing.T) {
t.Parallel()
startedAt := time.Now().UTC().Round(time.Second)
finishedAt := startedAt.Add(5 * time.Second)
run := database.ChatDebugRun{
ID: uuid.New(),
ChatID: uuid.New(),
Kind: "chat_turn",
Status: "completed",
Provider: sql.NullString{String: "openai", Valid: true},
Model: sql.NullString{String: "gpt-4o", Valid: true},
Summary: json.RawMessage(`{"step_count":3,"has_error":false}`),
StartedAt: startedAt,
UpdatedAt: finishedAt,
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
}
sdk := db2sdk.ChatDebugRunSummary(run)
require.Equal(t, run.ID, sdk.ID)
require.Equal(t, run.ChatID, sdk.ChatID)
require.Equal(t, codersdk.ChatDebugRunKindChatTurn, sdk.Kind)
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
require.NotNil(t, sdk.Provider)
require.Equal(t, "openai", *sdk.Provider)
require.NotNil(t, sdk.Model)
require.Equal(t, "gpt-4o", *sdk.Model)
require.Equal(t, map[string]any{"step_count": float64(3), "has_error": false}, sdk.Summary)
require.Equal(t, startedAt, sdk.StartedAt)
require.Equal(t, finishedAt, sdk.UpdatedAt)
require.NotNil(t, sdk.FinishedAt)
require.Equal(t, finishedAt, *sdk.FinishedAt)
}
func TestChatDebugRunSummary_NullableFieldsNil(t *testing.T) {
t.Parallel()
run := database.ChatDebugRun{
ID: uuid.New(),
ChatID: uuid.New(),
Kind: "title_generation",
Status: "in_progress",
Summary: json.RawMessage(`{}`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugRunSummary(run)
require.Nil(t, sdk.Provider, "NULL Provider should map to nil")
require.Nil(t, sdk.Model, "NULL Model should map to nil")
require.Nil(t, sdk.FinishedAt, "NULL FinishedAt should map to nil")
}
func TestChatDebugStep(t *testing.T) {
t.Parallel()
startedAt := time.Now().UTC().Round(time.Second)
finishedAt := startedAt.Add(2 * time.Second)
attempts := json.RawMessage(`[
{
"attempt_number": 1,
"status": "completed",
"raw_request": {"url": "https://example.com"},
"raw_response": {"status": "200"},
"duration_ms": 123,
"started_at": "2026-03-01T10:00:01Z",
"finished_at": "2026-03-01T10:00:02Z"
}
]`)
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 1,
Operation: "stream",
Status: "completed",
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
Attempts: attempts,
Metadata: json.RawMessage(`{"provider":"openai"}`),
StartedAt: startedAt,
UpdatedAt: finishedAt,
FinishedAt: sql.NullTime{Time: finishedAt, Valid: true},
}
sdk := db2sdk.ChatDebugStep(step)
// Verify all scalar fields are mapped correctly.
require.Equal(t, step.ID, sdk.ID)
require.Equal(t, step.RunID, sdk.RunID)
require.Equal(t, step.ChatID, sdk.ChatID)
require.Equal(t, step.StepNumber, sdk.StepNumber)
require.Equal(t, codersdk.ChatDebugStepOperationStream, sdk.Operation)
require.Equal(t, codersdk.ChatDebugStatusCompleted, sdk.Status)
require.Equal(t, startedAt, sdk.StartedAt)
require.Equal(t, finishedAt, sdk.UpdatedAt)
require.Equal(t, &finishedAt, sdk.FinishedAt)
// Verify JSON object fields are deserialized.
require.NotNil(t, sdk.NormalizedRequest)
require.Equal(t, map[string]any{"messages": []any{}}, sdk.NormalizedRequest)
require.NotNil(t, sdk.Metadata)
require.Equal(t, map[string]any{"provider": "openai"}, sdk.Metadata)
// Verify nullable fields are nil when the DB row has NULL values.
require.Nil(t, sdk.HistoryTipMessageID, "NULL HistoryTipMessageID should map to nil")
require.Nil(t, sdk.AssistantMessageID, "NULL AssistantMessageID should map to nil")
require.Nil(t, sdk.NormalizedResponse, "NULL NormalizedResponse should map to nil")
require.Nil(t, sdk.Usage, "NULL Usage should map to nil")
require.Nil(t, sdk.Error, "NULL Error should map to nil")
// Verify attempts are preserved with all fields.
require.Len(t, sdk.Attempts, 1)
require.Equal(t, float64(1), sdk.Attempts[0]["attempt_number"])
require.Equal(t, "completed", sdk.Attempts[0]["status"])
require.Equal(t, float64(123), sdk.Attempts[0]["duration_ms"])
require.Equal(t, map[string]any{"url": "https://example.com"}, sdk.Attempts[0]["raw_request"])
require.Equal(t, map[string]any{"status": "200"}, sdk.Attempts[0]["raw_response"])
}
func TestChatDebugStep_NullableFieldsPopulated(t *testing.T) {
t.Parallel()
tipID := int64(42)
asstID := int64(99)
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 2,
Operation: "generate",
Status: "completed",
HistoryTipMessageID: sql.NullInt64{Int64: tipID, Valid: true},
AssistantMessageID: sql.NullInt64{Int64: asstID, Valid: true},
NormalizedRequest: json.RawMessage(`{}`),
NormalizedResponse: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"text":"hi"}`), Valid: true},
Usage: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"tokens":10}`), Valid: true},
Error: pqtype.NullRawMessage{RawMessage: json.RawMessage(`{"code":"rate_limit"}`), Valid: true},
Attempts: json.RawMessage(`[]`),
Metadata: json.RawMessage(`{}`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugStep(step)
require.NotNil(t, sdk.HistoryTipMessageID)
require.Equal(t, tipID, *sdk.HistoryTipMessageID)
require.NotNil(t, sdk.AssistantMessageID)
require.Equal(t, asstID, *sdk.AssistantMessageID)
require.NotNil(t, sdk.NormalizedResponse)
require.Equal(t, map[string]any{"text": "hi"}, sdk.NormalizedResponse)
require.NotNil(t, sdk.Usage)
require.Equal(t, map[string]any{"tokens": float64(10)}, sdk.Usage)
require.NotNil(t, sdk.Error)
require.Equal(t, map[string]any{"code": "rate_limit"}, sdk.Error)
}
func TestChatDebugStep_PreservesMalformedAttempts(t *testing.T) {
t.Parallel()
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 1,
Operation: "stream",
Status: "completed",
NormalizedRequest: json.RawMessage(`{"messages":[]}`),
Attempts: json.RawMessage(`{"bad":true}`),
Metadata: json.RawMessage(`{"provider":"openai"}`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugStep(step)
require.Len(t, sdk.Attempts, 1)
require.Equal(t, "malformed attempts payload", sdk.Attempts[0]["error"])
require.Equal(t, `{"bad":true}`, sdk.Attempts[0]["raw"])
}
func TestChatDebugRunSummary_PreservesMalformedSummary(t *testing.T) {
t.Parallel()
run := database.ChatDebugRun{
ID: uuid.New(),
ChatID: uuid.New(),
Kind: "chat_turn",
Status: "completed",
Summary: json.RawMessage(`not-an-object`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugRunSummary(run)
require.Equal(t, "malformed debug payload", sdk.Summary["error"])
require.Equal(t, "not-an-object", sdk.Summary["raw"])
}
func TestChatDebugStep_PreservesMalformedRequest(t *testing.T) {
t.Parallel()
step := database.ChatDebugStep{
ID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 1,
Operation: "stream",
Status: "completed",
NormalizedRequest: json.RawMessage(`[1,2,3]`),
Attempts: json.RawMessage(`[]`),
Metadata: json.RawMessage(`"just-a-string"`),
StartedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
sdk := db2sdk.ChatDebugStep(step)
require.Equal(t, "malformed debug payload", sdk.NormalizedRequest["error"])
require.Equal(t, "[1,2,3]", sdk.NormalizedRequest["raw"])
require.Equal(t, "malformed debug payload", sdk.Metadata["error"])
require.Equal(t, `"just-a-string"`, sdk.Metadata["raw"])
}
func TestAIBridgeInterception(t *testing.T) {
t.Parallel()
@@ -1708,17 +1708,6 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
}
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
}
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -1860,28 +1849,6 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.DeleteChatDebugDataAfterMessageID(ctx, arg)
}
func (q *querier) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return 0, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.DeleteChatDebugDataByChatID(ctx, chatID)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -2202,10 +2169,10 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
return 0, err
return err
}
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
}
@@ -2369,14 +2336,6 @@ func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context,
return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt)
}
func (q *querier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
// Background sweep operates across all chats.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return database.FinalizeStaleChatDebugRowsRow{}, err
}
return q.db.FinalizeStaleChatDebugRows(ctx, updatedBefore)
}
func (q *querier) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
_, err := q.GetTemplateVersionByID(ctx, arg.TemplateVersionID)
if err != nil {
@@ -2454,10 +2413,6 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
return q.db.GetActiveAISeatCount(ctx)
}
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
}
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
return nil, err
@@ -2585,59 +2540,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
// The allow-users flag is a deployment-wide setting read by any
// authenticated chat user. We only require that an explicit actor
// is present in the context so unauthenticated calls fail closed.
if _, ok := ActorFromContext(ctx); !ok {
return false, ErrNoActor
}
return q.db.GetChatDebugLoggingAllowUsers(ctx)
}
func (q *querier) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
run, err := q.db.GetChatDebugRunByID(ctx, id)
if err != nil {
return database.ChatDebugRun{}, err
}
// Authorize via the owning chat.
chat, err := q.db.GetChatByID(ctx, run.ChatID)
if err != nil {
return database.ChatDebugRun{}, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
return database.ChatDebugRun{}, err
}
return run, nil
}
func (q *querier) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
return nil, err
}
return q.db.GetChatDebugRunsByChatID(ctx, arg)
}
func (q *querier) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
run, err := q.db.GetChatDebugRunByID(ctx, runID)
if err != nil {
return nil, err
}
// Authorize via the owning chat.
chat, err := q.db.GetChatByID(ctx, run.ChatID)
if err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, chat); err != nil {
return nil, err
}
return q.db.GetChatDebugStepsByRunID(ctx, runID)
}
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
// The desktop-enabled flag is a deployment-wide setting read by any
// authenticated chat user and by chatd when deciding whether to expose
@@ -3484,11 +3386,11 @@ func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRI
return q.db.GetPRInsightsPerModel(ctx, arg)
}
func (q *querier) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsPullRequests(ctx, arg)
return q.db.GetPRInsightsRecentPRs(ctx, arg)
}
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
@@ -4186,17 +4088,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
return false, err
}
if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil {
return false, err
}
return q.db.GetUserChatDebugLoggingEnabled(ctx, userID)
}
func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, userID)
if err != nil {
@@ -4943,33 +4834,6 @@ func (q *querier) InsertChat(ctx context.Context, arg database.InsertChatParams)
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()), q.db.InsertChat)(ctx, arg)
}
func (q *querier) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugRun{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugRun{}, err
}
return q.db.InsertChatDebugRun(ctx, arg)
}
// InsertChatDebugStep creates a new step in a debug run. The underlying
// SQL uses INSERT ... SELECT ... FROM chat_debug_runs to enforce that the
// run exists and belongs to the specified chat. If the run_id is invalid
// or the chat_id doesn't match, the INSERT produces 0 rows and SQLC
// returns sql.ErrNoRows.
func (q *querier) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugStep{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugStep{}, err
}
return q.db.InsertChatDebugStep(ctx, arg)
}
func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
// Authorize create on chat resource scoped to the owner and org.
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
@@ -5864,17 +5728,6 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
}
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
chat, err := q.db.GetChatByID(ctx, chatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
}
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}
@@ -5968,28 +5821,6 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugRun{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugRun{}, err
}
return q.db.UpdateChatDebugRun(ctx, arg)
}
func (q *querier) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatDebugStep{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatDebugStep{}, err
}
return q.db.UpdateChatDebugStep(ctx, arg)
}
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// The batch heartbeat is a system-level operation filtered by
// worker_id. Authorization is enforced by the AsChatd context
@@ -6926,19 +6757,6 @@ func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg da
return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg)
}
func (q *querier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdateAgent, workspace); err != nil {
return err
}
return q.db.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
}
func (q *querier) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
if err != nil {
@@ -7222,13 +7040,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
}
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -7459,17 +7270,6 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error {
return q.db.UpsertTemplateUsageStats(ctx)
}
func (q *querier) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil {
return err
}
return q.db.UpsertUserChatDebugLoggingEnabled(ctx, arg)
}
func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
u, err := q.db.GetUserByID(ctx, arg.UserID)
if err != nil {
+5 -130
@@ -461,89 +461,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().DeleteChatQueuedMessage(gomock.Any(), args).Return(nil).AnyTimes()
check.Args(args).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatDebugDataAfterMessageID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.DeleteChatDebugDataAfterMessageIDParams{ChatID: chat.ID, MessageID: 123}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatDebugDataAfterMessageID(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("DeleteChatDebugDataByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatDebugDataByChatID(gomock.Any(), chat.ID).Return(int64(1), nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
}))
s.Run("FinalizeStaleChatDebugRows", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
updatedBefore := dbtime.Now()
row := database.FinalizeStaleChatDebugRowsRow{RunsFinalized: 1, StepsFinalized: 2}
dbm.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), updatedBefore).Return(row, nil).AnyTimes()
check.Args(updatedBefore).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(row)
}))
s.Run("GetChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatDebugLoggingAllowUsers(gomock.Any()).Return(true, nil).AnyTimes()
check.Args().Asserts().Returns(true)
}))
s.Run("GetChatDebugRunByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(run)
}))
s.Run("GetChatDebugRunsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
runs := []database.ChatDebugRun{{ID: uuid.New(), ChatID: chat.ID}}
arg := database.GetChatDebugRunsByChatIDParams{ChatID: chat.ID, LimitVal: 100}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatDebugRunsByChatID(gomock.Any(), arg).Return(runs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(runs)
}))
s.Run("GetChatDebugStepsByRunID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
steps := []database.ChatDebugStep{{ID: uuid.New(), RunID: run.ID, ChatID: chat.ID}}
dbm.EXPECT().GetChatDebugRunByID(gomock.Any(), run.ID).Return(run, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), run.ID).Return(steps, nil).AnyTimes()
check.Args(run.ID).Asserts(chat, policy.ActionRead).Returns(steps)
}))
s.Run("InsertChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.InsertChatDebugRunParams{ChatID: chat.ID, Kind: "chat_turn", Status: "in_progress"}
run := database.ChatDebugRun{ID: uuid.New(), ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
}))
s.Run("InsertChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.InsertChatDebugStepParams{RunID: uuid.New(), ChatID: chat.ID, StepNumber: 1, Operation: "stream", Status: "in_progress"}
step := database.ChatDebugStep{ID: uuid.New(), RunID: arg.RunID, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
}))
s.Run("UpdateChatDebugRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatDebugRunParams{ID: uuid.New(), ChatID: chat.ID}
run := database.ChatDebugRun{ID: arg.ID, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run)
}))
s.Run("UpdateChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatDebugStepParams{ID: uuid.New(), ChatID: chat.ID}
step := database.ChatDebugStep{ID: arg.ID, ChatID: chat.ID}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatDebugStep(gomock.Any(), arg).Return(step, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(step)
}))
s.Run("UpsertChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatDebugLoggingAllowUsers(gomock.Any(), true).Return(nil).AnyTimes()
check.Args(true).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetChatByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
@@ -561,24 +478,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
agentID := uuid.New()
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
}))
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerChatParams{
OwnerID: uuid.New(),
@@ -2344,9 +2243,9 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsPullRequests", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsPullRequestsParams{}
dbm.EXPECT().GetPRInsightsPullRequests(gomock.Any(), arg).Return([]database.GetPRInsightsPullRequestsRow{}, nil).AnyTimes()
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsRecentPRsParams{}
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
@@ -2577,19 +2476,6 @@ func (s *MethodTestSuite) TestUser() {
dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key)
}))
s.Run("GetUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), u.ID).Return(true, nil).AnyTimes()
check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns(true)
}))
s.Run("UpsertUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
arg := database.UpsertUserChatDebugLoggingEnabledParams{UserID: u.ID, DebugLoggingEnabled: true}
dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes()
dbm.EXPECT().UpsertUserChatDebugLoggingEnabled(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdatePersonal)
}))
s.Run("UpdateUserChatCustomPrompt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
uc := database.UserConfig{UserID: u.ID, Key: "chat_custom_prompt", Value: "my custom prompt"}
@@ -3031,17 +2917,6 @@ func (s *MethodTestSuite) TestWorkspace() {
dbm.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(w, policy.ActionUpdate).Returns()
}))
s.Run("UpdateWorkspaceAgentDirectoryByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
arg := database.UpdateWorkspaceAgentDirectoryByIDParams{
ID: agt.ID,
Directory: "/workspaces/project",
}
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
dbm.EXPECT().UpdateWorkspaceAgentDirectoryByID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(w, policy.ActionUpdateAgent).Returns()
}))
s.Run("UpdateWorkspaceAgentDisplayAppsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
@@ -5538,10 +5413,10 @@ func (s *MethodTestSuite) TestUserSecrets() {
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
Returns(int64(1))
Returns()
}))
}
+7 -151
@@ -280,14 +280,6 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
return r0
}
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
return r0
}
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
@@ -416,22 +408,6 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
return r0
}
func (m queryMetricsStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteChatDebugDataAfterMessageID(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteChatDebugDataAfterMessageID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataAfterMessageID").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteChatDebugDataByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("DeleteChatDebugDataByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatDebugDataByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
@@ -752,12 +728,12 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
return r0
}
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
start := time.Now()
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
return r0, r1
return r0
}
func (m queryMetricsStore) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
@@ -888,14 +864,6 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.
return r0, r1
}
func (m queryMetricsStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
start := time.Now()
r0, r1 := m.s.FinalizeStaleChatDebugRows(ctx, updatedBefore)
m.queryLatencies.WithLabelValues("FinalizeStaleChatDebugRows").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "FinalizeStaleChatDebugRows").Inc()
return r0, r1
}
func (m queryMetricsStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.FindMatchingPresetID(ctx, arg)
@@ -1000,14 +968,6 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
return r0, r1
}
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
start := time.Now()
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
@@ -1152,38 +1112,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugLoggingAllowUsers(ctx)
m.queryLatencies.WithLabelValues("GetChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugLoggingAllowUsers").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugRunByID(ctx, id)
m.queryLatencies.WithLabelValues("GetChatDebugRunByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDebugRunsByChatID(ctx context.Context, chatID database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugRunsByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("GetChatDebugRunsByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugRunsByChatID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
start := time.Now()
r0, r1 := m.s.GetChatDebugStepsByRunID(ctx, runID)
m.queryLatencies.WithLabelValues("GetChatDebugStepsByRunID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDebugStepsByRunID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
@@ -2048,11 +1976,11 @@ func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg databa
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsPullRequests(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsPullRequests").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPullRequests").Inc()
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
return r0, r1
}
@@ -2672,14 +2600,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
return r0, r1
}
func (m queryMetricsStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatDebugLoggingEnabled(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatDebugLoggingEnabled").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID)
@@ -3376,22 +3296,6 @@ func (m queryMetricsStore) InsertChat(ctx context.Context, arg database.InsertCh
return r0, r1
}
func (m queryMetricsStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.InsertChatDebugRun(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatDebugRun").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugRun").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
start := time.Now()
r0, r1 := m.s.InsertChatDebugStep(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatDebugStep").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatDebugStep").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
start := time.Now()
r0, r1 := m.s.InsertChatFile(ctx, arg)
@@ -4200,14 +4104,6 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
return r0
}
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
return r0
}
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
start := time.Now()
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
@@ -4288,22 +4184,6 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
return r0, r1
}
func (m queryMetricsStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatDebugRun(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatDebugRun").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugRun").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatDebugStep(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatDebugStep").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatDebugStep").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
@@ -4936,14 +4816,6 @@ func (m queryMetricsStore) UpdateWorkspaceAgentConnectionByID(ctx context.Contex
return r0
}
func (m queryMetricsStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
start := time.Now()
r0 := m.s.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateWorkspaceAgentDirectoryByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateWorkspaceAgentDirectoryByID").Inc()
return r0
}
func (m queryMetricsStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
start := time.Now()
r0 := m.s.UpdateWorkspaceAgentDisplayAppsByID(ctx, arg)
@@ -5144,14 +5016,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
return r0, r1
}
func (m queryMetricsStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
start := time.Now()
r0 := m.s.UpsertChatDebugLoggingAllowUsers(ctx, allowUsers)
m.queryLatencies.WithLabelValues("UpsertChatDebugLoggingAllowUsers").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDebugLoggingAllowUsers").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
start := time.Now()
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
@@ -5384,14 +5248,6 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error {
return r0
}
func (m queryMetricsStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
start := time.Now()
r0 := m.s.UpsertUserChatDebugLoggingEnabled(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertUserChatDebugLoggingEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatDebugLoggingEnabled").Inc()
return r0
}
func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
start := time.Now()
r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg)
+10 -276
@@ -363,20 +363,6 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
}
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
}
// CountAIBridgeInterceptions mocks base method.
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -671,36 +657,6 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
}
// DeleteChatDebugDataAfterMessageID mocks base method.
func (m *MockStore) DeleteChatDebugDataAfterMessageID(ctx context.Context, arg database.DeleteChatDebugDataAfterMessageIDParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatDebugDataAfterMessageID", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteChatDebugDataAfterMessageID indicates an expected call of DeleteChatDebugDataAfterMessageID.
func (mr *MockStoreMockRecorder) DeleteChatDebugDataAfterMessageID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataAfterMessageID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataAfterMessageID), ctx, arg)
}
// DeleteChatDebugDataByChatID mocks base method.
func (m *MockStore) DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatDebugDataByChatID", ctx, chatID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteChatDebugDataByChatID indicates an expected call of DeleteChatDebugDataByChatID.
func (mr *MockStoreMockRecorder) DeleteChatDebugDataByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatDebugDataByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatDebugDataByChatID), ctx, chatID)
}
// DeleteChatModelConfigByID mocks base method.
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -1274,12 +1230,11 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
}
// DeleteUserSecretByUserIDAndName mocks base method.
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
@@ -1517,21 +1472,6 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, u
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchVolumesResourceMonitorsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).FetchVolumesResourceMonitorsUpdatedAfter), ctx, updatedAt)
}
// FinalizeStaleChatDebugRows mocks base method.
func (m *MockStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FinalizeStaleChatDebugRows", ctx, updatedBefore)
ret0, _ := ret[0].(database.FinalizeStaleChatDebugRowsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FinalizeStaleChatDebugRows indicates an expected call of FinalizeStaleChatDebugRows.
func (mr *MockStoreMockRecorder) FinalizeStaleChatDebugRows(ctx, updatedBefore any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinalizeStaleChatDebugRows", reflect.TypeOf((*MockStore)(nil).FinalizeStaleChatDebugRows), ctx, updatedBefore)
}
// FindMatchingPresetID mocks base method.
func (m *MockStore) FindMatchingPresetID(ctx context.Context, arg database.FindMatchingPresetIDParams) (uuid.UUID, error) {
m.ctrl.T.Helper()
@@ -1727,21 +1667,6 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
}
// GetActiveChatsByAgentID mocks base method.
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
}
// GetActivePresetPrebuildSchedules mocks base method.
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
m.ctrl.T.Helper()
@@ -2117,66 +2042,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
}
// GetChatDebugLoggingAllowUsers mocks base method.
func (m *MockStore) GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugLoggingAllowUsers", ctx)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugLoggingAllowUsers indicates an expected call of GetChatDebugLoggingAllowUsers.
func (mr *MockStoreMockRecorder) GetChatDebugLoggingAllowUsers(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).GetChatDebugLoggingAllowUsers), ctx)
}
// GetChatDebugRunByID mocks base method.
func (m *MockStore) GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugRunByID", ctx, id)
ret0, _ := ret[0].(database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugRunByID indicates an expected call of GetChatDebugRunByID.
func (mr *MockStoreMockRecorder) GetChatDebugRunByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunByID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunByID), ctx, id)
}
// GetChatDebugRunsByChatID mocks base method.
func (m *MockStore) GetChatDebugRunsByChatID(ctx context.Context, arg database.GetChatDebugRunsByChatIDParams) ([]database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugRunsByChatID", ctx, arg)
ret0, _ := ret[0].([]database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugRunsByChatID indicates an expected call of GetChatDebugRunsByChatID.
func (mr *MockStoreMockRecorder) GetChatDebugRunsByChatID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugRunsByChatID", reflect.TypeOf((*MockStore)(nil).GetChatDebugRunsByChatID), ctx, arg)
}
// GetChatDebugStepsByRunID mocks base method.
func (m *MockStore) GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]database.ChatDebugStep, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDebugStepsByRunID", ctx, runID)
ret0, _ := ret[0].([]database.ChatDebugStep)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDebugStepsByRunID indicates an expected call of GetChatDebugStepsByRunID.
func (mr *MockStoreMockRecorder) GetChatDebugStepsByRunID(ctx, runID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDebugStepsByRunID", reflect.TypeOf((*MockStore)(nil).GetChatDebugStepsByRunID), ctx, runID)
}
// GetChatDesktopEnabled mocks base method.
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
@@ -3797,19 +3662,19 @@ func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
}
// GetPRInsightsPullRequests mocks base method.
func (m *MockStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
// GetPRInsightsRecentPRs mocks base method.
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsPullRequests", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsPullRequestsRow)
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsPullRequests indicates an expected call of GetPRInsightsPullRequests.
func (mr *MockStoreMockRecorder) GetPRInsightsPullRequests(ctx, arg any) *gomock.Call {
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPullRequests", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPullRequests), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
}
// GetPRInsightsSummary mocks base method.
@@ -4997,21 +4862,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
}
// GetUserChatDebugLoggingEnabled mocks base method.
func (m *MockStore) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatDebugLoggingEnabled", ctx, userID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatDebugLoggingEnabled indicates an expected call of GetUserChatDebugLoggingEnabled.
func (mr *MockStoreMockRecorder) GetUserChatDebugLoggingEnabled(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).GetUserChatDebugLoggingEnabled), ctx, userID)
}
// GetUserChatProviderKeys mocks base method.
func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
@@ -6331,36 +6181,6 @@ func (mr *MockStoreMockRecorder) InsertChat(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChat", reflect.TypeOf((*MockStore)(nil).InsertChat), ctx, arg)
}
// InsertChatDebugRun mocks base method.
func (m *MockStore) InsertChatDebugRun(ctx context.Context, arg database.InsertChatDebugRunParams) (database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatDebugRun", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatDebugRun indicates an expected call of InsertChatDebugRun.
func (mr *MockStoreMockRecorder) InsertChatDebugRun(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugRun", reflect.TypeOf((*MockStore)(nil).InsertChatDebugRun), ctx, arg)
}
// InsertChatDebugStep mocks base method.
func (m *MockStore) InsertChatDebugStep(ctx context.Context, arg database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatDebugStep", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugStep)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatDebugStep indicates an expected call of InsertChatDebugStep.
func (mr *MockStoreMockRecorder) InsertChatDebugStep(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatDebugStep", reflect.TypeOf((*MockStore)(nil).InsertChatDebugStep), ctx, arg)
}
// InsertChatFile mocks base method.
func (m *MockStore) InsertChatFile(ctx context.Context, arg database.InsertChatFileParams) (database.InsertChatFileRow, error) {
m.ctrl.T.Helper()
@@ -7960,20 +7780,6 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
}
// SoftDeleteContextFileMessages mocks base method.
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
}
// TryAcquireLock mocks base method.
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
m.ctrl.T.Helper()
@@ -8119,36 +7925,6 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
}
// UpdateChatDebugRun mocks base method.
func (m *MockStore) UpdateChatDebugRun(ctx context.Context, arg database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatDebugRun", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugRun)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatDebugRun indicates an expected call of UpdateChatDebugRun.
func (mr *MockStoreMockRecorder) UpdateChatDebugRun(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugRun", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugRun), ctx, arg)
}
// UpdateChatDebugStep mocks base method.
func (m *MockStore) UpdateChatDebugStep(ctx context.Context, arg database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatDebugStep", ctx, arg)
ret0, _ := ret[0].(database.ChatDebugStep)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatDebugStep indicates an expected call of UpdateChatDebugStep.
func (mr *MockStoreMockRecorder) UpdateChatDebugStep(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatDebugStep", reflect.TypeOf((*MockStore)(nil).UpdateChatDebugStep), ctx, arg)
}
// UpdateChatHeartbeats mocks base method.
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
@@ -9300,20 +9076,6 @@ func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentConnectionByID(ctx, arg any
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentConnectionByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentConnectionByID), ctx, arg)
}
// UpdateWorkspaceAgentDirectoryByID mocks base method.
func (m *MockStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateWorkspaceAgentDirectoryByID", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateWorkspaceAgentDirectoryByID indicates an expected call of UpdateWorkspaceAgentDirectoryByID.
func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentDirectoryByID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentDirectoryByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentDirectoryByID), ctx, arg)
}
// UpdateWorkspaceAgentDisplayAppsByID mocks base method.
func (m *MockStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
m.ctrl.T.Helper()
@@ -9669,20 +9431,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
}
// UpsertChatDebugLoggingAllowUsers mocks base method.
func (m *MockStore) UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDebugLoggingAllowUsers", ctx, allowUsers)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatDebugLoggingAllowUsers indicates an expected call of UpsertChatDebugLoggingAllowUsers.
func (mr *MockStoreMockRecorder) UpsertChatDebugLoggingAllowUsers(ctx, allowUsers any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDebugLoggingAllowUsers", reflect.TypeOf((*MockStore)(nil).UpsertChatDebugLoggingAllowUsers), ctx, allowUsers)
}
// UpsertChatDesktopEnabled mocks base method.
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
m.ctrl.T.Helper()
@@ -10100,20 +9848,6 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx)
}
// UpsertUserChatDebugLoggingEnabled mocks base method.
func (m *MockStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertUserChatDebugLoggingEnabled", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertUserChatDebugLoggingEnabled indicates an expected call of UpsertUserChatDebugLoggingEnabled.
func (mr *MockStoreMockRecorder) UpsertUserChatDebugLoggingEnabled(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatDebugLoggingEnabled", reflect.TypeOf((*MockStore)(nil).UpsertUserChatDebugLoggingEnabled), ctx, arg)
}
// UpsertUserChatProviderKey mocks base method.
func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) {
m.ctrl.T.Helper()
@@ -1255,44 +1255,6 @@ COMMENT ON COLUMN boundary_usage_stats.window_start IS 'Start of the time window
COMMENT ON COLUMN boundary_usage_stats.updated_at IS 'Timestamp of the last update to this row.';
CREATE TABLE chat_debug_runs (
id uuid DEFAULT gen_random_uuid() NOT NULL,
chat_id uuid NOT NULL,
root_chat_id uuid,
parent_chat_id uuid,
model_config_id uuid,
trigger_message_id bigint,
history_tip_message_id bigint,
kind text NOT NULL,
status text NOT NULL,
provider text,
model text,
summary jsonb DEFAULT '{}'::jsonb NOT NULL,
started_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
finished_at timestamp with time zone
);
CREATE TABLE chat_debug_steps (
id uuid DEFAULT gen_random_uuid() NOT NULL,
run_id uuid NOT NULL,
chat_id uuid NOT NULL,
step_number integer NOT NULL,
operation text NOT NULL,
status text NOT NULL,
history_tip_message_id bigint,
assistant_message_id bigint,
normalized_request jsonb NOT NULL,
normalized_response jsonb,
usage jsonb,
attempts jsonb DEFAULT '[]'::jsonb NOT NULL,
error jsonb,
metadata jsonb DEFAULT '{}'::jsonb NOT NULL,
started_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
finished_at timestamp with time zone
);
CREATE TABLE chat_diff_statuses (
chat_id uuid NOT NULL,
url text,
@@ -3397,12 +3359,6 @@ ALTER TABLE ONLY audit_logs
ALTER TABLE ONLY boundary_usage_stats
ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
ALTER TABLE ONLY chat_debug_runs
ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_debug_steps
ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
@@ -3797,20 +3753,6 @@ CREATE INDEX idx_audit_log_user_id ON audit_logs USING btree (user_id);
CREATE INDEX idx_audit_logs_time_desc ON audit_logs USING btree ("time" DESC);
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs USING btree (chat_id, started_at DESC);
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs USING btree (updated_at) WHERE (finished_at IS NULL);
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps USING btree (chat_id, assistant_message_id) WHERE (assistant_message_id IS NOT NULL);
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps USING btree (chat_id, history_tip_message_id);
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps USING btree (updated_at) WHERE (finished_at IS NULL);
CREATE INDEX idx_chat_diff_statuses_stale_at ON chat_diff_statuses USING btree (stale_at);
CREATE INDEX idx_chat_file_links_chat_id ON chat_file_links USING btree (chat_id);
@@ -3841,14 +3783,14 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
CREATE INDEX idx_chats_owner_updated_id ON chats USING btree (owner_id, updated_at DESC, id DESC);
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
@@ -4114,12 +4056,6 @@ ALTER TABLE ONLY aibridge_interceptions
ALTER TABLE ONLY api_keys
ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_debug_runs
ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_debug_steps
ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_diff_statuses
ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -4192,9 +4128,6 @@ ALTER TABLE ONLY connection_logs
ALTER TABLE ONLY crypto_keys
ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY chat_debug_steps
ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
ALTER TABLE ONLY oauth2_provider_app_tokens
ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
@@ -9,8 +9,6 @@ const (
ForeignKeyAiSeatStateUserID ForeignKeyConstraint = "ai_seat_state_user_id_fkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatDebugRunsChatID ForeignKeyConstraint = "chat_debug_runs_chat_id_fkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatDebugStepsChatID ForeignKeyConstraint = "chat_debug_steps_chat_id_fkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatFileLinksChatID ForeignKeyConstraint = "chat_file_links_chat_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
@@ -35,7 +33,6 @@ const (
ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyConnectionLogsWorkspaceOwnerID ForeignKeyConstraint = "connection_logs_workspace_owner_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_owner_id_fkey FOREIGN KEY (workspace_owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyCryptoKeysSecretKeyID ForeignKeyConstraint = "crypto_keys_secret_key_id_fkey" // ALTER TABLE ONLY crypto_keys ADD CONSTRAINT crypto_keys_secret_key_id_fkey FOREIGN KEY (secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyFkChatDebugStepsRunChat ForeignKeyConstraint = "fk_chat_debug_steps_run_chat" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT fk_chat_debug_steps_run_chat FOREIGN KEY (run_id, chat_id) REFERENCES chat_debug_runs(id, chat_id) ON DELETE CASCADE;
ForeignKeyFkOauth2ProviderAppTokensUserID ForeignKeyConstraint = "fk_oauth2_provider_app_tokens_user_id" // ALTER TABLE ONLY oauth2_provider_app_tokens ADD CONSTRAINT fk_oauth2_provider_app_tokens_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyGitAuthLinksOauthAccessTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyGitAuthLinksOauthRefreshTokenKeyID ForeignKeyConstraint = "git_auth_links_oauth_refresh_token_key_id_fkey" // ALTER TABLE ONLY external_auth_links ADD CONSTRAINT git_auth_links_oauth_refresh_token_key_id_fkey FOREIGN KEY (oauth_refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
@@ -1 +0,0 @@
DROP INDEX IF EXISTS idx_chats_agent_id;
@@ -1 +0,0 @@
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
@@ -1 +0,0 @@
CREATE INDEX idx_chats_owner_updated_id ON chats (owner_id, updated_at DESC, id DESC);
@@ -1,5 +0,0 @@
-- The GetChats ORDER BY changed from (updated_at, id) DESC to a 4-column
-- expression sort (pinned-first flag, negated pin_order, updated_at, id).
-- This index was purpose-built for the old sort and no longer provides
-- read benefit. The simpler idx_chats_owner covers the owner_id filter.
DROP INDEX IF EXISTS idx_chats_owner_updated_id;
@@ -1,2 +0,0 @@
DROP TABLE IF EXISTS chat_debug_steps;
DROP TABLE IF EXISTS chat_debug_runs;
@@ -1,59 +0,0 @@
CREATE TABLE chat_debug_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
-- root_chat_id and parent_chat_id are intentionally NOT
-- foreign-keyed to chats(id). They are snapshot values that
-- record the subchat hierarchy at run time. The referenced
-- chat may be archived or deleted independently, and we want
-- to preserve the historical lineage in debug rows rather
-- than cascade-delete them.
root_chat_id UUID,
parent_chat_id UUID,
model_config_id UUID,
trigger_message_id BIGINT,
history_tip_message_id BIGINT,
kind TEXT NOT NULL,
status TEXT NOT NULL,
provider TEXT,
model TEXT,
summary JSONB NOT NULL DEFAULT '{}'::jsonb,
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
finished_at TIMESTAMPTZ
);
CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs(id, chat_id);
CREATE INDEX idx_chat_debug_runs_chat_started ON chat_debug_runs(chat_id, started_at DESC);
CREATE TABLE chat_debug_steps (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
run_id UUID NOT NULL,
chat_id UUID NOT NULL REFERENCES chats(id) ON DELETE CASCADE,
step_number INT NOT NULL,
operation TEXT NOT NULL,
status TEXT NOT NULL,
history_tip_message_id BIGINT,
assistant_message_id BIGINT,
normalized_request JSONB NOT NULL,
normalized_response JSONB,
usage JSONB,
attempts JSONB NOT NULL DEFAULT '[]'::jsonb,
error JSONB,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
finished_at TIMESTAMPTZ,
CONSTRAINT fk_chat_debug_steps_run_chat
FOREIGN KEY (run_id, chat_id)
REFERENCES chat_debug_runs(id, chat_id)
ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps(run_id, step_number);
CREATE INDEX idx_chat_debug_steps_chat_tip ON chat_debug_steps(chat_id, history_tip_message_id);
-- Supports DeleteChatDebugDataAfterMessageID assistant_message_id branch.
CREATE INDEX idx_chat_debug_steps_chat_assistant_msg ON chat_debug_steps(chat_id, assistant_message_id) WHERE assistant_message_id IS NOT NULL;
-- Supports FinalizeStaleChatDebugRows worker query.
CREATE INDEX idx_chat_debug_runs_stale ON chat_debug_runs(updated_at) WHERE finished_at IS NULL;
CREATE INDEX idx_chat_debug_steps_stale ON chat_debug_steps(updated_at) WHERE finished_at IS NULL;
@@ -1,65 +0,0 @@
INSERT INTO chat_debug_runs (
id,
chat_id,
model_config_id,
history_tip_message_id,
kind,
status,
provider,
model,
summary,
started_at,
updated_at,
finished_at
) VALUES (
'c98518f8-9fb3-458b-a642-57552af1db63',
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
'9af5f8d5-6a57-4505-8a69-3d6c787b95fd',
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
'chat_turn',
'completed',
'openai',
'gpt-5.2',
'{"step_count":1,"has_error":false}'::jsonb,
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:01+00',
'2024-01-01 00:00:01+00'
);
INSERT INTO chat_debug_steps (
id,
run_id,
chat_id,
step_number,
operation,
status,
history_tip_message_id,
assistant_message_id,
normalized_request,
normalized_response,
usage,
attempts,
error,
metadata,
started_at,
updated_at,
finished_at
) VALUES (
'59471c60-7851-4fa6-bf05-e21dd939721f',
'c98518f8-9fb3-458b-a642-57552af1db63',
'72c0438a-18eb-4688-ab80-e4c6a126ef96',
1,
'stream',
'completed',
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
(SELECT MAX(id) FROM chat_messages WHERE chat_id = '72c0438a-18eb-4688-ab80-e4c6a126ef96'),
'{"messages":[]}'::jsonb,
'{"finish_reason":"stop"}'::jsonb,
'{"input_tokens":1,"output_tokens":1}'::jsonb,
'[]'::jsonb,
NULL,
'{"provider":"openai"}'::jsonb,
'2024-01-01 00:00:00+00',
'2024-01-01 00:00:01+00',
'2024-01-01 00:00:01+00'
);
@@ -4248,44 +4248,6 @@ type Chat struct {
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
}
type ChatDebugRun struct {
ID uuid.UUID `db:"id" json:"id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
RootChatID uuid.NullUUID `db:"root_chat_id" json:"root_chat_id"`
ParentChatID uuid.NullUUID `db:"parent_chat_id" json:"parent_chat_id"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
TriggerMessageID sql.NullInt64 `db:"trigger_message_id" json:"trigger_message_id"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
Kind string `db:"kind" json:"kind"`
Status string `db:"status" json:"status"`
Provider sql.NullString `db:"provider" json:"provider"`
Model sql.NullString `db:"model" json:"model"`
Summary json.RawMessage `db:"summary" json:"summary"`
StartedAt time.Time `db:"started_at" json:"started_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
}
type ChatDebugStep struct {
ID uuid.UUID `db:"id" json:"id"`
RunID uuid.UUID `db:"run_id" json:"run_id"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
StepNumber int32 `db:"step_number" json:"step_number"`
Operation string `db:"operation" json:"operation"`
Status string `db:"status" json:"status"`
HistoryTipMessageID sql.NullInt64 `db:"history_tip_message_id" json:"history_tip_message_id"`
AssistantMessageID sql.NullInt64 `db:"assistant_message_id" json:"assistant_message_id"`
NormalizedRequest json.RawMessage `db:"normalized_request" json:"normalized_request"`
NormalizedResponse pqtype.NullRawMessage `db:"normalized_response" json:"normalized_response"`
Usage pqtype.NullRawMessage `db:"usage" json:"usage"`
Attempts json.RawMessage `db:"attempts" json:"attempts"`
Error pqtype.NullRawMessage `db:"error" json:"error"`
Metadata json.RawMessage `db:"metadata" json:"metadata"`
StartedAt time.Time `db:"started_at" json:"started_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"`
}
type ChatDiffStatus struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
Url sql.NullString `db:"url" json:"url"`
@@ -0,0 +1,749 @@
package pubsub
import (
"bytes"
"context"
"database/sql"
"errors"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/quartz"
)
const (
// DefaultBatchingFlushInterval is the default upper bound on how long chatd
// publishes wait before a scheduled flush when nearby publishes do not
// naturally coalesce sooner.
DefaultBatchingFlushInterval = 50 * time.Millisecond
// DefaultBatchingQueueSize is the default capacity of the buffer holding
// chatd publish requests waiting to be flushed.
DefaultBatchingQueueSize = 8192
defaultBatchingPressureWait = 10 * time.Millisecond
defaultBatchingFinalFlushLimit = 15 * time.Second
batchingWarnInterval = 10 * time.Second
batchFlushScheduled = "scheduled"
batchFlushShutdown = "shutdown"
batchFlushStageNone = "none"
batchFlushStageBegin = "begin"
batchFlushStageExec = "exec"
batchFlushStageCommit = "commit"
batchDelegateFallbackReasonQueueFull = "queue_full"
batchDelegateFallbackReasonFlushError = "flush_error"
batchChannelClassStreamNotify = "stream_notify"
batchChannelClassOwnerEvent = "owner_event"
batchChannelClassConfigChange = "config_change"
batchChannelClassOther = "other"
)
// ErrBatchingPubsubClosed is returned when a batched pubsub publish is
// attempted after shutdown has started.
var ErrBatchingPubsubClosed = xerrors.New("batched pubsub is closed")
// BatchingConfig controls the chatd-specific PostgreSQL pubsub batching path.
// Flush timing is automatic: the run loop wakes every FlushInterval (or on
// backpressure) and drains everything currently queued into a single
// transaction. There is no fixed batch-size knob — the batch size is simply
// whatever accumulated since the last flush, which naturally adapts to load.
type BatchingConfig struct {
FlushInterval time.Duration
QueueSize int
PressureWait time.Duration
FinalFlushTimeout time.Duration
Clock quartz.Clock
}
type queuedPublish struct {
event string
channelClass string
message []byte
}
type batchSender interface {
Flush(ctx context.Context, batch []queuedPublish) error
Close() error
}
type batchFlushError struct {
stage string
err error
}
func (e *batchFlushError) Error() string {
return e.err.Error()
}
func (e *batchFlushError) Unwrap() error {
return e.err
}
// BatchingPubsub batches chatd publish traffic onto a dedicated PostgreSQL
// sender connection while delegating subscribe behavior to the shared listener
// pubsub instance.
type BatchingPubsub struct {
logger slog.Logger
delegate *PGPubsub
// sender is only accessed from the run() goroutine (including
// flushBatch and resetSender which it calls). Do not read or
// write this field from Publish or any other goroutine.
sender batchSender
newSender func(context.Context) (batchSender, error)
clock quartz.Clock
publishCh chan queuedPublish
flushCh chan struct{}
closeCh chan struct{}
doneCh chan struct{}
spaceMu sync.Mutex
spaceSignal chan struct{}
warnTicker *quartz.Ticker
flushInterval time.Duration
pressureWait time.Duration
finalFlushTimeout time.Duration
queuedCount atomic.Int64
closed atomic.Bool
closeOnce sync.Once
closeErr error
runErr error
runCtx context.Context
cancel context.CancelFunc
metrics batchingMetrics
}
type batchingMetrics struct {
QueueDepth prometheus.Gauge
BatchSize prometheus.Histogram
FlushDuration *prometheus.HistogramVec
DelegateFallbacksTotal *prometheus.CounterVec
SenderResetsTotal prometheus.Counter
SenderResetFailuresTotal prometheus.Counter
}
func newBatchingMetrics() batchingMetrics {
return batchingMetrics{
QueueDepth: prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_queue_depth",
Help: "The number of chatd notifications waiting in the batching queue.",
}),
BatchSize: prometheus.NewHistogram(prometheus.HistogramOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_size",
Help: "The number of logical notifications sent in each chatd batch flush.",
Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192},
}),
FlushDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_flush_duration_seconds",
Help: "The time spent flushing one chatd batch to PostgreSQL.",
Buckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 30},
}, []string{"reason"}),
DelegateFallbacksTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_delegate_fallbacks_total",
Help: "The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage.",
}, []string{"channel_class", "reason", "stage"}),
SenderResetsTotal: prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_sender_resets_total",
Help: "The number of successful batched pubsub sender resets after flush failures.",
}),
SenderResetFailuresTotal: prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "coder",
Subsystem: "pubsub",
Name: "batch_sender_reset_failures_total",
Help: "The number of batched pubsub sender reset attempts that failed.",
}),
}
}
func (m batchingMetrics) Describe(descs chan<- *prometheus.Desc) {
m.QueueDepth.Describe(descs)
m.BatchSize.Describe(descs)
m.FlushDuration.Describe(descs)
m.DelegateFallbacksTotal.Describe(descs)
m.SenderResetsTotal.Describe(descs)
m.SenderResetFailuresTotal.Describe(descs)
}
func (m batchingMetrics) Collect(metrics chan<- prometheus.Metric) {
m.QueueDepth.Collect(metrics)
m.BatchSize.Collect(metrics)
m.FlushDuration.Collect(metrics)
m.DelegateFallbacksTotal.Collect(metrics)
m.SenderResetsTotal.Collect(metrics)
m.SenderResetFailuresTotal.Collect(metrics)
}
// NewBatching creates a chatd-specific batched pubsub wrapper around the
// shared PostgreSQL listener implementation.
func NewBatching(
ctx context.Context,
logger slog.Logger,
delegate *PGPubsub,
prototype *sql.DB,
connectURL string,
cfg BatchingConfig,
) (*BatchingPubsub, error) {
if delegate == nil {
return nil, xerrors.New("delegate pubsub is nil")
}
if prototype == nil {
return nil, xerrors.New("prototype database is nil")
}
if connectURL == "" {
return nil, xerrors.New("connect URL is empty")
}
newSender := func(ctx context.Context) (batchSender, error) {
return newPGBatchSender(ctx, logger.Named("sender"), prototype, connectURL)
}
sender, err := newSender(ctx)
if err != nil {
return nil, err
}
ps, err := newBatchingPubsub(logger, delegate, sender, cfg)
if err != nil {
_ = sender.Close()
return nil, err
}
ps.newSender = newSender
return ps, nil
}
func newBatchingPubsub(
logger slog.Logger,
delegate *PGPubsub,
sender batchSender,
cfg BatchingConfig,
) (*BatchingPubsub, error) {
if delegate == nil {
return nil, xerrors.New("delegate pubsub is nil")
}
if sender == nil {
return nil, xerrors.New("batch sender is nil")
}
flushInterval := cfg.FlushInterval
if flushInterval == 0 {
flushInterval = DefaultBatchingFlushInterval
}
if flushInterval < 0 {
return nil, xerrors.New("flush interval must not be negative")
}
queueSize := cfg.QueueSize
if queueSize == 0 {
queueSize = DefaultBatchingQueueSize
}
if queueSize < 0 {
return nil, xerrors.New("queue size must not be negative")
}
pressureWait := cfg.PressureWait
if pressureWait == 0 {
pressureWait = defaultBatchingPressureWait
}
if pressureWait < 0 {
return nil, xerrors.New("pressure wait must not be negative")
}
finalFlushTimeout := cfg.FinalFlushTimeout
if finalFlushTimeout == 0 {
finalFlushTimeout = defaultBatchingFinalFlushLimit
}
if finalFlushTimeout < 0 {
return nil, xerrors.New("final flush timeout must not be negative")
}
clock := cfg.Clock
if clock == nil {
clock = quartz.NewReal()
}
runCtx, cancel := context.WithCancel(context.Background())
ps := &BatchingPubsub{
logger: logger,
delegate: delegate,
sender: sender,
clock: clock,
publishCh: make(chan queuedPublish, queueSize),
flushCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
spaceSignal: make(chan struct{}),
warnTicker: clock.NewTicker(batchingWarnInterval, "pubsubBatcher", "warn"),
flushInterval: flushInterval,
pressureWait: pressureWait,
finalFlushTimeout: finalFlushTimeout,
runCtx: runCtx,
cancel: cancel,
metrics: newBatchingMetrics(),
}
ps.metrics.QueueDepth.Set(0)
go ps.run()
return ps, nil
}
// Describe implements prometheus.Collector.
func (p *BatchingPubsub) Describe(descs chan<- *prometheus.Desc) {
p.metrics.Describe(descs)
}
// Collect implements prometheus.Collector.
func (p *BatchingPubsub) Collect(metrics chan<- prometheus.Metric) {
p.metrics.Collect(metrics)
}
// Subscribe delegates to the shared PostgreSQL listener pubsub.
func (p *BatchingPubsub) Subscribe(event string, listener Listener) (func(), error) {
return p.delegate.Subscribe(event, listener)
}
// SubscribeWithErr delegates to the shared PostgreSQL listener pubsub.
func (p *BatchingPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (func(), error) {
return p.delegate.SubscribeWithErr(event, listener)
}
// Publish enqueues a logical notification for asynchronous batched delivery.
func (p *BatchingPubsub) Publish(event string, message []byte) error {
channelClass := batchChannelClass(event)
if p.closed.Load() {
return ErrBatchingPubsubClosed
}
req := queuedPublish{
event: event,
channelClass: channelClass,
message: bytes.Clone(message),
}
if p.tryEnqueue(req) {
return nil
}
timer := p.clock.NewTimer(p.pressureWait, "pubsubBatcher", "pressureWait")
defer timer.Stop("pubsubBatcher", "pressureWait")
for {
if p.closed.Load() {
return ErrBatchingPubsubClosed
}
p.signalPressureFlush()
spaceSignal := p.currentSpaceSignal()
if p.tryEnqueue(req) {
return nil
}
select {
case <-spaceSignal:
continue
case <-timer.C:
if p.tryEnqueue(req) {
return nil
}
// The batching queue is still full after a pressure
// flush and brief wait. Fall back to the shared
// pubsub pool so the notification is still delivered
// rather than dropped.
p.observeDelegateFallback(channelClass, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)
p.logPublishRejection(event)
return p.delegate.Publish(event, message)
case <-p.doneCh:
return ErrBatchingPubsubClosed
}
}
}
// Close stops accepting new publishes, performs a bounded best-effort drain,
// and then closes the dedicated sender connection.
func (p *BatchingPubsub) Close() error {
p.closeOnce.Do(func() {
p.closed.Store(true)
p.cancel()
p.notifySpaceAvailable()
close(p.closeCh)
<-p.doneCh
p.closeErr = p.runErr
})
return p.closeErr
}
func (p *BatchingPubsub) tryEnqueue(req queuedPublish) bool {
if p.closed.Load() {
return false
}
select {
case p.publishCh <- req:
queuedDepth := p.queuedCount.Add(1)
p.observeQueueDepth(queuedDepth)
return true
default:
return false
}
}
func (p *BatchingPubsub) observeQueueDepth(depth int64) {
p.metrics.QueueDepth.Set(float64(depth))
}
func (p *BatchingPubsub) signalPressureFlush() {
select {
case p.flushCh <- struct{}{}:
default:
}
}
func (p *BatchingPubsub) currentSpaceSignal() <-chan struct{} {
p.spaceMu.Lock()
defer p.spaceMu.Unlock()
return p.spaceSignal
}
func (p *BatchingPubsub) notifySpaceAvailable() {
p.spaceMu.Lock()
defer p.spaceMu.Unlock()
close(p.spaceSignal)
p.spaceSignal = make(chan struct{})
}
func batchChannelClass(event string) string {
switch {
case strings.HasPrefix(event, "chat:stream:"):
return batchChannelClassStreamNotify
case strings.HasPrefix(event, "chat:owner:"):
return batchChannelClassOwnerEvent
case event == "chat:config_change":
return batchChannelClassConfigChange
default:
return batchChannelClassOther
}
}
func (p *BatchingPubsub) observeDelegateFallback(channelClass string, reason string, stage string) {
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Inc()
}
func (p *BatchingPubsub) observeDelegateFallbackBatch(batch []queuedPublish, reason string, stage string) {
if len(batch) == 0 {
return
}
counts := make(map[string]int)
for _, item := range batch {
counts[item.channelClass]++
}
for channelClass, count := range counts {
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Add(float64(count))
}
}
func batchFlushStage(err error) string {
var flushErr *batchFlushError
if errors.As(err, &flushErr) {
return flushErr.stage
}
return "unknown"
}
func (p *BatchingPubsub) run() {
defer close(p.doneCh)
defer p.warnTicker.Stop("pubsubBatcher", "warn")
batch := make([]queuedPublish, 0, 64)
timer := p.clock.NewTimer(p.flushInterval, "pubsubBatcher", "scheduledFlush")
defer timer.Stop("pubsubBatcher", "scheduledFlush")
flush := func(reason string) {
batch = p.drainIntoBatch(batch)
batch, _ = p.flushBatch(p.runCtx, batch, reason)
timer.Reset(p.flushInterval, "pubsubBatcher", reason+"Flush")
}
for {
select {
case item := <-p.publishCh:
// An item arrived before the timer fired. Append it and
// let the timer or pressure signal trigger the actual
// flush so that nearby publishes coalesce naturally.
batch = append(batch, item)
p.notifySpaceAvailable()
case <-timer.C:
flush(batchFlushScheduled)
case <-p.flushCh:
flush("pressure")
case <-p.closeCh:
p.runErr = errors.Join(p.drain(batch), p.sender.Close())
return
}
}
}
func (p *BatchingPubsub) drainIntoBatch(batch []queuedPublish) []queuedPublish {
drained := false
for {
select {
case item := <-p.publishCh:
batch = append(batch, item)
drained = true
default:
if drained {
p.notifySpaceAvailable()
}
return batch
}
}
}
func (p *BatchingPubsub) flushBatch(
ctx context.Context,
batch []queuedPublish,
reason string,
) ([]queuedPublish, error) {
if len(batch) == 0 {
return batch[:0], nil
}
count := len(batch)
totalBytes := 0
for _, item := range batch {
totalBytes += len(item.message)
}
p.metrics.BatchSize.Observe(float64(count))
start := p.clock.Now()
senderErr := p.sender.Flush(ctx, batch)
elapsed := p.clock.Since(start)
p.metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
var err error
if senderErr != nil {
stage := batchFlushStage(senderErr)
delivered, failed, fallbackErr := p.replayBatchViaDelegate(batch, batchDelegateFallbackReasonFlushError, stage)
var resetErr error
if reason != batchFlushShutdown {
resetErr = p.resetSender()
}
p.logFlushFailure(reason, stage, count, totalBytes, delivered, failed, senderErr, fallbackErr, resetErr)
if fallbackErr != nil || resetErr != nil {
err = errors.Join(senderErr, fallbackErr, resetErr)
}
} else if p.delegate != nil {
p.delegate.publishesTotal.WithLabelValues("true").Add(float64(count))
p.delegate.publishedBytesTotal.Add(float64(totalBytes))
}
queuedDepth := p.queuedCount.Add(-int64(count))
p.observeQueueDepth(queuedDepth)
clear(batch)
return batch[:0], err
}
func (p *BatchingPubsub) replayBatchViaDelegate(batch []queuedPublish, reason string, stage string) (delivered int, failed int, err error) {
if len(batch) == 0 {
return 0, 0, nil
}
p.observeDelegateFallbackBatch(batch, reason, stage)
if p.delegate == nil {
return 0, len(batch), xerrors.New("delegate pubsub is nil")
}
var errs []error
for _, item := range batch {
if err := p.delegate.Publish(item.event, item.message); err != nil {
failed++
errs = append(errs, xerrors.Errorf("delegate publish %q: %w", item.event, err))
continue
}
delivered++
}
return delivered, failed, errors.Join(errs...)
}
func (p *BatchingPubsub) resetSender() error {
if p.newSender == nil {
return nil
}
newSender, err := p.newSender(context.Background())
if err != nil {
p.metrics.SenderResetFailuresTotal.Inc()
return err
}
oldSender := p.sender
p.sender = newSender
p.metrics.SenderResetsTotal.Inc()
if oldSender == nil {
return nil
}
if err := oldSender.Close(); err != nil {
p.logger.Warn(context.Background(), "failed to close old batched pubsub sender after reset", slog.Error(err))
}
return nil
}
func (p *BatchingPubsub) logFlushFailure(reason string, stage string, count int, totalBytes int, delivered int, failed int, senderErr error, fallbackErr error, resetErr error) {
fields := []slog.Field{
slog.F("reason", reason),
slog.F("stage", stage),
slog.F("count", count),
slog.F("total_bytes", totalBytes),
slog.F("delegate_delivered", delivered),
slog.F("delegate_failed", failed),
slog.Error(senderErr),
}
if fallbackErr != nil {
fields = append(fields, slog.F("delegate_error", fallbackErr.Error()))
}
if resetErr != nil {
fields = append(fields, slog.F("sender_reset_error", resetErr.Error()))
}
p.logger.Error(context.Background(), "batched pubsub flush failed", fields...)
}
func (p *BatchingPubsub) drain(batch []queuedPublish) error {
ctx, cancel := context.WithTimeout(context.Background(), p.finalFlushTimeout)
defer cancel()
var errs []error
for {
batch = p.drainIntoBatch(batch)
if len(batch) == 0 {
break
}
var err error
batch, err = p.flushBatch(ctx, batch, batchFlushShutdown)
if err != nil {
errs = append(errs, err)
}
if ctx.Err() != nil {
break
}
}
dropped := p.dropPendingPublishes()
if dropped > 0 {
errs = append(errs, xerrors.Errorf("dropped %d queued notifications during shutdown", dropped))
}
if ctx.Err() != nil {
errs = append(errs, xerrors.Errorf("shutdown flush timed out: %w", ctx.Err()))
}
return errors.Join(errs...)
}
func (p *BatchingPubsub) dropPendingPublishes() int {
count := 0
for {
select {
case <-p.publishCh:
count++
default:
if count > 0 {
queuedDepth := p.queuedCount.Add(-int64(count))
p.observeQueueDepth(queuedDepth)
}
return count
}
}
}
func (p *BatchingPubsub) logPublishRejection(event string) {
fields := []slog.Field{
slog.F("event", event),
slog.F("queue_size", cap(p.publishCh)),
slog.F("queued", p.queuedCount.Load()),
}
select {
case <-p.warnTicker.C:
p.logger.Warn(context.Background(), "batched pubsub queue is full", fields...)
default:
p.logger.Debug(context.Background(), "batched pubsub queue is full", fields...)
}
}
type pgBatchSender struct {
logger slog.Logger
db *sql.DB
}
func newPGBatchSender(
ctx context.Context,
logger slog.Logger,
prototype *sql.DB,
connectURL string,
) (*pgBatchSender, error) {
connector, err := newConnector(ctx, logger, prototype, connectURL)
if err != nil {
return nil, err
}
db := sql.OpenDB(connector)
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxIdleTime(0)
db.SetConnMaxLifetime(0)
pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := db.PingContext(pingCtx); err != nil {
_ = db.Close()
return nil, xerrors.Errorf("ping batched pubsub sender database: %w", err)
}
return &pgBatchSender{logger: logger, db: db}, nil
}
func (s *pgBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return &batchFlushError{stage: batchFlushStageBegin, err: xerrors.Errorf("begin batched pubsub transaction: %w", err)}
}
committed := false
defer func() {
if !committed {
_ = tx.Rollback()
}
}()
for _, item := range batch {
// pg_notify does not allow its channel argument to be passed as a bound
// parameter, so the event name is interpolated into the query. This is
// safe because pq.QuoteLiteral escapes it.
//nolint:gosec
_, err = tx.ExecContext(ctx, `select pg_notify(`+pq.QuoteLiteral(item.event)+`, $1)`, item.message)
if err != nil {
return &batchFlushError{stage: batchFlushStageExec, err: xerrors.Errorf("exec pg_notify: %w", err)}
}
}
if err := tx.Commit(); err != nil {
return &batchFlushError{stage: batchFlushStageCommit, err: xerrors.Errorf("commit batched pubsub transaction: %w", err)}
}
committed = true
return nil
}
func (s *pgBatchSender) Close() error {
return s.db.Close()
}
@@ -0,0 +1,520 @@
package pubsub
import (
"bytes"
"context"
"database/sql"
"sync"
"testing"
"time"
_ "github.com/lib/pq"
prom_testutil "github.com/prometheus/client_golang/prometheus/testutil"
dto "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestBatchingPubsubScheduledFlush(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
require.Empty(t, sender.Batches())
clock.Advance(10 * time.Millisecond).MustWait(ctx)
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
batch := testutil.TryReceive(ctx, t, sender.flushes)
require.Len(t, batch, 2)
require.Equal(t, []byte("one"), batch[0].message)
require.Equal(t, []byte("two"), batch[1].message)
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
require.Equal(t, uint64(1), batchSizeCount)
require.InDelta(t, 2, batchSizeSum, 0.000001)
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
require.Equal(t, uint64(1), flushDurationCount)
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}
func TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults(t *testing.T) {
t.Parallel()
clock := quartz.NewMock(t)
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: clock})
require.Equal(t, DefaultBatchingFlushInterval, ps.flushInterval)
require.Equal(t, DefaultBatchingQueueSize, cap(ps.publishCh))
require.Equal(t, defaultBatchingPressureWait, ps.pressureWait)
require.Equal(t, defaultBatchingFinalFlushLimit, ps.finalFlushTimeout)
}
func TestBatchChannelClass(t *testing.T) {
t.Parallel()
tests := []struct {
name string
event string
want string
}{
{name: "stream notify", event: "chat:stream:123", want: batchChannelClassStreamNotify},
{name: "owner event", event: "chat:owner:123", want: batchChannelClassOwnerEvent},
{name: "config change", event: "chat:config_change", want: batchChannelClassConfigChange},
{name: "fallback", event: "workspace:owner:123", want: batchChannelClassOther},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, batchChannelClass(tt.event))
})
}
}
func TestBatchingPubsubTimerFlushDrainsAll(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 64,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
// Enqueue many messages before the timer fires — all should be
// drained and flushed in a single batch.
for _, msg := range []string{"one", "two", "three", "four", "five"} {
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
}
require.Empty(t, sender.Batches())
clock.Advance(10 * time.Millisecond).MustWait(ctx)
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
batch := testutil.TryReceive(ctx, t, sender.flushes)
require.Len(t, batch, 5)
require.Equal(t, []byte("one"), batch[0].message)
require.Equal(t, []byte("five"), batch[4].message)
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}
func TestBatchingPubsubQueueFullFallsBackToDelegate(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
pressureTrap := clock.Trap().NewTimer("pubsubBatcher", "pressureWait")
defer pressureTrap.Close()
sender := newFakeBatchSender()
sender.blockCh = make(chan struct{})
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 1,
PressureWait: 10 * time.Millisecond,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
// Fill the queue (capacity 1).
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
// Fire the timer so the run loop starts flushing "one" — the
// sender blocks on blockCh so the flush stays in-flight.
clock.Advance(10 * time.Millisecond).MustWait(ctx)
<-sender.started
// The run loop is blocked in flushBatch. Fill the queue again.
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
// A third publish should fall back to the delegate (which has a
// closed db, so the delegate Publish itself will error — but we
// verify the fallback metric was incremented).
errCh := make(chan error, 1)
go func() {
errCh <- ps.Publish("chat:stream:a", []byte("three"))
}()
pressureCall, err := pressureTrap.Wait(ctx)
require.NoError(t, err)
pressureCall.MustRelease(ctx)
clock.Advance(10 * time.Millisecond).MustWait(ctx)
err = testutil.TryReceive(ctx, t, errCh)
// The delegate has a closed db so it returns an error from the
// shared pool, not a batching-specific sentinel.
require.Error(t, err)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)))
close(sender.blockCh)
// Let the run loop finish the blocked flush and process "two".
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
require.NoError(t, ps.Close())
}
func TestBatchingPubsubCloseDrainsQueue(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: time.Hour,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
require.NoError(t, ps.Publish("chat:stream:a", []byte("three")))
require.NoError(t, ps.Close())
batches := sender.Batches()
require.Len(t, batches, 1)
require.Len(t, batches[0], 3)
require.Equal(t, []byte("one"), batches[0][0].message)
require.Equal(t, []byte("two"), batches[0][1].message)
require.Equal(t, []byte("three"), batches[0][2].message)
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
require.Equal(t, 1, sender.CloseCalls())
}
func TestBatchingPubsubPreservesOrder(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
sender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: time.Hour,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
for _, msg := range []string{"one", "two", "three", "four", "five"} {
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
}
require.NoError(t, ps.Close())
batches := sender.Batches()
require.NotEmpty(t, batches)
messages := make([]string, 0, 5)
for _, batch := range batches {
for _, item := range batch {
messages = append(messages, string(item.message))
}
}
require.Equal(t, []string{"one", "two", "three", "four", "five"}, messages)
}
func TestBatchingPubsubFlushFailureMetrics(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
clock := quartz.NewMock(t)
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
defer newTimerTrap.Close()
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
defer resetTrap.Close()
sender := newFakeBatchSender()
sender.err = context.DeadlineExceeded
sender.errStage = batchFlushStageExec
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{
Clock: clock,
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
call, err := newTimerTrap.Wait(ctx)
require.NoError(t, err)
call.MustRelease(ctx)
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
clock.Advance(10 * time.Millisecond).MustWait(ctx)
resetCall, err := resetTrap.Wait(ctx)
require.NoError(t, err)
resetCall.MustRelease(ctx)
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
require.Equal(t, uint64(1), batchSizeCount)
require.InDelta(t, 1, batchSizeSum, 0.000001)
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
require.Equal(t, uint64(1), flushDurationCount)
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
require.Zero(t, prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("true")))
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
}
func TestBatchingPubsubFlushFailureStageAccounting(t *testing.T) {
t.Parallel()
stages := []string{batchFlushStageBegin, batchFlushStageExec, batchFlushStageCommit}
for _, stage := range stages {
stage := stage
t.Run(stage, func(t *testing.T) {
t.Parallel()
sender := newFakeBatchSender()
sender.err = context.DeadlineExceeded
sender.errStage = stage
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
batch := []queuedPublish{{
event: "chat:stream:test",
channelClass: batchChannelClass("chat:stream:test"),
message: []byte("fallback-" + stage),
}}
ps.queuedCount.Store(int64(len(batch)))
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
require.Error(t, err)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, stage)))
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
})
}
}
func TestBatchingPubsubFlushFailureResetSender(t *testing.T) {
t.Parallel()
clock := quartz.NewMock(t)
firstSender := newFakeBatchSender()
firstSender.err = context.DeadlineExceeded
firstSender.errStage = batchFlushStageExec
secondSender := newFakeBatchSender()
ps, _ := newTestBatchingPubsub(t, firstSender, BatchingConfig{Clock: clock})
ps.newSender = func(context.Context) (batchSender, error) {
return secondSender, nil
}
firstBatch := []queuedPublish{{
event: "chat:stream:first",
channelClass: batchChannelClass("chat:stream:first"),
message: []byte("first"),
}}
ps.queuedCount.Store(int64(len(firstBatch)))
_, err := ps.flushBatch(context.Background(), firstBatch, batchFlushScheduled)
require.Error(t, err)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.SenderResetsTotal))
require.Equal(t, 1, firstSender.CloseCalls())
secondBatch := []queuedPublish{{
event: "chat:stream:second",
channelClass: batchChannelClass("chat:stream:second"),
message: []byte("second"),
}}
ps.queuedCount.Store(int64(len(secondBatch)))
_, err = ps.flushBatch(context.Background(), secondBatch, batchFlushScheduled)
require.NoError(t, err)
batches := secondSender.Batches()
require.Len(t, batches, 1)
require.Len(t, batches[0], 1)
require.Equal(t, []byte("second"), batches[0][0].message)
}
func TestBatchingPubsubFlushFailureReturnsJoinedErrorWhenReplayFails(t *testing.T) {
t.Parallel()
sender := newFakeBatchSender()
sender.err = context.DeadlineExceeded
sender.errStage = batchFlushStageExec
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
batch := []queuedPublish{{
event: "chat:stream:error",
channelClass: batchChannelClass("chat:stream:error"),
message: []byte("error"),
}}
ps.queuedCount.Store(int64(len(batch)))
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
require.Error(t, err)
require.ErrorContains(t, err, context.DeadlineExceeded.Error())
require.ErrorContains(t, err, `delegate publish "chat:stream:error"`)
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
}
func newTestBatchingPubsub(t *testing.T, sender batchSender, cfg BatchingConfig) (*BatchingPubsub, *PGPubsub) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
// Use a closed *sql.DB so that delegate.Publish returns a real
// error instead of panicking on a nil pointer when the batching
// queue falls back to the shared pool under pressure.
closedDB := newClosedDB(t)
delegate := newWithoutListener(logger.Named("delegate"), closedDB)
ps, err := newBatchingPubsub(logger.Named("batcher"), delegate, sender, cfg)
require.NoError(t, err)
t.Cleanup(func() {
_ = ps.Close()
})
return ps, delegate
}
// newClosedDB returns an *sql.DB whose connections have been closed,
// so any ExecContext call returns an error rather than panicking.
func newClosedDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("postgres", "host=localhost dbname=closed_db_stub sslmode=disable connect_timeout=1")
require.NoError(t, err)
require.NoError(t, db.Close())
return db
}
type fakeBatchSender struct {
mu sync.Mutex
batches [][]queuedPublish
flushes chan []queuedPublish
started chan struct{}
blockCh chan struct{}
err error
errStage string
closeErr error
closeCall int
}
func newFakeBatchSender() *fakeBatchSender {
return &fakeBatchSender{
flushes: make(chan []queuedPublish, 16),
started: make(chan struct{}, 16),
}
}
func (s *fakeBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
select {
case s.started <- struct{}{}:
default:
}
if s.blockCh != nil {
select {
case <-s.blockCh:
case <-ctx.Done():
return ctx.Err()
}
}
clone := make([]queuedPublish, len(batch))
for i, item := range batch {
clone[i] = queuedPublish{
event: item.event,
message: bytes.Clone(item.message),
}
}
s.mu.Lock()
s.batches = append(s.batches, clone)
s.mu.Unlock()
select {
case s.flushes <- clone:
default:
}
if s.err == nil {
return nil
}
if s.errStage != "" {
return &batchFlushError{stage: s.errStage, err: s.err}
}
return s.err
}
type metricWriter interface {
Write(*dto.Metric) error
}
func histogramCountAndSum(t *testing.T, observer any) (uint64, float64) {
t.Helper()
writer, ok := observer.(metricWriter)
require.True(t, ok)
metric := &dto.Metric{}
require.NoError(t, writer.Write(metric))
histogram := metric.GetHistogram()
require.NotNil(t, histogram)
return histogram.GetSampleCount(), histogram.GetSampleSum()
}
func (s *fakeBatchSender) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
s.closeCall++
return s.closeErr
}
func (s *fakeBatchSender) Batches() [][]queuedPublish {
s.mu.Lock()
defer s.mu.Unlock()
clone := make([][]queuedPublish, len(s.batches))
for i, batch := range s.batches {
clone[i] = make([]queuedPublish, len(batch))
copy(clone[i], batch)
}
return clone
}
func (s *fakeBatchSender) CloseCalls() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.closeCall
}
@@ -0,0 +1,130 @@
package pubsub_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/testutil"
)
func TestBatchingPubsubDedicatedSenderConnection(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
connectionURL, err := dbtestutil.Open(t)
require.NoError(t, err)
trackedDriver := dbtestutil.NewDriver()
defer trackedDriver.Close()
tconn, err := trackedDriver.Connector(connectionURL)
require.NoError(t, err)
trackedDB := sql.OpenDB(tconn)
defer trackedDB.Close()
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
require.NoError(t, err)
defer base.Close()
listenerConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
require.NoError(t, err)
defer batched.Close()
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
require.NotEqual(t, fmt.Sprintf("%p", listenerConn), fmt.Sprintf("%p", senderConn))
event := t.Name()
messageCh := make(chan []byte, 1)
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
messageCh <- message
})
require.NoError(t, err)
defer cancel()
require.NoError(t, batched.Publish(event, []byte("hello")))
require.Equal(t, []byte("hello"), testutil.TryReceive(ctx, t, messageCh))
}
func TestBatchingPubsubReconnectsAfterSenderDisconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
connectionURL, err := dbtestutil.Open(t)
require.NoError(t, err)
trackedDriver := dbtestutil.NewDriver()
defer trackedDriver.Close()
tconn, err := trackedDriver.Connector(connectionURL)
require.NoError(t, err)
trackedDB := sql.OpenDB(tconn)
defer trackedDB.Close()
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
require.NoError(t, err)
defer base.Close()
_ = testutil.TryReceive(ctx, t, trackedDriver.Connections) // listener connection
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
require.NoError(t, err)
defer batched.Close()
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
event := t.Name()
messageCh := make(chan []byte, 4)
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
messageCh <- message
})
require.NoError(t, err)
defer cancel()
require.NoError(t, batched.Publish(event, []byte("before-disconnect")))
require.Equal(t, []byte("before-disconnect"), testutil.TryReceive(ctx, t, messageCh))
require.NoError(t, senderConn.Close())
reconnected := false
delivered := false
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
if !reconnected {
select {
case conn := <-trackedDriver.Connections:
reconnected = conn != nil
default:
}
}
select {
case <-messageCh:
default:
}
if err := batched.Publish(event, []byte("after-disconnect")); err != nil {
return false
}
select {
case msg := <-messageCh:
delivered = string(msg) == "after-disconnect"
case <-time.After(testutil.IntervalFast):
delivered = false
}
return reconnected && delivered
}, testutil.IntervalMedium, "batched sender did not recover after disconnect")
}
@@ -487,12 +487,14 @@ func (d logDialer) DialContext(ctx context.Context, network, address string) (ne
return conn, nil
}
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
p.connected.Set(0)
// newConnector creates the driver connector used to open new pubsub connections.
func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (driver.Connector, error) {
if db == nil {
return nil, xerrors.New("database is nil")
}
var (
dialer = logDialer{
logger: p.logger,
logger: logger,
// pq.defaultDialer uses a zero net.Dialer as well.
d: net.Dialer{},
}
@@ -501,28 +503,38 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
)
// Create a custom connector if the database driver supports it.
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
connectorCreator, ok := db.Driver().(database.ConnectorCreator)
if ok {
connector, err = connectorCreator.Connector(connectURL)
if err != nil {
return xerrors.Errorf("create custom connector: %w", err)
return nil, xerrors.Errorf("create custom connector: %w", err)
}
} else {
// use the default pq connector otherwise
// Use the default pq connector otherwise.
connector, err = pq.NewConnector(connectURL)
if err != nil {
return xerrors.Errorf("create pq connector: %w", err)
return nil, xerrors.Errorf("create pq connector: %w", err)
}
}
// Set the dialer if the connector supports it.
dc, ok := connector.(database.DialerConnector)
if !ok {
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
} else {
dc.Dialer(dialer)
}
return connector, nil
}
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
p.connected.Set(0)
connector, err := newConnector(ctx, p.logger, p.db, connectURL)
if err != nil {
return err
}
var (
errCh = make(chan error, 1)
sentErrCh = false
@@ -76,7 +76,6 @@ type sqlcQuerier interface {
CleanTailnetLostPeers(ctx context.Context) error
CleanTailnetTunnels(ctx context.Context) error
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
@@ -102,8 +101,6 @@ type sqlcQuerier interface {
// be recreated.
DeleteAllWebpushSubscriptions(ctx context.Context) error
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error
DeleteChatDebugDataAfterMessageID(ctx context.Context, arg DeleteChatDebugDataAfterMessageIDParams) (int64, error)
DeleteChatDebugDataByChatID(ctx context.Context, chatID uuid.UUID) (int64, error)
DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error
DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error
DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error
@@ -171,7 +168,7 @@ type sqlcQuerier interface {
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error)
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
@@ -196,16 +193,6 @@ type sqlcQuerier interface {
FetchNewMessageMetadata(ctx context.Context, arg FetchNewMessageMetadataParams) (FetchNewMessageMetadataRow, error)
FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error)
FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error)
// Marks orphaned in-progress rows as interrupted so they do not stay
// in a non-terminal state forever. The NOT IN list must match the
// terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
//
// The steps CTE also catches steps whose parent run was just finalized
// (via run_id IN), because PostgreSQL data-modifying CTEs share the
// same snapshot and cannot see each other's row updates. Without this,
// a step with a recent updated_at would survive its run's finalization
// and remain in 'in_progress' state permanently.
FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error)
// FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters.
// It returns the preset ID if a match is found, or NULL if no match is found.
// The query finds presets where all preset parameters are present in the provided parameters,
@@ -228,7 +215,6 @@ type sqlcQuerier interface {
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActiveAISeatCount(ctx context.Context) (int64, error)
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
@@ -270,15 +256,6 @@ type sqlcQuerier interface {
// Aggregate cost summary for a single user within a date range.
// Only counts assistant-role messages.
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
// GetChatDebugLoggingAllowUsers returns the runtime admin setting that
// allows users to opt into chat debug logging when the deployment does
// not already force debug logging on globally.
GetChatDebugLoggingAllowUsers(ctx context.Context) (bool, error)
GetChatDebugRunByID(ctx context.Context, id uuid.UUID) (ChatDebugRun, error)
// Returns the most recent debug runs for a chat, ordered newest-first.
// Callers must supply an explicit limit to avoid unbounded result sets.
GetChatDebugRunsByChatID(ctx context.Context, arg GetChatDebugRunsByChatIDParams) ([]ChatDebugRun, error)
GetChatDebugStepsByRunID(ctx context.Context, runID uuid.UUID) ([]ChatDebugStep, error)
GetChatDesktopEnabled(ctx context.Context) (bool, error)
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
@@ -439,12 +416,11 @@ type sqlcQuerier interface {
// per PR for state/additions/deletions/model (model comes from the
// most recent chat).
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
// Returns all individual PR rows with cost for the selected time range.
// Returns individual PR rows with cost for the recent PRs table.
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
// direct children (that lack their own PR), and deduped picks one row
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
// large result sets from direct API callers.
GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error)
// per PR for metadata.
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
// PR Insights queries for the /agents analytics dashboard.
// These aggregate data from chat_diff_statuses (PR metadata) joined
// with chats and chat_messages (cost) to power the PR Insights view.
@@ -640,7 +616,6 @@ type sqlcQuerier interface {
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserChatCompactionThreshold(ctx context.Context, arg GetUserChatCompactionThresholdParams) (string, error)
GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error)
GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error)
GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error)
GetUserChatSpendInPeriod(ctx context.Context, arg GetUserChatSpendInPeriodParams) (int64, error)
GetUserCount(ctx context.Context, includeSystem bool) (int64, error)
@@ -760,8 +735,6 @@ type sqlcQuerier interface {
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
InsertChatDebugRun(ctx context.Context, arg InsertChatDebugRunParams) (ChatDebugRun, error)
InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error)
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error)
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
@@ -920,7 +893,6 @@ type sqlcQuerier interface {
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
// Non-blocking lock. Returns true if the lock was acquired, false otherwise.
//
// This must be called from within a transaction. The lock will be automatically
@@ -940,16 +912,6 @@ type sqlcQuerier interface {
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Uses COALESCE so that passing NULL from Go means "keep the
// existing value." This is intentional: debug rows follow a
// write-once-finalize pattern where fields are set at creation
// or finalization and never cleared back to NULL.
UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error)
// Uses COALESCE so that passing NULL from Go means "keep the
// existing value." This is intentional: debug rows follow a
// write-once-finalize pattern where fields are set at creation
// or finalization and never cleared back to NULL.
UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
@@ -1046,7 +1008,6 @@ type sqlcQuerier interface {
UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (WorkspaceTable, error)
UpdateWorkspaceACLByID(ctx context.Context, arg UpdateWorkspaceACLByIDParams) error
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error
UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg UpdateWorkspaceAgentDisplayAppsByIDParams) error
UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg UpdateWorkspaceAgentLifecycleStateByIDParams) error
UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg UpdateWorkspaceAgentLogOverflowByIDParams) error
@@ -1078,9 +1039,6 @@ type sqlcQuerier interface {
// cumulative values for unique counts (accurate period totals). Request counts
// are always deltas, accumulated in DB. Returns true if insert, false if update.
UpsertBoundaryUsageStats(ctx context.Context, arg UpsertBoundaryUsageStatsParams) (bool, error)
// UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
// allows users to opt into chat debug logging.
UpsertChatDebugLoggingAllowUsers(ctx context.Context, allowUsers bool) error
UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
@@ -1118,7 +1076,6 @@ type sqlcQuerier interface {
// used to store the data, and the minutes are summed for each user and template
// combination. The result is stored in the template_usage_stats table.
UpsertTemplateUsageStats(ctx context.Context) error
UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error
UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error)
UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error
UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error)
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,205 +0,0 @@
-- name: InsertChatDebugRun :one
INSERT INTO chat_debug_runs (
chat_id,
root_chat_id,
parent_chat_id,
model_config_id,
trigger_message_id,
history_tip_message_id,
kind,
status,
provider,
model,
summary,
started_at,
updated_at,
finished_at
)
VALUES (
@chat_id::uuid,
sqlc.narg('root_chat_id')::uuid,
sqlc.narg('parent_chat_id')::uuid,
sqlc.narg('model_config_id')::uuid,
sqlc.narg('trigger_message_id')::bigint,
sqlc.narg('history_tip_message_id')::bigint,
@kind::text,
@status::text,
sqlc.narg('provider')::text,
sqlc.narg('model')::text,
COALESCE(sqlc.narg('summary')::jsonb, '{}'::jsonb),
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
sqlc.narg('finished_at')::timestamptz
)
RETURNING *;
-- name: UpdateChatDebugRun :one
-- Uses COALESCE so that passing NULL from Go means "keep the
-- existing value." This is intentional: debug rows follow a
-- write-once-finalize pattern where fields are set at creation
-- or finalization and never cleared back to NULL.
UPDATE chat_debug_runs
SET
root_chat_id = COALESCE(sqlc.narg('root_chat_id')::uuid, root_chat_id),
parent_chat_id = COALESCE(sqlc.narg('parent_chat_id')::uuid, parent_chat_id),
model_config_id = COALESCE(sqlc.narg('model_config_id')::uuid, model_config_id),
trigger_message_id = COALESCE(sqlc.narg('trigger_message_id')::bigint, trigger_message_id),
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
status = COALESCE(sqlc.narg('status')::text, status),
provider = COALESCE(sqlc.narg('provider')::text, provider),
model = COALESCE(sqlc.narg('model')::text, model),
summary = COALESCE(sqlc.narg('summary')::jsonb, summary),
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
updated_at = NOW()
WHERE id = @id::uuid
AND chat_id = @chat_id::uuid
RETURNING *;
-- name: InsertChatDebugStep :one
INSERT INTO chat_debug_steps (
run_id,
chat_id,
step_number,
operation,
status,
history_tip_message_id,
assistant_message_id,
normalized_request,
normalized_response,
usage,
attempts,
error,
metadata,
started_at,
updated_at,
finished_at
)
SELECT
@run_id::uuid,
run.chat_id,
@step_number::int,
@operation::text,
@status::text,
sqlc.narg('history_tip_message_id')::bigint,
sqlc.narg('assistant_message_id')::bigint,
COALESCE(sqlc.narg('normalized_request')::jsonb, '{}'::jsonb),
sqlc.narg('normalized_response')::jsonb,
sqlc.narg('usage')::jsonb,
COALESCE(sqlc.narg('attempts')::jsonb, '[]'::jsonb),
sqlc.narg('error')::jsonb,
COALESCE(sqlc.narg('metadata')::jsonb, '{}'::jsonb),
COALESCE(sqlc.narg('started_at')::timestamptz, NOW()),
COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()),
sqlc.narg('finished_at')::timestamptz
FROM chat_debug_runs run
WHERE run.id = @run_id::uuid
AND run.chat_id = @chat_id::uuid
RETURNING *;
-- name: UpdateChatDebugStep :one
-- Uses COALESCE so that passing NULL from Go means "keep the
-- existing value." This is intentional: debug rows follow a
-- write-once-finalize pattern where fields are set at creation
-- or finalization and never cleared back to NULL.
UPDATE chat_debug_steps
SET
status = COALESCE(sqlc.narg('status')::text, status),
history_tip_message_id = COALESCE(sqlc.narg('history_tip_message_id')::bigint, history_tip_message_id),
assistant_message_id = COALESCE(sqlc.narg('assistant_message_id')::bigint, assistant_message_id),
normalized_request = COALESCE(sqlc.narg('normalized_request')::jsonb, normalized_request),
normalized_response = COALESCE(sqlc.narg('normalized_response')::jsonb, normalized_response),
usage = COALESCE(sqlc.narg('usage')::jsonb, usage),
attempts = COALESCE(sqlc.narg('attempts')::jsonb, attempts),
error = COALESCE(sqlc.narg('error')::jsonb, error),
metadata = COALESCE(sqlc.narg('metadata')::jsonb, metadata),
finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at),
updated_at = NOW()
WHERE id = @id::uuid
AND chat_id = @chat_id::uuid
RETURNING *;
-- name: GetChatDebugRunsByChatID :many
-- Returns the most recent debug runs for a chat, ordered newest-first.
-- Callers must supply an explicit limit to avoid unbounded result sets.
SELECT *
FROM chat_debug_runs
WHERE chat_id = @chat_id::uuid
ORDER BY started_at DESC, id DESC
LIMIT @limit_val::int;
-- name: GetChatDebugRunByID :one
SELECT *
FROM chat_debug_runs
WHERE id = @id::uuid;
-- name: GetChatDebugStepsByRunID :many
SELECT *
FROM chat_debug_steps
WHERE run_id = @run_id::uuid
ORDER BY step_number ASC, started_at ASC;
-- name: DeleteChatDebugDataByChatID :execrows
DELETE FROM chat_debug_runs
WHERE chat_id = @chat_id::uuid;
-- name: DeleteChatDebugDataAfterMessageID :execrows
WITH affected_runs AS (
SELECT DISTINCT run.id
FROM chat_debug_runs run
WHERE run.chat_id = @chat_id::uuid
AND (
run.history_tip_message_id > @message_id::bigint
OR run.trigger_message_id > @message_id::bigint
)
UNION
SELECT DISTINCT step.run_id AS id
FROM chat_debug_steps step
WHERE step.chat_id = @chat_id::uuid
AND (
step.assistant_message_id > @message_id::bigint
OR step.history_tip_message_id > @message_id::bigint
)
)
DELETE FROM chat_debug_runs
WHERE chat_id = @chat_id::uuid
AND id IN (SELECT id FROM affected_runs);
-- name: FinalizeStaleChatDebugRows :one
-- Marks orphaned in-progress rows as interrupted so they do not stay
-- in a non-terminal state forever. The NOT IN list must match the
-- terminal statuses defined by ChatDebugStatus in codersdk/chats.go.
--
-- The steps CTE also catches steps whose parent run was just finalized
-- (via run_id IN), because PostgreSQL data-modifying CTEs share the
-- same snapshot and cannot see each other's row updates. Without this,
-- a step with a recent updated_at would survive its run's finalization
-- and remain in 'in_progress' state permanently.
WITH finalized_runs AS (
UPDATE chat_debug_runs
SET
status = 'interrupted',
updated_at = NOW(),
finished_at = NOW()
WHERE updated_at < @updated_before::timestamptz
AND finished_at IS NULL
AND status NOT IN ('completed', 'error', 'interrupted')
RETURNING id
), finalized_steps AS (
UPDATE chat_debug_steps
SET
status = 'interrupted',
updated_at = NOW(),
finished_at = NOW()
WHERE (
updated_at < @updated_before::timestamptz
OR run_id IN (SELECT id FROM finalized_runs)
)
AND finished_at IS NULL
AND status NOT IN ('completed', 'error', 'interrupted')
RETURNING 1
)
SELECT
(SELECT COUNT(*) FROM finalized_runs)::bigint AS runs_finalized,
(SELECT COUNT(*) FROM finalized_steps)::bigint AS steps_finalized;
@@ -173,12 +173,11 @@ JOIN pr_costs pc ON pc.pr_key = d.pr_key
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
ORDER BY total_prs DESC;
-- name: GetPRInsightsPullRequests :many
-- Returns all individual PR rows with cost for the selected time range.
-- name: GetPRInsightsRecentPRs :many
-- Returns individual PR rows with cost for the recent PRs table.
-- Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
-- direct children (that lack their own PR), and deduped picks one row
-- per PR for metadata. A safety-cap LIMIT guards against unexpectedly
-- large result sets from direct API callers.
-- per PR for metadata.
WITH pr_costs AS (
SELECT
prc.pr_key,
@@ -265,4 +264,4 @@ SELECT * FROM (
JOIN pr_costs pc ON pc.pr_key = d.pr_key
) sub
ORDER BY sub.created_at DESC
LIMIT 500;
LIMIT @limit_val::int;
@@ -353,18 +353,20 @@ WHERE
ELSE chats.archived = sqlc.narg('archived') :: boolean
END
AND CASE
-- Cursor pagination: the last element on a page acts as the cursor.
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
-- (pin_order is negated so lower values sort first in DESC order),
-- which lets us use a single tuple < comparison.
-- This allows using the last element on a page as effectively a cursor.
-- This is an important option for scripts that need to paginate without
-- duplicating or missing data.
WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
-- The pagination cursor is the last ID of the previous page.
-- The query is ordered by the updated_at field, so select all
-- rows before the cursor.
(updated_at, id) < (
SELECT
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
updated_at, id
FROM
chats c2
chats
WHERE
c2.id = @after_id
id = @after_id
)
)
ELSE true
@@ -376,15 +378,9 @@ WHERE
-- Authorize Filter clause will be injected below in GetAuthorizedChats
-- @authorize_filter
ORDER BY
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
-- pinned chats, lower pin_order values come first. The negation
-- trick (-pin_order) keeps all sort columns DESC so the cursor
-- tuple < comparison works with uniform direction.
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
-pin_order DESC,
updated_at DESC,
id DESC
OFFSET @offset_opt
-- Deterministic and consistent ordering of all rows, even if they share
-- a timestamp. This is to ensure consistent pagination.
(updated_at, id) DESC OFFSET @offset_opt
LIMIT
-- The chat list is unbounded and expected to grow large.
-- Default to 50 to prevent accidental excessively large queries.
@@ -1297,26 +1293,3 @@ GROUP BY cm.chat_id;
SELECT id, provider, model, context_limit, enabled, is_default
FROM chat_model_configs
WHERE deleted = false;
-- name: GetActiveChatsByAgentID :many
SELECT *
FROM chats
WHERE agent_id = @agent_id::uuid
AND archived = false
-- Active statuses only: waiting, pending, running, paused,
-- requires_action.
-- Excludes completed and error (terminal states).
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
ORDER BY updated_at DESC;
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
UPDATE chat_messages
SET provider_response_id = NULL
WHERE chat_id = @chat_id::uuid
AND deleted = false
AND provider_response_id IS NOT NULL;
-- name: SoftDeleteContextFileMessages :exec
UPDATE chat_messages SET deleted = true
WHERE chat_id = @chat_id::uuid
AND deleted = false
AND content::jsonb @> '[{"type": "context-file"}]';
@@ -195,8 +195,7 @@ SELECT
w.id AS workspace_id,
COALESCE(w.name, '') AS workspace_name,
-- Include the name of the provisioner_daemon associated to the job
COALESCE(pd.name, '') AS worker_name,
wb.transition as workspace_build_transition
COALESCE(pd.name, '') AS worker_name
FROM
provisioner_jobs pj
LEFT JOIN
@@ -241,8 +240,7 @@ GROUP BY
t.icon,
w.id,
w.name,
pd.name,
wb.transition
pd.name
ORDER BY
pj.created_at DESC
LIMIT
-25
@@ -179,31 +179,6 @@ SET value = CASE
END
WHERE site_configs.key = 'agents_desktop_enabled';
-- GetChatDebugLoggingAllowUsers returns the runtime admin setting that
-- allows users to opt into chat debug logging when the deployment does
-- not already force debug logging on globally.
-- name: GetChatDebugLoggingAllowUsers :one
SELECT
COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_chat_debug_logging_allow_users'), false) :: boolean AS allow_users;
-- UpsertChatDebugLoggingAllowUsers updates the runtime admin setting that
-- allows users to opt into chat debug logging.
-- name: UpsertChatDebugLoggingAllowUsers :exec
INSERT INTO site_configs (key, value)
VALUES (
'agents_chat_debug_logging_allow_users',
CASE
WHEN sqlc.arg(allow_users)::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT (key) DO UPDATE
SET value = CASE
WHEN sqlc.arg(allow_users)::bool THEN 'true'
ELSE 'false'
END
WHERE site_configs.key = 'agents_chat_debug_logging_allow_users';
-- GetChatTemplateAllowlist returns the JSON-encoded template allowlist.
-- Returns an empty string when no allowlist has been configured (all templates allowed).
-- name: GetChatTemplateAllowlist :one
+1 -1
@@ -56,6 +56,6 @@ SET
WHERE user_id = @user_id AND name = @name
RETURNING *;
-- name: DeleteUserSecretByUserIDAndName :execrows
-- name: DeleteUserSecretByUserIDAndName :exec
DELETE FROM user_secrets
WHERE user_id = @user_id AND name = @name;
-25
@@ -213,31 +213,6 @@ RETURNING *;
-- name: DeleteUserChatCompactionThreshold :exec
DELETE FROM user_configs WHERE user_id = @user_id AND key = @key;
-- name: GetUserChatDebugLoggingEnabled :one
SELECT
value = 'true' AS debug_logging_enabled
FROM user_configs
WHERE user_id = @user_id
AND key = 'chat_debug_logging_enabled';
-- name: UpsertUserChatDebugLoggingEnabled :exec
INSERT INTO user_configs (user_id, key, value)
VALUES (
@user_id,
'chat_debug_logging_enabled',
CASE
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
ELSE 'false'
END
)
ON CONFLICT ON CONSTRAINT user_configs_pkey
DO UPDATE SET value = CASE
WHEN sqlc.arg(debug_logging_enabled)::bool THEN 'true'
ELSE 'false'
END
WHERE user_configs.user_id = @user_id
AND user_configs.key = 'chat_debug_logging_enabled';
-- name: GetUserTaskNotificationAlertDismissed :one
SELECT
value::boolean as task_notification_alert_dismissed
@@ -190,14 +190,6 @@ SET
WHERE
id = $1;
-- name: UpdateWorkspaceAgentDirectoryByID :exec
UPDATE
workspace_agents
SET
directory = $2, updated_at = $3
WHERE
id = $1;
-- name: GetWorkspaceAgentLogsAfter :many
SELECT
*
-4
@@ -15,8 +15,6 @@ const (
UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id);
UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id);
UniqueBoundaryUsageStatsPkey UniqueConstraint = "boundary_usage_stats_pkey" // ALTER TABLE ONLY boundary_usage_stats ADD CONSTRAINT boundary_usage_stats_pkey PRIMARY KEY (replica_id);
UniqueChatDebugRunsPkey UniqueConstraint = "chat_debug_runs_pkey" // ALTER TABLE ONLY chat_debug_runs ADD CONSTRAINT chat_debug_runs_pkey PRIMARY KEY (id);
UniqueChatDebugStepsPkey UniqueConstraint = "chat_debug_steps_pkey" // ALTER TABLE ONLY chat_debug_steps ADD CONSTRAINT chat_debug_steps_pkey PRIMARY KEY (id);
UniqueChatDiffStatusesPkey UniqueConstraint = "chat_diff_statuses_pkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_pkey PRIMARY KEY (chat_id);
UniqueChatFileLinksChatIDFileIDKey UniqueConstraint = "chat_file_links_chat_id_file_id_key" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_chat_id_file_id_key UNIQUE (chat_id, file_id);
UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id);
@@ -130,8 +128,6 @@ const (
UniqueWorkspaceResourcesPkey UniqueConstraint = "workspace_resources_pkey" // ALTER TABLE ONLY workspace_resources ADD CONSTRAINT workspace_resources_pkey PRIMARY KEY (id);
UniqueWorkspacesPkey UniqueConstraint = "workspaces_pkey" // ALTER TABLE ONLY workspaces ADD CONSTRAINT workspaces_pkey PRIMARY KEY (id);
UniqueIndexAPIKeyName UniqueConstraint = "idx_api_key_name" // CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type);
UniqueIndexChatDebugRunsIDChat UniqueConstraint = "idx_chat_debug_runs_id_chat" // CREATE UNIQUE INDEX idx_chat_debug_runs_id_chat ON chat_debug_runs USING btree (id, chat_id);
UniqueIndexChatDebugStepsRunStep UniqueConstraint = "idx_chat_debug_steps_run_step" // CREATE UNIQUE INDEX idx_chat_debug_steps_run_step ON chat_debug_steps USING btree (run_id, step_number);
UniqueIndexChatModelConfigsSingleDefault UniqueConstraint = "idx_chat_model_configs_single_default" // CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false));
UniqueIndexConnectionLogsConnectionIDWorkspaceIDAgentName UniqueConstraint = "idx_connection_logs_connection_id_workspace_id_agent_name" // CREATE UNIQUE INDEX idx_connection_logs_connection_id_workspace_id_agent_name ON connection_logs USING btree (connection_id, workspace_id, agent_name);
UniqueIndexCustomRolesNameLowerOrganizationID UniqueConstraint = "idx_custom_roles_name_lower_organization_id" // CREATE UNIQUE INDEX idx_custom_roles_name_lower_organization_id ON custom_roles USING btree (lower(name), COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid));
+77 -70
@@ -137,9 +137,8 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
logger := api.Logger.Named("chat_watcher")
conn, err := websocket.Accept(rw, r, nil)
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to open chat watch stream.",
@@ -147,44 +146,54 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
})
return
}
defer func() {
<-senderClosed
}()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
_ = conn.CloseRead(context.Background())
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close()
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
// The encoder is only written from the SubscribeWithErr callback,
// which delivers serially per subscription. Do not add a second
// write path without introducing synchronization.
encoder := json.NewEncoder(wsNetConn)
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
pubsub.HandleChatWatchEvent(
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
pubsub.HandleChatEvent(
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
if err != nil {
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
return
}
if err := encoder.Encode(payload); err != nil {
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
cancel()
return
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: payload,
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
}
},
))
if err != nil {
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Internal error subscribing to chat events.",
Detail: err.Error(),
},
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
}
return
}
defer cancelSubscribe()
<-ctx.Done()
// Send initial ping to signal the connection is ready.
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypePing,
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
}
for {
select {
case <-ctx.Done():
return
case <-senderClosed:
return
}
}
}
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
@@ -1810,9 +1819,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
// - pinOrder > 0 && already pinned: reorder (shift
// neighbors, clamp to [1, count]).
// - pinOrder > 0 && not pinned: append to end. The
// requested value is intentionally ignored; the
// SQL ORDER BY sorts pinned chats first so they
// appear on page 1 of the paginated sidebar.
// requested value is intentionally ignored because
// PinChatByID also bumps updated_at to keep the
// chat visible in the paginated sidebar.
var err error
errMsg := "Failed to pin chat."
switch {
@@ -2167,7 +2176,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
chatID := chat.ID
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
if api.chatDaemon == nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -2190,22 +2198,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
}
}
// Subscribe before accepting the WebSocket so that failures
// can still be reported as normal HTTP errors.
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
// Subscribe only fails today when the receiver is nil, which
// the chatDaemon == nil guard above already catches. This is
// defensive against future Subscribe failure modes.
if !ok {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Chat streaming is not available.",
Detail: "Chat stream state is not configured.",
})
return
}
defer cancelSub()
conn, err := websocket.Accept(rw, r, nil)
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to open chat stream.",
@@ -2213,30 +2206,41 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
})
return
}
ctx, cancel := context.WithCancel(ctx)
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
if !ok {
if err := sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeError,
Data: codersdk.Response{
Message: "Chat streaming is not available.",
Detail: "Chat stream state is not configured.",
},
}); err != nil {
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
}
// Ensure the WebSocket is closed so senderClosed
// completes and the handler can return.
<-senderClosed
return
}
defer func() {
<-senderClosed
}()
defer cancel()
_ = conn.CloseRead(context.Background())
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
defer wsNetConn.Close()
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
// Mark the chat as read when the stream connects and again
// when it disconnects so we avoid per-message API calls while
// messages are actively streaming.
api.markChatAsRead(ctx, chatID)
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
encoder := json.NewEncoder(wsNetConn)
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
if len(batch) == 0 {
return nil
}
return encoder.Encode(batch)
return sendEvent(codersdk.ServerSentEvent{
Type: codersdk.ServerSentEventTypeData,
Data: batch,
})
}
drainChatStreamBatch := func(
@@ -2269,7 +2273,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
end = len(snapshot)
}
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
return
}
}
@@ -2278,6 +2282,8 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
select {
case <-ctx.Done():
return
case <-senderClosed:
return
case firstEvent, ok := <-events:
if !ok {
return
@@ -2287,7 +2293,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
chatStreamBatchSize,
)
if err := sendChatStreamBatch(batch); err != nil {
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
return
}
if streamClosed {
@@ -2302,7 +2308,6 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
chatID := chat.ID
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
if api.chatDaemon != nil {
chat = api.chatDaemon.InterruptChat(ctx, chat)
@@ -2316,7 +2321,8 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
LastError: sql.NullString{},
})
if updateErr != nil {
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
api.Logger.Error(ctx, "failed to mark chat as waiting",
slog.F("chat_id", chatID), slog.Error(updateErr))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to interrupt chat.",
Detail: updateErr.Error(),
@@ -5626,7 +5632,7 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
previousSummary database.GetPRInsightsSummaryRow
timeSeries []database.GetPRInsightsTimeSeriesRow
byModel []database.GetPRInsightsPerModelRow
recentPRs []database.GetPRInsightsPullRequestsRow
recentPRs []database.GetPRInsightsRecentPRsRow
)
eg, egCtx := errgroup.WithContext(ctx)
@@ -5674,10 +5680,11 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
eg.Go(func() error {
var err error
recentPRs, err = api.Database.GetPRInsightsPullRequests(egCtx, database.GetPRInsightsPullRequestsParams{
recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
StartDate: startDate,
EndDate: endDate,
OwnerID: ownerID,
LimitVal: 20,
})
return err
})
@@ -5787,10 +5794,10 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
Summary: summary,
TimeSeries: tsEntries,
ByModel: modelEntries,
PullRequests: prEntries,
Summary: summary,
TimeSeries: tsEntries,
ByModel: modelEntries,
RecentPRs: prEntries,
})
}
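The watchChats and streamChat rewrites above route every write through `httpapi.OneWayWebSocketEventSender` and block on a `senderClosed` channel rather than managing the socket directly. A standalone sketch of that sender shape (assumed behavior only; the real helper's API may differ):

```go
package main

import (
	"context"
	"fmt"
)

// event is an illustrative stand-in for codersdk.ServerSentEvent.
type event struct {
	Type string
	Data any
}

// newOneWaySender mimics the assumed shape of the one-way sender:
// a send function plus a channel that closes when the peer goes
// away, so a handler can block until its context or the connection
// ends. A real implementation would drain ch into the socket.
func newOneWaySender(buf int) (send func(event) error, closed chan struct{}, stop func()) {
	ch := make(chan event, buf)
	closed = make(chan struct{})
	send = func(e event) error {
		select {
		case ch <- e:
			return nil
		case <-closed:
			return fmt.Errorf("sender closed")
		}
	}
	stop = func() { close(closed) }
	return send, closed, stop
}

func main() {
	send, closed, stop := newOneWaySender(1)
	_ = send(event{Type: "ping"}) // initial ready signal, as in watchChats
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go stop() // simulate the client disconnecting
	select {
	case <-ctx.Done():
	case <-closed:
		fmt.Println("sender closed; handler returns")
	}
}
```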
+98 -199
@@ -876,186 +876,6 @@ func TestListChats(t *testing.T) {
require.NoError(t, err)
require.Len(t, allChats, totalChats)
})
// Test that a pinned chat with an old updated_at appears on page 1.
t.Run("PinnedOnFirstPage", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _ := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
// Create the chat that will later be pinned. It gets the
// earliest updated_at because it is inserted first.
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "pinned-chat",
}},
})
require.NoError(t, err)
// Fill page 1 with newer chats so the pinned chat would
// normally be pushed off the first page (default limit 50).
const fillerCount = 51
fillerChats := make([]codersdk.Chat, 0, fillerCount)
for i := range fillerCount {
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: fmt.Sprintf("filler-%d", i),
}},
})
require.NoError(t, createErr)
fillerChats = append(fillerChats, c)
}
// Wait for all chats to reach a terminal status so
// updated_at is stable before paginating. A single
// polling loop checks every chat per tick to avoid
// O(N) separate Eventually loops.
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
pending := make(map[uuid.UUID]struct{}, len(allCreated))
for _, c := range allCreated {
pending[c.ID] = struct{}{}
}
testutil.Eventually(ctx, t, func(_ context.Context) bool {
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
})
if listErr != nil {
return false
}
for _, ch := range all {
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
delete(pending, ch.ID)
}
}
return len(pending) == 0
}, testutil.IntervalFast)
// Pin the earliest chat.
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
PinOrder: ptr.Ref(int32(1)),
})
require.NoError(t, err)
// Fetch page 1 with default limit (50).
page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: 50},
})
require.NoError(t, err)
// The pinned chat must appear on page 1.
page1IDs := make(map[uuid.UUID]struct{}, len(page1))
for _, c := range page1 {
page1IDs[c.ID] = struct{}{}
}
_, found := page1IDs[pinnedChat.ID]
require.True(t, found, "pinned chat should appear on page 1")
// The pinned chat should be the first item in the list.
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
})
// Test cursor pagination with a mix of pinned and unpinned chats.
t.Run("CursorWithPins", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, _ := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
// Create 5 chats: 2 will be pinned, 3 unpinned.
const totalChats = 5
createdChats := make([]codersdk.Chat, 0, totalChats)
for i := range totalChats {
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: fmt.Sprintf("cursor-pin-chat-%d", i),
}},
})
require.NoError(t, createErr)
createdChats = append(createdChats, c)
}
// Wait for all chats to reach terminal status.
// Check each chat by ID rather than fetching the full list.
testutil.Eventually(ctx, t, func(_ context.Context) bool {
for _, c := range createdChats {
ch, err := client.GetChat(ctx, c.ID)
require.NoError(t, err, "GetChat should succeed for just-created chat %s", c.ID)
if ch.Status == codersdk.ChatStatusPending || ch.Status == codersdk.ChatStatusRunning {
return false
}
}
return true
}, testutil.IntervalFast)
// Pin the first two chats (oldest updated_at).
err := client.UpdateChat(ctx, createdChats[0].ID, codersdk.UpdateChatRequest{
PinOrder: ptr.Ref(int32(1)),
})
require.NoError(t, err)
err = client.UpdateChat(ctx, createdChats[1].ID, codersdk.UpdateChatRequest{
PinOrder: ptr.Ref(int32(1)),
})
require.NoError(t, err)
// Paginate with limit=2 using cursor (after_id).
const pageSize = 2
maxPages := totalChats/pageSize + 2
var allPaginated []codersdk.Chat
var afterID uuid.UUID
for range maxPages {
opts := &codersdk.ListChatsOptions{
Pagination: codersdk.Pagination{Limit: pageSize},
}
if afterID != uuid.Nil {
opts.Pagination.AfterID = afterID
}
page, listErr := client.ListChats(ctx, opts)
require.NoError(t, listErr)
if len(page) == 0 {
break
}
allPaginated = append(allPaginated, page...)
afterID = page[len(page)-1].ID
}
// All chats should appear exactly once.
seenIDs := make(map[uuid.UUID]struct{}, len(allPaginated))
for _, c := range allPaginated {
_, dup := seenIDs[c.ID]
require.False(t, dup, "chat %s appeared more than once", c.ID)
seenIDs[c.ID] = struct{}{}
}
require.Len(t, seenIDs, totalChats, "all chats should appear in paginated results")
// Pinned chats should come before unpinned ones, and
// within the pinned group, lower pin_order sorts first.
pinnedSeen := false
unpinnedSeen := false
for _, c := range allPaginated {
if c.PinOrder > 0 {
require.False(t, unpinnedSeen, "pinned chat %s appeared after unpinned chat", c.ID)
pinnedSeen = true
} else {
unpinnedSeen = true
}
}
require.True(t, pinnedSeen, "at least one pinned chat should exist")
// Verify within-pinned ordering: pin_order=1 before
// pin_order=2 (the -pin_order DESC column).
require.Equal(t, createdChats[0].ID, allPaginated[0].ID,
"pin_order=1 chat should be first")
require.Equal(t, createdChats[1].ID, allPaginated[1].ID,
"pin_order=2 chat should be second")
})
}
func TestListChatModels(t *testing.T) {
@@ -1294,6 +1114,17 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
var event watchEvent
err = wsjson.Read(ctx, conn, &event)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
@@ -1305,16 +1136,25 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
for {
var payload codersdk.ChatWatchEvent
err = wsjson.Read(ctx, conn, &payload)
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var payload coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &payload)
require.NoError(t, err)
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
payload.Chat.ID == createdChat.ID {
break
}
}
})
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
t.Parallel()
@@ -1334,6 +1174,18 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
// Skip the initial ping.
var event watchEvent
err = wsjson.Read(ctx, conn, &event)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
@@ -1346,11 +1198,18 @@ func TestWatchChats(t *testing.T) {
var got codersdk.Chat
testutil.Eventually(ctx, t, func(_ context.Context) bool {
var payload codersdk.ChatWatchEvent
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
var update watchEvent
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
return false
}
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
if update.Type != codersdk.ServerSentEventTypeData {
return false
}
var payload coderdpubsub.ChatEvent
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
return false
}
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
payload.Chat.ID == createdChat.ID {
got = payload.Chat
return true
@@ -1423,14 +1282,25 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
// Read the initial ping.
var ping watchEvent
err = wsjson.Read(ctx, conn, &ping)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
// Publish a diff_status_change event via pubsub,
// mimicking what PublishDiffStatusChange does after
// it reads the diff status from the DB.
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
event := codersdk.ChatWatchEvent{
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
event := coderdpubsub.ChatEvent{
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
Chat: codersdk.Chat{
ID: chat.ID,
OwnerID: chat.OwnerID,
@@ -1443,15 +1313,25 @@ func TestWatchChats(t *testing.T) {
}
payload, err := json.Marshal(event)
require.NoError(t, err)
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
require.NoError(t, err)
// Read events until we find the diff_status_change.
for {
var received codersdk.ChatWatchEvent
err = wsjson.Read(ctx, conn, &received)
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var received coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &received)
require.NoError(t, err)
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
received.Chat.ID != chat.ID {
continue
}
@@ -1470,6 +1350,7 @@ func TestWatchChats(t *testing.T) {
break
}
})
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
t.Parallel()
@@ -1512,13 +1393,31 @@ func TestWatchChats(t *testing.T) {
require.NoError(t, err)
defer conn.Close(websocket.StatusNormalClosure, "done")
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
type watchEvent struct {
Type codersdk.ServerSentEventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
var ping watchEvent
err = wsjson.Read(ctx, conn, &ping)
require.NoError(t, err)
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
t.Helper()
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
for len(events) < 3 {
var payload codersdk.ChatWatchEvent
err = wsjson.Read(ctx, conn, &payload)
var update watchEvent
err = wsjson.Read(ctx, conn, &update)
require.NoError(t, err)
if update.Type == codersdk.ServerSentEventTypePing {
continue
}
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
var payload coderdpubsub.ChatEvent
err = json.Unmarshal(update.Data, &payload)
require.NoError(t, err)
if payload.Kind != expectedKind {
continue
@@ -1528,7 +1427,7 @@ func TestWatchChats(t *testing.T) {
return events
}
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
t.Helper()
require.Len(t, events, 3)
@@ -1541,12 +1440,12 @@ func TestWatchChats(t *testing.T) {
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
require.NoError(t, err)
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
assertLifecycleEvents(deletedEvents, true)
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
require.NoError(t, err)
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
assertLifecycleEvents(createdEvents, false)
})
-4
@@ -1,4 +0,0 @@
package coderd
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
+3 -27
@@ -5,7 +5,6 @@ import (
"database/sql"
"encoding/hex"
"errors"
htmltemplate "html/template"
"net/http"
"net/url"
"strings"
@@ -147,35 +146,12 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
cancel := params.redirectURL
cancelQuery := params.redirectURL.Query()
cancelQuery.Add("error", "access_denied")
cancelQuery.Add("error_description", "The resource owner or authorization server denied the request")
if params.state != "" {
cancelQuery.Add("state", params.state)
}
cancel.RawQuery = cancelQuery.Encode()
cancelURI := cancel.String()
if err := codersdk.ValidateRedirectURIScheme(cancel); err != nil {
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
Status: http.StatusBadRequest,
HideStatus: false,
Title: "Invalid Callback URL",
Description: "The application's registered callback URL has an invalid scheme.",
Actions: []site.Action{
{
URL: accessURL.String(),
Text: "Back to site",
},
},
})
return
}
site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{
AppIcon: app.Icon,
AppName: app.Name,
// #nosec G203 -- The scheme is validated by
// codersdk.ValidateRedirectURIScheme above.
CancelURI: htmltemplate.URL(cancelURI),
AppIcon: app.Icon,
AppName: app.Name,
CancelURI: cancel.String(),
RedirectURI: r.URL.String(),
CSRFToken: nosurf.Token(r),
Username: ua.FriendlyName,
+1 -2
@@ -1,7 +1,6 @@
package oauth2provider_test
import (
htmltemplate "html/template"
"net/http"
"net/http/httptest"
"testing"
@@ -21,7 +20,7 @@ func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
AppName: "Test OAuth App",
CancelURI: htmltemplate.URL("https://coder.com/cancel"),
CancelURI: "https://coder.com/cancel",
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
CSRFToken: csrfFieldValue,
Username: "test-user",
-3
@@ -435,9 +435,6 @@ func convertProvisionerJobWithQueuePosition(pj database.GetProvisionerJobsByOrga
if pj.WorkspaceID.Valid {
job.Metadata.WorkspaceID = &pj.WorkspaceID.UUID
}
if pj.WorkspaceBuildTransition.Valid {
job.Metadata.WorkspaceBuildTransition = codersdk.WorkspaceTransition(pj.WorkspaceBuildTransition.WorkspaceTransition)
}
return job
}
+7 -8
@@ -97,14 +97,13 @@ func TestProvisionerJobs(t *testing.T) {
// Verify that job metadata is correct.
assert.Equal(t, job2.Metadata, codersdk.ProvisionerJobMetadata{
TemplateVersionName: version.Name,
TemplateID: template.ID,
TemplateName: template.Name,
TemplateDisplayName: template.DisplayName,
TemplateIcon: template.Icon,
WorkspaceID: &w.ID,
WorkspaceName: w.Name,
WorkspaceBuildTransition: codersdk.WorkspaceTransitionStart,
TemplateVersionName: version.Name,
TemplateID: template.ID,
TemplateName: template.Name,
TemplateDisplayName: template.DisplayName,
TemplateIcon: template.Icon,
WorkspaceID: &w.ID,
WorkspaceName: w.Name,
})
})
})
+1 -1
@@ -14,7 +14,7 @@ import (
const ChatConfigEventChannel = "chat:config_change"
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
// messages, following the same pattern as HandleChatWatchEvent.
// messages, following the same pattern as HandleChatEvent.
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
return func(ctx context.Context, message []byte, err error) {
if err != nil {
+49
@@ -0,0 +1,49 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
func ChatEventChannel(ownerID uuid.UUID) string {
return fmt.Sprintf("chat:owner:%s", ownerID)
}
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
return func(ctx context.Context, message []byte, err error) {
if err != nil {
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
return
}
var payload ChatEvent
if err := json.Unmarshal(message, &payload); err != nil {
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err))
return
}
cb(ctx, payload, err)
}
}
type ChatEvent struct {
Kind ChatEventKind `json:"kind"`
Chat codersdk.Chat `json:"chat"`
ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
}
type ChatEventKind string
const (
ChatEventKindStatusChange ChatEventKind = "status_change"
ChatEventKindTitleChange ChatEventKind = "title_change"
ChatEventKindCreated ChatEventKind = "created"
ChatEventKindDeleted ChatEventKind = "deleted"
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
ChatEventKindActionRequired ChatEventKind = "action_required"
)
-36
View File
@@ -1,36 +0,0 @@
package pubsub
import (
"context"
"encoding/json"
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
// ChatWatchEventChannel returns the pubsub channel for chat
// lifecycle events scoped to a single user.
func ChatWatchEventChannel(ownerID uuid.UUID) string {
return fmt.Sprintf("chat:owner:%s", ownerID)
}
// HandleChatWatchEvent wraps a typed callback for
// ChatWatchEvent messages delivered via pubsub.
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
return func(ctx context.Context, message []byte, err error) {
if err != nil {
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
return
}
var payload codersdk.ChatWatchEvent
if err := json.Unmarshal(message, &payload); err != nil {
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
return
}
cb(ctx, payload, err)
}
}
-280
View File
@@ -1,280 +0,0 @@
package coderd
import (
"database/sql"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
)
// @Summary Create a new user secret
// @ID create-a-new-user-secret
// @Security CoderSessionToken
// @Accept json
// @Produce json
// @Tags Secrets
// @Param user path string true "User ID, username, or me"
// @Param request body codersdk.CreateUserSecretRequest true "Create secret request"
// @Success 201 {object} codersdk.UserSecret
// @Router /users/{user}/secrets [post]
func (api *API) postUserSecret(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
user := httpmw.UserParam(r)
var req codersdk.CreateUserSecretRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if req.Name == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Name is required.",
})
return
}
if req.Value == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Value is required.",
})
return
}
envOpts := codersdk.UserSecretEnvValidationOptions{
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
}
if err := codersdk.UserSecretEnvNameValid(req.EnvName, envOpts); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid environment variable name.",
Detail: err.Error(),
})
return
}
if err := codersdk.UserSecretFilePathValid(req.FilePath); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid file path.",
Detail: err.Error(),
})
return
}
secret, err := api.Database.CreateUserSecret(ctx, database.CreateUserSecretParams{
ID: uuid.New(),
UserID: user.ID,
Name: req.Name,
Description: req.Description,
Value: req.Value,
ValueKeyID: sql.NullString{},
EnvName: req.EnvName,
FilePath: req.FilePath,
})
if err != nil {
if database.IsUniqueViolation(err) {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "A secret with that name, environment variable, or file path already exists.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error creating secret.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.UserSecretFromFull(secret))
}
// @Summary List user secrets
// @ID list-user-secrets
// @Security CoderSessionToken
// @Produce json
// @Tags Secrets
// @Param user path string true "User ID, username, or me"
// @Success 200 {array} codersdk.UserSecret
// @Router /users/{user}/secrets [get]
func (api *API) getUserSecrets(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
ctx := r.Context()
user := httpmw.UserParam(r)
secrets, err := api.Database.ListUserSecrets(ctx, user.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error listing secrets.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecrets(secrets))
}
// @Summary Get a user secret by name
// @ID get-a-user-secret-by-name
// @Security CoderSessionToken
// @Produce json
// @Tags Secrets
// @Param user path string true "User ID, username, or me"
// @Param name path string true "Secret name"
// @Success 200 {object} codersdk.UserSecret
// @Router /users/{user}/secrets/{name} [get]
func (api *API) getUserSecret(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
ctx := r.Context()
user := httpmw.UserParam(r)
name := chi.URLParam(r, "name")
secret, err := api.Database.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: user.ID,
Name: name,
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching secret.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
}
// @Summary Update a user secret
// @ID update-a-user-secret
// @Security CoderSessionToken
// @Accept json
// @Produce json
// @Tags Secrets
// @Param user path string true "User ID, username, or me"
// @Param name path string true "Secret name"
// @Param request body codersdk.UpdateUserSecretRequest true "Update secret request"
// @Success 200 {object} codersdk.UserSecret
// @Router /users/{user}/secrets/{name} [patch]
func (api *API) patchUserSecret(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
user := httpmw.UserParam(r)
name := chi.URLParam(r, "name")
var req codersdk.UpdateUserSecretRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if req.Value == nil && req.Description == nil && req.EnvName == nil && req.FilePath == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "At least one field must be provided.",
})
return
}
if req.EnvName != nil {
envOpts := codersdk.UserSecretEnvValidationOptions{
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
}
if err := codersdk.UserSecretEnvNameValid(*req.EnvName, envOpts); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid environment variable name.",
Detail: err.Error(),
})
return
}
}
if req.FilePath != nil {
if err := codersdk.UserSecretFilePathValid(*req.FilePath); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid file path.",
Detail: err.Error(),
})
return
}
}
params := database.UpdateUserSecretByUserIDAndNameParams{
UserID: user.ID,
Name: name,
UpdateValue: req.Value != nil,
Value: "",
ValueKeyID: sql.NullString{},
UpdateDescription: req.Description != nil,
Description: "",
UpdateEnvName: req.EnvName != nil,
EnvName: "",
UpdateFilePath: req.FilePath != nil,
FilePath: "",
}
if req.Value != nil {
params.Value = *req.Value
}
if req.Description != nil {
params.Description = *req.Description
}
if req.EnvName != nil {
params.EnvName = *req.EnvName
}
if req.FilePath != nil {
params.FilePath = *req.FilePath
}
secret, err := api.Database.UpdateUserSecretByUserIDAndName(ctx, params)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.ResourceNotFound(rw)
return
}
if database.IsUniqueViolation(err) {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Update would conflict with an existing secret.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error updating secret.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
}
// @Summary Delete a user secret
// @ID delete-a-user-secret
// @Security CoderSessionToken
// @Tags Secrets
// @Param user path string true "User ID, username, or me"
// @Param name path string true "Secret name"
// @Success 204
// @Router /users/{user}/secrets/{name} [delete]
func (api *API) deleteUserSecret(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
user := httpmw.UserParam(r)
name := chi.URLParam(r, "name")
rowsAffected, err := api.Database.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
UserID: user.ID,
Name: name,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error deleting secret.",
Detail: err.Error(),
})
return
}
if rowsAffected == 0 {
httpapi.ResourceNotFound(rw)
return
}
rw.WriteHeader(http.StatusNoContent)
}
-413
View File
@@ -1,413 +0,0 @@
package coderd_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestPostUserSecret(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
secret, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "github-token",
Value: "ghp_xxxxxxxxxxxx",
Description: "Personal GitHub PAT",
EnvName: "GITHUB_TOKEN",
FilePath: "~/.github-token",
})
require.NoError(t, err)
assert.Equal(t, "github-token", secret.Name)
assert.Equal(t, "Personal GitHub PAT", secret.Description)
assert.Equal(t, "GITHUB_TOKEN", secret.EnvName)
assert.Equal(t, "~/.github-token", secret.FilePath)
assert.NotZero(t, secret.ID)
assert.NotZero(t, secret.CreatedAt)
})
t.Run("MissingName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Value: "some-value",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
assert.Contains(t, sdkErr.Message, "Name is required")
})
t.Run("MissingValue", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "missing-value-secret",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
assert.Contains(t, sdkErr.Message, "Value is required")
})
t.Run("DuplicateName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "dup-secret",
Value: "value1",
})
require.NoError(t, err)
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "dup-secret",
Value: "value2",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
})
t.Run("DuplicateEnvName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "env-dup-1",
Value: "value1",
EnvName: "DUPLICATE_ENV",
})
require.NoError(t, err)
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "env-dup-2",
Value: "value2",
EnvName: "DUPLICATE_ENV",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
})
t.Run("DuplicateFilePath", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "fp-dup-1",
Value: "value1",
FilePath: "/tmp/dup-file",
})
require.NoError(t, err)
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "fp-dup-2",
Value: "value2",
FilePath: "/tmp/dup-file",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
})
t.Run("InvalidEnvName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "invalid-env-secret",
Value: "value",
EnvName: "1INVALID",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("ReservedEnvName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "reserved-env-secret",
Value: "value",
EnvName: "PATH",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("CoderPrefixEnvName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "coder-prefix-secret",
Value: "value",
EnvName: "CODER_AGENT_TOKEN",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("InvalidFilePath", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "bad-path-secret",
Value: "value",
FilePath: "relative/path",
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
}
func TestGetUserSecrets(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
// Verify no secrets exist on a fresh user.
ctx := testutil.Context(t, testutil.WaitMedium)
secrets, err := client.UserSecrets(ctx, codersdk.Me)
require.NoError(t, err)
assert.Empty(t, secrets)
t.Run("WithSecrets", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "list-secret-a",
Value: "value-a",
})
require.NoError(t, err)
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "list-secret-b",
Value: "value-b",
})
require.NoError(t, err)
secrets, err := client.UserSecrets(ctx, codersdk.Me)
require.NoError(t, err)
require.Len(t, secrets, 2)
// Sorted by name.
assert.Equal(t, "list-secret-a", secrets[0].Name)
assert.Equal(t, "list-secret-b", secrets[1].Name)
})
}
func TestGetUserSecret(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("Found", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
created, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "get-found-secret",
Value: "my-value",
EnvName: "GET_FOUND_SECRET",
})
require.NoError(t, err)
got, err := client.UserSecretByName(ctx, codersdk.Me, "get-found-secret")
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
assert.Equal(t, "get-found-secret", got.Name)
assert.Equal(t, "GET_FOUND_SECRET", got.EnvName)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.UserSecretByName(ctx, codersdk.Me, "nonexistent")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
}
func TestPatchUserSecret(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("UpdateDescription", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "patch-desc-secret",
Value: "my-value",
Description: "original",
EnvName: "PATCH_DESC_ENV",
})
require.NoError(t, err)
newDesc := "updated"
updated, err := client.UpdateUserSecret(ctx, codersdk.Me, "patch-desc-secret", codersdk.UpdateUserSecretRequest{
Description: &newDesc,
})
require.NoError(t, err)
assert.Equal(t, "updated", updated.Description)
// Other fields unchanged.
assert.Equal(t, "PATCH_DESC_ENV", updated.EnvName)
})
t.Run("NoFields", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "patch-nofields-secret",
Value: "my-value",
})
require.NoError(t, err)
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "patch-nofields-secret", codersdk.UpdateUserSecretRequest{})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
newVal := "new-value"
_, err := client.UpdateUserSecret(ctx, codersdk.Me, "nonexistent", codersdk.UpdateUserSecretRequest{
Value: &newVal,
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
t.Run("ConflictEnvName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "conflict-env-1",
Value: "value1",
EnvName: "CONFLICT_TAKEN_ENV",
})
require.NoError(t, err)
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "conflict-env-2",
Value: "value2",
})
require.NoError(t, err)
taken := "CONFLICT_TAKEN_ENV"
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-env-2", codersdk.UpdateUserSecretRequest{
EnvName: &taken,
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
})
t.Run("ConflictFilePath", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "conflict-fp-1",
Value: "value1",
FilePath: "/tmp/conflict-taken",
})
require.NoError(t, err)
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "conflict-fp-2",
Value: "value2",
})
require.NoError(t, err)
taken := "/tmp/conflict-taken"
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-fp-2", codersdk.UpdateUserSecretRequest{
FilePath: &taken,
})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
})
}
func TestDeleteUserSecret(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
Name: "delete-me-secret",
Value: "my-value",
})
require.NoError(t, err)
err = client.DeleteUserSecret(ctx, codersdk.Me, "delete-me-secret")
require.NoError(t, err)
// Verify it's gone.
_, err = client.UserSecretByName(ctx, codersdk.Me, "delete-me-secret")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
err := client.DeleteUserSecret(ctx, codersdk.Me, "nonexistent")
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
}
-597
View File
@@ -42,8 +42,6 @@ import (
"github.com/coder/coder/v2/coderd/telemetry"
maputil "github.com/coder/coder/v2/coderd/util/maps"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/gitsync"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -2395,598 +2393,3 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor
}
return sdk
}
// maxChatContextParts caps the number of parts per request to
// prevent unbounded message payloads.
const maxChatContextParts = 100
// maxChatContextFileBytes caps each context-file part to the same
// 64KiB budget used when the agent reads instruction files from disk.
const maxChatContextFileBytes = 64 * 1024
// maxChatContextRequestBodyBytes caps the JSON request body size for
// agent-added context to roughly the same per-part budget used when
// reading instruction files from disk.
const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes
// sanitizeWorkspaceAgentContextFileContent applies prompt
// sanitization, then enforces the 64KiB per-file budget. The
// truncated flag is preserved when the caller already capped the
// file before sending it.
func sanitizeWorkspaceAgentContextFileContent(
content string,
truncated bool,
) (string, bool) {
content = chatd.SanitizePromptText(content)
if len(content) > maxChatContextFileBytes {
content = content[:maxChatContextFileBytes]
truncated = true
}
return content, truncated
}
// readChatContextBody reads and validates the request body for chat
// context endpoints. It handles MaxBytesReader wrapping, error
// responses, and body rewind. If the body is empty or whitespace-only
// and allowEmpty is true, it returns false without writing an error.
//
//nolint:revive // Add and clear endpoints only differ by empty-body handling.
func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool {
r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes)
body, err := io.ReadAll(r.Body)
if err != nil {
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
Message: "Request body too large.",
Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes),
})
return false
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to read request body.",
Detail: err.Error(),
})
return false
}
if allowEmpty && len(bytes.TrimSpace(body)) == 0 {
r.Body = http.NoBody
return false
}
r.Body = io.NopCloser(bytes.NewReader(body))
return httpapi.Read(ctx, rw, r, dst)
}
// @x-apidocgen {"skip": true}
func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
var req agentsdk.AddChatContextRequest
if !readChatContextBody(ctx, rw, r, &req, false) {
return
}
if len(req.Parts) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "No context parts provided.",
})
return
}
if len(req.Parts) > maxChatContextParts {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts),
})
return
}
// Filter to only non-empty context-file and skill parts.
filtered := chatd.FilterContextParts(req.Parts, false)
if len(filtered) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "No context-file or skill parts provided.",
})
return
}
req.Parts = filtered
responsePartCount := 0
// Use system context for chat operations since the
// workspace agent scope does not include chat resources.
// We verify agent-to-chat ownership explicitly below.
//nolint:gocritic // Agent needs system access to read/write chat resources.
sysCtx := dbauthz.AsSystemRestricted(ctx)
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to determine workspace from agent token.",
Detail: err.Error(),
})
return
}
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
if err != nil {
writeAgentChatError(ctx, rw, err)
return
}
// Stamp each persisted part with the agent identity. Context-file
// parts also get server-authoritative workspace metadata.
directory := workspaceAgent.ExpandedDirectory
if directory == "" {
directory = workspaceAgent.Directory
}
for i := range req.Parts {
req.Parts[i].ContextFileAgentID = uuid.NullUUID{
UUID: workspaceAgent.ID,
Valid: true,
}
if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile {
continue
}
req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent(
req.Parts[i].ContextFileContent,
req.Parts[i].ContextFileTruncated,
)
req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem
req.Parts[i].ContextFileDirectory = directory
}
req.Parts = chatd.FilterContextParts(req.Parts, false)
if len(req.Parts) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "No context-file or skill parts provided.",
})
return
}
responsePartCount = len(req.Parts)
// Skill-only messages need a sentinel context-file part so the turn
// pipeline trusts the associated skill metadata.
req.Parts = prependAgentChatContextSentinelIfNeeded(
req.Parts,
workspaceAgent.ID,
workspaceAgent.OperatingSystem,
directory,
)
content, err := chatprompt.MarshalParts(req.Parts)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to marshal context parts.",
Detail: err.Error(),
})
return
}
err = api.Database.InTx(func(tx database.Store) error {
locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID)
if err != nil {
return xerrors.Errorf("lock chat: %w", err)
}
if !isActiveAgentChat(locked) {
return errChatNotActive
}
if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID {
return errChatDoesNotBelongToAgent
}
if locked.OwnerID != workspace.OwnerID {
return errChatDoesNotBelongToWorkspaceOwner
}
if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams(
chat.ID,
database.ChatMessageRoleUser,
content,
database.ChatMessageVisibilityBoth,
locked.LastModelConfigID,
chatprompt.CurrentContentVersion,
uuid.Nil,
)); err != nil {
return xerrors.Errorf("insert context message: %w", err)
}
if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil {
return xerrors.Errorf("rebuild injected context cache: %w", err)
}
return nil
}, nil)
if err != nil {
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
writeAgentChatError(ctx, rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to persist context message.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{
ChatID: chat.ID,
Count: responsePartCount,
})
}
// @x-apidocgen {"skip": true}
func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
var req agentsdk.ClearChatContextRequest
populated := readChatContextBody(ctx, rw, r, &req, true)
if !populated && r.Body != http.NoBody {
return
}
// Use system context for chat operations since the
// workspace agent scope does not include chat resources.
//nolint:gocritic // Agent needs system access to read/write chat resources.
sysCtx := dbauthz.AsSystemRestricted(ctx)
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to determine workspace from agent token.",
Detail: err.Error(),
})
return
}
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
if err != nil {
// Zero active chats is not an error for clear.
if errors.Is(err, errNoActiveChats) {
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{})
return
}
writeAgentChatError(ctx, rw, err)
return
}
err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID)
if err != nil {
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
writeAgentChatError(ctx, rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to clear context from chat.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{
ChatID: chat.ID,
})
}
var (
errNoActiveChats = xerrors.New("no active chats found")
errChatNotFound = xerrors.New("chat not found")
errChatNotActive = xerrors.New("chat is not active")
errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent")
errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner")
)
type multipleActiveChatsError struct {
count int
}
func (e *multipleActiveChatsError) Error() string {
return fmt.Sprintf(
"multiple active chats (%d) found for this agent, specify a chat ID",
e.count,
)
}
func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) {
switch len(chats) {
case 0:
return database.Chat{}, errNoActiveChats
case 1:
return chats[0], nil
}
var rootChat *database.Chat
for i := range chats {
chat := &chats[i]
if chat.ParentChatID.Valid {
continue
}
if rootChat != nil {
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
}
rootChat = chat
}
if rootChat != nil {
return *rootChat, nil
}
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
}
// resolveAgentChat finds the target chat from either an explicit ID
// or auto-detection via the agent's active chats.
func resolveAgentChat(
ctx context.Context,
db database.Store,
agentID uuid.UUID,
workspaceOwnerID uuid.UUID,
explicitChatID uuid.UUID,
) (database.Chat, error) {
if explicitChatID == uuid.Nil {
chats, err := db.GetActiveChatsByAgentID(ctx, agentID)
if err != nil {
return database.Chat{}, xerrors.Errorf("list active chats: %w", err)
}
ownerChats := make([]database.Chat, 0, len(chats))
for _, chat := range chats {
if chat.OwnerID != workspaceOwnerID {
continue
}
ownerChats = append(ownerChats, chat)
}
return resolveDefaultAgentChat(ownerChats)
}
chat, err := db.GetChatByID(ctx, explicitChatID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return database.Chat{}, errChatNotFound
}
return database.Chat{}, xerrors.Errorf("get chat by id: %w", err)
}
if !chat.AgentID.Valid || chat.AgentID.UUID != agentID {
return database.Chat{}, errChatDoesNotBelongToAgent
}
if chat.OwnerID != workspaceOwnerID {
return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner
}
if !isActiveAgentChat(chat) {
return database.Chat{}, errChatNotActive
}
return chat, nil
}
func isActiveAgentChat(chat database.Chat) bool {
if chat.Archived {
return false
}
switch chat.Status {
case database.ChatStatusWaiting,
database.ChatStatusPending,
database.ChatStatusRunning,
database.ChatStatusPaused,
database.ChatStatusRequiresAction:
return true
default:
return false
}
}
func clearAgentChatContext(
ctx context.Context,
db database.Store,
chatID uuid.UUID,
agentID uuid.UUID,
workspaceOwnerID uuid.UUID,
) error {
return db.InTx(func(tx database.Store) error {
locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
if err != nil {
return xerrors.Errorf("lock chat: %w", err)
}
if !isActiveAgentChat(locked) {
return errChatNotActive
}
if !locked.AgentID.Valid || locked.AgentID.UUID != agentID {
return errChatDoesNotBelongToAgent
}
if locked.OwnerID != workspaceOwnerID {
return errChatDoesNotBelongToWorkspaceOwner
}
messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
})
if err != nil {
return xerrors.Errorf("get chat messages: %w", err)
}
hadInjectedContext := locked.LastInjectedContext.Valid
var skillOnlyMessageIDs []int64
for _, msg := range messages {
if !msg.Content.Valid {
continue
}
hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile)
hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill)
if hasContextFile || hasSkill {
hadInjectedContext = true
}
if hasSkill && !hasContextFile {
skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID)
}
}
if !hadInjectedContext {
return nil
}
if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil {
return xerrors.Errorf("soft delete context-file messages: %w", err)
}
for _, messageID := range skillOnlyMessageIDs {
if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil {
return xerrors.Errorf("soft delete context message %d: %w", messageID, err)
}
}
// Reset provider-side Responses chaining so the next turn replays
// the post-clear history instead of inheriting cleared context.
if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil {
return xerrors.Errorf("clear provider response chain: %w", err)
}
// Clear the injected-context cache inside the transaction so it is
// atomic with the soft-deletes.
param, err := chatd.BuildLastInjectedContext(nil)
if err != nil {
return xerrors.Errorf("clear injected context cache: %w", err)
}
if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
ID: chatID,
LastInjectedContext: param,
}); err != nil {
return xerrors.Errorf("clear injected context cache: %w", err)
}
return nil
}, nil)
}
// prependAgentChatContextSentinelIfNeeded adds an empty context-file
// part when the request only carries skills. The turn pipeline uses
// the sentinel's agent metadata to trust the skill parts.
func prependAgentChatContextSentinelIfNeeded(
parts []codersdk.ChatMessagePart,
agentID uuid.UUID,
operatingSystem string,
directory string,
) []codersdk.ChatMessagePart {
hasContextFile := false
hasSkill := false
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeContextFile:
hasContextFile = true
case codersdk.ChatMessagePartTypeSkill:
hasSkill = true
}
if hasContextFile && hasSkill {
return parts
}
}
if !hasSkill || hasContextFile {
return parts
}
return append([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: chatd.AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
ContextFileOS: operatingSystem,
ContextFileDirectory: directory,
}}, parts...)
}
func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) {
sort.SliceStable(messages, func(i, j int) bool {
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
return messages[i].ID < messages[j].ID
}
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
})
}
// updateAgentChatLastInjectedContextFromMessages rebuilds the
// injected-context cache from all persisted context-file and skill parts.
func updateAgentChatLastInjectedContextFromMessages(
ctx context.Context,
logger slog.Logger,
db database.Store,
chatID uuid.UUID,
) error {
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
})
if err != nil {
return xerrors.Errorf("load context messages for injected context: %w", err)
}
sortChatMessagesByCreatedAtAndID(messages)
parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true)
if err != nil {
return xerrors.Errorf("collect injected context parts: %w", err)
}
parts = chatd.FilterContextPartsToLatestAgent(parts)
param, err := chatd.BuildLastInjectedContext(parts)
if err != nil {
return xerrors.Errorf("update injected context: %w", err)
}
if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
ID: chatID,
LastInjectedContext: param,
}); err != nil {
return xerrors.Errorf("update injected context: %w", err)
}
return nil
}
func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool {
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal(raw, &parts); err != nil {
return false
}
for _, part := range parts {
for _, typ := range types {
if part.Type == typ {
return true
}
}
}
return false
}
// writeAgentChatError translates resolveAgentChat errors to HTTP
// responses.
func writeAgentChatError(
ctx context.Context,
rw http.ResponseWriter,
err error,
) {
if errors.Is(err, errNoActiveChats) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "No active chats found for this agent.",
})
return
}
if errors.Is(err, errChatNotFound) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: "Chat not found.",
})
return
}
if errors.Is(err, errChatDoesNotBelongToAgent) {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "Chat does not belong to this agent.",
})
return
}
if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: "Chat does not belong to this workspace owner.",
})
return
}
if errors.Is(err, errChatNotActive) {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Cannot modify context: this chat is no longer active.",
})
return
}
var multipleErr *multipleActiveChatsError
if errors.As(err, &multipleErr) {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to resolve chat.",
Detail: err.Error(),
})
}
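The messageHasPartTypes helper above drives both the injected-context detection in clearAgentChatContext and the skill-only soft-delete pass. A minimal standalone sketch of the same scan, using a simplified part struct and illustrative type names rather than the real codersdk types, looks like this:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// part mirrors only the field the scan needs; the real
// codersdk.ChatMessagePart carries many more fields.
type part struct {
	Type string `json:"type"`
}

// hasPartTypes reports whether the raw JSON part array contains
// any of the given types. Unparseable content is treated as a
// non-match, matching the defensive behavior in the handler.
func hasPartTypes(raw []byte, types ...string) bool {
	var parts []part
	if err := json.Unmarshal(raw, &parts); err != nil {
		return false
	}
	for _, p := range parts {
		for _, t := range types {
			if p.Type == t {
				return true
			}
		}
	}
	return false
}

func main() {
	raw := []byte(`[{"type":"skill"},{"type":"text"}]`)
	fmt.Println(hasPartTypes(raw, "context_file")) // false
	fmt.Println(hasPartTypes(raw, "skill"))        // true
	fmt.Println(hasPartTypes([]byte(`not json`), "skill"))
}
```

Returning false on malformed content (rather than an error) keeps the caller's loop simple: an unreadable message simply never triggers a soft-delete.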
@@ -1,76 +0,0 @@
package coderd
import (
"fmt"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/testutil"
)
func TestActiveAgentChatDefinitionsAgree(t *testing.T) {
t.Parallel()
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
db, _ := dbtestutil.NewDB(t)
org, err := db.GetDefaultOrganization(ctx)
require.NoError(t, err)
owner := dbgen.User(t, db, database.User{})
workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: org.ID,
OwnerID: owner.ID,
}).WithAgent().Do()
modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID)
insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2)
for _, archived := range []bool{false, true} {
for _, status := range database.AllChatStatusValues() {
chat, err := db.InsertChat(ctx, database.InsertChatParams{
Status: status,
OwnerID: owner.ID,
LastModelConfigID: modelConfig.ID,
Title: fmt.Sprintf("%s-archived-%t", status, archived),
AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true},
})
require.NoError(t, err)
if archived {
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
chat, err = db.GetChatByID(ctx, chat.ID)
require.NoError(t, err)
}
insertedChats = append(insertedChats, chat)
}
}
activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID)
require.NoError(t, err)
activeByID := make(map[uuid.UUID]bool, len(activeChats))
for _, chat := range activeChats {
activeByID[chat.ID] = true
}
for _, chat := range insertedChats {
require.Equalf(
t,
isActiveAgentChat(chat),
activeByID[chat.ID],
"status=%s archived=%t",
chat.Status,
chat.Archived,
)
}
}
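The test above pins the Go-side isActiveAgentChat predicate to the SQL definition behind GetActiveChatsByAgentID. As a rough sketch of the predicate itself (status names shown as plain strings, illustrative of the database enum):

```go
package main

import "fmt"

// activeStatuses mirrors the statuses isActiveAgentChat accepts:
// a chat is active when it is not archived and sits in one of
// these states.
var activeStatuses = map[string]bool{
	"waiting":         true,
	"pending":         true,
	"running":         true,
	"paused":          true,
	"requires_action": true,
}

// isActive applies the same rule: archived always wins, then the
// status must be in the active set.
func isActive(status string, archived bool) bool {
	return !archived && activeStatuses[status]
}

func main() {
	fmt.Println(isActive("running", false)) // true
	fmt.Println(isActive("running", true))  // false: archived wins
	fmt.Println(isActive("completed", false))
}
```

Keeping the two definitions in agreement matters because resolveAgentChat filters with the SQL query on the auto-detect path but re-checks with the Go predicate on the explicit-ID path.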
@@ -1,128 +0,0 @@
package coderd
import (
"context"
"database/sql"
"encoding/json"
"testing"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/codersdk"
)
func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC)
oldAgentID := uuid.New()
newAgentID := uuid.New()
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}})
require.NoError(t, err)
newContent, err := json.Marshal([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}})
require.NoError(t, err)
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{
{
ID: 2,
CreatedAt: createdAt,
Content: pqtype.NullRawMessage{
RawMessage: newContent,
Valid: true,
},
},
{
ID: 1,
CreatedAt: createdAt,
Content: pqtype.NullRawMessage{
RawMessage: oldContent,
Valid: true,
},
},
}, nil)
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
require.Equal(t, chatID, arg.ID)
require.True(t, arg.LastInjectedContext.Valid)
var cached []codersdk.ChatMessagePart
require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached))
require.Len(t, cached, 1)
require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath)
require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID)
return database.Chat{}, nil
},
)
err = updateAgentChatLastInjectedContextFromMessages(
context.Background(),
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
db,
chatID,
)
require.NoError(t, err)
}
func insertAgentChatTestModelConfig(
ctx context.Context,
t testing.TB,
db database.Store,
userID uuid.UUID,
) database.ChatModelConfig {
t.Helper()
sysCtx := dbauthz.AsSystemRestricted(ctx)
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
_, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-api-key",
ApiKeyKeyID: sql.NullString{},
CreatedBy: createdBy,
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: createdBy,
UpdatedBy: createdBy,
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
return model
}
@@ -91,7 +91,7 @@ func TestWorkspaceAgent(t *testing.T) {
require.Equal(t, tmpDir, workspace.LatestBuild.Resources[0].Agents[0].Directory)
_, err = anotherClient.WorkspaceAgent(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID)
require.NoError(t, err)
require.False(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
require.True(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
})
t.Run("HasFallbackTroubleshootingURL", func(t *testing.T) {
t.Parallel()
@@ -213,39 +213,6 @@ func TestWorkspace(t *testing.T) {
t.Parallel()
t.Run("Healthy", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
_ = agenttest.New(t, client.URL, authToken)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
var err error
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
workspace, err = client.Workspace(ctx, workspace.ID)
return assert.NoError(t, err) && workspace.Health.Healthy
}, testutil.IntervalMedium)
agent := workspace.LatestBuild.Resources[0].Agents[0]
assert.True(t, workspace.Health.Healthy)
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
assert.True(t, agent.Health.Healthy)
assert.Empty(t, agent.Health.Reason)
})
t.Run("Connecting", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
@@ -280,10 +247,10 @@ func TestWorkspace(t *testing.T) {
agent := workspace.LatestBuild.Resources[0].Agents[0]
assert.False(t, workspace.Health.Healthy)
assert.Equal(t, []uuid.UUID{agent.ID}, workspace.Health.FailingAgents)
assert.False(t, agent.Health.Healthy)
assert.Equal(t, "agent has not yet connected", agent.Health.Reason)
assert.True(t, workspace.Health.Healthy)
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
assert.True(t, agent.Health.Healthy)
assert.Empty(t, agent.Health.Reason)
})
t.Run("Unhealthy", func(t *testing.T) {
@@ -335,7 +302,6 @@ func TestWorkspace(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
a1AuthToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionGraph: []*proto.Response{{
@@ -347,9 +313,7 @@ func TestWorkspace(t *testing.T) {
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Name: "a1",
Auth: &proto.Agent_Token{
Token: a1AuthToken,
},
Auth: &proto.Agent_Token{},
}, {
Id: uuid.NewString(),
Name: "a2",
@@ -366,21 +330,13 @@ func TestWorkspace(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
_ = agenttest.New(t, client.URL, a1AuthToken)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
var err error
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
workspace, err = client.Workspace(ctx, workspace.ID)
if err != nil {
return false
}
// Wait for the mixed state: a1 connected (healthy)
// and workspace unhealthy (because a2 timed out).
agent1 := workspace.LatestBuild.Resources[0].Agents[0]
return agent1.Health.Healthy && !workspace.Health.Healthy
return assert.NoError(t, err) && !workspace.Health.Healthy
}, testutil.IntervalMedium)
assert.False(t, workspace.Health.Healthy)
@@ -404,7 +360,6 @@ func TestWorkspace(t *testing.T) {
// disconnected, but this should not make the workspace unhealthy.
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionGraph: []*proto.Response{{
@@ -416,9 +371,7 @@ func TestWorkspace(t *testing.T) {
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Name: "parent",
Auth: &proto.Agent_Token{
Token: authToken,
},
Auth: &proto.Agent_Token{},
}},
}},
},
@@ -430,23 +383,14 @@ func TestWorkspace(t *testing.T) {
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
_ = agenttest.New(t, client.URL, authToken)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Wait for the parent agent to connect and be healthy.
var parentAgent codersdk.WorkspaceAgent
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
var err error
workspace, err = client.Workspace(ctx, workspace.ID)
if err != nil {
return false
}
parentAgent = workspace.LatestBuild.Resources[0].Agents[0]
return parentAgent.Health.Healthy
}, testutil.IntervalMedium)
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy")
// Get the workspace and parent agent.
workspace, err := client.Workspace(ctx, workspace.ID)
require.NoError(t, err)
parentAgent := workspace.LatestBuild.Resources[0].Agents[0]
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy initially")
// Create a sub-agent with a short connection timeout so it becomes
// unhealthy quickly (simulating a devcontainer rebuild scenario).
@@ -460,7 +404,6 @@ func TestWorkspace(t *testing.T) {
// Wait for the sub-agent to become unhealthy due to timeout.
var subAgentUnhealthy bool
require.Eventually(t, func() bool {
var err error
workspace, err = client.Workspace(ctx, workspace.ID)
if err != nil {
return false
@@ -73,9 +73,9 @@ const (
// maxConcurrentRecordingUploads caps the number of recording
// stop-and-store operations that can run concurrently. Each
// slot buffers up to MaxRecordingSize + MaxThumbnailSize
// (110 MB) in memory, so this value implicitly bounds memory
// to roughly maxConcurrentRecordingUploads * 110 MB.
// slot buffers up to MaxRecordingSize (100 MB) in memory, so
// this value implicitly bounds memory to roughly
// maxConcurrentRecordingUploads * 100 MB.
maxConcurrentRecordingUploads = 25
// staleRecoveryIntervalDivisor determines how often the stale
@@ -996,7 +996,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
return database.Chat{}, txErr
}
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
p.signalWake()
return chat, nil
}
@@ -1158,7 +1158,7 @@ func (p *Server) SendMessage(
p.publishMessage(opts.ChatID, result.Message)
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
p.signalWake()
return result, nil
}
@@ -1301,7 +1301,7 @@ func (p *Server) EditMessage(
QueueUpdate: true,
})
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
p.signalWake()
return result, nil
@@ -1355,10 +1355,10 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
if interrupted {
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
}
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
return nil
}
@@ -1373,7 +1373,7 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
ctx,
chat.ID,
"unarchive",
codersdk.ChatWatchEventKindCreated,
coderdpubsub.ChatEventKindCreated,
p.db.UnarchiveChatByID,
)
}
@@ -1382,7 +1382,7 @@ func (p *Server) applyChatLifecycleTransition(
ctx context.Context,
chatID uuid.UUID,
action string,
kind codersdk.ChatWatchEventKind,
kind coderdpubsub.ChatEventKind,
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
) error {
updatedChats, err := transition(ctx, chatID)
@@ -1545,7 +1545,7 @@ func (p *Server) PromoteQueued(
})
p.publishMessage(opts.ChatID, promoted)
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
p.signalWake()
return result, nil
@@ -2092,7 +2092,7 @@ func (p *Server) regenerateChatTitleWithStore(
return updatedChat, nil
}
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
return updatedChat, nil
}
@@ -2347,7 +2347,7 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
return database.Chat{}, err
}
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
return updatedChat, nil
}
@@ -2461,33 +2461,6 @@ type chainModeInfo struct {
// trailingUserCount is the number of contiguous user messages
// at the end of the conversation that form the current turn.
trailingUserCount int
// contributingTrailingUserCount counts the trailing user
// messages that materially change the provider input.
contributingTrailingUserCount int
}
func userMessageContributesToChainMode(msg database.ChatMessage) bool {
parts, err := chatprompt.ParseContent(msg)
if err != nil {
return false
}
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeText,
codersdk.ChatMessagePartTypeReasoning:
if strings.TrimSpace(part.Text) != "" {
return true
}
case codersdk.ChatMessagePartTypeFile,
codersdk.ChatMessagePartTypeFileReference:
return true
case codersdk.ChatMessagePartTypeContextFile:
if part.ContextFileContent != "" {
return true
}
}
}
return false
}
// resolveChainMode scans DB messages from the end to count trailing user
@@ -2497,13 +2470,11 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
var info chainModeInfo
i := len(messages) - 1
for ; i >= 0; i-- {
if messages[i].Role != database.ChatMessageRoleUser {
break
}
info.trailingUserCount++
if userMessageContributesToChainMode(messages[i]) {
info.contributingTrailingUserCount++
if messages[i].Role == database.ChatMessageRoleUser {
info.trailingUserCount++
continue
}
break
}
for ; i >= 0; i-- {
switch messages[i].Role {
@@ -2526,15 +2497,15 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
return info
}
// filterPromptForChainMode keeps only system messages and the trailing
// user messages that still contribute model-visible content to the
// current turn. Assistant and tool messages are dropped because the
// provider already has them via the previous_response_id chain.
// filterPromptForChainMode keeps only system messages and the last
// trailingUserCount user messages from the prompt. Assistant and tool
// messages are dropped because the provider already has them via the
// previous_response_id chain.
func filterPromptForChainMode(
prompt []fantasy.Message,
info chainModeInfo,
trailingUserCount int,
) []fantasy.Message {
if info.contributingTrailingUserCount <= 0 {
if trailingUserCount <= 0 {
return prompt
}
@@ -2545,12 +2516,7 @@ func filterPromptForChainMode(
}
}
// Prompt construction already drops user turns with no model-visible
// content, such as skill-only sentinel messages. That means the user
// count here stays aligned with contributingTrailingUserCount even
// when non-contributing DB turns are interleaved in the trailing
// block.
usersToSkip := totalUsers - info.contributingTrailingUserCount
usersToSkip := totalUsers - trailingUserCount
if usersToSkip < 0 {
usersToSkip = 0
}
@@ -2596,28 +2562,6 @@ func appendChatMessage(
params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
}
// BuildSingleChatMessageInsertParams creates batch insert params for one
// message using the shared chat message builder.
func BuildSingleChatMessageInsertParams(
chatID uuid.UUID,
role database.ChatMessageRole,
content pqtype.NullRawMessage,
visibility database.ChatMessageVisibility,
modelConfigID uuid.UUID,
contentVersion int16,
createdBy uuid.UUID,
) database.InsertChatMessagesParams {
params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
ChatID: chatID,
}
msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion)
if createdBy != uuid.Nil {
msg = msg.withCreatedBy(createdBy)
}
appendChatMessage(&params, msg)
return params
}
func insertUserMessageAndSetPending(
ctx context.Context,
store database.Store,
@@ -3627,7 +3571,7 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
}
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
for _, chat := range chats {
p.publishChatPubsubEvent(chat, kind, nil)
}
@@ -3635,7 +3579,7 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.Ch
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
// pubsub so that all replicas can push updates to watching clients.
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
if p.pubsub == nil {
return
}
@@ -3647,7 +3591,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWa
if diffStatus != nil {
sdkChat.DiffStatus = diffStatus
}
event := codersdk.ChatWatchEvent{
event := coderdpubsub.ChatEvent{
Kind: kind,
Chat: sdkChat,
}
@@ -3659,7 +3603,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWa
)
return
}
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
slog.F("chat_id", chat.ID),
slog.F("kind", kind),
@@ -3692,8 +3636,8 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
toolCalls := pendingToStreamToolCalls(pending)
sdkChat := db2sdk.Chat(chat, nil, nil)
event := codersdk.ChatWatchEvent{
Kind: codersdk.ChatWatchEventKindActionRequired,
event := coderdpubsub.ChatEvent{
Kind: coderdpubsub.ChatEventKindActionRequired,
Chat: sdkChat,
ToolCalls: toolCalls,
}
@@ -3705,7 +3649,7 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
)
return
}
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
slog.F("chat_id", chat.ID),
slog.Error(err),
@@ -3733,7 +3677,7 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID)
}
sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus)
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange, &sdkStatus)
return nil
}
@@ -4215,7 +4159,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
if title, ok := generatedTitle.Load(); ok {
updatedChat.Title = title
}
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
// When the chat is parked in requires_action,
// publish the stream event and global pubsub event
@@ -4486,21 +4430,13 @@ func (p *Server) runChat(
// the workspace agent has changed (e.g. workspace rebuilt).
needsInstructionPersist := false
hasContextFiles := false
persistedSkills := skillsFromParts(messages)
latestInjectedAgentID, hasLatestInjectedAgent := latestContextAgentID(messages)
currentWorkspaceAgentID := uuid.Nil
hasCurrentWorkspaceAgent := false
if chat.WorkspaceID.Valid {
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
currentWorkspaceAgentID = agent.ID
hasCurrentWorkspaceAgent = true
}
persistedAgentID, found := contextFileAgentID(messages)
hasContextFiles = found
if !hasPersistedInstructionFiles(messages) {
if !hasContextFiles {
needsInstructionPersist = true
} else if hasCurrentWorkspaceAgent && currentWorkspaceAgentID != persistedAgentID {
// Agent changed. Persist fresh instruction files.
} else if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID != persistedAgentID {
// Agent changed — persist fresh instruction files.
// Old context-file messages remain in the conversation
// to preserve the prompt cache prefix.
needsInstructionPersist = true
@@ -4523,8 +4459,7 @@ func (p *Server) runChat(
if needsInstructionPersist {
g2.Go(func() error {
var persistErr error
var discoveredSkills []chattool.SkillMeta
instruction, discoveredSkills, persistErr = p.persistInstructionFiles(
instruction, skills, persistErr = p.persistInstructionFiles(
ctx,
chat,
modelConfig.ID,
@@ -4536,12 +4471,6 @@ func (p *Server) runChat(
return workspaceCtx.getWorkspaceConn(instructionCtx)
},
)
skills = selectSkillMetasForInstructionRefresh(
persistedSkills,
discoveredSkills,
uuid.NullUUID{UUID: currentWorkspaceAgentID, Valid: hasCurrentWorkspaceAgent},
uuid.NullUUID{UUID: latestInjectedAgentID, Valid: hasLatestInjectedAgent},
)
if persistErr != nil {
p.logger.Warn(ctx, "failed to persist instruction files",
slog.F("chat_id", chat.ID),
@@ -4556,7 +4485,7 @@ func (p *Server) runChat(
// re-injected via InsertSystem after compaction drops
// those messages. No workspace dial needed.
instruction = instructionFromContextFiles(messages)
skills = persistedSkills
skills = skillsFromParts(messages)
}
g2.Go(func() error {
resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID)
@@ -5174,14 +5103,14 @@ func (p *Server) runChat(
// assistant and tool messages that the provider already has.
chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) &&
chainInfo.previousResponseID != "" &&
chainInfo.contributingTrailingUserCount > 0 &&
chainInfo.trailingUserCount > 0 &&
chainInfo.modelConfigID == modelConfig.ID
if chainModeActive {
providerOptions = chatprovider.CloneWithPreviousResponseID(
providerOptions,
chainInfo.previousResponseID,
)
prompt = filterPromptForChainMode(prompt, chainInfo)
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
}
err = chatloop.Run(ctx, chatloop.RunOptions{
Model: model,
@@ -5235,7 +5164,7 @@ func (p *Server) runChat(
if chainModeActive {
reloadedPrompt = filterPromptForChainMode(
reloadedPrompt,
chainInfo,
chainInfo.trailingUserCount,
)
}
return reloadedPrompt, nil
@@ -5608,9 +5537,8 @@ func refreshChatWorkspaceSnapshot(
}
// contextFileAgentID extracts the workspace agent ID from the most
// recent persisted instruction-file parts. The skill-only sentinel is
// ignored because it does not represent persisted instruction content.
// Returns uuid.Nil, false if no instruction-file parts exist.
// recent persisted context-file parts. Returns uuid.Nil, false if no
// context-file parts exist.
func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
var lastID uuid.UUID
found := false
@@ -5623,14 +5551,11 @@ func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
continue
}
for _, p := range parts {
if p.Type != codersdk.ChatMessagePartTypeContextFile ||
!p.ContextFileAgentID.Valid ||
p.ContextFilePath == AgentChatContextSentinelPath {
continue
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFileAgentID.Valid {
lastID = p.ContextFileAgentID.UUID
found = true
break
}
lastID = p.ContextFileAgentID.UUID
found = true
break
}
}
return lastID, found
@@ -5700,14 +5625,13 @@ func (p *Server) persistInstructionFiles(
// agent cannot know its own UUID, OS metadata, or
// directory — those are added here at the trust boundary.
var discoveredSkills []chattool.SkillMeta
var hasContent, hasContextFilePart bool
var hasContent bool
agentID := uuid.NullUUID{UUID: agent.ID, Valid: true}
for i := range agentParts {
agentParts[i].ContextFileAgentID = agentID
switch agentParts[i].Type {
case codersdk.ChatMessagePartTypeContextFile:
hasContextFilePart = true
agentParts[i].ContextFileContent = SanitizePromptText(agentParts[i].ContextFileContent)
agentParts[i].ContextFileOS = agent.OperatingSystem
agentParts[i].ContextFileDirectory = directory
@@ -5728,13 +5652,13 @@ func (p *Server) persistInstructionFiles(
if !workspaceConnOK {
return "", nil, nil
}
// Persist a blank context-file marker (plus any skill-only
// parts) so subsequent turns skip the workspace agent dial.
if !hasContextFilePart {
agentParts = append([]codersdk.ChatMessagePart{{
// Persist a sentinel (plus any skill-only parts) so
// subsequent turns skip the workspace agent dial.
if len(agentParts) == 0 {
agentParts = []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFileAgentID: agentID,
}}, agentParts...)
}}
}
content, err := chatprompt.MarshalParts(agentParts)
if err != nil {
@@ -8,7 +8,6 @@ import (
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
@@ -71,14 +70,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
updatedChat.Title = wantTitle
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
coderdpubsub.ChatEventChannel(ownerID),
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}{payload: payload, err: err}
}),
@@ -184,7 +183,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
@@ -234,14 +233,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
unlockedChat.StartedAt = sql.NullTime{}
messageEvents := make(chan struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatWatchEventChannel(ownerID),
coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) {
coderdpubsub.ChatEventChannel(ownerID),
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
messageEvents <- struct {
payload codersdk.ChatWatchEvent
payload coderdpubsub.ChatEvent
err error
}{payload: payload, err: err}
}),
@@ -373,7 +372,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, codersdk.ChatWatchEventKindTitleChange, event.payload.Kind)
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
@@ -704,33 +703,7 @@ func TestPersistInstructionFilesSentinelWithSkills(t *testing.T) {
gomock.Any(),
agentID,
).Return(workspaceAgent, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.InsertChatMessagesParams)
if !ok || arg.ChatID != chat.ID || len(arg.Content) != 1 {
return false
}
var parts []codersdk.ChatMessagePart
if err := json.Unmarshal([]byte(arg.Content[0]), &parts); err != nil {
return false
}
foundMarker := false
foundSkill := false
for _, p := range parts {
switch p.Type {
case codersdk.ChatMessagePartTypeContextFile:
if p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) && p.ContextFileContent == "" {
foundMarker = true
}
case codersdk.ChatMessagePartTypeSkill:
if p.SkillName == "my-skill" && p.ContextFileAgentID == (uuid.NullUUID{UUID: agentID, Valid: true}) {
foundSkill = true
}
}
}
return foundMarker && foundSkill
}),
).Return(nil, nil).Times(1)
db.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(),
gomock.Cond(func(x any) bool {
arg, ok := x.(database.UpdateChatLastInjectedContextParams)
@@ -2047,30 +2020,6 @@ func TestContextFileAgentID(t *testing.T) {
require.True(t, ok)
})
t.Run("IgnoresSkillOnlySentinel", func(t *testing.T) {
t.Parallel()
instructionAgentID := uuid.New()
sentinelAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: instructionAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: sentinelAgentID,
Valid: true,
},
}}),
}
id, ok := contextFileAgentID(msgs)
require.Equal(t, instructionAgentID, id)
require.True(t, ok)
})
t.Run("SentinelWithoutAgentID", func(t *testing.T) {
t.Parallel()
msgs := []database.ChatMessage{
@@ -2087,492 +2036,6 @@ func TestContextFileAgentID(t *testing.T) {
})
}
func TestHasPersistedInstructionFiles(t *testing.T) {
t.Parallel()
t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: agentID,
Valid: true,
},
}}),
}
require.False(t, hasPersistedInstructionFiles(msgs))
})
t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) {
t.Parallel()
agentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/workspace/AGENTS.md",
ContextFileContent: "repo instructions",
ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
}}),
}
require.True(t, hasPersistedInstructionFiles(msgs))
})
}
func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestInstructionFromContextFilesKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/legacy/AGENTS.md",
ContextFileContent: "legacy instructions",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileContent: "old instructions",
ContextFileOS: "darwin",
ContextFileDirectory: "/old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/new/AGENTS.md",
ContextFileContent: "new instructions",
ContextFileOS: "linux",
ContextFileDirectory: "/new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
}}),
}
got := instructionFromContextFiles(msgs)
require.Contains(t, got, "legacy instructions")
require.Contains(t, got, "new instructions")
require.Contains(t, got, "Operating System: linux")
require.Contains(t, got, "Working Directory: /new")
require.NotContains(t, got, "old instructions")
require.NotContains(t, got, "Operating System: darwin")
}
func TestSkillsFromPartsKeepsLegacyUnstampedParts(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-legacy",
SkillDir: "/skills/repo-helper-legacy",
}}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{
{Name: "repo-helper-legacy", Dir: "/skills/repo-helper-legacy"},
{Name: "repo-helper-new", Dir: "/skills/repo-helper-new"},
}, got)
}
func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) {
t.Parallel()
oldAgentID := uuid.New()
newAgentID := uuid.New()
msgs := []database.ChatMessage{
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: "/old/AGENTS.md",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-old",
SkillDir: "/skills/repo-helper-old",
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
},
}),
chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: newAgentID,
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper-new",
SkillDir: "/skills/repo-helper-new",
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
},
}),
}
got := skillsFromParts(msgs)
require.Equal(t, []chattool.SkillMeta{{
Name: "repo-helper-new",
Dir: "/skills/repo-helper-new",
}}, got)
}
func TestMergeSkillMetas(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{
Name: "repo-helper",
Description: "Persisted skill",
Dir: "/skills/repo-helper-old",
}}
discovered := []chattool.SkillMeta{
{
Name: "repo-helper",
Description: "Discovered replacement",
Dir: "/skills/repo-helper-new",
MetaFile: "SKILL.md",
},
{
Name: "deep-review",
Description: "Discovered skill",
Dir: "/skills/deep-review",
},
}
got := mergeSkillMetas(persisted, discovered)
require.Equal(t, []chattool.SkillMeta{
discovered[0],
discovered[1],
}, got)
}
func TestSelectSkillMetasForInstructionRefresh(t *testing.T) {
t.Parallel()
persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}}
discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}}
currentAgentID := uuid.New()
otherAgentID := uuid.New()
t.Run("MergesCurrentAgentSkills", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: currentAgentID, Valid: true},
)
require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got)
})
t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
discovered,
uuid.NullUUID{UUID: currentAgentID, Valid: true},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, discovered, got)
})
t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) {
t.Parallel()
got := selectSkillMetasForInstructionRefresh(
persisted,
nil,
uuid.NullUUID{},
uuid.NullUUID{UUID: otherAgentID, Valid: true},
)
require.Equal(t, persisted, got)
})
}
func TestResolveChainModeIgnoresSkillOnlySentinelMessages(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
user := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
user.Role = database.ChatMessageRoleUser
got := resolveChainMode([]database.ChatMessage{assistant, skillOnly, user})
require.Equal(t, "resp-123", got.previousResponseID)
require.Equal(t, modelConfigID, got.modelConfigID)
require.Equal(t, 2, got.trailingUserCount)
require.Equal(t, 1, got.contributingTrailingUserCount)
}
func TestFilterPromptForChainModeKeepsContributingUsersAcrossSkippedSentinelTurns(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
firstTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "first trailing user",
}})
firstTrailingUser.Role = database.ChatMessageRoleUser
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
lastTrailingUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "last trailing user",
}})
lastTrailingUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
firstTrailingUser,
skillOnly,
lastTrailingUser,
})
require.Equal(t, 3, chainInfo.trailingUserCount)
require.Equal(t, 2, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "first trailing user"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "last trailing user"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 3)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
require.Equal(t, fantasy.MessageRoleUser, got[2].Role)
firstPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "first trailing user", firstPart.Text)
lastPart, ok := fantasy.AsMessagePart[fantasy.TextPart](got[2].Content[0])
require.True(t, ok)
require.Equal(t, "last trailing user", lastPart.Text)
}
func TestFilterPromptForChainModeUsesContributingTrailingUsers(t *testing.T) {
t.Parallel()
modelConfigID := uuid.New()
priorUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "prior user message",
}})
priorUser.Role = database.ChatMessageRoleUser
assistant := database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ProviderResponseID: sql.NullString{String: "resp-123", Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
}
skillOnly := chatMessageWithParts([]codersdk.ChatMessagePart{
{
Type: codersdk.ChatMessagePartTypeContextFile,
ContextFilePath: AgentChatContextSentinelPath,
ContextFileAgentID: uuid.NullUUID{
UUID: uuid.New(),
Valid: true,
},
},
{
Type: codersdk.ChatMessagePartTypeSkill,
SkillName: "repo-helper",
SkillDir: "/skills/repo-helper",
},
})
skillOnly.Role = database.ChatMessageRoleUser
latestUser := chatMessageWithParts([]codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: "latest user message",
}})
latestUser.Role = database.ChatMessageRoleUser
chainInfo := resolveChainMode([]database.ChatMessage{
priorUser,
assistant,
skillOnly,
latestUser,
})
require.Equal(t, 2, chainInfo.trailingUserCount)
require.Equal(t, 1, chainInfo.contributingTrailingUserCount)
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "system instruction"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "prior user message"},
},
},
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "assistant reply"},
},
},
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{Text: "latest user message"},
},
},
}
got := filterPromptForChainMode(prompt, chainInfo)
require.Len(t, got, 2)
require.Equal(t, fantasy.MessageRoleSystem, got[0].Role)
require.Equal(t, fantasy.MessageRoleUser, got[1].Role)
part, ok := fantasy.AsMessagePart[fantasy.TextPart](got[1].Content[0])
require.True(t, ok)
require.Equal(t, "latest user message", part.Text)
}
func chatMessageWithParts(parts []codersdk.ChatMessagePart) database.ChatMessage {
raw, _ := json.Marshal(parts)
return database.ChatMessage{
@@ -1,84 +0,0 @@
package chatdebug
import (
"context"
"runtime"
"sync"
"github.com/google/uuid"
)
type (
runContextKey struct{}
stepContextKey struct{}
reuseStepKey struct{}
reuseHolder struct {
mu sync.Mutex
handle *stepHandle
}
)
// ContextWithRun stores rc in ctx.
//
// Step counter cleanup is reference-counted per RunID: each live
// RunContext increments a counter and runtime.AddCleanup decrements
// it when the struct is garbage collected. Shared state (step
// counters) is only deleted when the last RunContext for a given
// RunID becomes unreachable, preventing premature cleanup when
// multiple RunContext instances share the same RunID.
func ContextWithRun(ctx context.Context, rc *RunContext) context.Context {
if rc == nil {
panic("chatdebug: nil RunContext")
}
enriched := context.WithValue(ctx, runContextKey{}, rc)
if rc.RunID != uuid.Nil {
trackRunRef(rc.RunID)
runtime.AddCleanup(rc, func(id uuid.UUID) {
releaseRunRef(id)
}, rc.RunID)
}
return enriched
}
// RunFromContext returns the debug run context stored in ctx.
func RunFromContext(ctx context.Context) (*RunContext, bool) {
rc, ok := ctx.Value(runContextKey{}).(*RunContext)
if !ok {
return nil, false
}
return rc, true
}
// ContextWithStep stores sc in ctx.
func ContextWithStep(ctx context.Context, sc *StepContext) context.Context {
if sc == nil {
panic("chatdebug: nil StepContext")
}
return context.WithValue(ctx, stepContextKey{}, sc)
}
// StepFromContext returns the debug step context stored in ctx.
func StepFromContext(ctx context.Context) (*StepContext, bool) {
sc, ok := ctx.Value(stepContextKey{}).(*StepContext)
if !ok {
return nil, false
}
return sc, true
}
// ReuseStep marks ctx so wrapped model calls under it share one debug step.
func ReuseStep(ctx context.Context) context.Context {
if holder, ok := reuseHolderFromContext(ctx); ok {
return context.WithValue(ctx, reuseStepKey{}, holder)
}
return context.WithValue(ctx, reuseStepKey{}, &reuseHolder{})
}
func reuseHolderFromContext(ctx context.Context) (*reuseHolder, bool) {
holder, ok := ctx.Value(reuseStepKey{}).(*reuseHolder)
if !ok {
return nil, false
}
return holder, true
}
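The deleted file above reference-counts shared per-RunID state: each live RunContext increments a counter, and a `runtime.AddCleanup` callback decrements it on GC, so shared step counters are only dropped when the last owner dies. The counting half can be sketched standalone; this is a hypothetical minimal version (names `trackRef`, `releaseRef`, `counters` are illustrative, not the real `trackRunRef`/`releaseRunRef`), with string keys in place of `uuid.UUID` and the GC hook left out:

```go
package main

import (
	"fmt"
	"sync"
)

// Shared per-key state survives until the LAST owner releases it.
var (
	mu        sync.Mutex
	refCounts = map[string]int{}
	counters  = map[string]int{} // shared state, e.g. a step counter
)

func trackRef(key string) {
	mu.Lock()
	defer mu.Unlock()
	refCounts[key]++
}

func releaseRef(key string) {
	mu.Lock()
	defer mu.Unlock()
	refCounts[key]--
	if refCounts[key] <= 0 {
		// Last reference gone: drop both the count and the shared state.
		delete(refCounts, key)
		delete(counters, key)
	}
}

func main() {
	trackRef("run-1")
	trackRef("run-1") // second RunContext sharing the same RunID
	counters["run-1"] = 3

	releaseRef("run-1")
	_, alive := counters["run-1"]
	fmt.Println(alive) // one owner remains, so state survives

	releaseRef("run-1")
	_, alive = counters["run-1"]
	fmt.Println(alive) // last release dropped it
}
```

In the real code the `releaseRef` side runs from a GC cleanup rather than an explicit call, which is why the tests below have to loop `runtime.GC()` inside `require.Eventually` instead of asserting immediately.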
@@ -1,118 +0,0 @@
package chatdebug
import (
"context"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/testutil"
)
func TestReuseStep_PreservesExistingHolder(t *testing.T) {
t.Parallel()
ctx := ReuseStep(context.Background())
first, ok := reuseHolderFromContext(ctx)
require.True(t, ok)
reused := ReuseStep(ctx)
second, ok := reuseHolderFromContext(reused)
require.True(t, ok)
require.Same(t, first, second)
}
func TestContextWithRun_CleansUpStepCounterAfterGC(t *testing.T) {
t.Parallel()
runID := uuid.New()
chatID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
func() {
_ = ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
require.Equal(t, int32(1), nextStepNumber(runID))
_, ok := stepCounters.Load(runID)
require.True(t, ok)
}()
require.Eventually(t, func() bool {
runtime.GC()
runtime.Gosched()
_, ok := stepCounters.Load(runID)
return !ok
}, testutil.WaitShort, testutil.IntervalFast)
}
func TestContextWithRun_MultipleInstancesSameRunID(t *testing.T) {
t.Parallel()
runID := uuid.New()
chatID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
// rc2 is the surviving instance that should keep the step counter alive.
rc2 := &RunContext{RunID: runID, ChatID: chatID}
ctx2 := ContextWithRun(context.Background(), rc2)
// Create a second RunContext with the same RunID and let it become
// unreachable. Its GC cleanup must NOT delete the step counter
// because rc2 is still alive.
func() {
rc1 := &RunContext{RunID: runID, ChatID: chatID}
ctx1 := ContextWithRun(context.Background(), rc1)
h, _ := beginStep(ctx1, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
require.NotNil(t, h)
require.Equal(t, int32(1), h.stepCtx.StepNumber)
}()
// Force GC to collect rc1.
for range 5 {
runtime.GC()
runtime.Gosched()
}
// The step counter must still be present because rc2 is alive.
_, ok := stepCounters.Load(runID)
require.True(t, ok, "step counter was prematurely cleaned up while another RunContext is still alive")
// Subsequent steps on the surviving context must continue numbering.
h2, _ := beginStep(ctx2, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil)
require.NotNil(t, h2)
require.Equal(t, int32(2), h2.stepCtx.StepNumber)
}
func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) {
t.Parallel()
runID := uuid.New()
chatID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
// Run in a closure so the RunContext becomes unreachable after
// context cancellation, allowing GC to trigger the cleanup.
func() {
ctx, cancel := context.WithCancel(context.Background())
ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID})
require.Equal(t, int32(1), nextStepNumber(runID))
_, ok := stepCounters.Load(runID)
require.True(t, ok)
cancel()
}()
// After the closure, the RunContext is unreachable.
// runtime.AddCleanup fires during GC.
require.Eventually(t, func() bool {
runtime.GC()
runtime.Gosched()
_, ok := stepCounters.Load(runID)
return !ok
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, int32(1), nextStepNumber(runID))
}
@@ -1,105 +0,0 @@
package chatdebug_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
)
func TestContextWithRunRoundTrip(t *testing.T) {
t.Parallel()
rc := &chatdebug.RunContext{
RunID: uuid.New(),
ChatID: uuid.New(),
RootChatID: uuid.New(),
ParentChatID: uuid.New(),
ModelConfigID: uuid.New(),
TriggerMessageID: 11,
HistoryTipMessageID: 22,
Kind: chatdebug.KindChatTurn,
Provider: "anthropic",
Model: "claude-sonnet",
}
ctx := chatdebug.ContextWithRun(context.Background(), rc)
got, ok := chatdebug.RunFromContext(ctx)
require.True(t, ok)
require.Same(t, rc, got)
require.Equal(t, *rc, *got)
}
func TestRunFromContextAbsent(t *testing.T) {
t.Parallel()
got, ok := chatdebug.RunFromContext(context.Background())
require.False(t, ok)
require.Nil(t, got)
}
func TestContextWithStepRoundTrip(t *testing.T) {
t.Parallel()
sc := &chatdebug.StepContext{
StepID: uuid.New(),
RunID: uuid.New(),
ChatID: uuid.New(),
StepNumber: 7,
Operation: chatdebug.OperationStream,
HistoryTipMessageID: 33,
}
ctx := chatdebug.ContextWithStep(context.Background(), sc)
got, ok := chatdebug.StepFromContext(ctx)
require.True(t, ok)
require.Same(t, sc, got)
require.Equal(t, *sc, *got)
}
func TestStepFromContextAbsent(t *testing.T) {
t.Parallel()
got, ok := chatdebug.StepFromContext(context.Background())
require.False(t, ok)
require.Nil(t, got)
}
func TestContextWithRunAndStep(t *testing.T) {
t.Parallel()
rc := &chatdebug.RunContext{RunID: uuid.New(), ChatID: uuid.New()}
sc := &chatdebug.StepContext{StepID: uuid.New(), RunID: rc.RunID, ChatID: rc.ChatID}
ctx := chatdebug.ContextWithStep(
chatdebug.ContextWithRun(context.Background(), rc),
sc,
)
gotRun, ok := chatdebug.RunFromContext(ctx)
require.True(t, ok)
require.Same(t, rc, gotRun)
gotStep, ok := chatdebug.StepFromContext(ctx)
require.True(t, ok)
require.Same(t, sc, gotStep)
}
func TestContextWithRunPanicsOnNil(t *testing.T) {
t.Parallel()
require.Panics(t, func() {
_ = chatdebug.ContextWithRun(context.Background(), nil)
})
}
func TestContextWithStepPanicsOnNil(t *testing.T) {
t.Parallel()
require.Panics(t, func() {
_ = chatdebug.ContextWithStep(context.Background(), nil)
})
}
(File diff suppressed because it is too large.)
@@ -1,331 +0,0 @@
package chatdebug //nolint:testpackage // Checks unexported normalized structs against fantasy source types.
import (
"reflect"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/require"
)
// fieldDisposition documents whether a fantasy struct field is captured
// by the corresponding normalized struct ("normalized") or
// intentionally omitted ("skipped: <reason>"). The test fails when a
// fantasy type gains a field that is not yet classified, forcing the
// developer to decide whether to normalize or skip it.
//
// This mirrors the audit-table exhaustiveness check in
// enterprise/audit/table.go — same idea, different domain.
type fieldDisposition = map[string]string
// TestNormalizationFieldCoverage ensures every exported field on the
// fantasy types that model.go normalizes is explicitly accounted for.
// When the fantasy library adds a field the test fails, surfacing the
// drift at `go test` time rather than silently dropping data.
func TestNormalizationFieldCoverage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
typ reflect.Type
fields fieldDisposition
}{
// ── struct-to-struct mappings ──────────────────────────
{
name: "fantasy.Usage → normalizedUsage",
typ: reflect.TypeFor[fantasy.Usage](),
fields: fieldDisposition{
"InputTokens": "normalized",
"OutputTokens": "normalized",
"TotalTokens": "normalized",
"ReasoningTokens": "normalized",
"CacheCreationTokens": "normalized",
"CacheReadTokens": "normalized",
},
},
{
name: "fantasy.Call → normalizedCallPayload",
typ: reflect.TypeFor[fantasy.Call](),
fields: fieldDisposition{
"Prompt": "normalized",
"MaxOutputTokens": "normalized",
"Temperature": "normalized",
"TopP": "normalized",
"TopK": "normalized",
"PresencePenalty": "normalized",
"FrequencyPenalty": "normalized",
"Tools": "normalized",
"ToolChoice": "normalized",
"UserAgent": "skipped: internal transport header, not useful for debug panel",
"ProviderOptions": "skipped: opaque provider data, only count preserved",
},
},
{
name: "fantasy.ObjectCall → normalizedObjectCallPayload",
typ: reflect.TypeFor[fantasy.ObjectCall](),
fields: fieldDisposition{
"Prompt": "normalized",
"Schema": "skipped: full schema too large; SchemaName+SchemaDescription captured instead",
"SchemaName": "normalized",
"SchemaDescription": "normalized",
"MaxOutputTokens": "normalized",
"Temperature": "normalized",
"TopP": "normalized",
"TopK": "normalized",
"PresencePenalty": "normalized",
"FrequencyPenalty": "normalized",
"UserAgent": "skipped: internal transport header, not useful for debug panel",
"ProviderOptions": "skipped: opaque provider data, only count preserved",
"RepairText": "skipped: function value, not serializable",
},
},
{
name: "fantasy.Response → normalizedResponsePayload",
typ: reflect.TypeFor[fantasy.Response](),
fields: fieldDisposition{
"Content": "normalized",
"FinishReason": "normalized",
"Usage": "normalized",
"Warnings": "normalized",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ObjectResponse → normalizedObjectResponsePayload",
typ: reflect.TypeFor[fantasy.ObjectResponse](),
fields: fieldDisposition{
"Object": "skipped: arbitrary user type, not serializable generically",
"RawText": "normalized: as RawTextLength (length only, content unbounded)",
"Usage": "normalized",
"FinishReason": "normalized",
"Warnings": "normalized",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.CallWarning → normalizedWarning",
typ: reflect.TypeFor[fantasy.CallWarning](),
fields: fieldDisposition{
"Type": "normalized",
"Setting": "normalized",
"Tool": "skipped: interface value, warning message+type sufficient for debug panel",
"Details": "normalized",
"Message": "normalized",
},
},
{
name: "fantasy.StreamPart → appendNormalizedStreamContent",
typ: reflect.TypeFor[fantasy.StreamPart](),
fields: fieldDisposition{
"Type": "normalized",
"ID": "normalized: as ToolCallID in content parts",
"ToolCallName": "normalized: as ToolName in content parts",
"ToolCallInput": "normalized: as Arguments or Result (bounded)",
"Delta": "normalized: accumulated into text/reasoning content parts",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"Usage": "normalized: captured in stream finalize",
"FinishReason": "normalized: captured in stream finalize",
"Error": "normalized: captured in stream error handling",
"Warnings": "normalized: captured in stream warning accumulation",
"SourceType": "normalized",
"URL": "normalized",
"Title": "normalized",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ObjectStreamPart → wrapObjectStreamSeq",
typ: reflect.TypeFor[fantasy.ObjectStreamPart](),
fields: fieldDisposition{
"Type": "normalized: drives switch in wrapObjectStreamSeq",
"Object": "skipped: arbitrary user type, only ObjectPartCount tracked",
"Delta": "normalized: accumulated into rawTextLength",
"Error": "normalized: captured in stream error handling",
"Usage": "normalized: captured in stream finalize",
"FinishReason": "normalized: captured in stream finalize",
"Warnings": "normalized: captured in stream warning accumulation",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
// ── message part types (normalizeMessageParts) ────────
{
name: "fantasy.TextPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.TextPart](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ReasoningPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.ReasoningPart](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.FilePart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.FilePart](),
fields: fieldDisposition{
"Filename": "normalized",
"Data": "skipped: binary data never stored in debug records",
"MediaType": "normalized",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ToolCallPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.ToolCallPart](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"ToolName": "normalized",
"Input": "normalized: as Arguments (bounded)",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ToolResultPart → normalizedMessagePart",
typ: reflect.TypeFor[fantasy.ToolResultPart](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"Output": "normalized: text extracted via normalizeToolResultOutput",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
// ── response content types (normalizeContentParts) ────
{
name: "fantasy.TextContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.TextContent](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ReasoningContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.ReasoningContent](),
fields: fieldDisposition{
"Text": "normalized: bounded to MaxMessagePartTextLength",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.FileContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.FileContent](),
fields: fieldDisposition{
"MediaType": "normalized",
"Data": "skipped: binary data never stored in debug records",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.SourceContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.SourceContent](),
fields: fieldDisposition{
"SourceType": "normalized",
"ID": "skipped: provider-internal identifier, not actionable in debug panel",
"URL": "normalized",
"Title": "normalized",
"MediaType": "skipped: only relevant for document sources, rarely useful for debugging",
"Filename": "skipped: only relevant for document sources, rarely useful for debugging",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
{
name: "fantasy.ToolCallContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.ToolCallContent](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"ToolName": "normalized",
"Input": "normalized: as Arguments (bounded), InputLength tracks original",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
"Invalid": "skipped: validation state not surfaced in debug panel",
"ValidationError": "skipped: validation state not surfaced in debug panel",
},
},
{
name: "fantasy.ToolResultContent → normalizedContentPart",
typ: reflect.TypeFor[fantasy.ToolResultContent](),
fields: fieldDisposition{
"ToolCallID": "normalized",
"ToolName": "normalized",
"Result": "normalized: text extracted via normalizeToolResultOutput",
"ClientMetadata": "skipped: client execution metadata not needed for debug panel",
"ProviderExecuted": "skipped: provider vs client distinction not needed for debug panel",
"ProviderMetadata": "skipped: opaque provider-specific metadata",
},
},
// ── tool types (normalizeTools) ───────────────────────
{
name: "fantasy.FunctionTool → normalizedTool",
typ: reflect.TypeFor[fantasy.FunctionTool](),
fields: fieldDisposition{
"Name": "normalized",
"Description": "normalized",
"InputSchema": "normalized: preserved as JSON for debug panel rendering",
"ProviderOptions": "skipped: opaque provider-specific options",
},
},
{
name: "fantasy.ProviderDefinedTool → normalizedTool",
typ: reflect.TypeFor[fantasy.ProviderDefinedTool](),
fields: fieldDisposition{
"ID": "normalized",
"Name": "normalized",
"Args": "skipped: provider-specific configuration not needed for debug panel",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Every exported field on the fantasy type must be
// registered as "normalized" or "skipped: <reason>".
for i := range tt.typ.NumField() {
field := tt.typ.Field(i)
if !field.IsExported() {
continue
}
disposition, ok := tt.fields[field.Name]
if !ok {
require.Failf(t, "unregistered field",
"%s.%s is not in the coverage map — "+
"add it as \"normalized\" or \"skipped: <reason>\"",
tt.typ.Name(), field.Name)
}
require.NotEmptyf(t, disposition,
"%s.%s has an empty disposition — "+
"use \"normalized\" or \"skipped: <reason>\"",
tt.typ.Name(), field.Name)
}
// Catch stale entries that reference removed fields.
for name := range tt.fields {
found := false
for i := range tt.typ.NumField() {
if tt.typ.Field(i).Name == name {
found = true
break
}
}
require.Truef(t, found,
"stale coverage entry %s.%s — "+
"field no longer exists in fantasy, remove it",
tt.typ.Name(), name)
}
})
}
}
@@ -1,987 +0,0 @@
package chatdebug
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/testutil"
)
type testError struct{ message string }
func (e *testError) Error() string { return e.message }
func expectDebugLoggingEnabled(
t *testing.T,
db *dbmock.MockStore,
ownerID uuid.UUID,
) {
t.Helper()
db.EXPECT().GetChatDebugLoggingEnabled(gomock.Any()).Return(true, nil)
db.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), ownerID).Return(true, nil)
}
func expectCreateStepNumberWithRequestValidity(
t *testing.T,
db *dbmock.MockStore,
runID uuid.UUID,
chatID uuid.UUID,
stepNumber int32,
op Operation,
normalizedRequestValid bool,
) uuid.UUID {
t.Helper()
stepID := uuid.New()
db.EXPECT().
InsertChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{})).
DoAndReturn(func(_ context.Context, params database.InsertChatDebugStepParams) (database.ChatDebugStep, error) {
require.Equal(t, runID, params.RunID)
require.Equal(t, chatID, params.ChatID)
require.Equal(t, stepNumber, params.StepNumber)
require.Equal(t, string(op), params.Operation)
require.Equal(t, string(StatusInProgress), params.Status)
require.Equal(t, normalizedRequestValid, params.NormalizedRequest.Valid)
return database.ChatDebugStep{
ID: stepID,
RunID: runID,
ChatID: chatID,
StepNumber: params.StepNumber,
Operation: params.Operation,
Status: params.Status,
}, nil
})
// CreateStep now touches the parent run's updated_at to prevent
// premature stale finalization.
db.EXPECT().
UpdateChatDebugRun(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugRunParams{})).
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugRunParams) (database.ChatDebugRun, error) {
require.Equal(t, runID, params.ID)
require.Equal(t, chatID, params.ChatID)
return database.ChatDebugRun{ID: runID, ChatID: chatID}, nil
})
return stepID
}
func expectCreateStepNumber(
t *testing.T,
db *dbmock.MockStore,
runID uuid.UUID,
chatID uuid.UUID,
stepNumber int32,
op Operation,
) uuid.UUID {
t.Helper()
return expectCreateStepNumberWithRequestValidity(
t,
db,
runID,
chatID,
stepNumber,
op,
true,
)
}
func expectCreateStep(
t *testing.T,
db *dbmock.MockStore,
runID uuid.UUID,
chatID uuid.UUID,
op Operation,
) uuid.UUID {
t.Helper()
return expectCreateStepNumber(t, db, runID, chatID, 1, op)
}
func expectUpdateStep(
t *testing.T,
db *dbmock.MockStore,
stepID uuid.UUID,
chatID uuid.UUID,
status Status,
assertFn func(database.UpdateChatDebugStepParams),
) {
t.Helper()
db.EXPECT().
UpdateChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugStepParams{})).
DoAndReturn(func(_ context.Context, params database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) {
require.Equal(t, stepID, params.ID)
require.Equal(t, chatID, params.ChatID)
require.True(t, params.Status.Valid)
require.Equal(t, string(status), params.Status.String)
require.True(t, params.FinishedAt.Valid)
if assertFn != nil {
assertFn(params)
}
return database.ChatDebugStep{
ID: stepID,
ChatID: chatID,
Status: params.Status.String,
}, nil
})
}
func TestDebugModel_Provider(t *testing.T) {
t.Parallel()
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
model := &debugModel{inner: inner}
require.Equal(t, inner.Provider(), model.Provider())
}
func TestDebugModel_Model(t *testing.T) {
t.Parallel()
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
model := &debugModel{inner: inner}
require.Equal(t, inner.Model(), model.Model())
}
func TestDebugModel_Disabled(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
svc := NewService(db, testutil.Logger(t), nil)
respWant := &fantasy.Response{FinishReason: fantasy.FinishReasonStop}
inner := &chattest.FakeModel{
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
_, ok := StepFromContext(ctx)
require.False(t, ok)
require.Nil(t, attemptSinkFromContext(ctx))
return respWant, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{
ChatID: chatID,
OwnerID: ownerID,
},
}
resp, err := model.Generate(context.Background(), fantasy.Call{})
require.NoError(t, err)
require.Same(t, respWant, resp)
}
func TestDebugModel_Generate(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
call := fantasy.Call{
Prompt: fantasy.Prompt{fantasy.NewUserMessage("hello")},
MaxOutputTokens: int64Ptr(128),
Temperature: float64Ptr(0.25),
}
respWant := &fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.TextContent{Text: "hello"},
fantasy.ToolCallContent{ToolCallID: "tool-1", ToolName: "tool", Input: `{}`},
fantasy.SourceContent{ID: "source-1", Title: "docs", URL: "https://example.com"},
},
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{InputTokens: 10, OutputTokens: 4, TotalTokens: 14},
Warnings: []fantasy.CallWarning{{Message: "warning"}},
}
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
require.True(t, params.NormalizedResponse.Valid)
require.True(t, params.Usage.Valid)
require.True(t, params.Attempts.Valid)
// Clean successes (no prior error) leave the error column
// as SQL NULL rather than sending jsonClear.
require.False(t, params.Error.Valid)
require.False(t, params.Metadata.Valid)
// Verify actual JSON content so a broken tag or field
// rename is caught rather than only checking .Valid.
var usage fantasy.Usage
require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage))
require.EqualValues(t, 10, usage.InputTokens)
require.EqualValues(t, 4, usage.OutputTokens)
require.EqualValues(t, 14, usage.TotalTokens)
var resp map[string]any
require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp))
require.Equal(t, "stop", resp["finish_reason"])
})
svc := NewService(db, testutil.Logger(t), nil)
inner := &chattest.FakeModel{
GenerateFn: func(ctx context.Context, got fantasy.Call) (*fantasy.Response, error) {
require.Equal(t, call, got)
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, int32(1), stepCtx.StepNumber)
require.Equal(t, OperationGenerate, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return respWant, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.Generate(ctx, call)
require.NoError(t, err)
require.Same(t, respWant, resp)
}
func TestDebugModel_GeneratePersistsAttemptsWithoutResponseClose(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.JSONEq(t, `{"message":"hello","api_key":"super-secret"}`, string(body))
require.Equal(t, "Bearer top-secret", req.Header.Get("Authorization"))
rw.Header().Set("Content-Type", "application/json")
rw.Header().Set("X-API-Key", "response-secret")
rw.WriteHeader(http.StatusCreated)
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
}))
defer server.Close()
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
require.True(t, params.Attempts.Valid)
require.True(t, params.NormalizedResponse.Valid)
require.True(t, params.Usage.Valid)
var attempts []Attempt
require.NoError(t, json.Unmarshal(params.Attempts.RawMessage, &attempts))
require.Len(t, attempts, 1)
require.Equal(t, attemptStatusCompleted, attempts[0].Status)
require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus)
})
svc := NewService(db, testutil.Logger(t), nil)
inner := &chattest.FakeModel{
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
req, err := http.NewRequestWithContext(
ctx,
http.MethodPost,
server.URL,
strings.NewReader(`{"message":"hello","api_key":"super-secret"}`),
)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer top-secret")
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
require.NoError(t, resp.Body.Close())
return &fantasy.Response{FinishReason: fantasy.FinishReasonStop}, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.Generate(ctx, fantasy.Call{})
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestDebugModel_GenerateError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
wantErr := &testError{message: "boom"}
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate)
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
require.False(t, params.NormalizedResponse.Valid)
require.False(t, params.Usage.Valid)
require.True(t, params.Attempts.Valid)
require.True(t, params.Error.Valid)
require.False(t, params.Metadata.Valid)
})
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
GenerateFn: func(context.Context, fantasy.Call) (*fantasy.Response, error) {
return nil, wantErr
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.Generate(ctx, fantasy.Call{})
require.Nil(t, resp)
require.ErrorIs(t, err, wantErr)
}
func TestStepStatusForError(t *testing.T) {
t.Parallel()
t.Run("Canceled", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, stepStatusForError(context.Canceled))
})
t.Run("DeadlineExceeded", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, stepStatusForError(context.DeadlineExceeded))
})
t.Run("OtherError", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusError, stepStatusForError(xerrors.New("boom")))
})
}
func TestDebugModel_Stream(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
errPart := xerrors.New("chunk failed")
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hel"},
{Type: fantasy.StreamPartTypeToolCall, ID: "tool-call-1", ToolCallName: "tool"},
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
{Type: fantasy.StreamPartTypeWarnings, Warnings: []fantasy.CallWarning{{Message: "w1"}, {Message: "w2"}}},
{Type: fantasy.StreamPartTypeError, Error: errPart},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 8, OutputTokens: 3, TotalTokens: 11}},
}
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
require.True(t, params.NormalizedResponse.Valid)
require.True(t, params.Usage.Valid)
require.True(t, params.Attempts.Valid)
require.True(t, params.Error.Valid)
require.True(t, params.Metadata.Valid)
})
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, int32(1), stepCtx.StepNumber)
require.Equal(t, OperationStream, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
got := make([]fantasy.StreamPart, 0, len(parts))
for part := range seq {
got = append(got, part)
}
require.Equal(t, parts, got)
}
func TestDebugModel_StreamObject(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
parts := []fantasy.ObjectStreamPart{
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
{Type: fantasy.ObjectStreamPartTypeObject, Object: map[string]any{"value": "object"}},
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}},
}
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) {
require.True(t, params.NormalizedResponse.Valid)
require.True(t, params.Usage.Valid)
require.True(t, params.Attempts.Valid)
// Clean successes (no prior error) leave the error column
// as SQL NULL rather than sending jsonClear.
require.False(t, params.Error.Valid)
require.True(t, params.Metadata.Valid)
})
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, int32(1), stepCtx.StepNumber)
require.Equal(t, OperationStream, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return objectPartsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
require.NoError(t, err)
got := make([]fantasy.ObjectStreamPart, 0, len(parts))
for part := range seq {
got = append(got, part)
}
require.Equal(t, parts, got)
}
// TestDebugModel_StreamCompletedAfterFinish verifies that when a consumer
// stops iteration after receiving a finish part, the step is marked as
// completed rather than interrupted.
func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
}
// The mock expectation for UpdateStep with StatusCompleted is the
// assertion: if the wrapper chose StatusInterrupted instead, the
// mock would reject the call.
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
expectUpdateStep(t, db, stepID, chatID, StatusCompleted, nil)
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
// Consumer reads the finish part then breaks — this should still
// be considered a completed stream, not interrupted.
for part := range seq {
if part.Type == fantasy.StreamPartTypeFinish {
break
}
}
// gomock verifies UpdateStep was called with StatusCompleted.
}
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
// stops iteration before receiving a finish part, the step is marked as
// interrupted.
func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeTextDelta, Delta: " world"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}
// The mock expectation for UpdateStep with StatusInterrupted is the
// assertion: breaking before the finish part means interrupted.
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, nil)
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
// Consumer reads the first delta then breaks before finish.
count := 0
for range seq {
count++
if count == 1 {
break
}
}
require.Equal(t, 1, count)
// gomock verifies UpdateStep was called with StatusInterrupted.
}
func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
require.False(t, params.NormalizedResponse.Valid)
require.False(t, params.Usage.Valid)
require.True(t, params.Attempts.Valid)
require.True(t, params.Error.Valid)
require.False(t, params.Metadata.Valid)
})
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
var nilStream fantasy.StreamResponse
return nilStream, nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.Nil(t, seq)
require.ErrorIs(t, err, ErrNilModelResult)
}
func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) {
require.False(t, params.NormalizedResponse.Valid)
require.False(t, params.Usage.Valid)
require.True(t, params.Attempts.Valid)
require.True(t, params.Error.Valid)
require.True(t, params.Metadata.Valid)
})
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamObjectFn: func(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
var nilStream fantasy.ObjectStreamResponse
return nilStream, nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
require.Nil(t, seq)
require.ErrorIs(t, err, ErrNilModelResult)
}
func TestDebugModel_StreamEarlyStop(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "first"},
{Type: fantasy.StreamPartTypeTextDelta, Delta: "second"},
}
expectDebugLoggingEnabled(t, db, ownerID)
stepID := expectCreateStep(t, db, runID, chatID, OperationStream)
expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, func(params database.UpdateChatDebugStepParams) {
require.True(t, params.NormalizedResponse.Valid)
require.False(t, params.Usage.Valid)
require.True(t, params.Attempts.Valid)
require.False(t, params.Error.Valid)
require.True(t, params.Metadata.Valid)
})
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
return partsToSeq(parts), nil
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
seq, err := model.Stream(ctx, fantasy.Call{})
require.NoError(t, err)
count := 0
for part := range seq {
require.Equal(t, parts[0], part)
count++
break
}
require.Equal(t, 1, count)
}
func TestStreamErrorStatus(t *testing.T) {
t.Parallel()
t.Run("CancellationBecomesInterrupted", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.Canceled))
})
t.Run("DeadlineExceededBecomesInterrupted", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.DeadlineExceeded))
})
t.Run("NilErrorBecomesError", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusError, streamErrorStatus(StatusCompleted, nil))
})
t.Run("ExistingErrorWins", func(t *testing.T) {
t.Parallel()
require.Equal(t, StatusError, streamErrorStatus(StatusError, context.Canceled))
})
}
func objectPartsToSeq(parts []fantasy.ObjectStreamPart) fantasy.ObjectStreamResponse {
return func(yield func(fantasy.ObjectStreamPart) bool) {
for _, part := range parts {
if !yield(part) {
return
}
}
}
}
func partsToSeq(parts []fantasy.StreamPart) fantasy.StreamResponse {
return func(yield func(fantasy.StreamPart) bool) {
for _, part := range parts {
if !yield(part) {
return
}
}
}
}
func TestDebugModel_GenerateObject(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
call := fantasy.ObjectCall{
Prompt: fantasy.Prompt{fantasy.NewUserMessage("summarize")},
SchemaName: "Summary",
MaxOutputTokens: int64Ptr(256),
}
respWant := &fantasy.ObjectResponse{
RawText: `{"title":"test"}`,
FinishReason: fantasy.FinishReasonStop,
Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
}
svc := NewService(db, testutil.Logger(t), nil)
inner := &chattest.FakeModel{
GenerateObjectFn: func(ctx context.Context, got fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
require.Equal(t, call, got)
stepCtx, ok := StepFromContext(ctx)
require.True(t, ok)
require.Equal(t, runID, stepCtx.RunID)
require.Equal(t, chatID, stepCtx.ChatID)
require.Equal(t, OperationGenerate, stepCtx.Operation)
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
require.NotNil(t, attemptSinkFromContext(ctx))
return respWant, nil
},
}
model := &debugModel{
inner: inner,
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.GenerateObject(ctx, call)
require.NoError(t, err)
require.Same(t, respWant, resp)
}
func TestDebugModel_GenerateObjectError(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
wantErr := &testError{message: "object boom"}
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, wantErr
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
require.Nil(t, resp)
require.ErrorIs(t, err, wantErr)
}
func TestDebugModel_GenerateObjectRejectsNilResponse(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
runID := uuid.New()
svc := NewService(db, testutil.Logger(t), nil)
model := &debugModel{
inner: &chattest.FakeModel{
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
return nil, nil //nolint:nilnil // Intentionally testing nil response handling.
},
},
svc: svc,
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
}
t.Cleanup(func() { CleanupStepCounter(runID) })
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
require.Nil(t, resp)
require.ErrorIs(t, err, ErrNilModelResult)
}
func TestWrapStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
// Create a context that we cancel after the stream finishes.
ctx, cancel := context.WithCancel(context.Background())
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
}
seq := wrapStreamSeq(ctx, handle, partsToSeq(parts))
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
for range seq {
}
// Cancel the context after the stream has been fully consumed
// and finalized. The status should remain completed.
cancel()
handle.mu.Lock()
status := handle.status
handle.mu.Unlock()
require.Equal(t, StatusCompleted, status)
}
func TestWrapObjectStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
ctx, cancel := context.WithCancel(context.Background())
parts := []fantasy.ObjectStreamPart{
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "obj"},
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 1, TotalTokens: 4}},
}
seq := wrapObjectStreamSeq(ctx, handle, objectPartsToSeq(parts))
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
for range seq {
}
cancel()
handle.mu.Lock()
status := handle.status
handle.mu.Unlock()
require.Equal(t, StatusCompleted, status)
}
func TestWrapStreamSeq_DroppedStreamFinalizedOnCtxCancel(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
ctx, cancel := context.WithCancel(context.Background())
parts := []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}
// Create the wrapped stream but never iterate it.
_ = wrapStreamSeq(ctx, handle, partsToSeq(parts))
// Cancel the context — the AfterFunc safety net should finalize
// the step as interrupted.
cancel()
// AfterFunc fires asynchronously; give it a moment.
require.Eventually(t, func() bool {
handle.mu.Lock()
defer handle.mu.Unlock()
return handle.status == StatusInterrupted
}, testutil.WaitShort, testutil.IntervalFast)
}
func int64Ptr(v int64) *int64 { return &v }
func float64Ptr(v float64) *float64 { return &v }
@@ -1,379 +0,0 @@
package chatdebug //nolint:testpackage // Uses unexported normalization helpers.
import (
"context"
"strings"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
func TestNormalizeCall_PreservesToolSchemasAndMessageToolPayloads(t *testing.T) {
t.Parallel()
payload := normalizeCall(fantasy.Call{
Prompt: fantasy.Prompt{
{
Role: fantasy.MessageRoleAssistant,
Content: []fantasy.MessagePart{
fantasy.ToolCallPart{
ToolCallID: "call-search",
ToolName: "search_docs",
Input: `{"query":"debug panel"}`,
},
},
},
{
Role: fantasy.MessageRoleTool,
Content: []fantasy.MessagePart{
fantasy.ToolResultPart{
ToolCallID: "call-search",
Output: fantasy.ToolResultOutputContentText{
Text: `{"matches":["model.go","DebugStepCard.tsx"]}`,
},
},
},
},
},
Tools: []fantasy.Tool{
fantasy.FunctionTool{
Name: "search_docs",
Description: "Searches documentation.",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{"type": "string"},
},
"required": []string{"query"},
},
},
},
})
require.Len(t, payload.Tools, 1)
require.True(t, payload.Tools[0].HasInputSchema)
require.JSONEq(t, `{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}`,
string(payload.Tools[0].InputSchema))
require.Len(t, payload.Messages, 2)
require.Equal(t, "tool-call", payload.Messages[0].Parts[0].Type)
require.Equal(t, `{"query":"debug panel"}`, payload.Messages[0].Parts[0].Arguments)
require.Equal(t, "tool-result", payload.Messages[1].Parts[0].Type)
require.Equal(t,
`{"matches":["model.go","DebugStepCard.tsx"]}`,
payload.Messages[1].Parts[0].Result,
)
}
func TestNormalizers_SkipTypedNilInterfaceValues(t *testing.T) {
t.Parallel()
t.Run("MessageParts", func(t *testing.T) {
t.Parallel()
var nilPart *fantasy.TextPart
parts := normalizeMessageParts([]fantasy.MessagePart{
nilPart,
fantasy.TextPart{Text: "hello"},
})
require.Len(t, parts, 1)
require.Equal(t, "text", parts[0].Type)
require.Equal(t, "hello", parts[0].Text)
})
t.Run("Tools", func(t *testing.T) {
t.Parallel()
var nilTool *fantasy.FunctionTool
tools := normalizeTools([]fantasy.Tool{
nilTool,
fantasy.FunctionTool{Name: "search_docs"},
})
require.Len(t, tools, 1)
require.Equal(t, "function", tools[0].Type)
require.Equal(t, "search_docs", tools[0].Name)
})
t.Run("ContentParts", func(t *testing.T) {
t.Parallel()
var nilContent *fantasy.TextContent
content := normalizeContentParts(fantasy.ResponseContent{
nilContent,
fantasy.TextContent{Text: "hello"},
})
require.Len(t, content, 1)
require.Equal(t, "text", content[0].Type)
require.Equal(t, "hello", content[0].Text)
})
}
func TestAppendNormalizedStreamContent_PreservesOrderAndCanonicalTypes(t *testing.T) {
t.Parallel()
var content []normalizedContentPart
streamDebugBytes := 0
for _, part := range []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: "before "},
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"query":"debug"}`},
{Type: fantasy.StreamPartTypeToolResult, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{"matches":1}`},
{Type: fantasy.StreamPartTypeTextDelta, Delta: "after"},
} {
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
}
require.Equal(t, []normalizedContentPart{
{Type: "text", Text: "before "},
{Type: "tool-call", ToolCallID: "call-1", ToolName: "search_docs", Arguments: `{"query":"debug"}`, InputLength: len(`{"query":"debug"}`)},
{Type: "tool-result", ToolCallID: "call-1", ToolName: "search_docs", Result: `{"matches":1}`},
{Type: "text", Text: "after"},
}, content)
}
func TestAppendNormalizedStreamContent_GlobalTextCap(t *testing.T) {
t.Parallel()
streamDebugBytes := 0
long := strings.Repeat("a", maxStreamDebugTextBytes)
var content []normalizedContentPart
for _, part := range []fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextDelta, Delta: long},
{Type: fantasy.StreamPartTypeToolCall, ID: "call-1", ToolCallName: "search_docs", ToolCallInput: `{}`},
{Type: fantasy.StreamPartTypeTextDelta, Delta: "tail"},
} {
content = appendNormalizedStreamContent(content, part, &streamDebugBytes)
}
require.Len(t, content, 2)
require.Equal(t, strings.Repeat("a", maxStreamDebugTextBytes), content[0].Text)
require.Equal(t, "tool-call", content[1].Type)
require.Equal(t, maxStreamDebugTextBytes, streamDebugBytes)
}
func TestWrapStreamSeq_SourceCountExcludesToolResults(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
seq := wrapStreamSeq(context.Background(), handle, partsToSeq([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolResult, ID: "tool-1", ToolCallName: "search_docs"},
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}))
partCount := 0
for range seq {
partCount++
}
require.Equal(t, 3, partCount)
metadata, ok := handle.metadata.(map[string]any)
require.True(t, ok)
summary, ok := metadata["stream_summary"].(streamSummary)
require.True(t, ok)
require.Equal(t, 1, summary.SourceCount)
}
func TestWrapObjectStreamSeq_UsesStructuredOutputPayload(t *testing.T) {
t.Parallel()
handle := &stepHandle{
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
sink: &attemptSink{},
}
usage := fantasy.Usage{InputTokens: 3, OutputTokens: 2, TotalTokens: 5}
seq := wrapObjectStreamSeq(context.Background(), handle, objectPartsToSeq([]fantasy.ObjectStreamPart{
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: usage},
}))
partCount := 0
for range seq {
partCount++
}
require.Equal(t, 3, partCount)
resp, ok := handle.response.(normalizedObjectResponsePayload)
require.True(t, ok)
require.Equal(t, normalizedObjectResponsePayload{
RawTextLength: len("object"),
FinishReason: string(fantasy.FinishReasonStop),
Usage: normalizeUsage(usage),
StructuredOutput: true,
}, resp)
}
func TestNormalizeResponse_UsesCanonicalToolTypes(t *testing.T) {
t.Parallel()
payload := normalizeResponse(&fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.ToolCallContent{
ToolCallID: "call-calc",
ToolName: "calculator",
Input: `{"operation":"add","operands":[2,2]}`,
},
fantasy.ToolResultContent{
ToolCallID: "call-calc",
ToolName: "calculator",
Result: fantasy.ToolResultOutputContentText{Text: `{"sum":4}`},
},
},
})
require.Len(t, payload.Content, 2)
require.Equal(t, "tool-call", payload.Content[0].Type)
require.Equal(t, "tool-result", payload.Content[1].Type)
}
func TestBoundText_RespectsDocumentedRuneLimit(t *testing.T) {
t.Parallel()
runes := make([]rune, MaxMessagePartTextLength+5)
for i := range runes {
runes[i] = 'a'
}
input := string(runes)
got := boundText(input)
require.Equal(t, MaxMessagePartTextLength, len([]rune(got)))
require.Equal(t, '…', []rune(got)[len([]rune(got))-1])
}
func TestNormalizeToolResultOutput(t *testing.T) {
t.Parallel()
t.Run("TextValue", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{Text: "hello"})
require.Equal(t, "hello", got)
})
t.Run("TextPointer", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentText{Text: "hello"})
require.Equal(t, "hello", got)
})
t.Run("TextPointerNil", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentText)(nil))
require.Equal(t, "", got)
})
t.Run("ErrorValue", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{
Error: xerrors.New("tool failed"),
})
require.Equal(t, "tool failed", got)
})
t.Run("ErrorValueNilError", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentError{Error: nil})
require.Equal(t, "", got)
})
t.Run("ErrorPointer", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{
Error: xerrors.New("ptr fail"),
})
require.Equal(t, "ptr fail", got)
})
t.Run("ErrorPointerNil", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentError)(nil))
require.Equal(t, "", got)
})
t.Run("ErrorPointerNilError", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentError{Error: nil})
require.Equal(t, "", got)
})
t.Run("MediaWithText", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
Text: "caption",
MediaType: "image/png",
})
require.Equal(t, "caption", got)
})
t.Run("MediaWithoutText", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{
MediaType: "image/png",
})
require.Equal(t, "[media output: image/png]", got)
})
t.Run("MediaWithoutTextOrType", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentMedia{})
require.Equal(t, "[media output]", got)
})
t.Run("MediaPointerNil", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput((*fantasy.ToolResultOutputContentMedia)(nil))
require.Equal(t, "", got)
})
t.Run("MediaPointerWithText", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(&fantasy.ToolResultOutputContentMedia{
Text: "ptr caption",
MediaType: "image/jpeg",
})
require.Equal(t, "ptr caption", got)
})
t.Run("NilOutput", func(t *testing.T) {
t.Parallel()
got := normalizeToolResultOutput(nil)
require.Equal(t, "", got)
})
t.Run("DefaultJSON", func(t *testing.T) {
t.Parallel()
		// There is no exported unknown output type to force the
		// default JSON-marshal branch directly, so a known text
		// type stands in and round-trips through the switch.
got := normalizeToolResultOutput(fantasy.ToolResultOutputContentText{
Text: "fallback",
})
require.Equal(t, "fallback", got)
})
}
func TestNormalizeResponse_PreservesToolCallArguments(t *testing.T) {
t.Parallel()
payload := normalizeResponse(&fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.ToolCallContent{
ToolCallID: "call-calc",
ToolName: "calculator",
Input: `{"operation":"add","operands":[2,2]}`,
},
},
})
require.Len(t, payload.Content, 1)
require.Equal(t, "call-calc", payload.Content[0].ToolCallID)
require.Equal(t, "calculator", payload.Content[0].ToolName)
require.JSONEq(t,
`{"operation":"add","operands":[2,2]}`,
payload.Content[0].Arguments,
)
require.Equal(t, len(`{"operation":"add","operands":[2,2]}`), payload.Content[0].InputLength)
}
@@ -1,319 +0,0 @@
package chatdebug
import (
"context"
"sync"
"sync/atomic"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"cdr.dev/slog/v3"
)
// RecorderOptions identifies the chat/model context for debug recording.
type RecorderOptions struct {
ChatID uuid.UUID
OwnerID uuid.UUID
Provider string
Model string
}
// WrapModel returns model unchanged when debug recording is disabled, or a
// debug wrapper when a service is available.
func WrapModel(
model fantasy.LanguageModel,
svc *Service,
opts RecorderOptions,
) fantasy.LanguageModel {
if model == nil {
panic("chatdebug: nil LanguageModel")
}
if svc == nil {
return model
}
return &debugModel{inner: model, svc: svc, opts: opts}
}
type attemptSink struct {
mu sync.Mutex
attempts []Attempt
attemptCounter atomic.Int32
}
func (s *attemptSink) nextAttemptNumber() int {
if s == nil {
panic("chatdebug: nil attemptSink")
}
return int(s.attemptCounter.Add(1))
}
func (s *attemptSink) record(a Attempt) {
s.mu.Lock()
defer s.mu.Unlock()
s.attempts = append(s.attempts, a)
}
func (s *attemptSink) snapshot() []Attempt {
s.mu.Lock()
defer s.mu.Unlock()
attempts := make([]Attempt, len(s.attempts))
copy(attempts, s.attempts)
return attempts
}
type attemptSinkKey struct{}
func withAttemptSink(ctx context.Context, sink *attemptSink) context.Context {
if sink == nil {
panic("chatdebug: nil attemptSink")
}
return context.WithValue(ctx, attemptSinkKey{}, sink)
}
func attemptSinkFromContext(ctx context.Context) *attemptSink {
sink, _ := ctx.Value(attemptSinkKey{}).(*attemptSink)
return sink
}
var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32
// runRefCounts tracks how many live RunContext instances reference each
// RunID. Cleanup of shared state (step counters) is deferred until the
// last RunContext for a given RunID is garbage collected.
var runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32
func trackRunRef(runID uuid.UUID) {
val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{})
counter := val.(*atomic.Int32)
counter.Add(1)
}
// releaseRunRef decrements the reference count for runID and cleans up
// shared state when the last reference is released.
func releaseRunRef(runID uuid.UUID) {
val, ok := runRefCounts.Load(runID)
if !ok {
return
}
counter := val.(*atomic.Int32)
if counter.Add(-1) <= 0 {
runRefCounts.Delete(runID)
stepCounters.Delete(runID)
}
}
func nextStepNumber(runID uuid.UUID) int32 {
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
counter, ok := val.(*atomic.Int32)
if !ok {
panic("chatdebug: invalid step counter type")
}
return counter.Add(1)
}
// CleanupStepCounter removes per-run step counter and reference count
// state. This is used by tests and later stacked branches that have a
// real run lifecycle.
func CleanupStepCounter(runID uuid.UUID) {
stepCounters.Delete(runID)
runRefCounts.Delete(runID)
}
const stepFinalizeTimeout = 5 * time.Second
func stepFinalizeContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
panic("chatdebug: nil context")
}
return context.WithTimeout(context.WithoutCancel(ctx), stepFinalizeTimeout)
}
func syncStepCounter(runID uuid.UUID, stepNumber int32) {
val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{})
counter, ok := val.(*atomic.Int32)
if !ok {
panic("chatdebug: invalid step counter type")
}
for {
current := counter.Load()
if current >= stepNumber {
return
}
if counter.CompareAndSwap(current, stepNumber) {
return
}
}
}
type stepHandle struct {
stepCtx *StepContext
sink *attemptSink
svc *Service
opts RecorderOptions
mu sync.Mutex
status Status
response any
usage any
err any
metadata any
// hadError tracks whether a prior finalization wrote an error
// payload. Used to decide whether a successful retry needs to
// explicitly clear the error field via jsonClear.
hadError bool
}
// beginStep validates preconditions, creates a debug step, and returns a
// handle plus an enriched context carrying StepContext and attemptSink.
// Returns (nil, original ctx) when debug recording should be skipped.
func beginStep(
ctx context.Context,
svc *Service,
opts RecorderOptions,
op Operation,
normalizedReq any,
) (*stepHandle, context.Context) {
if svc == nil {
return nil, ctx
}
rc, ok := RunFromContext(ctx)
if !ok || rc.RunID == uuid.Nil {
return nil, ctx
}
chatID := opts.ChatID
if chatID == uuid.Nil {
chatID = rc.ChatID
}
if !svc.IsEnabled(ctx, chatID, opts.OwnerID) {
return nil, ctx
}
holder, reuseStep := reuseHolderFromContext(ctx)
if reuseStep {
holder.mu.Lock()
defer holder.mu.Unlock()
// Only reuse the cached handle if it belongs to the same run.
// A different RunContext means a new logical run, so we must
// create a fresh step to avoid cross-run attribution.
if holder.handle != nil && holder.handle.stepCtx.RunID == rc.RunID {
enriched := ContextWithStep(ctx, holder.handle.stepCtx)
enriched = withAttemptSink(enriched, holder.handle.sink)
return holder.handle, enriched
}
}
stepNum := nextStepNumber(rc.RunID)
step, err := svc.CreateStep(ctx, CreateStepParams{
RunID: rc.RunID,
ChatID: chatID,
StepNumber: stepNum,
Operation: op,
Status: StatusInProgress,
HistoryTipMessageID: rc.HistoryTipMessageID,
NormalizedRequest: normalizedReq,
})
if err != nil {
svc.log.Warn(ctx, "failed to create chat debug step",
slog.Error(err),
slog.F("chat_id", chatID),
slog.F("run_id", rc.RunID),
slog.F("operation", op),
)
return nil, ctx
}
syncStepCounter(rc.RunID, step.StepNumber)
actualStepNumber := step.StepNumber
if actualStepNumber == 0 {
actualStepNumber = stepNum
}
sc := &StepContext{
StepID: step.ID,
RunID: rc.RunID,
ChatID: chatID,
StepNumber: actualStepNumber,
Operation: op,
HistoryTipMessageID: rc.HistoryTipMessageID,
}
handle := &stepHandle{stepCtx: sc, sink: &attemptSink{}, svc: svc, opts: opts}
enriched := ContextWithStep(ctx, handle.stepCtx)
enriched = withAttemptSink(enriched, handle.sink)
if reuseStep {
holder.handle = handle
}
return handle, enriched
}
// finish updates the debug step with final status and data. A mutex
// guards the write so concurrent callers (e.g. retried stream wrappers
// sharing a reuse handle) don't race. Unlike sync.Once, later retries
// are allowed to overwrite earlier failure results so the step reflects
// the final outcome.
func (h *stepHandle) finish(
ctx context.Context,
status Status,
response any,
usage any,
errPayload any,
metadata any,
) {
if h == nil || h.stepCtx == nil {
return
}
h.mu.Lock()
defer h.mu.Unlock()
h.status = status
h.response = response
h.usage = usage
h.err = errPayload
h.metadata = metadata
if errPayload != nil {
h.hadError = true
}
if h.svc == nil {
return
}
updateCtx, cancel := stepFinalizeContext(ctx)
defer cancel()
// When the step completes successfully after a prior failed
// attempt, the error field must be explicitly cleared. A plain
// nil would leave the COALESCE-based SQL untouched, so we send
// jsonClear{} which serializes as a valid JSONB null. Only do
// this when a prior error was actually recorded — otherwise
// clean successes would get a spurious JSONB null that downstream
// aggregation could misread as an error.
errValue := errPayload
if errValue == nil && status == StatusCompleted && h.hadError {
errValue = jsonClear{}
}
if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{
ID: h.stepCtx.StepID,
ChatID: h.stepCtx.ChatID,
Status: status,
NormalizedResponse: response,
Usage: usage,
Attempts: h.sink.snapshot(),
Error: errValue,
Metadata: metadata,
FinishedAt: time.Now(),
}); updateErr != nil {
h.svc.log.Warn(updateCtx, "failed to finalize chat debug step",
slog.Error(updateErr),
slog.F("step_id", h.stepCtx.StepID),
slog.F("chat_id", h.stepCtx.ChatID),
slog.F("status", status),
)
}
}
@@ -1,184 +0,0 @@
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
import (
"context"
"sort"
"sync"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/testutil"
)
func TestAttemptSink_ThreadSafe(t *testing.T) {
t.Parallel()
const n = 256
sink := &attemptSink{}
var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func() {
defer wg.Done()
sink.record(Attempt{Number: i + 1, ResponseStatus: 200 + i})
}()
}
wg.Wait()
attempts := sink.snapshot()
require.Len(t, attempts, n)
numbers := make([]int, 0, n)
statuses := make([]int, 0, n)
for _, attempt := range attempts {
numbers = append(numbers, attempt.Number)
statuses = append(statuses, attempt.ResponseStatus)
}
sort.Ints(numbers)
sort.Ints(statuses)
for i := range n {
require.Equal(t, i+1, numbers[i])
require.Equal(t, 200+i, statuses[i])
}
}
func TestAttemptSinkContext(t *testing.T) {
t.Parallel()
ctx := context.Background()
require.Nil(t, attemptSinkFromContext(ctx))
sink := &attemptSink{}
ctx = withAttemptSink(ctx, sink)
require.Same(t, sink, attemptSinkFromContext(ctx))
}
func TestWrapModel_NilModel(t *testing.T) {
t.Parallel()
require.Panics(t, func() {
WrapModel(nil, &Service{}, RecorderOptions{})
})
}
func TestWrapModel_NilService(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
wrapped := WrapModel(model, nil, RecorderOptions{})
require.Same(t, model, wrapped)
}
func TestNextStepNumber_Concurrent(t *testing.T) {
t.Parallel()
const n = 256
runID := uuid.New()
results := make([]int, n)
var wg sync.WaitGroup
wg.Add(n)
for i := range n {
go func() {
defer wg.Done()
results[i] = int(nextStepNumber(runID))
}()
}
wg.Wait()
sort.Ints(results)
for i := range n {
require.Equal(t, i+1, results[i])
}
}
func TestStepFinalizeContext_StripsCancellation(t *testing.T) {
t.Parallel()
baseCtx, cancelBase := context.WithCancel(context.Background())
cancelBase()
require.ErrorIs(t, baseCtx.Err(), context.Canceled)
finalizeCtx, cancelFinalize := stepFinalizeContext(baseCtx)
defer cancelFinalize()
require.NoError(t, finalizeCtx.Err())
_, hasDeadline := finalizeCtx.Deadline()
require.True(t, hasDeadline)
}
func TestSyncStepCounter_AdvancesCounter(t *testing.T) {
t.Parallel()
runID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
syncStepCounter(runID, 7)
require.Equal(t, int32(8), nextStepNumber(runID))
}
func TestStepHandleFinish_NilHandle(t *testing.T) {
t.Parallel()
var handle *stepHandle
handle.finish(context.Background(), StatusCompleted, nil, nil, nil, nil)
}
func TestBeginStep_NilService(t *testing.T) {
t.Parallel()
ctx := context.Background()
handle, enriched := beginStep(ctx, nil, RecorderOptions{}, OperationGenerate, nil)
require.Nil(t, handle)
require.Nil(t, attemptSinkFromContext(enriched))
_, ok := StepFromContext(enriched)
require.False(t, ok)
}
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
runID := uuid.New()
runChatID := uuid.New()
ownerID := uuid.New()
expectDebugLoggingEnabled(t, db, ownerID)
expectCreateStepNumberWithRequestValidity(t, db, runID, runChatID, 1, OperationGenerate, false)
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
svc := NewService(db, testutil.Logger(t), nil)
handle, enriched := beginStep(ctx, svc, RecorderOptions{OwnerID: ownerID}, OperationGenerate, nil)
require.NotNil(t, handle)
require.Equal(t, runChatID, handle.stepCtx.ChatID)
stepCtx, ok := StepFromContext(enriched)
require.True(t, ok)
require.Equal(t, runChatID, stepCtx.ChatID)
}
func TestWrapModel_ReturnsDebugModel(t *testing.T) {
t.Parallel()
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
wrapped := WrapModel(model, &Service{}, RecorderOptions{})
require.NotSame(t, model, wrapped)
require.IsType(t, &debugModel{}, wrapped)
require.Implements(t, (*fantasy.LanguageModel)(nil), wrapped)
require.Equal(t, model.Provider(), wrapped.Provider())
require.Equal(t, model.Model(), wrapped.Model())
}
@@ -1,227 +0,0 @@
package chatdebug
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"golang.org/x/xerrors"
)
// RedactedValue replaces sensitive values in debug payloads.
const RedactedValue = "[REDACTED]"
var sensitiveHeaderNames = map[string]struct{}{
"authorization": {},
"x-api-key": {},
"api-key": {},
"proxy-authorization": {},
"cookie": {},
"set-cookie": {},
}
// sensitiveJSONKeyFragments triggers redaction for JSON keys containing
// these substrings. Notably, "token" is intentionally absent because it
// would falsely redact LLM token-usage fields (input_tokens,
// output_tokens, prompt_tokens, completion_tokens, reasoning_tokens,
// cache_creation_input_tokens, cache_read_input_tokens, etc.). Auth-
// related token fields are caught by the exact-match set below.
var sensitiveJSONKeyFragments = []string{
"secret",
"password",
"authorization",
"credential",
}
// sensitiveJSONKeyExact matches auth-related token/key field names
// without false positives on LLM usage counters. Includes both
// snake_case originals and their camelCase-lowered equivalents
// (e.g. "accessToken" → "accesstoken") so that providers using
// either convention are caught.
var sensitiveJSONKeyExact = map[string]struct{}{
"token": {},
"access_token": {},
"accesstoken": {},
"refresh_token": {},
"refreshtoken": {},
"id_token": {},
"idtoken": {},
"api_token": {},
"apitoken": {},
"api_key": {},
"apikey": {},
"api-key": {},
"x-api-key": {},
"auth_token": {},
"authtoken": {},
"bearer_token": {},
"bearertoken": {},
"session_token": {},
"sessiontoken": {},
"security_token": {},
"securitytoken": {},
"private_key": {},
"privatekey": {},
"signing_key": {},
"signingkey": {},
"secret_key": {},
"secretkey": {},
}
// RedactHeaders returns a flattened copy of h with sensitive values redacted.
func RedactHeaders(h http.Header) map[string]string {
if h == nil {
return nil
}
redacted := make(map[string]string, len(h))
for name, values := range h {
if isSensitiveName(name) {
redacted[name] = RedactedValue
continue
}
redacted[name] = strings.Join(values, ", ")
}
return redacted
}
// RedactJSONSecrets redacts sensitive JSON values by key name. When
// the input is not valid JSON (truncated body, HTML error page, etc.)
// the raw bytes are replaced entirely with a diagnostic placeholder
// to avoid leaking credentials from malformed payloads.
func RedactJSONSecrets(data []byte) []byte {
if len(data) == 0 {
return data
}
decoder := json.NewDecoder(bytes.NewReader(data))
decoder.UseNumber()
var value any
if err := decoder.Decode(&value); err != nil {
// Cannot parse: replace entirely to prevent credential leaks
// from non-JSON error responses (HTML pages, partial bodies).
return []byte(`{"error":"chatdebug: body is not valid JSON, redacted for safety"}`)
}
if err := consumeJSONEOF(decoder); err != nil {
return []byte(`{"error":"chatdebug: body contains extra JSON values, redacted for safety"}`)
}
redacted, changed := redactJSONValue(value)
if !changed {
return data
}
encoded, err := json.Marshal(redacted)
if err != nil {
return data
}
return encoded
}
func consumeJSONEOF(decoder *json.Decoder) error {
var extra any
err := decoder.Decode(&extra)
if errors.Is(err, io.EOF) {
return nil
}
if err == nil {
return xerrors.New("chatdebug: extra JSON values")
}
return err
}
var safeRateLimitHeaderNames = map[string]struct{}{
"anthropic-ratelimit-requests-limit": {},
"anthropic-ratelimit-requests-remaining": {},
"anthropic-ratelimit-requests-reset": {},
"anthropic-ratelimit-tokens-limit": {},
"anthropic-ratelimit-tokens-remaining": {},
"anthropic-ratelimit-tokens-reset": {},
"x-ratelimit-limit-requests": {},
"x-ratelimit-limit-tokens": {},
"x-ratelimit-remaining-requests": {},
"x-ratelimit-remaining-tokens": {},
"x-ratelimit-reset-requests": {},
"x-ratelimit-reset-tokens": {},
}
// isSensitiveName reports whether a name (header or query parameter)
// looks like a credential-carrying key. Exact-match headers are
// checked first, then the rate-limit allowlist, then substring
// patterns for API keys and auth tokens.
func isSensitiveName(name string) bool {
lowerName := strings.ToLower(name)
if _, ok := sensitiveHeaderNames[lowerName]; ok {
return true
}
if _, ok := safeRateLimitHeaderNames[lowerName]; ok {
return false
}
if strings.Contains(lowerName, "api-key") ||
strings.Contains(lowerName, "api_key") ||
strings.Contains(lowerName, "apikey") {
return true
}
// Catch any header containing "token" (e.g. Token, X-Token,
// X-Auth-Token). Safe rate-limit headers like
// x-ratelimit-remaining-tokens are already allowlisted above
// and will not reach this point.
if strings.Contains(lowerName, "token") {
return true
}
return strings.Contains(lowerName, "secret") ||
strings.Contains(lowerName, "bearer")
}
func isSensitiveJSONKey(key string) bool {
lowerKey := strings.ToLower(key)
if _, ok := sensitiveJSONKeyExact[lowerKey]; ok {
return true
}
for _, fragment := range sensitiveJSONKeyFragments {
if strings.Contains(lowerKey, fragment) {
return true
}
}
return false
}
func redactJSONValue(value any) (any, bool) {
switch typed := value.(type) {
case map[string]any:
changed := false
for key, child := range typed {
if isSensitiveJSONKey(key) {
if current, ok := child.(string); ok && current == RedactedValue {
continue
}
typed[key] = RedactedValue
changed = true
continue
}
redactedChild, childChanged := redactJSONValue(child)
if childChanged {
typed[key] = redactedChild
changed = true
}
}
return typed, changed
case []any:
changed := false
for i, child := range typed {
redactedChild, childChanged := redactJSONValue(child)
if childChanged {
typed[i] = redactedChild
changed = true
}
}
return typed, changed
default:
return value, false
}
}
@@ -1,277 +0,0 @@
package chatdebug_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
)
func TestRedactHeaders(t *testing.T) {
t.Parallel()
t.Run("nil input", func(t *testing.T) {
t.Parallel()
require.Nil(t, chatdebug.RedactHeaders(nil))
})
t.Run("empty header", func(t *testing.T) {
t.Parallel()
redacted := chatdebug.RedactHeaders(http.Header{})
require.NotNil(t, redacted)
require.Empty(t, redacted)
})
t.Run("authorization redacted and others preserved", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Authorization": {"Bearer secret-token"},
"Accept": {"application/json"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
require.Equal(t, "application/json", redacted["Accept"])
})
t.Run("multi-value headers are flattened", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Accept": {"application/json", "text/plain"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, "application/json, text/plain", redacted["Accept"])
})
t.Run("header name matching is case insensitive", func(t *testing.T) {
t.Parallel()
lowerAuthorization := "authorization"
upperAuthorization := "AUTHORIZATION"
headers := http.Header{
lowerAuthorization: {"lower"},
upperAuthorization: {"upper"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted[lowerAuthorization])
require.Equal(t, chatdebug.RedactedValue, redacted[upperAuthorization])
})
t.Run("token and secret substrings are redacted", func(t *testing.T) {
t.Parallel()
traceHeader := "X-Trace-ID"
headers := http.Header{
"X-Auth-Token": {"abc"},
"X-Custom-Secret": {"def"},
"X-Bearer": {"ghi"},
traceHeader: {"trace"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["X-Auth-Token"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Bearer"])
require.Equal(t, "trace", redacted[traceHeader])
})
t.Run("known safe rate limit headers containing token are not redacted", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Anthropic-Ratelimit-Tokens-Limit": {"1000000"},
"Anthropic-Ratelimit-Tokens-Remaining": {"999000"},
"Anthropic-Ratelimit-Tokens-Reset": {"2026-03-31T08:55:26Z"},
"X-RateLimit-Limit-Tokens": {"120000"},
"X-RateLimit-Remaining-Tokens": {"119500"},
"X-RateLimit-Reset-Tokens": {"12ms"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, "1000000", redacted["Anthropic-Ratelimit-Tokens-Limit"])
require.Equal(t, "999000", redacted["Anthropic-Ratelimit-Tokens-Remaining"])
require.Equal(t, "2026-03-31T08:55:26Z", redacted["Anthropic-Ratelimit-Tokens-Reset"])
require.Equal(t, "120000", redacted["X-RateLimit-Limit-Tokens"])
require.Equal(t, "119500", redacted["X-RateLimit-Remaining-Tokens"])
require.Equal(t, "12ms", redacted["X-RateLimit-Reset-Tokens"])
})
t.Run("non-standard headers with api-key pattern are redacted", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"X-Custom-Api-Key": {"secret-key"},
"X-Custom-Secret": {"secret-val"},
"X-Custom-Session-Token": {"session-id"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Api-Key"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Secret"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Custom-Session-Token"])
})
t.Run("rate limit headers with token in name are preserved", func(t *testing.T) {
t.Parallel()
// Rate-limit headers containing "token" should NOT be redacted
// because they carry usage/limit counts, not credentials.
headers := http.Header{
"X-Ratelimit-Limit-Tokens": {"1000000"},
"X-Ratelimit-Remaining-Tokens": {"999000"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, "1000000", redacted["X-Ratelimit-Limit-Tokens"])
require.Equal(t, "999000", redacted["X-Ratelimit-Remaining-Tokens"])
})
t.Run("original header is not modified", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"Authorization": {"Bearer keep-me"},
"X-Test": {"value"},
}
redacted := chatdebug.RedactHeaders(headers)
redacted["X-Test"] = "changed"
require.Equal(t, []string{"Bearer keep-me"}, headers["Authorization"])
require.Equal(t, []string{"value"}, headers["X-Test"])
require.Equal(t, chatdebug.RedactedValue, redacted["Authorization"])
})
t.Run("api-key header variants are redacted", func(t *testing.T) {
t.Parallel()
headers := http.Header{
"X-Goog-Api-Key": {"secret"},
"X-Api_Key": {"other-secret"},
"X-Safe": {"ok"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["X-Goog-Api-Key"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Api_Key"])
require.Equal(t, "ok", redacted["X-Safe"])
})
t.Run("plain token headers are redacted", func(t *testing.T) {
t.Parallel()
// Headers like "Token" or "X-Token" should be redacted
// even without auth/session/access qualifiers.
headers := http.Header{
"Token": {"my-secret-token"},
"X-Token": {"another-secret"},
"X-Safe": {"ok"},
}
redacted := chatdebug.RedactHeaders(headers)
require.Equal(t, chatdebug.RedactedValue, redacted["Token"])
require.Equal(t, chatdebug.RedactedValue, redacted["X-Token"])
require.Equal(t, "ok", redacted["X-Safe"])
})
}
func TestRedactJSONSecrets(t *testing.T) {
t.Parallel()
t.Run("redacts top level secret fields", func(t *testing.T) {
t.Parallel()
input := []byte(`{"api_key":"abc","token":"def","password":"ghi","safe":"ok"}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"api_key":"[REDACTED]","token":"[REDACTED]","password":"[REDACTED]","safe":"ok"}`, string(redacted))
})
t.Run("redacts security_token exact key", func(t *testing.T) {
t.Parallel()
input := []byte(`{"security_token":"s3cret","securityToken":"tok","safe":"ok"}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"security_token":"[REDACTED]","securityToken":"[REDACTED]","safe":"ok"}`, string(redacted))
})
t.Run("preserves LLM token usage fields", func(t *testing.T) {
t.Parallel()
input := []byte(`{"input_tokens":100,"output_tokens":50,"prompt_tokens":80,"completion_tokens":20,"reasoning_tokens":10,"cache_creation_input_tokens":5,"cache_read_input_tokens":3,"total_tokens":150,"max_tokens":4096,"max_output_tokens":2048}`)
redacted := chatdebug.RedactJSONSecrets(input)
// All usage/limit fields should be preserved, not redacted.
require.Equal(t, input, redacted)
})
t.Run("redacts nested objects", func(t *testing.T) {
t.Parallel()
input := []byte(`{"outer":{"nested_secret":"abc","safe":1},"keep":true}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"outer":{"nested_secret":"[REDACTED]","safe":1},"keep":true}`, string(redacted))
})
t.Run("redacts arrays of objects", func(t *testing.T) {
t.Parallel()
input := []byte(`[{"token":"abc"},{"value":1,"credentials":{"access_key":"def"}}]`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `[{"token":"[REDACTED]"},{"value":1,"credentials":"[REDACTED]"}]`, string(redacted))
})
t.Run("concatenated JSON is replaced with diagnostic", func(t *testing.T) {
t.Parallel()
input := []byte(`{"token":"abc"}{"safe":"ok"}`)
result := chatdebug.RedactJSONSecrets(input)
require.Contains(t, string(result), "extra JSON values")
})
t.Run("non JSON input is replaced with diagnostic", func(t *testing.T) {
t.Parallel()
input := []byte("not json")
result := chatdebug.RedactJSONSecrets(input)
require.Contains(t, string(result), "not valid JSON")
})
t.Run("empty input is unchanged", func(t *testing.T) {
t.Parallel()
input := []byte{}
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
})
t.Run("JSON without sensitive keys is unchanged", func(t *testing.T) {
t.Parallel()
input := []byte(`{"safe":"ok","nested":{"value":1}}`)
require.Equal(t, input, chatdebug.RedactJSONSecrets(input))
})
t.Run("key matching is case insensitive", func(t *testing.T) {
t.Parallel()
input := []byte(`{"API_KEY":"abc","Token":"def","PASSWORD":"ghi"}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"API_KEY":"[REDACTED]","Token":"[REDACTED]","PASSWORD":"[REDACTED]"}`, string(redacted))
})
t.Run("camelCase token field names are redacted", func(t *testing.T) {
t.Parallel()
// Providers may use camelCase (e.g. accessToken, refreshToken).
// These should be redacted even though they don't match the
// snake_case originals exactly.
input := []byte(`{"accessToken":"abc","refreshToken":"def","authToken":"ghi","input_tokens":100,"output_tokens":50}`)
redacted := chatdebug.RedactJSONSecrets(input)
require.JSONEq(t, `{"accessToken":"[REDACTED]","refreshToken":"[REDACTED]","authToken":"[REDACTED]","input_tokens":100,"output_tokens":50}`, string(redacted))
})
}
@@ -1,113 +0,0 @@
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/testutil"
)
func TestBeginStepReuseStep(t *testing.T) {
t.Parallel()
t.Run("reuses handle under ReuseStep", func(t *testing.T) {
t.Parallel()
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
expectDebugLoggingEnabled(t, db, ownerID)
expectCreateStepNumberWithRequestValidity(
t,
db,
runID,
chatID,
1,
OperationStream,
false,
)
expectDebugLoggingEnabled(t, db, ownerID)
svc := NewService(db, testutil.Logger(t), nil)
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
ctx = ReuseStep(ctx)
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
firstHandle, firstEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
secondHandle, secondEnriched := beginStep(ctx, svc, opts, OperationStream, nil)
require.NotNil(t, firstHandle)
require.Same(t, firstHandle, secondHandle)
require.Same(t, firstHandle.stepCtx, secondHandle.stepCtx)
require.Same(t, firstHandle.sink, secondHandle.sink)
require.Equal(t, runID, firstHandle.stepCtx.RunID)
require.Equal(t, chatID, firstHandle.stepCtx.ChatID)
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
require.Equal(t, OperationStream, firstHandle.stepCtx.Operation)
require.NotEqual(t, uuid.Nil, firstHandle.stepCtx.StepID)
firstStepCtx, ok := StepFromContext(firstEnriched)
require.True(t, ok)
secondStepCtx, ok := StepFromContext(secondEnriched)
require.True(t, ok)
require.Same(t, firstStepCtx, secondStepCtx)
require.Same(t, firstHandle.stepCtx, firstStepCtx)
require.Same(t, attemptSinkFromContext(firstEnriched), attemptSinkFromContext(secondEnriched))
})
t.Run("creates new handles without ReuseStep", func(t *testing.T) {
t.Parallel()
chatID := uuid.New()
ownerID := uuid.New()
runID := uuid.New()
t.Cleanup(func() { CleanupStepCounter(runID) })
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
expectDebugLoggingEnabled(t, db, ownerID)
expectCreateStepNumberWithRequestValidity(
t,
db,
runID,
chatID,
1,
OperationStream,
false,
)
expectDebugLoggingEnabled(t, db, ownerID)
expectCreateStepNumberWithRequestValidity(
t,
db,
runID,
chatID,
2,
OperationStream,
false,
)
svc := NewService(db, testutil.Logger(t), nil)
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID}
firstHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
secondHandle, _ := beginStep(ctx, svc, opts, OperationStream, nil)
require.NotNil(t, firstHandle)
require.NotNil(t, secondHandle)
require.NotSame(t, firstHandle, secondHandle)
require.NotSame(t, firstHandle.sink, secondHandle.sink)
require.Equal(t, int32(1), firstHandle.stepCtx.StepNumber)
require.Equal(t, int32(2), secondHandle.stepCtx.StepNumber)
require.NotEqual(t, firstHandle.stepCtx.StepID, secondHandle.stepCtx.StepID)
})
}
@@ -1,539 +0,0 @@
package chatdebug
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/pubsub"
)
// DefaultStaleThreshold is the fallback stale timeout for debug rows
// when no caller-provided value is supplied.
const DefaultStaleThreshold = 5 * time.Minute
// Service persists chat debug rows and fans out lightweight change events.
type Service struct {
db database.Store
log slog.Logger
pubsub pubsub.Pubsub
alwaysEnable bool
// staleAfterNanos stores the stale threshold as nanoseconds in an
// atomic.Int64 so SetStaleAfter and FinalizeStale can be called
// from concurrent goroutines without a data race.
staleAfterNanos atomic.Int64
}
// ServiceOption configures optional Service behavior.
type ServiceOption func(*Service)
// WithStaleThreshold overrides the default stale-row finalization
// threshold. Callers that already have a configurable in-flight chat
// timeout (e.g. chatd's InFlightChatStaleAfter) should pass it here
// so the two sweeps stay in sync.
func WithStaleThreshold(d time.Duration) ServiceOption {
return func(s *Service) {
if d > 0 {
s.staleAfterNanos.Store(d.Nanoseconds())
}
}
}
// WithAlwaysEnable forces debug logging on for every chat regardless
// of the runtime admin and user opt-in settings. This is used for the
// deployment-level serpent flag.
func WithAlwaysEnable(always bool) ServiceOption {
return func(s *Service) {
s.alwaysEnable = always
}
}
// CreateRunParams contains friendly inputs for creating a debug run.
type CreateRunParams struct {
ChatID uuid.UUID
RootChatID uuid.UUID
ParentChatID uuid.UUID
ModelConfigID uuid.UUID
TriggerMessageID int64
HistoryTipMessageID int64
Kind RunKind
Status Status
Provider string
Model string
Summary any
}
// UpdateRunParams contains inputs for updating a debug run.
// Zero-valued fields are treated as "keep the existing value" by the
// COALESCE-based SQL query. Once a field is set it cannot be cleared
// back to NULL — this is intentional for the write-once-finalize
// lifecycle of debug rows.
type UpdateRunParams struct {
ID uuid.UUID
ChatID uuid.UUID
Status Status
Summary any
FinishedAt time.Time
}
// CreateStepParams contains friendly inputs for creating a debug step.
type CreateStepParams struct {
RunID uuid.UUID
ChatID uuid.UUID
StepNumber int32
Operation Operation
Status Status
HistoryTipMessageID int64
NormalizedRequest any
}
// UpdateStepParams contains optional inputs for updating a debug step.
// Most payload fields are typed as any and serialized through nullJSON
// because their shape varies by provider. The Attempts field uses a
// concrete slice for compile-time safety where the schema is stable.
// Zero-valued fields are treated as "keep the existing value" by the
// COALESCE-based SQL query — once set, fields cannot be cleared back
// to NULL. This is intentional for the write-once-finalize lifecycle
// of debug rows.
type UpdateStepParams struct {
ID uuid.UUID
ChatID uuid.UUID
Status Status
AssistantMessageID int64
NormalizedResponse any
Usage any
Attempts []Attempt
Error any
Metadata any
FinishedAt time.Time
}
// NewService constructs a chat debug persistence service.
func NewService(db database.Store, log slog.Logger, ps pubsub.Pubsub, opts ...ServiceOption) *Service {
if db == nil {
panic("chatdebug: nil database.Store")
}
s := &Service{
db: db,
log: log,
pubsub: ps,
}
s.staleAfterNanos.Store(DefaultStaleThreshold.Nanoseconds())
for _, opt := range opts {
opt(s)
}
return s
}
// SetStaleAfter overrides the in-flight stale threshold used when
// finalizing abandoned debug rows. Zero or negative durations keep the
// default threshold.
func (s *Service) SetStaleAfter(staleAfter time.Duration) {
if s == nil || staleAfter <= 0 {
return
}
s.staleAfterNanos.Store(staleAfter.Nanoseconds())
}
func chatdContext(ctx context.Context) context.Context {
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
// chat debug persistence reads and writes.
return dbauthz.AsChatd(ctx)
}
// IsEnabled returns whether debug logging is enabled for the given chat.
func (s *Service) IsEnabled(
ctx context.Context,
chatID uuid.UUID,
ownerID uuid.UUID,
) bool {
if s == nil {
return false
}
if s.alwaysEnable {
return true
}
if s.db == nil {
return false
}
authCtx := chatdContext(ctx)
allowUsers, err := s.db.GetChatDebugLoggingEnabled(authCtx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false
}
s.log.Warn(ctx, "failed to load runtime admin chat debug logging setting",
slog.Error(err),
)
return false
}
if !allowUsers {
return false
}
if ownerID == uuid.Nil {
s.log.Warn(ctx, "missing chat owner for debug logging enablement check",
slog.F("chat_id", chatID),
)
return false
}
enabled, err := s.db.GetUserChatDebugLoggingEnabled(authCtx, ownerID)
if err == nil {
return enabled
}
if errors.Is(err, sql.ErrNoRows) {
return false
}
s.log.Warn(ctx, "failed to load user chat debug logging setting",
slog.Error(err),
slog.F("chat_id", chatID),
slog.F("owner_id", ownerID),
)
return false
}
// CreateRun inserts a new debug run and emits a run update event.
func (s *Service) CreateRun(
ctx context.Context,
params CreateRunParams,
) (database.ChatDebugRun, error) {
run, err := s.db.InsertChatDebugRun(chatdContext(ctx),
database.InsertChatDebugRunParams{
ChatID: params.ChatID,
RootChatID: nullUUID(params.RootChatID),
ParentChatID: nullUUID(params.ParentChatID),
ModelConfigID: nullUUID(params.ModelConfigID),
TriggerMessageID: nullInt64(params.TriggerMessageID),
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
Kind: string(params.Kind),
Status: string(params.Status),
Provider: nullString(params.Provider),
Model: nullString(params.Model),
Summary: s.nullJSON(ctx, params.Summary),
StartedAt: sql.NullTime{},
UpdatedAt: sql.NullTime{},
FinishedAt: sql.NullTime{},
})
if err != nil {
return database.ChatDebugRun{}, err
}
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
return run, nil
}
// UpdateRun updates an existing debug run and emits a run update event.
func (s *Service) UpdateRun(
ctx context.Context,
params UpdateRunParams,
) (database.ChatDebugRun, error) {
run, err := s.db.UpdateChatDebugRun(chatdContext(ctx),
database.UpdateChatDebugRunParams{
RootChatID: uuid.NullUUID{},
ParentChatID: uuid.NullUUID{},
ModelConfigID: uuid.NullUUID{},
TriggerMessageID: sql.NullInt64{},
HistoryTipMessageID: sql.NullInt64{},
Status: nullString(string(params.Status)),
Provider: sql.NullString{},
Model: sql.NullString{},
Summary: s.nullJSON(ctx, params.Summary),
FinishedAt: nullTime(params.FinishedAt),
ID: params.ID,
ChatID: params.ChatID,
})
if err != nil {
return database.ChatDebugRun{}, err
}
s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil)
return run, nil
}
// CreateStep inserts a new debug step and emits a step update event.
func (s *Service) CreateStep(
ctx context.Context,
params CreateStepParams,
) (database.ChatDebugStep, error) {
insert := database.InsertChatDebugStepParams{
RunID: params.RunID,
StepNumber: params.StepNumber,
Operation: string(params.Operation),
Status: string(params.Status),
HistoryTipMessageID: nullInt64(params.HistoryTipMessageID),
AssistantMessageID: sql.NullInt64{},
NormalizedRequest: s.nullJSON(ctx, params.NormalizedRequest),
NormalizedResponse: pqtype.NullRawMessage{},
Usage: pqtype.NullRawMessage{},
Attempts: pqtype.NullRawMessage{},
Error: pqtype.NullRawMessage{},
Metadata: pqtype.NullRawMessage{},
StartedAt: sql.NullTime{},
UpdatedAt: sql.NullTime{},
FinishedAt: sql.NullTime{},
ChatID: params.ChatID,
}
// Cap retry attempts to prevent infinite loops under
// pathological concurrency. Each iteration performs two DB
// round-trips (insert + list), so 10 retries is generous.
const maxCreateStepRetries = 10
for attempt := 0; attempt < maxCreateStepRetries; attempt++ {
if err := ctx.Err(); err != nil {
return database.ChatDebugStep{}, err
}
step, err := s.db.InsertChatDebugStep(chatdContext(ctx), insert)
if err == nil {
// Touch the parent run's updated_at so the stale-
// finalization sweep does not prematurely interrupt
// long-running runs that are still producing steps.
if _, touchErr := s.db.UpdateChatDebugRun(chatdContext(ctx), database.UpdateChatDebugRunParams{
RootChatID: uuid.NullUUID{},
ParentChatID: uuid.NullUUID{},
ModelConfigID: uuid.NullUUID{},
TriggerMessageID: sql.NullInt64{},
HistoryTipMessageID: sql.NullInt64{},
Status: sql.NullString{},
Provider: sql.NullString{},
Model: sql.NullString{},
Summary: pqtype.NullRawMessage{},
FinishedAt: sql.NullTime{},
ID: params.RunID,
ChatID: params.ChatID,
}); touchErr != nil {
s.log.Warn(ctx, "failed to touch parent run updated_at",
slog.F("run_id", params.RunID),
slog.Error(touchErr),
)
}
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
return step, nil
}
if !database.IsUniqueViolation(err, database.UniqueIndexChatDebugStepsRunStep) {
return database.ChatDebugStep{}, err
}
steps, listErr := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), params.RunID)
if listErr != nil {
return database.ChatDebugStep{}, listErr
}
nextStepNumber := insert.StepNumber + 1
for _, existing := range steps {
if existing.StepNumber >= nextStepNumber {
nextStepNumber = existing.StepNumber + 1
}
}
insert.StepNumber = nextStepNumber
}
return database.ChatDebugStep{}, xerrors.Errorf(
"failed to create debug step after %d attempts (run_id=%s)",
maxCreateStepRetries, params.RunID,
)
}
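The capped retry loop in CreateStep can be sketched against a hypothetical in-memory store. fakeStore, errDuplicate, and createStep below are illustrative stand-ins, not the real database types; the real code checks database.IsUniqueViolation instead of a sentinel error:

```go
package main

import (
	"errors"
	"fmt"
)

// errDuplicate stands in for a unique-constraint violation on
// (run_id, step_number).
var errDuplicate = errors.New("duplicate step number")

// fakeStore is a hypothetical in-memory stand-in for the database.
type fakeStore struct{ taken map[int32]bool }

func (f *fakeStore) insert(n int32) error {
	if f.taken[n] {
		return errDuplicate
	}
	f.taken[n] = true
	return nil
}

func (f *fakeStore) maxStep() int32 {
	var max int32
	for n := range f.taken {
		if n > max {
			max = n
		}
	}
	return max
}

// createStep mirrors the capped retry loop above: try the requested
// number, and on a duplicate re-read the existing steps and advance
// past the highest one.
func createStep(store *fakeStore, stepNumber int32) (int32, error) {
	const maxRetries = 10
	for attempt := 0; attempt < maxRetries; attempt++ {
		if err := store.insert(stepNumber); err == nil {
			return stepNumber, nil
		} else if !errors.Is(err, errDuplicate) {
			return 0, err
		}
		stepNumber = store.maxStep() + 1
	}
	return 0, fmt.Errorf("failed after %d attempts", maxRetries)
}

func main() {
	store := &fakeStore{taken: map[int32]bool{1: true, 2: true}}
	n, err := createStep(store, 1)
	fmt.Println(n, err) // 3 <nil>: one conflict, then lands past step 2
}
```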
// UpdateStep updates an existing debug step and emits a step update event.
func (s *Service) UpdateStep(
ctx context.Context,
params UpdateStepParams,
) (database.ChatDebugStep, error) {
step, err := s.db.UpdateChatDebugStep(chatdContext(ctx),
database.UpdateChatDebugStepParams{
Status: nullString(string(params.Status)),
HistoryTipMessageID: sql.NullInt64{},
AssistantMessageID: nullInt64(params.AssistantMessageID),
NormalizedRequest: pqtype.NullRawMessage{},
NormalizedResponse: s.nullJSON(ctx, params.NormalizedResponse),
Usage: s.nullJSON(ctx, params.Usage),
Attempts: s.nullJSON(ctx, params.Attempts),
Error: s.nullJSON(ctx, params.Error),
Metadata: s.nullJSON(ctx, params.Metadata),
FinishedAt: nullTime(params.FinishedAt),
ID: params.ID,
ChatID: params.ChatID,
})
if err != nil {
return database.ChatDebugStep{}, err
}
s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID)
return step, nil
}
// DeleteByChatID deletes all debug data for a chat and emits a delete event.
func (s *Service) DeleteByChatID(
ctx context.Context,
chatID uuid.UUID,
) (int64, error) {
deleted, err := s.db.DeleteChatDebugDataByChatID(chatdContext(ctx), chatID)
if err != nil {
return 0, err
}
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
return deleted, nil
}
// DeleteAfterMessageID deletes debug data newer than the given message.
func (s *Service) DeleteAfterMessageID(
ctx context.Context,
chatID uuid.UUID,
messageID int64,
) (int64, error) {
deleted, err := s.db.DeleteChatDebugDataAfterMessageID(
chatdContext(ctx),
database.DeleteChatDebugDataAfterMessageIDParams{
ChatID: chatID,
MessageID: messageID,
},
)
if err != nil {
return 0, err
}
s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil)
return deleted, nil
}
// FinalizeStale finalizes stale in-flight debug rows and emits a broadcast.
func (s *Service) FinalizeStale(
ctx context.Context,
) (database.FinalizeStaleChatDebugRowsRow, error) {
ns := s.staleAfterNanos.Load()
staleAfter := time.Duration(ns)
if staleAfter <= 0 {
staleAfter = DefaultStaleThreshold
}
result, err := s.db.FinalizeStaleChatDebugRows(
chatdContext(ctx),
time.Now().Add(-staleAfter),
)
if err != nil {
return database.FinalizeStaleChatDebugRowsRow{}, err
}
if result.RunsFinalized > 0 || result.StepsFinalized > 0 {
s.publishEvent(ctx, uuid.Nil, EventKindFinalize, uuid.Nil, uuid.Nil)
}
return result, nil
}
func nullUUID(id uuid.UUID) uuid.NullUUID {
return uuid.NullUUID{UUID: id, Valid: id != uuid.Nil}
}
func nullInt64(v int64) sql.NullInt64 {
return sql.NullInt64{Int64: v, Valid: v != 0}
}
func nullString(value string) sql.NullString {
return sql.NullString{String: value, Valid: value != ""}
}
func nullTime(value time.Time) sql.NullTime {
return sql.NullTime{Time: value, Valid: !value.IsZero()}
}
// jsonClear is a sentinel value that tells nullJSON to emit a valid
// JSON null (JSONB 'null') instead of SQL NULL. COALESCE treats SQL
// NULL as "keep existing" but replaces it with a non-NULL JSONB value,
// so passing jsonClear explicitly overwrites a previously set field.
type jsonClear struct{}
// nullJSON marshals value to a NullRawMessage. When value is nil or
// marshals to JSON "null", the result is {Valid: false}. Combined with
// the COALESCE-based UPDATE queries, this means a caller cannot clear a
// previously-set JSON column back to NULL — passing nil preserves the
// existing value. This is acceptable for debug logs because fields
// accumulate monotonically (request → response → usage → error) and
// never need to be cleared during normal operation.
func (s *Service) nullJSON(ctx context.Context, value any) pqtype.NullRawMessage {
if value == nil {
return pqtype.NullRawMessage{}
}
// Sentinel: emit a valid JSONB null so COALESCE replaces
// any previously stored value.
if _, ok := value.(jsonClear); ok {
return pqtype.NullRawMessage{
RawMessage: json.RawMessage("null"),
Valid: true,
}
}
data, err := json.Marshal(value)
if err != nil {
s.log.Warn(ctx, "failed to marshal chat debug JSON",
slog.Error(err),
slog.F("value_type", fmt.Sprintf("%T", value)),
)
return pqtype.NullRawMessage{}
}
if bytes.Equal(data, []byte("null")) {
return pqtype.NullRawMessage{}
}
return pqtype.NullRawMessage{RawMessage: data, Valid: true}
}
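The tri-state NULL handling in nullJSON can be sketched with a simplified rawMessage stand-in for pqtype.NullRawMessage (the error logging is dropped for brevity):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// rawMessage is a simplified stand-in for pqtype.NullRawMessage.
type rawMessage struct {
	Raw   json.RawMessage
	Valid bool
}

type jsonClear struct{}

// nullJSON mirrors the tri-state encoding above: nil (and values that
// marshal to JSON "null") become SQL NULL so COALESCE keeps the stored
// column, while the jsonClear sentinel becomes JSONB 'null' and
// overwrites it.
func nullJSON(value any) rawMessage {
	if value == nil {
		return rawMessage{}
	}
	if _, ok := value.(jsonClear); ok {
		return rawMessage{Raw: json.RawMessage("null"), Valid: true}
	}
	data, err := json.Marshal(value)
	if err != nil || bytes.Equal(data, []byte("null")) {
		return rawMessage{}
	}
	return rawMessage{Raw: data, Valid: true}
}

func main() {
	fmt.Println(nullJSON(nil).Valid)         // false: SQL NULL, keep existing
	fmt.Println(nullJSON(jsonClear{}).Valid) // true: JSONB null, overwrite
	fmt.Println(string(nullJSON(map[string]int{"a": 1}).Raw))
}
```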
func (s *Service) publishEvent(
ctx context.Context,
chatID uuid.UUID,
kind EventKind,
runID uuid.UUID,
stepID uuid.UUID,
) {
if s.pubsub == nil {
s.log.Debug(ctx,
"chat debug pubsub unavailable; skipping event",
slog.F("kind", kind),
slog.F("chat_id", chatID),
)
return
}
event := DebugEvent{
Kind: kind,
ChatID: chatID,
RunID: runID,
StepID: stepID,
}
data, err := json.Marshal(event)
if err != nil {
s.log.Warn(ctx, "failed to marshal chat debug event",
slog.Error(err),
slog.F("kind", kind),
slog.F("chat_id", chatID),
)
return
}
channel := PubsubChannel(chatID)
if err := s.pubsub.Publish(channel, data); err != nil {
s.log.Warn(ctx, "failed to publish chat debug event",
slog.Error(err),
slog.F("channel", channel),
slog.F("kind", kind),
slog.F("chat_id", chatID),
)
}
}

Some files were not shown because too many files have changed in this diff.