Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 79bb0a4312 | |||
| 6145135fa1 | |||
| 97e9a1ce4d | |||
| ab93e58133 | |||
| 52fea06766 | |||
| a5f8f90aa8 | |||
| 8658aa59a1 | |||
| d8665551f1 | |||
| eb08701107 | |||
| 8efc89ad5b | |||
| 6e16297942 | |||
| ab1f0306a6 | |||
| a9350b2ebe | |||
| 83fd4cf5c2 | |||
| 38d4da82b9 | |||
| 19e0e0e8e6 | |||
| 1d0653cdab | |||
| 95cff8c5fb | |||
| ad2415ede7 | |||
| 1e40cea199 | |||
| 9d6557d173 | |||
| 224db483d7 | |||
| 8237822441 | |||
| 65bf7c3b18 | |||
| 76cbc580f0 | |||
| 391b22aef7 | |||
| f8e8f979a2 | |||
| fb0ed1162b | |||
| 3f519744aa | |||
| 2505f6245f | |||
| 29ad2c6201 | |||
| 27e5ff0a8e | |||
| 128a7c23e6 | |||
| efb19eb748 | |||
| 2c499484b7 | |||
| 33d9d0d875 | |||
| f219834f5c | |||
| 7a94a683c4 | |||
| 2e6fdf2344 | |||
| 3d139c1a24 | |||
| f957981c8b | |||
| 584c61acb5 | |||
| f95a5202bf |
@@ -84,6 +84,7 @@ jobs:
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
@@ -139,6 +140,7 @@ jobs:
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
@@ -171,4 +173,6 @@ jobs:
|
||||
--base "$RELEASE_VERSION" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY"
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
|
||||
@@ -42,6 +42,7 @@ jobs:
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
@@ -116,6 +117,7 @@ jobs:
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
@@ -136,4 +138,6 @@ jobs:
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY"
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Ensures that only bug fixes are cherry-picked to release branches.
|
||||
# PRs targeting release/* must have a title starting with "fix:" or "fix(scope):".
|
||||
name: PR Cherry-Pick Check
|
||||
|
||||
on:
|
||||
# zizmor: ignore[dangerous-triggers] Only reads PR metadata and comments; does not checkout PR code.
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited]
|
||||
branches:
|
||||
- "release/*"
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-cherry-pick:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Check PR title for bug fix
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const baseBranch = context.payload.pull_request.base.ref;
|
||||
const author = context.payload.pull_request.user.login;
|
||||
|
||||
console.log(`PR #${prNumber}: "${title}" -> ${baseBranch}`);
|
||||
|
||||
// Match conventional commit "fix:" or "fix(scope):" prefix.
|
||||
const isBugFix = /^fix(\(.+\))?:/.test(title);
|
||||
|
||||
if (isBugFix) {
|
||||
console.log("PR title indicates a bug fix. No action needed.");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("PR title does not indicate a bug fix. Commenting.");
|
||||
|
||||
// Check for an existing comment from this bot to avoid duplicates
|
||||
// on title edits.
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const marker = "<!-- cherry-pick-check -->";
|
||||
const existingComment = comments.find(
|
||||
(c) => c.body && c.body.includes(marker),
|
||||
);
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
`👋 Hey @${author}!`,
|
||||
"",
|
||||
`This PR is targeting the \`${baseBranch}\` release branch, but its title does not start with \`fix:\` or \`fix(scope):\`.`,
|
||||
"",
|
||||
"Only **bug fixes** should be cherry-picked to release branches. If this is a bug fix, please update the PR title to match the conventional commit format:",
|
||||
"",
|
||||
"```",
|
||||
"fix: description of the bug fix",
|
||||
"fix(scope): description of the bug fix",
|
||||
"```",
|
||||
"",
|
||||
"If this is **not** a bug fix, it likely should not target a release branch.",
|
||||
].join("\n");
|
||||
|
||||
if (existingComment) {
|
||||
console.log(`Updating existing comment ${existingComment.id}.`);
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existingComment.id,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
core.warning(
|
||||
`PR #${prNumber} targets ${baseBranch} but is not a bug fix. Title must start with "fix:" or "fix(scope):".`,
|
||||
);
|
||||
@@ -91,6 +91,59 @@ define atomic_write
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Helper binary targets. Built with go build -o to avoid caching
|
||||
# link-stage executables in GOCACHE. Each binary is a real Make
|
||||
# target so parallel -j builds serialize correctly instead of
|
||||
# racing on the same output path.
|
||||
|
||||
_gen/bin/apitypings: $(wildcard scripts/apitypings/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apitypings
|
||||
|
||||
_gen/bin/auditdocgen: $(wildcard scripts/auditdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/auditdocgen
|
||||
|
||||
_gen/bin/check-scopes: $(wildcard scripts/check-scopes/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/check-scopes
|
||||
|
||||
_gen/bin/clidocgen: $(wildcard scripts/clidocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/clidocgen
|
||||
|
||||
_gen/bin/dbdump: $(wildcard coderd/database/gen/dump/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./coderd/database/gen/dump
|
||||
|
||||
_gen/bin/examplegen: $(wildcard scripts/examplegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/examplegen
|
||||
|
||||
_gen/bin/gensite: $(wildcard scripts/gensite/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/gensite
|
||||
|
||||
_gen/bin/apikeyscopesgen: $(wildcard scripts/apikeyscopesgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apikeyscopesgen
|
||||
|
||||
_gen/bin/metricsdocgen: $(wildcard scripts/metricsdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen
|
||||
|
||||
_gen/bin/metricsdocgen-scanner: $(wildcard scripts/metricsdocgen/scanner/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen/scanner
|
||||
|
||||
_gen/bin/modeloptionsgen: $(wildcard scripts/modeloptionsgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/modeloptionsgen
|
||||
|
||||
_gen/bin/typegen: $(wildcard scripts/typegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/typegen
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
@@ -201,6 +254,7 @@ endif
|
||||
|
||||
clean:
|
||||
rm -rf build/ site/build/ site/out/
|
||||
rm -rf _gen/bin
|
||||
mkdir -p build/
|
||||
git restore site/out/
|
||||
.PHONY: clean
|
||||
@@ -654,8 +708,8 @@ lint/go:
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
|
||||
lint/examples:
|
||||
go run ./scripts/examplegen/main.go -lint
|
||||
lint/examples: | _gen/bin/examplegen
|
||||
_gen/bin/examplegen -lint
|
||||
.PHONY: lint/examples
|
||||
|
||||
# Use shfmt to determine the shell files, takes editorconfig into consideration.
|
||||
@@ -693,8 +747,8 @@ lint/actions/zizmor:
|
||||
.PHONY: lint/actions/zizmor
|
||||
|
||||
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
|
||||
lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
lint/check-scopes: coderd/database/dump.sql | _gen/bin/check-scopes
|
||||
_gen/bin/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
@@ -734,8 +788,8 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
|
||||
#
|
||||
# pre-commit uses two phases: gen+fmt first, then lint+build. This
|
||||
# avoids races where gen's `go run` creates temporary .go files that
|
||||
# lint's find-based checks pick up. Within each phase, targets run in
|
||||
# avoids races where gen creates temporary .go files that lint's
|
||||
# find-based checks pick up. Within each phase, targets run in
|
||||
# parallel via -j. It fails if any tracked files have unstaged
|
||||
# changes afterward.
|
||||
|
||||
@@ -949,8 +1003,8 @@ gen/mark-fresh:
|
||||
|
||||
# Runs migrations to output a dump of the database schema after migrations are
|
||||
# applied.
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
|
||||
go run ./coderd/database/gen/dump/main.go
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql) | _gen/bin/dbdump
|
||||
_gen/bin/dbdump
|
||||
touch "$@"
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
@@ -1067,88 +1121,88 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen _gen/bin/apitypings
|
||||
$(call atomic_write,_gen/bin/apitypings,./scripts/biome_format.sh)
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen _gen/bin/gensite
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
_gen/bin/gensite -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen _gen/bin/examplegen
|
||||
$(call atomic_write,_gen/bin/examplegen)
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac object)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
# NOTE: depends on object_gen.go because the generator build
|
||||
# compiles coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
coderd/rbac/object_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
$(call atomic_write,_gen/bin/typegen rbac scopenames)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
$(call atomic_write,_gen/bin/typegen rbac codersdk)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/apikeyscopesgen
|
||||
# Generate SDK constants for external API key scopes.
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
$(call atomic_write,_gen/bin/apikeyscopesgen)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac typescript,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen countries,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen _gen/bin/modeloptionsgen
|
||||
$(call atomic_write,_gen/bin/modeloptionsgen | tail -n +2,./scripts/biome_format.sh)
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen _gen/bin/metricsdocgen-scanner
|
||||
$(call atomic_write,_gen/bin/metricsdocgen-scanner)
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen _gen/bin/metricsdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
_gen/bin/metricsdocgen --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen _gen/bin/clidocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" _gen/bin/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen _gen/bin/auditdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
_gen/bin/auditdocgen --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
@@ -134,6 +134,33 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
|
||||
}, ResolvePaths(mcpConfigFile, workingDir)
|
||||
}
|
||||
|
||||
// ContextPartsFromDir reads instruction files and discovers skills
|
||||
// from a specific directory, using default file names. This is used
|
||||
// by the CLI chat context commands to read context from an arbitrary
|
||||
// directory without consulting agent env vars.
|
||||
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
|
||||
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
|
||||
parts = append(parts, entry)
|
||||
}
|
||||
|
||||
// Reuse ResolvePaths so CLI skill discovery follows the same
|
||||
// project-relative path handling as agent config resolution.
|
||||
skillParts := discoverSkills(
|
||||
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
|
||||
DefaultSkillMetaFile,
|
||||
)
|
||||
parts = append(parts, skillParts...)
|
||||
|
||||
// Guarantee non-nil slice.
|
||||
if parts == nil {
|
||||
parts = []codersdk.ChatMessagePart{}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// MCPConfigFiles returns the resolved MCP configuration file
|
||||
// paths for the agent's MCP manager.
|
||||
func (api *API) MCPConfigFiles() []string {
|
||||
|
||||
@@ -23,18 +23,144 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
|
||||
return out
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
|
||||
t.Helper()
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
skillDir := filepath.Join(skillsRoot, name)
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
return skillDir
|
||||
}
|
||||
|
||||
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
|
||||
t.Helper()
|
||||
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
|
||||
}
|
||||
|
||||
func TestContextPartsFromDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Empty(t, skillParts)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
|
||||
require.False(t, contextParts[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFileInRoot(
|
||||
t,
|
||||
filepath.Join(dir, "skills"),
|
||||
"my-skill",
|
||||
"A test skill",
|
||||
)
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
|
||||
|
||||
require.NotNil(t, parts)
|
||||
require.Empty(t, parts)
|
||||
})
|
||||
|
||||
t.Run("ReturnsCombinedResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
|
||||
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 2)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
|
||||
require.Equal(t, "combined-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
})
|
||||
}
|
||||
|
||||
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
for key, value := range overrides {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
|
||||
return fakeHome
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -46,20 +172,18 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := t.TempDir()
|
||||
optSkills := t.TempDir()
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: optInstructions,
|
||||
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
|
||||
agentcontextconfig.EnvSkillsDirs: optSkills,
|
||||
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
|
||||
// Create files matching the custom names so we can
|
||||
// verify the env vars actually change lookup behavior.
|
||||
@@ -85,15 +209,12 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// Create a file matching the trimmed name.
|
||||
@@ -106,19 +227,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
a := t.TempDir()
|
||||
b := t.TempDir()
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
|
||||
})
|
||||
|
||||
// Put instruction files in both dirs.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
|
||||
@@ -133,17 +248,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsInstructionFiles", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
|
||||
// Create ~/.coder/AGENTS.md
|
||||
coderDir := filepath.Join(fakeHome, ".coder")
|
||||
@@ -164,16 +272,9 @@ func TestConfig(t *testing.T) {
|
||||
require.False(t, ctxFiles[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
|
||||
// Create AGENTS.md in the working directory.
|
||||
@@ -193,16 +294,9 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
largeContent := strings.Repeat("a", 64*1024+100)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
|
||||
@@ -215,79 +309,47 @@ func TestConfig(t *testing.T) {
|
||||
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
|
||||
})
|
||||
|
||||
t.Run("SanitizesHTMLComments", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
sanitizationTests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SanitizesHTMLComments",
|
||||
input: "visible\n<!-- hidden -->content",
|
||||
expected: "visible\ncontent",
|
||||
},
|
||||
{
|
||||
name: "SanitizesInvisibleUnicode",
|
||||
input: "before\u200bafter",
|
||||
expected: "beforeafter",
|
||||
},
|
||||
{
|
||||
name: "NormalizesCRLF",
|
||||
input: "line1\r\nline2\rline3",
|
||||
expected: "line1\nline2\nline3",
|
||||
},
|
||||
}
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
for _, tt := range sanitizationTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte(tt.input),
|
||||
0o600,
|
||||
))
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("visible\n<!-- hidden -->content"),
|
||||
0o600,
|
||||
))
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// U+200B (zero-width space) should be stripped.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("before\u200bafter"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("NormalizesCRLF", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("line1\r\nline2\rline3"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DiscoversSkills", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
@@ -320,17 +382,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkipsMissingDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: nonExistent,
|
||||
agentcontextconfig.EnvSkillsDirs: nonExistent,
|
||||
})
|
||||
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
@@ -340,17 +398,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, cfg.Parts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
|
||||
workDir := t.TempDir()
|
||||
_, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -358,14 +412,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, "skills")
|
||||
@@ -385,14 +435,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, skillParts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir1 := filepath.Join(workDir, "skills1")
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -620,6 +622,11 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
defer artifact.Reader.Close()
|
||||
defer func() {
|
||||
if artifact.ThumbnailReader != nil {
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
@@ -633,10 +640,60 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "video/mp4")
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
|
||||
// Discard the thumbnail if it exceeds the maximum size.
|
||||
// The server-side consumer also enforces this per-part, but
|
||||
// rejecting it here avoids streaming a large thumbnail over
|
||||
// the wire for nothing.
|
||||
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
|
||||
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.ThumbnailSize),
|
||||
slog.F("max_size", workspacesdk.MaxThumbnailSize),
|
||||
)
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
artifact.ThumbnailReader = nil
|
||||
artifact.ThumbnailSize = 0
|
||||
}
|
||||
|
||||
// The multipart response is best-effort: once WriteHeader(200) is
|
||||
// called, CreatePart failures produce a truncated response without
|
||||
// the closing boundary. The server-side consumer handles this
|
||||
// gracefully, preserving any parts read before the error.
|
||||
mw := multipart.NewWriter(rw)
|
||||
defer mw.Close()
|
||||
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(rw, artifact.Reader)
|
||||
|
||||
// Part 1: video/mp4 (always present).
|
||||
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
|
||||
a.logger.Warn(ctx, "failed to write video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Part 2: image/jpeg (present only when thumbnail was extracted).
|
||||
if artifact.ThumbnailReader != nil {
|
||||
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
|
||||
}
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
|
||||
@@ -4,12 +4,17 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -59,6 +64,8 @@ type fakeDesktop struct {
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
thumbnailData []byte // if set, StopRecording includes a thumbnail
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
@@ -187,10 +194,15 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
artifact := &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
if f.thumbnailData != nil {
|
||||
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
|
||||
artifact.ThumbnailSize = int64(len(f.thumbnailData))
|
||||
}
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
@@ -785,8 +797,8 @@ func TestRecordingStartStop(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
@@ -847,8 +859,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
@@ -872,7 +884,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var bodies [2][]byte
|
||||
var videoParts [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
@@ -880,10 +892,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
|
||||
videoParts[i] = parts["video/mp4"]
|
||||
}
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
assert.Equal(t, videoParts[0], videoParts[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
@@ -1004,8 +1016,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, expected[id], parts["video/mp4"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,8 +1124,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
firstData := firstParts["video/mp4"]
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
@@ -1128,8 +1140,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
secondData := secondParts["video/mp4"]
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
@@ -1235,3 +1247,166 @@ func TestRecordingStopCorrupted(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
// parseMultipartParts parses a multipart/mixed response and returns
|
||||
// a map from Content-Type to body bytes.
|
||||
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
require.NoError(t, err, "parse Content-Type")
|
||||
boundary := params["boundary"]
|
||||
require.NotEmpty(t, boundary, "missing boundary")
|
||||
mr := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
parts := make(map[string][]byte)
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err, "unexpected multipart parse error")
|
||||
ct := part.Header.Get("Content-Type")
|
||||
data, readErr := io.ReadAll(part)
|
||||
require.NoError(t, readErr)
|
||||
parts[ct] = data
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
|
||||
thumbnail := make([]byte, 512)
|
||||
thumbnail[0] = 0xff
|
||||
thumbnail[1] = 0xd8
|
||||
thumbnail[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: thumbnail,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
|
||||
|
||||
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
assert.Equal(t, thumbnail, parts["image/jpeg"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create thumbnail data that exceeds MaxThumbnailSize.
|
||||
oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
|
||||
oversizedThumb[0] = 0xff
|
||||
oversizedThumb[1] = 0xd8
|
||||
oversizedThumb[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: oversizedThumb,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response contains only the video part.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
@@ -105,6 +105,11 @@ type RecordingArtifact struct {
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
// ThumbnailReader is the JPEG thumbnail. May be nil if no
|
||||
// thumbnail was produced. Callers must close it when done.
|
||||
ThumbnailReader io.ReadCloser
|
||||
// ThumbnailSize is the byte length of the thumbnail.
|
||||
ThumbnailSize int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
|
||||
@@ -56,6 +56,7 @@ type screenshotOutput struct {
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
thumbPath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
@@ -383,13 +384,20 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old thumbnail file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
@@ -406,6 +414,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
@@ -419,6 +428,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
"--thumbnail", thumbPath,
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
@@ -427,9 +437,10 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
thumbPath: thumbPath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
@@ -499,10 +510,35 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
return &RecordingArtifact{
|
||||
artifact := &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
// Attach thumbnail if the subprocess wrote one.
|
||||
thumbFile, err := os.Open(rec.thumbPath)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "thumbnail not available",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
thumbInfo, err := thumbFile.Stat()
|
||||
if err != nil {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail stat failed",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
if thumbInfo.Size() == 0 {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail file is empty",
|
||||
slog.F("thumbnail_path", rec.thumbPath))
|
||||
return artifact, nil
|
||||
}
|
||||
artifact.ThumbnailReader = thumbFile
|
||||
artifact.ThumbnailSize = thumbInfo.Size()
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
@@ -571,18 +607,33 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; drop entry.
|
||||
// File already removed or inaccessible; clean up
|
||||
// any leftover thumbnail and drop the entry.
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
@@ -603,13 +654,14 @@ func (p *portableDesktop) Close() error {
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
id string
|
||||
filePath string
|
||||
thumbPath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
@@ -630,13 +682,20 @@ func (p *portableDesktop) Close() error {
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("thumbnail_path", entry.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
|
||||
@@ -2,6 +2,7 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -584,6 +585,7 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -666,6 +668,66 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// No thumbnail file exists, so ThumbnailReader should be nil.
|
||||
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write a dummy MP4 file at the expected path.
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(filePath) })
|
||||
|
||||
// Write a thumbnail file at the expected path.
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
|
||||
thumbContent := []byte("fake-jpeg-thumbnail")
|
||||
require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(thumbPath) })
|
||||
|
||||
artifact, err := pd.StopRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
defer artifact.Reader.Close()
|
||||
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// Thumbnail should be attached.
|
||||
require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
|
||||
defer artifact.ThumbnailReader.Close()
|
||||
assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)
|
||||
|
||||
// Read and verify thumbnail content.
|
||||
thumbData, err := io.ReadAll(artifact.ThumbnailReader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, thumbContent, thumbData)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
|
||||
@@ -87,6 +87,12 @@ func IsDevVersion(v string) bool {
|
||||
return strings.Contains(v, "-"+develPreRelease)
|
||||
}
|
||||
|
||||
// IsRCVersion returns true if the version has a release candidate
|
||||
// pre-release tag, e.g. "v2.31.0-rc.0".
|
||||
func IsRCVersion(v string) bool {
|
||||
return strings.Contains(v, "-rc.")
|
||||
}
|
||||
|
||||
// IsDev returns true if this is a development build.
|
||||
// CI builds are also considered development builds.
|
||||
func IsDev() bool {
|
||||
|
||||
@@ -102,3 +102,29 @@ func TestBuildInfo(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRCVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected bool
|
||||
}{
|
||||
{"RC0", "v2.31.0-rc.0", true},
|
||||
{"RC1WithBuild", "v2.31.0-rc.1+abc123", true},
|
||||
{"RC10", "v2.31.0-rc.10", true},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", true},
|
||||
{"DevelVersion", "v2.31.0-devel+abc123", false},
|
||||
{"StableVersion", "v2.31.0", false},
|
||||
{"DevNoVersion", "v0.0.0-devel+abc123", false},
|
||||
{"BetaVersion", "v2.31.0-beta.1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, buildinfo.IsRCVersion(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+194
@@ -0,0 +1,194 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) chatCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "chat",
|
||||
Short: "Manage agent chats",
|
||||
Long: "Commands for interacting with chats from within a workspace.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) chatContextCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "context",
|
||||
Short: "Manage chat context",
|
||||
Long: "Add or clear context files and skills for an active chat session.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextAddCommand(),
|
||||
r.chatContextClearCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextAddCommand() *serpent.Command {
|
||||
var (
|
||||
dir string
|
||||
chatID string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "add",
|
||||
Short: "Add context to an active chat",
|
||||
Long: "Read instruction files and discover skills from a directory, then add " +
|
||||
"them as context to an active chat session. Multiple calls " +
|
||||
"are additive.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
if dir == "" && inv.Environ.Get("CODER") != "true" {
|
||||
return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
|
||||
}
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedDir := dir
|
||||
if resolvedDir == "" {
|
||||
resolvedDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get working directory: %w", err)
|
||||
}
|
||||
}
|
||||
resolvedDir, err = filepath.Abs(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve directory: %w", err)
|
||||
}
|
||||
info, err := os.Stat(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return xerrors.Errorf("%q is not a directory", resolvedDir)
|
||||
}
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
|
||||
if len(parts) == 0 {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve chat ID from flag or auto-detect.
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("add chat context: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "Directory",
|
||||
Flag: "dir",
|
||||
Description: "Directory to read context files and skills from. Defaults to the current working directory.",
|
||||
Value: serpent.StringOf(&dir),
|
||||
},
|
||||
{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
},
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextClearCommand() *serpent.Command {
|
||||
var chatID string
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear context from an active chat",
|
||||
Long: "Soft-delete all context-file and skill messages from an active chat. " +
|
||||
"The next turn will re-fetch default context from the agent.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear chat context: %w", err)
|
||||
}
|
||||
|
||||
if resp.ChatID == uuid.Nil {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
}},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// parseChatID returns the chat UUID from the flag value (which
|
||||
// serpent already populates from --chat or CODER_CHAT_ID). Returns
|
||||
// uuid.Nil if empty (the server will auto-detect).
|
||||
func parseChatID(flagValue string) (uuid.UUID, error) {
|
||||
if flagValue == "" {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(flagValue)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
func TestExpChatContextAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
})
|
||||
|
||||
t.Run("AllowsExplicitDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
inv.Environ.Set("CODER", "true")
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
+29
-5
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -148,6 +149,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.chatCommand(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
@@ -710,7 +712,7 @@ func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
transport = wrapTransportWithUserAgentHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
transport = wrapTransportWithVersionCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
// otherwise it creates an infinite loop!
|
||||
basicClient := codersdk.New(serverURL)
|
||||
@@ -1434,6 +1436,21 @@ func defaultUpgradeMessage(version string) string {
|
||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||
}
|
||||
|
||||
// serverVersionMessage returns a warning message if the server version
|
||||
// is a release candidate or development build. Returns empty string
|
||||
// for stable versions. RC is checked before devel because RC dev
|
||||
// builds (e.g. v2.33.0-rc.1-devel+hash) contain both tags.
|
||||
func serverVersionMessage(serverVersion string) string {
|
||||
switch {
|
||||
case buildinfo.IsRCVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a release candidate of Coder (%s)", serverVersion)
|
||||
case buildinfo.IsDevVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a development version of Coder (%s)", serverVersion)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||
// that checks for entitlement warnings and prints them to the user.
|
||||
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||
@@ -1452,10 +1469,10 @@ func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||
// that checks for version mismatches between the client and server. If a mismatch
|
||||
// is detected, a warning is printed to the user.
|
||||
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
// wrapTransportWithVersionCheck adds a middleware to the HTTP transport
|
||||
// that checks the server version and warns about development builds,
|
||||
// release candidates, and client/server version mismatches.
|
||||
func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
@@ -1467,9 +1484,16 @@ func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.In
|
||||
if serverVersion == "" {
|
||||
return
|
||||
}
|
||||
// Warn about non-stable server versions. Skip
|
||||
// during tests to avoid polluting golden files.
|
||||
if msg := serverVersionMessage(serverVersion); msg != "" && flag.Lookup("test.v") == nil {
|
||||
warning := pretty.Sprint(cliui.DefaultStyles.Warn, msg)
|
||||
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||
}
|
||||
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||
if serverInfo, err := getBuildInfo(inv.Context()); err == nil {
|
||||
switch {
|
||||
|
||||
@@ -91,7 +91,7 @@ func Test_formatExamples(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoOutput", func(t *testing.T) {
|
||||
@@ -102,7 +102,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -131,7 +131,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -159,6 +159,53 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
|
||||
t.Run("ServerStableVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &RootCmd{}
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
codersdk.BuildVersionHeader: []string{"v2.31.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.31.0", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Empty(t, buf.String())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_serverVersionMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{"Stable", "v2.31.0", ""},
|
||||
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
|
||||
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
|
||||
{"Empty", "", ""},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, serverVersionMessage(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||
|
||||
@@ -768,30 +768,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("create pubsub: %w", err)
|
||||
}
|
||||
options.Pubsub = ps
|
||||
options.ChatPubsub = ps
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(ps)
|
||||
}
|
||||
defer options.Pubsub.Close()
|
||||
chatPubsub, err := pubsub.NewBatching(
|
||||
ctx,
|
||||
logger.Named("chatd").Named("pubsub_batch"),
|
||||
ps,
|
||||
sqlDB,
|
||||
dbURL,
|
||||
pubsub.BatchingConfig{
|
||||
FlushInterval: options.DeploymentValues.AI.Chat.PubsubFlushInterval.Value(),
|
||||
QueueSize: int(options.DeploymentValues.AI.Chat.PubsubQueueSize.Value()),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create chat pubsub batcher: %w", err)
|
||||
}
|
||||
options.ChatPubsub = chatPubsub
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(chatPubsub)
|
||||
}
|
||||
defer options.ChatPubsub.Close()
|
||||
psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
|
||||
pubsubWatchdogTimeout = psWatchdog.Timeout()
|
||||
defer psWatchdog.Close()
|
||||
|
||||
+6
-4
@@ -69,15 +69,17 @@ var (
|
||||
// isRetryableError checks for transient connection errors worth
|
||||
// retrying: DNS failures, connection refused, and server 5xx.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
if err == nil || xerrors.Is(err, context.Canceled) {
|
||||
return false
|
||||
}
|
||||
// Check connection errors before context.DeadlineExceeded because
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both.
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
if xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
|
||||
@@ -516,6 +516,23 @@ func TestIsRetryableError(t *testing.T) {
|
||||
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both
|
||||
// IsConnectionError and context.DeadlineExceeded. Verify it is retryable.
|
||||
t.Run("DialTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
|
||||
defer cancel()
|
||||
<-ctx.Done() // ensure deadline has fired
|
||||
_, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:1")
|
||||
require.Error(t, err)
|
||||
// Proves the ambiguity: this error matches BOTH checks.
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorAs(t, err, new(*net.OpError))
|
||||
assert.True(t, isRetryableError(err))
|
||||
// Also when wrapped, as runCoderConnectStdio does.
|
||||
assert.True(t, isRetryableError(xerrors.Errorf("dial coder connect: %w", err)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
|
||||
Generated
+278
@@ -9514,6 +9514,212 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": [
|
||||
@@ -13239,6 +13445,12 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -15142,6 +15354,26 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -21271,6 +21503,23 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21726,6 +21975,35 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
+256
@@ -8431,6 +8431,190 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11809,6 +11993,12 @@
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -13643,6 +13833,26 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -19545,6 +19755,23 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19975,6 +20202,35 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "dormant", "suspended"],
|
||||
|
||||
+15
-10
@@ -159,10 +159,7 @@ type Options struct {
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
// ChatPubsub allows chatd to use a dedicated publish path without changing
|
||||
// the shared pubsub used by the rest of coderd.
|
||||
ChatPubsub pubsub.Pubsub
|
||||
RuntimeConfig *runtimeconfig.Manager
|
||||
RuntimeConfig *runtimeconfig.Manager
|
||||
|
||||
// CacheDir is used for caching files served by the API.
|
||||
CacheDir string
|
||||
@@ -780,11 +777,6 @@ func New(options *Options) *API {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
chatPubsub := options.ChatPubsub
|
||||
if chatPubsub == nil {
|
||||
chatPubsub = options.Pubsub
|
||||
}
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
@@ -797,7 +789,7 @@ func New(options *Options) *API {
|
||||
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: chatPubsub,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
})
|
||||
@@ -1616,6 +1608,15 @@ func New(options *Options) *API {
|
||||
|
||||
r.Get("/gitsshkey", api.gitSSHKey)
|
||||
r.Put("/gitsshkey", api.regenerateGitSSHKey)
|
||||
r.Route("/secrets", func(r chi.Router) {
|
||||
r.Post("/", api.postUserSecret)
|
||||
r.Get("/", api.getUserSecrets)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
r.Get("/", api.getUserSecret)
|
||||
r.Patch("/", api.patchUserSecret)
|
||||
r.Delete("/", api.deleteUserSecret)
|
||||
})
|
||||
})
|
||||
r.Route("/notifications", func(r chi.Router) {
|
||||
r.Route("/preferences", func(r chi.Router) {
|
||||
r.Get("/", api.userNotificationPreferences)
|
||||
@@ -1661,6 +1662,10 @@ func New(options *Options) *API {
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Post("/log-source", api.workspaceAgentPostLogSource)
|
||||
r.Get("/reinit", api.workspaceAgentReinit)
|
||||
r.Route("/experimental", func(r chi.Router) {
|
||||
r.Post("/chat-context", api.workspaceAgentAddChatContext)
|
||||
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
|
||||
})
|
||||
r.Route("/tasks/{task}", func(r chi.Router) {
|
||||
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
|
||||
})
|
||||
|
||||
@@ -147,6 +147,10 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
|
||||
return c
|
||||
}
|
||||
|
||||
func isExperimentalEndpoint(route string) bool {
|
||||
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
|
||||
}
|
||||
|
||||
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
|
||||
assertUniqueRoutes(t, swaggerComments)
|
||||
assertSingleAnnotations(t, swaggerComments)
|
||||
@@ -165,6 +169,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
|
||||
if strings.HasSuffix(route, "/*") {
|
||||
return
|
||||
}
|
||||
if isExperimentalEndpoint(route) {
|
||||
return
|
||||
}
|
||||
|
||||
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
|
||||
assert.NotNil(t, c, "Missing @Router annotation")
|
||||
|
||||
@@ -538,6 +538,12 @@ func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
|
||||
switch {
|
||||
case workspaceAgent.Status != codersdk.WorkspaceAgentConnected && workspaceAgent.LifecycleState == codersdk.WorkspaceAgentLifecycleOff:
|
||||
workspaceAgent.Health.Reason = "agent is not running"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentConnecting:
|
||||
// Note: the case above catches connecting+off as "not running".
|
||||
// This case handles connecting agents with a non-off lifecycle
|
||||
// (e.g. "created" or "starting"), where the agent binary has
|
||||
// not yet established a connection to coderd.
|
||||
workspaceAgent.Health.Reason = "agent has not yet connected"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentTimeout:
|
||||
workspaceAgent.Health.Reason = "agent is taking too long to connect"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentDisconnected:
|
||||
@@ -1234,6 +1240,8 @@ func buildAIBridgeThread(
|
||||
if rootIntc != nil {
|
||||
thread.Model = rootIntc.Model
|
||||
thread.Provider = rootIntc.Provider
|
||||
thread.CredentialKind = string(rootIntc.CredentialKind)
|
||||
thread.CredentialHint = rootIntc.CredentialHint
|
||||
// Get first user prompt from root interception.
|
||||
// A thread can only have one prompt, by definition, since we currently
|
||||
// only store the last prompt observed in an interception.
|
||||
|
||||
@@ -1708,6 +1708,17 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -2169,10 +2180,10 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
@@ -2413,6 +2424,10 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -5728,6 +5743,17 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
|
||||
@@ -478,6 +478,24 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
agentID := uuid.New()
|
||||
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
@@ -5413,10 +5431,10 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns()
|
||||
Returns(int64(1))
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -728,12 +736,12 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -968,6 +976,14 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
|
||||
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -4104,6 +4120,14 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
|
||||
@@ -363,6 +363,20 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
|
||||
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
|
||||
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1230,11 +1244,12 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
@@ -1667,6 +1682,21 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID mocks base method.
|
||||
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
|
||||
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7780,6 +7810,20 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages mocks base method.
|
||||
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+2
@@ -3783,6 +3783,8 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS idx_chats_agent_id;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
|
||||
@@ -1,749 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultBatchingFlushInterval is the default upper bound on how long chatd
|
||||
// publishes wait before a scheduled flush when nearby publishes do not
|
||||
// naturally coalesce sooner.
|
||||
DefaultBatchingFlushInterval = 50 * time.Millisecond
|
||||
// DefaultBatchingQueueSize is the default number of buffered chatd publish
|
||||
// requests waiting to be flushed.
|
||||
DefaultBatchingQueueSize = 8192
|
||||
|
||||
defaultBatchingPressureWait = 10 * time.Millisecond
|
||||
defaultBatchingFinalFlushLimit = 15 * time.Second
|
||||
batchingWarnInterval = 10 * time.Second
|
||||
|
||||
batchFlushScheduled = "scheduled"
|
||||
batchFlushShutdown = "shutdown"
|
||||
|
||||
batchFlushStageNone = "none"
|
||||
batchFlushStageBegin = "begin"
|
||||
batchFlushStageExec = "exec"
|
||||
batchFlushStageCommit = "commit"
|
||||
|
||||
batchDelegateFallbackReasonQueueFull = "queue_full"
|
||||
batchDelegateFallbackReasonFlushError = "flush_error"
|
||||
|
||||
batchChannelClassStreamNotify = "stream_notify"
|
||||
batchChannelClassOwnerEvent = "owner_event"
|
||||
batchChannelClassConfigChange = "config_change"
|
||||
batchChannelClassOther = "other"
|
||||
)
|
||||
|
||||
// ErrBatchingPubsubClosed is returned when a batched pubsub publish is
|
||||
// attempted after shutdown has started.
|
||||
var ErrBatchingPubsubClosed = xerrors.New("batched pubsub is closed")
|
||||
|
||||
// BatchingConfig controls the chatd-specific PostgreSQL pubsub batching path.
|
||||
// Flush timing is automatic: the run loop wakes every FlushInterval (or on
|
||||
// backpressure) and drains everything currently queued into a single
|
||||
// transaction. There is no fixed batch-size knob — the batch size is simply
|
||||
// whatever accumulated since the last flush, which naturally adapts to load.
|
||||
type BatchingConfig struct {
|
||||
FlushInterval time.Duration
|
||||
QueueSize int
|
||||
PressureWait time.Duration
|
||||
FinalFlushTimeout time.Duration
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
type queuedPublish struct {
|
||||
event string
|
||||
channelClass string
|
||||
message []byte
|
||||
}
|
||||
|
||||
type batchSender interface {
|
||||
Flush(ctx context.Context, batch []queuedPublish) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type batchFlushError struct {
|
||||
stage string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *batchFlushError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *batchFlushError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// BatchingPubsub batches chatd publish traffic onto a dedicated PostgreSQL
|
||||
// sender connection while delegating subscribe behavior to the shared listener
|
||||
// pubsub instance.
|
||||
type BatchingPubsub struct {
|
||||
logger slog.Logger
|
||||
delegate *PGPubsub
|
||||
// sender is only accessed from the run() goroutine (including
|
||||
// flushBatch and resetSender which it calls). Do not read or
|
||||
// write this field from Publish or any other goroutine.
|
||||
sender batchSender
|
||||
newSender func(context.Context) (batchSender, error)
|
||||
clock quartz.Clock
|
||||
|
||||
publishCh chan queuedPublish
|
||||
flushCh chan struct{}
|
||||
closeCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
|
||||
spaceMu sync.Mutex
|
||||
spaceSignal chan struct{}
|
||||
|
||||
warnTicker *quartz.Ticker
|
||||
|
||||
flushInterval time.Duration
|
||||
pressureWait time.Duration
|
||||
finalFlushTimeout time.Duration
|
||||
|
||||
queuedCount atomic.Int64
|
||||
closed atomic.Bool
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
runErr error
|
||||
|
||||
runCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
metrics batchingMetrics
|
||||
}
|
||||
|
||||
type batchingMetrics struct {
|
||||
QueueDepth prometheus.Gauge
|
||||
BatchSize prometheus.Histogram
|
||||
FlushDuration *prometheus.HistogramVec
|
||||
DelegateFallbacksTotal *prometheus.CounterVec
|
||||
SenderResetsTotal prometheus.Counter
|
||||
SenderResetFailuresTotal prometheus.Counter
|
||||
}
|
||||
|
||||
func newBatchingMetrics() batchingMetrics {
|
||||
return batchingMetrics{
|
||||
QueueDepth: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_queue_depth",
|
||||
Help: "The number of chatd notifications waiting in the batching queue.",
|
||||
}),
|
||||
BatchSize: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_size",
|
||||
Help: "The number of logical notifications sent in each chatd batch flush.",
|
||||
Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192},
|
||||
}),
|
||||
FlushDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_flush_duration_seconds",
|
||||
Help: "The time spent flushing one chatd batch to PostgreSQL.",
|
||||
Buckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 30},
|
||||
}, []string{"reason"}),
|
||||
DelegateFallbacksTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_delegate_fallbacks_total",
|
||||
Help: "The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage.",
|
||||
}, []string{"channel_class", "reason", "stage"}),
|
||||
SenderResetsTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_sender_resets_total",
|
||||
Help: "The number of successful batched pubsub sender resets after flush failures.",
|
||||
}),
|
||||
SenderResetFailuresTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "coder",
|
||||
Subsystem: "pubsub",
|
||||
Name: "batch_sender_reset_failures_total",
|
||||
Help: "The number of batched pubsub sender reset attempts that failed.",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m batchingMetrics) Describe(descs chan<- *prometheus.Desc) {
|
||||
m.QueueDepth.Describe(descs)
|
||||
m.BatchSize.Describe(descs)
|
||||
m.FlushDuration.Describe(descs)
|
||||
m.DelegateFallbacksTotal.Describe(descs)
|
||||
m.SenderResetsTotal.Describe(descs)
|
||||
m.SenderResetFailuresTotal.Describe(descs)
|
||||
}
|
||||
|
||||
func (m batchingMetrics) Collect(metrics chan<- prometheus.Metric) {
|
||||
m.QueueDepth.Collect(metrics)
|
||||
m.BatchSize.Collect(metrics)
|
||||
m.FlushDuration.Collect(metrics)
|
||||
m.DelegateFallbacksTotal.Collect(metrics)
|
||||
m.SenderResetsTotal.Collect(metrics)
|
||||
m.SenderResetFailuresTotal.Collect(metrics)
|
||||
}
|
||||
|
||||
// NewBatching creates a chatd-specific batched pubsub wrapper around the
|
||||
// shared PostgreSQL listener implementation.
|
||||
func NewBatching(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
delegate *PGPubsub,
|
||||
prototype *sql.DB,
|
||||
connectURL string,
|
||||
cfg BatchingConfig,
|
||||
) (*BatchingPubsub, error) {
|
||||
if delegate == nil {
|
||||
return nil, xerrors.New("delegate pubsub is nil")
|
||||
}
|
||||
if prototype == nil {
|
||||
return nil, xerrors.New("prototype database is nil")
|
||||
}
|
||||
if connectURL == "" {
|
||||
return nil, xerrors.New("connect URL is empty")
|
||||
}
|
||||
|
||||
newSender := func(ctx context.Context) (batchSender, error) {
|
||||
return newPGBatchSender(ctx, logger.Named("sender"), prototype, connectURL)
|
||||
}
|
||||
|
||||
sender, err := newSender(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ps, err := newBatchingPubsub(logger, delegate, sender, cfg)
|
||||
if err != nil {
|
||||
_ = sender.Close()
|
||||
return nil, err
|
||||
}
|
||||
ps.newSender = newSender
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
func newBatchingPubsub(
|
||||
logger slog.Logger,
|
||||
delegate *PGPubsub,
|
||||
sender batchSender,
|
||||
cfg BatchingConfig,
|
||||
) (*BatchingPubsub, error) {
|
||||
if delegate == nil {
|
||||
return nil, xerrors.New("delegate pubsub is nil")
|
||||
}
|
||||
if sender == nil {
|
||||
return nil, xerrors.New("batch sender is nil")
|
||||
}
|
||||
|
||||
flushInterval := cfg.FlushInterval
|
||||
if flushInterval == 0 {
|
||||
flushInterval = DefaultBatchingFlushInterval
|
||||
}
|
||||
if flushInterval < 0 {
|
||||
return nil, xerrors.New("flush interval must be positive")
|
||||
}
|
||||
|
||||
queueSize := cfg.QueueSize
|
||||
if queueSize == 0 {
|
||||
queueSize = DefaultBatchingQueueSize
|
||||
}
|
||||
if queueSize < 0 {
|
||||
return nil, xerrors.New("queue size must be positive")
|
||||
}
|
||||
|
||||
pressureWait := cfg.PressureWait
|
||||
if pressureWait == 0 {
|
||||
pressureWait = defaultBatchingPressureWait
|
||||
}
|
||||
if pressureWait < 0 {
|
||||
return nil, xerrors.New("pressure wait must be positive")
|
||||
}
|
||||
|
||||
finalFlushTimeout := cfg.FinalFlushTimeout
|
||||
if finalFlushTimeout == 0 {
|
||||
finalFlushTimeout = defaultBatchingFinalFlushLimit
|
||||
}
|
||||
if finalFlushTimeout < 0 {
|
||||
return nil, xerrors.New("final flush timeout must be positive")
|
||||
}
|
||||
|
||||
clock := cfg.Clock
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(context.Background())
|
||||
ps := &BatchingPubsub{
|
||||
logger: logger,
|
||||
delegate: delegate,
|
||||
sender: sender,
|
||||
clock: clock,
|
||||
publishCh: make(chan queuedPublish, queueSize),
|
||||
flushCh: make(chan struct{}, 1),
|
||||
closeCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
spaceSignal: make(chan struct{}),
|
||||
warnTicker: clock.NewTicker(batchingWarnInterval, "pubsubBatcher", "warn"),
|
||||
flushInterval: flushInterval,
|
||||
pressureWait: pressureWait,
|
||||
finalFlushTimeout: finalFlushTimeout,
|
||||
runCtx: runCtx,
|
||||
cancel: cancel,
|
||||
metrics: newBatchingMetrics(),
|
||||
}
|
||||
ps.metrics.QueueDepth.Set(0)
|
||||
|
||||
go ps.run()
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
// Describe implements prometheus.Collector.
|
||||
func (p *BatchingPubsub) Describe(descs chan<- *prometheus.Desc) {
|
||||
p.metrics.Describe(descs)
|
||||
}
|
||||
|
||||
// Collect implements prometheus.Collector.
|
||||
func (p *BatchingPubsub) Collect(metrics chan<- prometheus.Metric) {
|
||||
p.metrics.Collect(metrics)
|
||||
}
|
||||
|
||||
// Subscribe delegates to the shared PostgreSQL listener pubsub.
|
||||
func (p *BatchingPubsub) Subscribe(event string, listener Listener) (func(), error) {
|
||||
return p.delegate.Subscribe(event, listener)
|
||||
}
|
||||
|
||||
// SubscribeWithErr delegates to the shared PostgreSQL listener pubsub.
|
||||
func (p *BatchingPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (func(), error) {
|
||||
return p.delegate.SubscribeWithErr(event, listener)
|
||||
}
|
||||
|
||||
// Publish enqueues a logical notification for asynchronous batched delivery.
|
||||
func (p *BatchingPubsub) Publish(event string, message []byte) error {
|
||||
channelClass := batchChannelClass(event)
|
||||
if p.closed.Load() {
|
||||
return ErrBatchingPubsubClosed
|
||||
}
|
||||
|
||||
req := queuedPublish{
|
||||
event: event,
|
||||
channelClass: channelClass,
|
||||
message: bytes.Clone(message),
|
||||
}
|
||||
if p.tryEnqueue(req) {
|
||||
return nil
|
||||
}
|
||||
|
||||
timer := p.clock.NewTimer(p.pressureWait, "pubsubBatcher", "pressureWait")
|
||||
defer timer.Stop("pubsubBatcher", "pressureWait")
|
||||
|
||||
for {
|
||||
if p.closed.Load() {
|
||||
return ErrBatchingPubsubClosed
|
||||
}
|
||||
p.signalPressureFlush()
|
||||
spaceSignal := p.currentSpaceSignal()
|
||||
if p.tryEnqueue(req) {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-spaceSignal:
|
||||
continue
|
||||
case <-timer.C:
|
||||
if p.tryEnqueue(req) {
|
||||
return nil
|
||||
}
|
||||
// The batching queue is still full after a pressure
|
||||
// flush and brief wait. Fall back to the shared
|
||||
// pubsub pool so the notification is still delivered
|
||||
// rather than dropped.
|
||||
p.observeDelegateFallback(channelClass, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)
|
||||
p.logPublishRejection(event)
|
||||
return p.delegate.Publish(event, message)
|
||||
case <-p.doneCh:
|
||||
return ErrBatchingPubsubClosed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops accepting new publishes, performs a bounded best-effort drain,
|
||||
// and then closes the dedicated sender connection.
|
||||
func (p *BatchingPubsub) Close() error {
|
||||
p.closeOnce.Do(func() {
|
||||
p.closed.Store(true)
|
||||
p.cancel()
|
||||
p.notifySpaceAvailable()
|
||||
close(p.closeCh)
|
||||
<-p.doneCh
|
||||
p.closeErr = p.runErr
|
||||
})
|
||||
return p.closeErr
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) tryEnqueue(req queuedPublish) bool {
|
||||
if p.closed.Load() {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case p.publishCh <- req:
|
||||
queuedDepth := p.queuedCount.Add(1)
|
||||
p.observeQueueDepth(queuedDepth)
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeQueueDepth(depth int64) {
|
||||
p.metrics.QueueDepth.Set(float64(depth))
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) signalPressureFlush() {
|
||||
select {
|
||||
case p.flushCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) currentSpaceSignal() <-chan struct{} {
|
||||
p.spaceMu.Lock()
|
||||
defer p.spaceMu.Unlock()
|
||||
return p.spaceSignal
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) notifySpaceAvailable() {
|
||||
p.spaceMu.Lock()
|
||||
defer p.spaceMu.Unlock()
|
||||
close(p.spaceSignal)
|
||||
p.spaceSignal = make(chan struct{})
|
||||
}
|
||||
|
||||
func batchChannelClass(event string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(event, "chat:stream:"):
|
||||
return batchChannelClassStreamNotify
|
||||
case strings.HasPrefix(event, "chat:owner:"):
|
||||
return batchChannelClassOwnerEvent
|
||||
case event == "chat:config_change":
|
||||
return batchChannelClassConfigChange
|
||||
default:
|
||||
return batchChannelClassOther
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeDelegateFallback(channelClass string, reason string, stage string) {
|
||||
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Inc()
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeDelegateFallbackBatch(batch []queuedPublish, reason string, stage string) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
counts := make(map[string]int)
|
||||
for _, item := range batch {
|
||||
counts[item.channelClass]++
|
||||
}
|
||||
for channelClass, count := range counts {
|
||||
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Add(float64(count))
|
||||
}
|
||||
}
|
||||
|
||||
func batchFlushStage(err error) string {
|
||||
var flushErr *batchFlushError
|
||||
if errors.As(err, &flushErr) {
|
||||
return flushErr.stage
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) run() {
|
||||
defer close(p.doneCh)
|
||||
defer p.warnTicker.Stop("pubsubBatcher", "warn")
|
||||
|
||||
batch := make([]queuedPublish, 0, 64)
|
||||
timer := p.clock.NewTimer(p.flushInterval, "pubsubBatcher", "scheduledFlush")
|
||||
defer timer.Stop("pubsubBatcher", "scheduledFlush")
|
||||
|
||||
flush := func(reason string) {
|
||||
batch = p.drainIntoBatch(batch)
|
||||
batch, _ = p.flushBatch(p.runCtx, batch, reason)
|
||||
timer.Reset(p.flushInterval, "pubsubBatcher", reason+"Flush")
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case item := <-p.publishCh:
|
||||
// An item arrived before the timer fired. Append it and
|
||||
// let the timer or pressure signal trigger the actual
|
||||
// flush so that nearby publishes coalesce naturally.
|
||||
batch = append(batch, item)
|
||||
p.notifySpaceAvailable()
|
||||
case <-timer.C:
|
||||
flush(batchFlushScheduled)
|
||||
case <-p.flushCh:
|
||||
flush("pressure")
|
||||
case <-p.closeCh:
|
||||
p.runErr = errors.Join(p.drain(batch), p.sender.Close())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) drainIntoBatch(batch []queuedPublish) []queuedPublish {
|
||||
drained := false
|
||||
for {
|
||||
select {
|
||||
case item := <-p.publishCh:
|
||||
batch = append(batch, item)
|
||||
drained = true
|
||||
default:
|
||||
if drained {
|
||||
p.notifySpaceAvailable()
|
||||
}
|
||||
return batch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) flushBatch(
|
||||
ctx context.Context,
|
||||
batch []queuedPublish,
|
||||
reason string,
|
||||
) ([]queuedPublish, error) {
|
||||
if len(batch) == 0 {
|
||||
return batch[:0], nil
|
||||
}
|
||||
|
||||
count := len(batch)
|
||||
totalBytes := 0
|
||||
for _, item := range batch {
|
||||
totalBytes += len(item.message)
|
||||
}
|
||||
|
||||
p.metrics.BatchSize.Observe(float64(count))
|
||||
start := p.clock.Now()
|
||||
senderErr := p.sender.Flush(ctx, batch)
|
||||
elapsed := p.clock.Since(start)
|
||||
p.metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
|
||||
|
||||
var err error
|
||||
if senderErr != nil {
|
||||
stage := batchFlushStage(senderErr)
|
||||
delivered, failed, fallbackErr := p.replayBatchViaDelegate(batch, batchDelegateFallbackReasonFlushError, stage)
|
||||
var resetErr error
|
||||
if reason != batchFlushShutdown {
|
||||
resetErr = p.resetSender()
|
||||
}
|
||||
p.logFlushFailure(reason, stage, count, totalBytes, delivered, failed, senderErr, fallbackErr, resetErr)
|
||||
if fallbackErr != nil || resetErr != nil {
|
||||
err = errors.Join(senderErr, fallbackErr, resetErr)
|
||||
}
|
||||
} else if p.delegate != nil {
|
||||
p.delegate.publishesTotal.WithLabelValues("true").Add(float64(count))
|
||||
p.delegate.publishedBytesTotal.Add(float64(totalBytes))
|
||||
}
|
||||
|
||||
queuedDepth := p.queuedCount.Add(-int64(count))
|
||||
p.observeQueueDepth(queuedDepth)
|
||||
clear(batch)
|
||||
return batch[:0], err
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) replayBatchViaDelegate(batch []queuedPublish, reason string, stage string) (delivered int, failed int, err error) {
|
||||
if len(batch) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
p.observeDelegateFallbackBatch(batch, reason, stage)
|
||||
if p.delegate == nil {
|
||||
return 0, len(batch), xerrors.New("delegate pubsub is nil")
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, item := range batch {
|
||||
if err := p.delegate.Publish(item.event, item.message); err != nil {
|
||||
failed++
|
||||
errs = append(errs, xerrors.Errorf("delegate publish %q: %w", item.event, err))
|
||||
continue
|
||||
}
|
||||
delivered++
|
||||
}
|
||||
return delivered, failed, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) resetSender() error {
|
||||
if p.newSender == nil {
|
||||
return nil
|
||||
}
|
||||
newSender, err := p.newSender(context.Background())
|
||||
if err != nil {
|
||||
p.metrics.SenderResetFailuresTotal.Inc()
|
||||
return err
|
||||
}
|
||||
oldSender := p.sender
|
||||
p.sender = newSender
|
||||
p.metrics.SenderResetsTotal.Inc()
|
||||
if oldSender == nil {
|
||||
return nil
|
||||
}
|
||||
if err := oldSender.Close(); err != nil {
|
||||
p.logger.Warn(context.Background(), "failed to close old batched pubsub sender after reset", slog.Error(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) logFlushFailure(reason string, stage string, count int, totalBytes int, delivered int, failed int, senderErr error, fallbackErr error, resetErr error) {
|
||||
fields := []slog.Field{
|
||||
slog.F("reason", reason),
|
||||
slog.F("stage", stage),
|
||||
slog.F("count", count),
|
||||
slog.F("total_bytes", totalBytes),
|
||||
slog.F("delegate_delivered", delivered),
|
||||
slog.F("delegate_failed", failed),
|
||||
slog.Error(senderErr),
|
||||
}
|
||||
if fallbackErr != nil {
|
||||
fields = append(fields, slog.F("delegate_error", fallbackErr.Error()))
|
||||
}
|
||||
if resetErr != nil {
|
||||
fields = append(fields, slog.F("sender_reset_error", resetErr.Error()))
|
||||
}
|
||||
p.logger.Error(context.Background(), "batched pubsub flush failed", fields...)
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) drain(batch []queuedPublish) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), p.finalFlushTimeout)
|
||||
defer cancel()
|
||||
|
||||
var errs []error
|
||||
for {
|
||||
batch = p.drainIntoBatch(batch)
|
||||
if len(batch) == 0 {
|
||||
break
|
||||
}
|
||||
var err error
|
||||
batch, err = p.flushBatch(ctx, batch, batchFlushShutdown)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dropped := p.dropPendingPublishes()
|
||||
if dropped > 0 {
|
||||
errs = append(errs, xerrors.Errorf("dropped %d queued notifications during shutdown", dropped))
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
errs = append(errs, xerrors.Errorf("shutdown flush timed out: %w", ctx.Err()))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) dropPendingPublishes() int {
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case <-p.publishCh:
|
||||
count++
|
||||
default:
|
||||
if count > 0 {
|
||||
queuedDepth := p.queuedCount.Add(-int64(count))
|
||||
p.observeQueueDepth(queuedDepth)
|
||||
}
|
||||
return count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) logPublishRejection(event string) {
|
||||
fields := []slog.Field{
|
||||
slog.F("event", event),
|
||||
slog.F("queue_size", cap(p.publishCh)),
|
||||
slog.F("queued", p.queuedCount.Load()),
|
||||
}
|
||||
select {
|
||||
case <-p.warnTicker.C:
|
||||
p.logger.Warn(context.Background(), "batched pubsub queue is full", fields...)
|
||||
default:
|
||||
p.logger.Debug(context.Background(), "batched pubsub queue is full", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
type pgBatchSender struct {
|
||||
logger slog.Logger
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func newPGBatchSender(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
prototype *sql.DB,
|
||||
connectURL string,
|
||||
) (*pgBatchSender, error) {
|
||||
connector, err := newConnector(ctx, logger, prototype, connectURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := sql.OpenDB(connector)
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxIdleTime(0)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
if err := db.PingContext(pingCtx); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, xerrors.Errorf("ping batched pubsub sender database: %w", err)
|
||||
}
|
||||
|
||||
return &pgBatchSender{logger: logger, db: db}, nil
|
||||
}
|
||||
|
||||
func (s *pgBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return &batchFlushError{stage: batchFlushStageBegin, err: xerrors.Errorf("begin batched pubsub transaction: %w", err)}
|
||||
}
|
||||
committed := false
|
||||
defer func() {
|
||||
if !committed {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, item := range batch {
|
||||
// This is safe because we are calling pq.QuoteLiteral. pg_notify does
|
||||
// not support the first parameter being a prepared statement.
|
||||
//nolint:gosec
|
||||
_, err = tx.ExecContext(ctx, `select pg_notify(`+pq.QuoteLiteral(item.event)+`, $1)`, item.message)
|
||||
if err != nil {
|
||||
return &batchFlushError{stage: batchFlushStageExec, err: xerrors.Errorf("exec pg_notify: %w", err)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &batchFlushError{stage: batchFlushStageCommit, err: xerrors.Errorf("commit batched pubsub transaction: %w", err)}
|
||||
}
|
||||
committed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *pgBatchSender) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
@@ -1,520 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
prom_testutil "github.com/prometheus/client_golang/prometheus/testutil"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestBatchingPubsubScheduledFlush(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
|
||||
require.Empty(t, sender.Batches())
|
||||
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
|
||||
batch := testutil.TryReceive(ctx, t, sender.flushes)
|
||||
require.Len(t, batch, 2)
|
||||
require.Equal(t, []byte("one"), batch[0].message)
|
||||
require.Equal(t, []byte("two"), batch[1].message)
|
||||
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
|
||||
require.Equal(t, uint64(1), batchSizeCount)
|
||||
require.InDelta(t, 2, batchSizeSum, 0.000001)
|
||||
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
|
||||
require.Equal(t, uint64(1), flushDurationCount)
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: clock})
|
||||
|
||||
require.Equal(t, DefaultBatchingFlushInterval, ps.flushInterval)
|
||||
require.Equal(t, DefaultBatchingQueueSize, cap(ps.publishCh))
|
||||
require.Equal(t, defaultBatchingPressureWait, ps.pressureWait)
|
||||
require.Equal(t, defaultBatchingFinalFlushLimit, ps.finalFlushTimeout)
|
||||
}
|
||||
|
||||
func TestBatchChannelClass(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
event string
|
||||
want string
|
||||
}{
|
||||
{name: "stream notify", event: "chat:stream:123", want: batchChannelClassStreamNotify},
|
||||
{name: "owner event", event: "chat:owner:123", want: batchChannelClassOwnerEvent},
|
||||
{name: "config change", event: "chat:config_change", want: batchChannelClassConfigChange},
|
||||
{name: "fallback", event: "workspace:owner:123", want: batchChannelClassOther},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, batchChannelClass(tt.event))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchingPubsubTimerFlushDrainsAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 64,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
// Enqueue many messages before the timer fires — all should be
|
||||
// drained and flushed in a single batch.
|
||||
for _, msg := range []string{"one", "two", "three", "four", "five"} {
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
|
||||
}
|
||||
require.Empty(t, sender.Batches())
|
||||
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
|
||||
batch := testutil.TryReceive(ctx, t, sender.flushes)
|
||||
require.Len(t, batch, 5)
|
||||
require.Equal(t, []byte("one"), batch[0].message)
|
||||
require.Equal(t, []byte("five"), batch[4].message)
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubQueueFullFallsBackToDelegate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
pressureTrap := clock.Trap().NewTimer("pubsubBatcher", "pressureWait")
|
||||
defer pressureTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.blockCh = make(chan struct{})
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 1,
|
||||
PressureWait: 10 * time.Millisecond,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
// Fill the queue (capacity 1).
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
|
||||
// Fire the timer so the run loop starts flushing "one" — the
|
||||
// sender blocks on blockCh so the flush stays in-flight.
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
<-sender.started
|
||||
|
||||
// The run loop is blocked in flushBatch. Fill the queue again.
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
|
||||
|
||||
// A third publish should fall back to the delegate (which has a
|
||||
// closed db, so the delegate Publish itself will error — but we
|
||||
// verify the fallback metric was incremented).
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- ps.Publish("chat:stream:a", []byte("three"))
|
||||
}()
|
||||
|
||||
pressureCall, err := pressureTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
pressureCall.MustRelease(ctx)
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
|
||||
err = testutil.TryReceive(ctx, t, errCh)
|
||||
// The delegate has a closed db so it returns an error from the
|
||||
// shared pool, not a batching-specific sentinel.
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)))
|
||||
|
||||
close(sender.blockCh)
|
||||
// Let the run loop finish the blocked flush and process "two".
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
require.NoError(t, ps.Close())
|
||||
}
|
||||
|
||||
func TestBatchingPubsubCloseDrainsQueue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: time.Hour,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("three")))
|
||||
|
||||
require.NoError(t, ps.Close())
|
||||
batches := sender.Batches()
|
||||
require.Len(t, batches, 1)
|
||||
require.Len(t, batches[0], 3)
|
||||
require.Equal(t, []byte("one"), batches[0][0].message)
|
||||
require.Equal(t, []byte("two"), batches[0][1].message)
|
||||
require.Equal(t, []byte("three"), batches[0][2].message)
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
require.Equal(t, 1, sender.CloseCalls())
|
||||
}
|
||||
|
||||
func TestBatchingPubsubPreservesOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: time.Hour,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
for _, msg := range []string{"one", "two", "three", "four", "five"} {
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
|
||||
}
|
||||
|
||||
require.NoError(t, ps.Close())
|
||||
batches := sender.Batches()
|
||||
require.NotEmpty(t, batches)
|
||||
|
||||
messages := make([]string, 0, 5)
|
||||
for _, batch := range batches {
|
||||
for _, item := range batch {
|
||||
messages = append(messages, string(item.message))
|
||||
}
|
||||
}
|
||||
require.Equal(t, []string{"one", "two", "three", "four", "five"}, messages)
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = batchFlushStageExec
|
||||
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
|
||||
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
|
||||
require.Equal(t, uint64(1), batchSizeCount)
|
||||
require.InDelta(t, 1, batchSizeSum, 0.000001)
|
||||
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
|
||||
require.Equal(t, uint64(1), flushDurationCount)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
|
||||
require.Zero(t, prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("true")))
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureStageAccounting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
stages := []string{batchFlushStageBegin, batchFlushStageExec, batchFlushStageCommit}
|
||||
for _, stage := range stages {
|
||||
stage := stage
|
||||
t.Run(stage, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = stage
|
||||
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
|
||||
|
||||
batch := []queuedPublish{{
|
||||
event: "chat:stream:test",
|
||||
channelClass: batchChannelClass("chat:stream:test"),
|
||||
message: []byte("fallback-" + stage),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(batch)))
|
||||
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, stage)))
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureResetSender(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
firstSender := newFakeBatchSender()
|
||||
firstSender.err = context.DeadlineExceeded
|
||||
firstSender.errStage = batchFlushStageExec
|
||||
secondSender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, firstSender, BatchingConfig{Clock: clock})
|
||||
ps.newSender = func(context.Context) (batchSender, error) {
|
||||
return secondSender, nil
|
||||
}
|
||||
|
||||
firstBatch := []queuedPublish{{
|
||||
event: "chat:stream:first",
|
||||
channelClass: batchChannelClass("chat:stream:first"),
|
||||
message: []byte("first"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(firstBatch)))
|
||||
_, err := ps.flushBatch(context.Background(), firstBatch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.SenderResetsTotal))
|
||||
require.Equal(t, 1, firstSender.CloseCalls())
|
||||
|
||||
secondBatch := []queuedPublish{{
|
||||
event: "chat:stream:second",
|
||||
channelClass: batchChannelClass("chat:stream:second"),
|
||||
message: []byte("second"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(secondBatch)))
|
||||
_, err = ps.flushBatch(context.Background(), secondBatch, batchFlushScheduled)
|
||||
require.NoError(t, err)
|
||||
batches := secondSender.Batches()
|
||||
require.Len(t, batches, 1)
|
||||
require.Len(t, batches[0], 1)
|
||||
require.Equal(t, []byte("second"), batches[0][0].message)
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureReturnsJoinedErrorWhenReplayFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = batchFlushStageExec
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
|
||||
|
||||
batch := []queuedPublish{{
|
||||
event: "chat:stream:error",
|
||||
channelClass: batchChannelClass("chat:stream:error"),
|
||||
message: []byte("error"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(batch)))
|
||||
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, context.DeadlineExceeded.Error())
|
||||
require.ErrorContains(t, err, `delegate publish "chat:stream:error"`)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
|
||||
}
|
||||
|
||||
func newTestBatchingPubsub(t *testing.T, sender batchSender, cfg BatchingConfig) (*BatchingPubsub, *PGPubsub) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
// Use a closed *sql.DB so that delegate.Publish returns a real
|
||||
// error instead of panicking on a nil pointer when the batching
|
||||
// queue falls back to the shared pool under pressure.
|
||||
closedDB := newClosedDB(t)
|
||||
delegate := newWithoutListener(logger.Named("delegate"), closedDB)
|
||||
ps, err := newBatchingPubsub(logger.Named("batcher"), delegate, sender, cfg)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = ps.Close()
|
||||
})
|
||||
return ps, delegate
|
||||
}
|
||||
|
||||
// newClosedDB returns an *sql.DB whose connections have been closed,
|
||||
// so any ExecContext call returns an error rather than panicking.
|
||||
func newClosedDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
db, err := sql.Open("postgres", "host=localhost dbname=closed_db_stub sslmode=disable connect_timeout=1")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
return db
|
||||
}
|
||||
|
||||
type fakeBatchSender struct {
|
||||
mu sync.Mutex
|
||||
batches [][]queuedPublish
|
||||
flushes chan []queuedPublish
|
||||
started chan struct{}
|
||||
blockCh chan struct{}
|
||||
err error
|
||||
errStage string
|
||||
closeErr error
|
||||
closeCall int
|
||||
}
|
||||
|
||||
func newFakeBatchSender() *fakeBatchSender {
|
||||
return &fakeBatchSender{
|
||||
flushes: make(chan []queuedPublish, 16),
|
||||
started: make(chan struct{}, 16),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
|
||||
select {
|
||||
case s.started <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
if s.blockCh != nil {
|
||||
select {
|
||||
case <-s.blockCh:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
clone := make([]queuedPublish, len(batch))
|
||||
for i, item := range batch {
|
||||
clone[i] = queuedPublish{
|
||||
event: item.event,
|
||||
message: bytes.Clone(item.message),
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.batches = append(s.batches, clone)
|
||||
s.mu.Unlock()
|
||||
|
||||
select {
|
||||
case s.flushes <- clone:
|
||||
default:
|
||||
}
|
||||
if s.err == nil {
|
||||
return nil
|
||||
}
|
||||
if s.errStage != "" {
|
||||
return &batchFlushError{stage: s.errStage, err: s.err}
|
||||
}
|
||||
return s.err
|
||||
}
|
||||
|
||||
type metricWriter interface {
|
||||
Write(*dto.Metric) error
|
||||
}
|
||||
|
||||
func histogramCountAndSum(t *testing.T, observer any) (uint64, float64) {
|
||||
t.Helper()
|
||||
writer, ok := observer.(metricWriter)
|
||||
require.True(t, ok)
|
||||
|
||||
metric := &dto.Metric{}
|
||||
require.NoError(t, writer.Write(metric))
|
||||
histogram := metric.GetHistogram()
|
||||
require.NotNil(t, histogram)
|
||||
return histogram.GetSampleCount(), histogram.GetSampleSum()
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.closeCall++
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Batches() [][]queuedPublish {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
clone := make([][]queuedPublish, len(s.batches))
|
||||
for i, batch := range s.batches {
|
||||
clone[i] = make([]queuedPublish, len(batch))
|
||||
copy(clone[i], batch)
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) CloseCalls() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.closeCall
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package pubsub_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestBatchingPubsubDedicatedSenderConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
connectionURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
trackedDriver := dbtestutil.NewDriver()
|
||||
defer trackedDriver.Close()
|
||||
tconn, err := trackedDriver.Connector(connectionURL)
|
||||
require.NoError(t, err)
|
||||
trackedDB := sql.OpenDB(tconn)
|
||||
defer trackedDB.Close()
|
||||
|
||||
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer base.Close()
|
||||
|
||||
listenerConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
|
||||
QueueSize: 8,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer batched.Close()
|
||||
|
||||
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
require.NotEqual(t, fmt.Sprintf("%p", listenerConn), fmt.Sprintf("%p", senderConn))
|
||||
|
||||
event := t.Name()
|
||||
messageCh := make(chan []byte, 1)
|
||||
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageCh <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, batched.Publish(event, []byte("hello")))
|
||||
require.Equal(t, []byte("hello"), testutil.TryReceive(ctx, t, messageCh))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubReconnectsAfterSenderDisconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
connectionURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
trackedDriver := dbtestutil.NewDriver()
|
||||
defer trackedDriver.Close()
|
||||
tconn, err := trackedDriver.Connector(connectionURL)
|
||||
require.NoError(t, err)
|
||||
trackedDB := sql.OpenDB(tconn)
|
||||
defer trackedDB.Close()
|
||||
|
||||
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer base.Close()
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, trackedDriver.Connections) // listener connection
|
||||
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
|
||||
QueueSize: 8,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer batched.Close()
|
||||
|
||||
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
event := t.Name()
|
||||
messageCh := make(chan []byte, 4)
|
||||
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageCh <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, batched.Publish(event, []byte("before-disconnect")))
|
||||
require.Equal(t, []byte("before-disconnect"), testutil.TryReceive(ctx, t, messageCh))
|
||||
require.NoError(t, senderConn.Close())
|
||||
|
||||
reconnected := false
|
||||
delivered := false
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
if !reconnected {
|
||||
select {
|
||||
case conn := <-trackedDriver.Connections:
|
||||
reconnected = conn != nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-messageCh:
|
||||
default:
|
||||
}
|
||||
if err := batched.Publish(event, []byte("after-disconnect")); err != nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case msg := <-messageCh:
|
||||
delivered = string(msg) == "after-disconnect"
|
||||
case <-time.After(testutil.IntervalFast):
|
||||
delivered = false
|
||||
}
|
||||
return reconnected && delivered
|
||||
}, testutil.IntervalMedium, "batched sender did not recover after disconnect")
|
||||
}
|
||||
@@ -487,14 +487,12 @@ func (d logDialer) DialContext(ctx context.Context, network, address string) (ne
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (driver.Connector, error) {
|
||||
if db == nil {
|
||||
return nil, xerrors.New("database is nil")
|
||||
}
|
||||
|
||||
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
p.connected.Set(0)
|
||||
// Creates a new listener using pq.
|
||||
var (
|
||||
dialer = logDialer{
|
||||
logger: logger,
|
||||
logger: p.logger,
|
||||
// pq.defaultDialer uses a zero net.Dialer as well.
|
||||
d: net.Dialer{},
|
||||
}
|
||||
@@ -503,38 +501,28 @@ func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectUR
|
||||
)
|
||||
|
||||
// Create a custom connector if the database driver supports it.
|
||||
connectorCreator, ok := db.Driver().(database.ConnectorCreator)
|
||||
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
|
||||
if ok {
|
||||
connector, err = connectorCreator.Connector(connectURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create custom connector: %w", err)
|
||||
return xerrors.Errorf("create custom connector: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Use the default pq connector otherwise.
|
||||
// use the default pq connector otherwise
|
||||
connector, err = pq.NewConnector(connectURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create pq connector: %w", err)
|
||||
return xerrors.Errorf("create pq connector: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the dialer if the connector supports it.
|
||||
dc, ok := connector.(database.DialerConnector)
|
||||
if !ok {
|
||||
logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
|
||||
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
|
||||
} else {
|
||||
dc.Dialer(dialer)
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
p.connected.Set(0)
|
||||
connector, err := newConnector(ctx, p.logger, p.db, connectURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
errCh = make(chan error, 1)
|
||||
sentErrCh = false
|
||||
|
||||
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
@@ -168,7 +169,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error)
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -215,6 +216,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveAISeatCount(ctx context.Context) (int64, error)
|
||||
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
@@ -893,6 +895,7 @@ type sqlcQuerier interface {
|
||||
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
|
||||
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
|
||||
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
|
||||
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
|
||||
@@ -7376,7 +7376,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
_, err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
|
||||
@@ -4505,6 +4505,19 @@ func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatD
|
||||
return err
|
||||
}
|
||||
|
||||
const clearChatMessageProviderResponseIDsByChatID = `-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = $1::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, clearChatMessageProviderResponseIDsByChatID, chatID)
|
||||
return err
|
||||
}
|
||||
|
||||
const countEnabledModelsWithoutPricing = `-- name: CountEnabledModelsWithoutPricing :one
|
||||
SELECT COUNT(*)::bigint AS count
|
||||
FROM chat_model_configs
|
||||
@@ -4603,6 +4616,66 @@ func (q *sqlQuerier) DeleteOldChats(ctx context.Context, arg DeleteOldChatsParam
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const getActiveChatsByAgentID = `-- name: GetActiveChatsByAgentID :many
|
||||
SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
FROM chats
|
||||
WHERE agent_id = $1::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getActiveChatsByAgentID, agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Chat
|
||||
for rows.Next() {
|
||||
var i Chat
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
&i.LastReadMessageID,
|
||||
&i.LastInjectedContext,
|
||||
&i.DynamicTools,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getChatByID = `-- name: GetChatByID :one
|
||||
SELECT
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order, last_read_message_id, last_injected_context, dynamic_tools
|
||||
@@ -6706,6 +6779,18 @@ func (q *sqlQuerier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg Soft
|
||||
return err
|
||||
}
|
||||
|
||||
const softDeleteContextFileMessages = `-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = $1::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, softDeleteContextFileMessages, chatID)
|
||||
return err
|
||||
}
|
||||
|
||||
const unarchiveChatByID = `-- name: UnarchiveChatByID :many
|
||||
WITH chats AS (
|
||||
UPDATE chats SET
|
||||
@@ -23042,7 +23127,7 @@ func (q *sqlQuerier) CreateUserSecret(ctx context.Context, arg CreateUserSecretP
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
const deleteUserSecretByUserIDAndName = `-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2
|
||||
`
|
||||
@@ -23052,9 +23137,12 @@ type DeleteUserSecretByUserIDAndNameParams struct {
|
||||
Name string `db:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
|
||||
return err
|
||||
func (q *sqlQuerier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, deleteUserSecretByUserIDAndName, arg.UserID, arg.Name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
const getUserSecretByUserIDAndName = `-- name: GetUserSecretByUserIDAndName :one
|
||||
|
||||
@@ -1293,3 +1293,26 @@ GROUP BY cm.chat_id;
|
||||
SELECT id, provider, model, context_limit, enabled, is_default
|
||||
FROM chat_model_configs
|
||||
WHERE deleted = false;
|
||||
-- name: GetActiveChatsByAgentID :many
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE agent_id = @agent_id::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC;
|
||||
|
||||
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL;
|
||||
|
||||
-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]';
|
||||
|
||||
@@ -56,6 +56,6 @@ SET
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
+61
-67
@@ -137,8 +137,9 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
logger := api.Logger.Named("chat_watcher")
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat watch stream.",
|
||||
@@ -146,54 +147,44 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatEvent(
|
||||
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// The encoder is only written from the SubscribeWithErr callback,
|
||||
// which delivers serially per subscription. Do not add a second
|
||||
// write path without introducing synchronization.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatWatchEvent(
|
||||
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
||||
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
|
||||
if err := encoder.Encode(payload); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Internal error subscribing to chat events.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
|
||||
}
|
||||
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
|
||||
return
|
||||
}
|
||||
defer cancelSubscribe()
|
||||
|
||||
// Send initial ping to signal the connection is ready.
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypePing,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
@@ -2176,6 +2167,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -2198,7 +2190,22 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.SubscribeAuthorized(ctx, chat, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
@@ -2206,41 +2213,30 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
|
||||
}
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: batch,
|
||||
})
|
||||
return encoder.Encode(batch)
|
||||
}
|
||||
|
||||
drainChatStreamBatch := func(
|
||||
@@ -2273,7 +2269,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
end = len(snapshot)
|
||||
}
|
||||
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2282,8 +2278,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
case firstEvent, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
@@ -2293,7 +2287,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chatStreamBatchSize,
|
||||
)
|
||||
if err := sendChatStreamBatch(batch); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if streamClosed {
|
||||
@@ -2308,6 +2302,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon != nil {
|
||||
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
||||
@@ -2321,8 +2316,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
|
||||
+19
-98
@@ -1114,17 +1114,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1136,25 +1125,16 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1174,18 +1154,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Skip the initial ping.
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1198,18 +1166,11 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
var got codersdk.Chat
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
var update watchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
|
||||
return false
|
||||
}
|
||||
if update.Type != codersdk.ServerSentEventTypeData {
|
||||
return false
|
||||
}
|
||||
var payload coderdpubsub.ChatEvent
|
||||
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
|
||||
return false
|
||||
}
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
got = payload.Chat
|
||||
return true
|
||||
@@ -1282,25 +1243,14 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Read the initial ping.
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
// Publish a diff_status_change event via pubsub,
|
||||
// mimicking what PublishDiffStatusChange does after
|
||||
// it reads the diff status from the DB.
|
||||
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
|
||||
Chat: codersdk.Chat{
|
||||
ID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
@@ -1313,25 +1263,15 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read events until we find the diff_status_change.
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var received codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var received coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
|
||||
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
|
||||
received.Chat.ID != chat.ID {
|
||||
continue
|
||||
}
|
||||
@@ -1350,7 +1290,6 @@ func TestWatchChats(t *testing.T) {
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1393,31 +1332,13 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
|
||||
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
|
||||
t.Helper()
|
||||
|
||||
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
|
||||
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
|
||||
for len(events) < 3 {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind != expectedKind {
|
||||
continue
|
||||
@@ -1427,7 +1348,7 @@ func TestWatchChats(t *testing.T) {
|
||||
return events
|
||||
}
|
||||
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
|
||||
t.Helper()
|
||||
|
||||
require.Len(t, events, 3)
|
||||
@@ -1440,12 +1361,12 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
|
||||
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
|
||||
assertLifecycleEvents(deletedEvents, true)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
|
||||
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
|
||||
assertLifecycleEvents(createdEvents, false)
|
||||
})
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
package coderd
|
||||
|
||||
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
|
||||
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -146,12 +147,35 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
cancel := params.redirectURL
|
||||
cancelQuery := params.redirectURL.Query()
|
||||
cancelQuery.Add("error", "access_denied")
|
||||
cancelQuery.Add("error_description", "The resource owner or authorization server denied the request")
|
||||
if params.state != "" {
|
||||
cancelQuery.Add("state", params.state)
|
||||
}
|
||||
cancel.RawQuery = cancelQuery.Encode()
|
||||
|
||||
cancelURI := cancel.String()
|
||||
if err := codersdk.ValidateRedirectURIScheme(cancel); err != nil {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadRequest,
|
||||
HideStatus: false,
|
||||
Title: "Invalid Callback URL",
|
||||
Description: "The application's registered callback URL has an invalid scheme.",
|
||||
Actions: []site.Action{
|
||||
{
|
||||
URL: accessURL.String(),
|
||||
Text: "Back to site",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
CancelURI: cancel.String(),
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
// #nosec G203 -- The scheme is validated by
|
||||
// codersdk.ValidateRedirectURIScheme above.
|
||||
CancelURI: htmltemplate.URL(cancelURI),
|
||||
RedirectURI: r.URL.String(),
|
||||
CSRFToken: nosurf.Token(r),
|
||||
Username: ua.FriendlyName,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oauth2provider_test
|
||||
|
||||
import (
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -20,7 +21,7 @@ func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
|
||||
|
||||
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
|
||||
AppName: "Test OAuth App",
|
||||
CancelURI: "https://coder.com/cancel",
|
||||
CancelURI: htmltemplate.URL("https://coder.com/cancel"),
|
||||
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
CSRFToken: csrfFieldValue,
|
||||
Username: "test-user",
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
const ChatConfigEventChannel = "chat:config_change"
|
||||
|
||||
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
|
||||
// messages, following the same pattern as HandleChatEvent.
|
||||
// messages, following the same pattern as HandleChatWatchEvent.
|
||||
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
|
||||
ChatEventKindActionRequired ChatEventKind = "action_required"
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ChatWatchEventChannel returns the pubsub channel for chat
|
||||
// lifecycle events scoped to a single user.
|
||||
func ChatWatchEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
// HandleChatWatchEvent wraps a typed callback for
|
||||
// ChatWatchEvent messages delivered via pubsub.
|
||||
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,280 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary Create a new user secret
|
||||
// @ID create-a-new-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param request body codersdk.CreateUserSecretRequest true "Create secret request"
|
||||
// @Success 201 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [post]
|
||||
func (api *API) postUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
var req codersdk.CreateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Name is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Value == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Value is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := codersdk.UserSecretFilePathValid(req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
secret, err := api.Database.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Value: req.Value,
|
||||
ValueKeyID: sql.NullString{},
|
||||
EnvName: req.EnvName,
|
||||
FilePath: req.FilePath,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "A secret with that name, environment variable, or file path already exists.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary List user secrets
|
||||
// @ID list-user-secrets
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Success 200 {array} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [get]
|
||||
func (api *API) getUserSecrets(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
secrets, err := api.Database.ListUserSecrets(ctx, user.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error listing secrets.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecrets(secrets))
|
||||
}
|
||||
|
||||
// @Summary Get a user secret by name
|
||||
// @ID get-a-user-secret-by-name
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [get]
|
||||
func (api *API) getUserSecret(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
secret, err := api.Database.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Update a user secret
|
||||
// @ID update-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Param request body codersdk.UpdateUserSecretRequest true "Update secret request"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [patch]
|
||||
func (api *API) patchUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var req codersdk.UpdateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Value == nil && req.Description == nil && req.EnvName == nil && req.FilePath == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one field must be provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(*req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
if err := codersdk.UserSecretFilePathValid(*req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
params := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
UpdateValue: req.Value != nil,
|
||||
Value: "",
|
||||
ValueKeyID: sql.NullString{},
|
||||
UpdateDescription: req.Description != nil,
|
||||
Description: "",
|
||||
UpdateEnvName: req.EnvName != nil,
|
||||
EnvName: "",
|
||||
UpdateFilePath: req.FilePath != nil,
|
||||
FilePath: "",
|
||||
}
|
||||
if req.Value != nil {
|
||||
params.Value = *req.Value
|
||||
}
|
||||
if req.Description != nil {
|
||||
params.Description = *req.Description
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
params.EnvName = *req.EnvName
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
params.FilePath = *req.FilePath
|
||||
}
|
||||
|
||||
secret, err := api.Database.UpdateUserSecretByUserIDAndName(ctx, params)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Update would conflict with an existing secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Delete a user secret
|
||||
// @ID delete-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 204
|
||||
// @Router /users/{user}/secrets/{name} [delete]
|
||||
func (api *API) deleteUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
rowsAffected, err := api.Database.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPostUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
secret, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub PAT",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
FilePath: "~/.github-token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "github-token", secret.Name)
|
||||
assert.Equal(t, "Personal GitHub PAT", secret.Description)
|
||||
assert.Equal(t, "GITHUB_TOKEN", secret.EnvName)
|
||||
assert.Equal(t, "~/.github-token", secret.FilePath)
|
||||
assert.NotZero(t, secret.ID)
|
||||
assert.NotZero(t, secret.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("MissingName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Value: "some-value",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Name is required")
|
||||
})
|
||||
|
||||
t.Run("MissingValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "missing-value-secret",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Value is required")
|
||||
})
|
||||
|
||||
t.Run("DuplicateName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value2",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-1",
|
||||
Value: "value1",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-2",
|
||||
Value: "value2",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-2",
|
||||
Value: "value2",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "invalid-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "1INVALID",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ReservedEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "reserved-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "PATH",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("CoderPrefixEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "coder-prefix-secret",
|
||||
Value: "value",
|
||||
EnvName: "CODER_AGENT_TOKEN",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "bad-path-secret",
|
||||
Value: "value",
|
||||
FilePath: "relative/path",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Verify no secrets exist on a fresh user.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, secrets)
|
||||
|
||||
t.Run("WithSecrets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-a",
|
||||
Value: "value-a",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-b",
|
||||
Value: "value-b",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 2)
|
||||
// Sorted by name.
|
||||
assert.Equal(t, "list-secret-a", secrets[0].Name)
|
||||
assert.Equal(t, "list-secret-b", secrets[1].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
created, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "get-found-secret",
|
||||
Value: "my-value",
|
||||
EnvName: "GET_FOUND_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := client.UserSecretByName(ctx, codersdk.Me, "get-found-secret")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "get-found-secret", got.Name)
|
||||
assert.Equal(t, "GET_FOUND_SECRET", got.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.UserSecretByName(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPatchUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("UpdateDescription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-desc-secret",
|
||||
Value: "my-value",
|
||||
Description: "original",
|
||||
EnvName: "PATCH_DESC_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newDesc := "updated"
|
||||
updated, err := client.UpdateUserSecret(ctx, codersdk.Me, "patch-desc-secret", codersdk.UpdateUserSecretRequest{
|
||||
Description: &newDesc,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated", updated.Description)
|
||||
// Other fields unchanged.
|
||||
assert.Equal(t, "PATCH_DESC_ENV", updated.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NoFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-nofields-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "patch-nofields-secret", codersdk.UpdateUserSecretRequest{})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
newVal := "new-value"
|
||||
_, err := client.UpdateUserSecret(ctx, codersdk.Me, "nonexistent", codersdk.UpdateUserSecretRequest{
|
||||
Value: &newVal,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-1",
|
||||
Value: "value1",
|
||||
EnvName: "CONFLICT_TAKEN_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "CONFLICT_TAKEN_ENV"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-env-2", codersdk.UpdateUserSecretRequest{
|
||||
EnvName: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/conflict-taken",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "/tmp/conflict-taken"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-fp-2", codersdk.UpdateUserSecretRequest{
|
||||
FilePath: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "delete-me-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.DeleteUserSecret(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone.
|
||||
_, err = client.UserSecretByName(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
err := client.DeleteUserSecret(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
@@ -42,6 +42,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -2393,3 +2395,598 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor
|
||||
}
|
||||
return sdk
|
||||
}
|
||||
|
||||
// maxChatContextParts caps the number of parts per request to
|
||||
// prevent unbounded message payloads.
|
||||
const maxChatContextParts = 100
|
||||
|
||||
// maxChatContextFileBytes caps each context-file part to the same
|
||||
// 64KiB budget used when the agent reads instruction files from disk.
|
||||
const maxChatContextFileBytes = 64 * 1024
|
||||
|
||||
// maxChatContextRequestBodyBytes caps the JSON request body size for
|
||||
// agent-added context to roughly the same per-part budget used when
|
||||
// reading instruction files from disk.
|
||||
const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes
|
||||
|
||||
// sanitizeWorkspaceAgentContextFileContent applies prompt
|
||||
// sanitization, then enforces the 64KiB per-file budget. The
|
||||
// truncated flag is preserved when the caller already capped the
|
||||
// file before sending it.
|
||||
func sanitizeWorkspaceAgentContextFileContent(
|
||||
content string,
|
||||
truncated bool,
|
||||
) (string, bool) {
|
||||
content = chatd.SanitizePromptText(content)
|
||||
if len(content) > maxChatContextFileBytes {
|
||||
content = content[:maxChatContextFileBytes]
|
||||
truncated = true
|
||||
}
|
||||
return content, truncated
|
||||
}
|
||||
|
||||
// readChatContextBody reads and validates the request body for chat
|
||||
// context endpoints. It handles MaxBytesReader wrapping, error
|
||||
// responses, and body rewind. If the body is empty or whitespace-only
|
||||
// and allowEmpty is true, it returns false without writing an error.
|
||||
//
|
||||
//nolint:revive // Add and clear endpoints only differ by empty-body handling.
|
||||
func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool {
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Request body too large.",
|
||||
Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes),
|
||||
})
|
||||
return false
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return false
|
||||
}
|
||||
if allowEmpty && len(bytes.TrimSpace(body)) == 0 {
|
||||
r.Body = http.NoBody
|
||||
return false
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return httpapi.Read(ctx, rw, r, dst)
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.AddChatContextRequest
|
||||
if !readChatContextBody(ctx, rw, r, &req, false) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) > maxChatContextParts {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Filter to only non-empty context-file and skill parts.
|
||||
filtered := chatd.FilterContextParts(req.Parts, false)
|
||||
if len(filtered) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
req.Parts = filtered
|
||||
responsePartCount := 0
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
// We verify agent-to-chat ownership explicitly below.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Stamp each persisted part with the agent identity. Context-file
|
||||
// parts also get server-authoritative workspace metadata.
|
||||
directory := workspaceAgent.ExpandedDirectory
|
||||
if directory == "" {
|
||||
directory = workspaceAgent.Directory
|
||||
}
|
||||
for i := range req.Parts {
|
||||
req.Parts[i].ContextFileAgentID = uuid.NullUUID{
|
||||
UUID: workspaceAgent.ID,
|
||||
Valid: true,
|
||||
}
|
||||
if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile {
|
||||
continue
|
||||
}
|
||||
req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent(
|
||||
req.Parts[i].ContextFileContent,
|
||||
req.Parts[i].ContextFileTruncated,
|
||||
)
|
||||
req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem
|
||||
req.Parts[i].ContextFileDirectory = directory
|
||||
}
|
||||
req.Parts = chatd.FilterContextParts(req.Parts, false)
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
responsePartCount = len(req.Parts)
|
||||
|
||||
// Skill-only messages need a sentinel context-file part so the turn
|
||||
// pipeline trusts the associated skill metadata.
|
||||
req.Parts = prependAgentChatContextSentinelIfNeeded(
|
||||
req.Parts,
|
||||
workspaceAgent.ID,
|
||||
workspaceAgent.OperatingSystem,
|
||||
directory,
|
||||
)
|
||||
|
||||
content, err := chatprompt.MarshalParts(req.Parts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal context parts.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspace.OwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams(
|
||||
chat.ID,
|
||||
database.ChatMessageRoleUser,
|
||||
content,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
locked.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
uuid.Nil,
|
||||
)); err != nil {
|
||||
return xerrors.Errorf("insert context message: %w", err)
|
||||
}
|
||||
if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("rebuild injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to persist context message.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
Count: responsePartCount,
|
||||
})
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.ClearChatContextRequest
|
||||
populated := readChatContextBody(ctx, rw, r, &req, true)
|
||||
if !populated && r.Body != http.NoBody {
|
||||
return
|
||||
}
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
// Zero active chats is not an error for clear.
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{})
|
||||
return
|
||||
}
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to clear context from chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
errNoActiveChats = xerrors.New("no active chats found")
|
||||
errChatNotFound = xerrors.New("chat not found")
|
||||
errChatNotActive = xerrors.New("chat is not active")
|
||||
errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent")
|
||||
errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner")
|
||||
)
|
||||
|
||||
type multipleActiveChatsError struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (e *multipleActiveChatsError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"multiple active chats (%d) found for this agent, specify a chat ID",
|
||||
e.count,
|
||||
)
|
||||
}
|
||||
|
||||
func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) {
|
||||
switch len(chats) {
|
||||
case 0:
|
||||
return database.Chat{}, errNoActiveChats
|
||||
case 1:
|
||||
return chats[0], nil
|
||||
}
|
||||
|
||||
var rootChat *database.Chat
|
||||
for i := range chats {
|
||||
chat := &chats[i]
|
||||
if chat.ParentChatID.Valid {
|
||||
continue
|
||||
}
|
||||
if rootChat != nil {
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
rootChat = chat
|
||||
}
|
||||
if rootChat != nil {
|
||||
return *rootChat, nil
|
||||
}
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
|
||||
// resolveAgentChat finds the target chat from either an explicit ID
|
||||
// or auto-detection via the agent's active chats.
|
||||
func resolveAgentChat(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
explicitChatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
if explicitChatID == uuid.Nil {
|
||||
chats, err := db.GetActiveChatsByAgentID(ctx, agentID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("list active chats: %w", err)
|
||||
}
|
||||
ownerChats := make([]database.Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
continue
|
||||
}
|
||||
ownerChats = append(ownerChats, chat)
|
||||
}
|
||||
return resolveDefaultAgentChat(ownerChats)
|
||||
}
|
||||
|
||||
chat, err := db.GetChatByID(ctx, explicitChatID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return database.Chat{}, errChatNotFound
|
||||
}
|
||||
return database.Chat{}, xerrors.Errorf("get chat by id: %w", err)
|
||||
}
|
||||
if !chat.AgentID.Valid || chat.AgentID.UUID != agentID {
|
||||
return database.Chat{}, errChatDoesNotBelongToAgent
|
||||
}
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if !isActiveAgentChat(chat) {
|
||||
return database.Chat{}, errChatNotActive
|
||||
}
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func isActiveAgentChat(chat database.Chat) bool {
|
||||
if chat.Archived {
|
||||
return false
|
||||
}
|
||||
|
||||
switch chat.Status {
|
||||
case database.ChatStatusWaiting,
|
||||
database.ChatStatusPending,
|
||||
database.ChatStatusRunning,
|
||||
database.ChatStatusPaused,
|
||||
database.ChatStatusRequiresAction:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func clearAgentChatContext(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
) error {
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != agentID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspaceOwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
hadInjectedContext := locked.LastInjectedContext.Valid
|
||||
var skillOnlyMessageIDs []int64
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile)
|
||||
hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill)
|
||||
if hasContextFile || hasSkill {
|
||||
hadInjectedContext = true
|
||||
}
|
||||
if hasSkill && !hasContextFile {
|
||||
skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID)
|
||||
}
|
||||
}
|
||||
if !hadInjectedContext {
|
||||
return nil
|
||||
}
|
||||
if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("soft delete context-file messages: %w", err)
|
||||
}
|
||||
for _, messageID := range skillOnlyMessageIDs {
|
||||
if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil {
|
||||
return xerrors.Errorf("soft delete context message %d: %w", messageID, err)
|
||||
}
|
||||
}
|
||||
// Reset provider-side Responses chaining so the next turn replays
|
||||
// the post-clear history instead of inheriting cleared context.
|
||||
if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("clear provider response chain: %w", err)
|
||||
}
|
||||
// Clear the injected-context cache inside the transaction so it is
|
||||
// atomic with the soft-deletes.
|
||||
param, err := chatd.BuildLastInjectedContext(nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// prependAgentChatContextSentinelIfNeeded adds an empty context-file
|
||||
// part when the request only carries skills. The turn pipeline uses
|
||||
// the sentinel's agent metadata to trust the skill parts.
|
||||
func prependAgentChatContextSentinelIfNeeded(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
agentID uuid.UUID,
|
||||
operatingSystem string,
|
||||
directory string,
|
||||
) []codersdk.ChatMessagePart {
|
||||
hasContextFile := false
|
||||
hasSkill := false
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFile = true
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
hasSkill = true
|
||||
}
|
||||
if hasContextFile && hasSkill {
|
||||
return parts
|
||||
}
|
||||
}
|
||||
if !hasSkill || hasContextFile {
|
||||
return parts
|
||||
}
|
||||
return append([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: chatd.AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
ContextFileOS: operatingSystem,
|
||||
ContextFileDirectory: directory,
|
||||
}}, parts...)
|
||||
}
|
||||
|
||||
func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) {
|
||||
sort.SliceStable(messages, func(i, j int) bool {
|
||||
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
|
||||
return messages[i].ID < messages[j].ID
|
||||
}
|
||||
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
|
||||
})
|
||||
}
|
||||
|
||||
// updateAgentChatLastInjectedContextFromMessages rebuilds the
|
||||
// injected-context cache from all persisted context-file and skill parts.
|
||||
func updateAgentChatLastInjectedContextFromMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
) error {
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("load context messages for injected context: %w", err)
|
||||
}
|
||||
|
||||
sortChatMessagesByCreatedAtAndID(messages)
|
||||
|
||||
parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("collect injected context parts: %w", err)
|
||||
}
|
||||
parts = chatd.FilterContextPartsToLatestAgent(parts)
|
||||
|
||||
param, err := chatd.BuildLastInjectedContext(parts)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
for _, typ := range types {
|
||||
if part.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeAgentChatError translates resolveAgentChat errors to HTTP
|
||||
// responses.
|
||||
func writeAgentChatError(
|
||||
ctx context.Context,
|
||||
rw http.ResponseWriter,
|
||||
err error,
|
||||
) {
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "No active chats found for this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotFound) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Chat not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToAgent) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this workspace owner.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotActive) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Cannot modify context: this chat is no longer active.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var multipleErr *multipleActiveChatsError
|
||||
if errors.As(err, &multipleErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to resolve chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestActiveAgentChatDefinitionsAgree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: owner.ID,
|
||||
}).WithAgent().Do()
|
||||
modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID)
|
||||
|
||||
insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2)
|
||||
for _, archived := range []bool{false, true} {
|
||||
for _, status := range database.AllChatStatusValues() {
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: status,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: fmt.Sprintf("%s-archived-%t", status, archived),
|
||||
AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
insertedChats = append(insertedChats, chat)
|
||||
}
|
||||
}
|
||||
|
||||
activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeByID := make(map[uuid.UUID]bool, len(activeChats))
|
||||
for _, chat := range activeChats {
|
||||
activeByID[chat.ID] = true
|
||||
}
|
||||
|
||||
for _, chat := range insertedChats {
|
||||
require.Equalf(
|
||||
t,
|
||||
isActiveAgentChat(chat),
|
||||
activeByID[chat.ID],
|
||||
"status=%s archived=%t",
|
||||
chat.Status,
|
||||
chat.Archived,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC)
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
|
||||
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
newContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{
|
||||
{
|
||||
ID: 2,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: newContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: oldContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
require.Equal(t, chatID, arg.ID)
|
||||
require.True(t, arg.LastInjectedContext.Valid)
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 1)
|
||||
require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath)
|
||||
require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID)
|
||||
return database.Chat{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
err = updateAgentChatLastInjectedContextFromMessages(
|
||||
context.Background(),
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
db,
|
||||
chatID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func insertAgentChatTestModelConfig(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
|
||||
|
||||
_, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: createdBy,
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: createdBy,
|
||||
UpdatedBy: createdBy,
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return model
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,7 +91,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
require.Equal(t, tmpDir, workspace.LatestBuild.Resources[0].Agents[0].Directory)
|
||||
_, err = anotherClient.WorkspaceAgent(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
require.False(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
})
|
||||
t.Run("HasFallbackTroubleshootingURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
+69
-12
@@ -213,6 +213,39 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Connecting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
@@ -247,10 +280,10 @@ func TestWorkspace(t *testing.T) {
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{agent.ID}, workspace.Health.FailingAgents)
|
||||
assert.False(t, agent.Health.Healthy)
|
||||
assert.Equal(t, "agent has not yet connected", agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Unhealthy", func(t *testing.T) {
|
||||
@@ -302,6 +335,7 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
a1AuthToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -313,7 +347,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "a1",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: a1AuthToken,
|
||||
},
|
||||
}, {
|
||||
Id: uuid.NewString(),
|
||||
Name: "a2",
|
||||
@@ -330,13 +366,21 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, a1AuthToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && !workspace.Health.Healthy
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Wait for the mixed state: a1 connected (healthy)
|
||||
// and workspace unhealthy (because a2 timed out).
|
||||
agent1 := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return agent1.Health.Healthy && !workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
@@ -360,6 +404,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// disconnected, but this should not make the workspace unhealthy.
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -371,7 +416,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "parent",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
@@ -383,14 +430,23 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Get the workspace and parent agent.
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
parentAgent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy initially")
|
||||
// Wait for the parent agent to connect and be healthy.
|
||||
var parentAgent codersdk.WorkspaceAgent
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
parentAgent = workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return parentAgent.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy")
|
||||
|
||||
// Create a sub-agent with a short connection timeout so it becomes
|
||||
// unhealthy quickly (simulating a devcontainer rebuild scenario).
|
||||
@@ -404,6 +460,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// Wait for the sub-agent to become unhealthy due to timeout.
|
||||
var subAgentUnhealthy bool
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
+298
-103
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/util/singleflight"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -73,9 +74,9 @@ const (
|
||||
|
||||
// maxConcurrentRecordingUploads caps the number of recording
|
||||
// stop-and-store operations that can run concurrently. Each
|
||||
// slot buffers up to MaxRecordingSize (100 MB) in memory, so
|
||||
// this value implicitly bounds memory to roughly
|
||||
// maxConcurrentRecordingUploads * 100 MB.
|
||||
// slot buffers up to MaxRecordingSize + MaxThumbnailSize
|
||||
// (110 MB) in memory, so this value implicitly bounds memory
|
||||
// to roughly maxConcurrentRecordingUploads * 110 MB.
|
||||
maxConcurrentRecordingUploads = 25
|
||||
|
||||
// staleRecoveryIntervalDivisor determines how often the stale
|
||||
@@ -96,6 +97,13 @@ const (
|
||||
// cross-replica relay subscribers time to connect and
|
||||
// snapshot the buffer before it is garbage-collected.
|
||||
bufferRetainGracePeriod = 5 * time.Second
|
||||
// chatStreamHistoryFetchTimeout bounds server-owned shared
|
||||
// history reads. It is intentionally generous because initial
|
||||
// and catch-up scans may be larger than control-path lookups.
|
||||
chatStreamHistoryFetchTimeout = 30 * time.Second
|
||||
// chatStreamControlFetchTimeout bounds subscriber-owned
|
||||
// control-path DB reads when the caller has no deadline.
|
||||
chatStreamControlFetchTimeout = 5 * time.Second
|
||||
|
||||
// DefaultMaxChatsPerAcquire is the maximum number of chats to
|
||||
// acquire in a single processOnce call. Batching avoids
|
||||
@@ -137,6 +145,11 @@ type Server struct {
|
||||
// never contend with each other.
|
||||
chatStreams sync.Map // uuid.UUID -> *chatStreamState
|
||||
|
||||
// streamMessageFetches coalesces concurrent chat stream durable
|
||||
// history reads. It is not a cache: once a shared fetch
|
||||
// completes, future reads hit the database again.
|
||||
streamMessageFetches singleflight.Group[string, []database.ChatMessage]
|
||||
|
||||
// workspaceMCPToolsCache caches workspace MCP tool definitions
|
||||
// per chat to avoid re-fetching on every turn. The cache is
|
||||
// keyed by chat ID and invalidated when the agent changes.
|
||||
@@ -996,7 +1009,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
return database.Chat{}, txErr
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated, nil)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
|
||||
p.signalWake()
|
||||
return chat, nil
|
||||
}
|
||||
@@ -1158,7 +1171,7 @@ func (p *Server) SendMessage(
|
||||
|
||||
p.publishMessage(opts.ChatID, result.Message)
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
return result, nil
|
||||
}
|
||||
@@ -1301,7 +1314,7 @@ func (p *Server) EditMessage(
|
||||
QueueUpdate: true,
|
||||
})
|
||||
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
||||
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -1355,10 +1368,10 @@ func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
|
||||
if interrupted {
|
||||
p.publishStatus(chat.ID, statusChat.Status, statusChat.WorkerID)
|
||||
p.publishChatPubsubEvent(statusChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(statusChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvents(archivedChats, coderdpubsub.ChatEventKindDeleted)
|
||||
p.publishChatPubsubEvents(archivedChats, codersdk.ChatWatchEventKindDeleted)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1373,7 +1386,7 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
|
||||
ctx,
|
||||
chat.ID,
|
||||
"unarchive",
|
||||
coderdpubsub.ChatEventKindCreated,
|
||||
codersdk.ChatWatchEventKindCreated,
|
||||
p.db.UnarchiveChatByID,
|
||||
)
|
||||
}
|
||||
@@ -1382,7 +1395,7 @@ func (p *Server) applyChatLifecycleTransition(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
action string,
|
||||
kind coderdpubsub.ChatEventKind,
|
||||
kind codersdk.ChatWatchEventKind,
|
||||
transition func(context.Context, uuid.UUID) ([]database.Chat, error),
|
||||
) error {
|
||||
updatedChats, err := transition(ctx, chatID)
|
||||
@@ -1545,7 +1558,7 @@ func (p *Server) PromoteQueued(
|
||||
})
|
||||
p.publishMessage(opts.ChatID, promoted)
|
||||
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
p.signalWake()
|
||||
|
||||
return result, nil
|
||||
@@ -2092,7 +2105,7 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -2347,7 +2360,7 @@ func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database
|
||||
return database.Chat{}, err
|
||||
}
|
||||
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
@@ -2461,6 +2474,33 @@ type chainModeInfo struct {
|
||||
// trailingUserCount is the number of contiguous user messages
|
||||
// at the end of the conversation that form the current turn.
|
||||
trailingUserCount int
|
||||
// contributingTrailingUserCount counts the trailing user
|
||||
// messages that materially change the provider input.
|
||||
contributingTrailingUserCount int
|
||||
}
|
||||
|
||||
func userMessageContributesToChainMode(msg database.ChatMessage) bool {
|
||||
parts, err := chatprompt.ParseContent(msg)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeText,
|
||||
codersdk.ChatMessagePartTypeReasoning:
|
||||
if strings.TrimSpace(part.Text) != "" {
|
||||
return true
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeFile,
|
||||
codersdk.ChatMessagePartTypeFileReference:
|
||||
return true
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if part.ContextFileContent != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveChainMode scans DB messages from the end to count trailing user
|
||||
@@ -2470,11 +2510,13 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
|
||||
var info chainModeInfo
|
||||
i := len(messages) - 1
|
||||
for ; i >= 0; i-- {
|
||||
if messages[i].Role == database.ChatMessageRoleUser {
|
||||
info.trailingUserCount++
|
||||
continue
|
||||
if messages[i].Role != database.ChatMessageRoleUser {
|
||||
break
|
||||
}
|
||||
info.trailingUserCount++
|
||||
if userMessageContributesToChainMode(messages[i]) {
|
||||
info.contributingTrailingUserCount++
|
||||
}
|
||||
break
|
||||
}
|
||||
for ; i >= 0; i-- {
|
||||
switch messages[i].Role {
|
||||
@@ -2497,15 +2539,15 @@ func resolveChainMode(messages []database.ChatMessage) chainModeInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// filterPromptForChainMode keeps only system messages and the last
|
||||
// trailingUserCount user messages from the prompt. Assistant and tool
|
||||
// messages are dropped because the provider already has them via the
|
||||
// previous_response_id chain.
|
||||
// filterPromptForChainMode keeps only system messages and the trailing
|
||||
// user messages that still contribute model-visible content to the
|
||||
// current turn. Assistant and tool messages are dropped because the
|
||||
// provider already has them via the previous_response_id chain.
|
||||
func filterPromptForChainMode(
|
||||
prompt []fantasy.Message,
|
||||
trailingUserCount int,
|
||||
info chainModeInfo,
|
||||
) []fantasy.Message {
|
||||
if trailingUserCount <= 0 {
|
||||
if info.contributingTrailingUserCount <= 0 {
|
||||
return prompt
|
||||
}
|
||||
|
||||
@@ -2516,7 +2558,12 @@ func filterPromptForChainMode(
|
||||
}
|
||||
}
|
||||
|
||||
usersToSkip := totalUsers - trailingUserCount
|
||||
// Prompt construction already drops user turns with no model-visible
|
||||
// content, such as skill-only sentinel messages. That means the user
|
||||
// count here stays aligned with contributingTrailingUserCount even
|
||||
// when non-contributing DB turns are interleaved in the trailing
|
||||
// block.
|
||||
usersToSkip := totalUsers - info.contributingTrailingUserCount
|
||||
if usersToSkip < 0 {
|
||||
usersToSkip = 0
|
||||
}
|
||||
@@ -2562,6 +2609,28 @@ func appendChatMessage(
|
||||
params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
|
||||
}
|
||||
|
||||
// BuildSingleChatMessageInsertParams creates batch insert params for one
|
||||
// message using the shared chat message builder.
|
||||
func BuildSingleChatMessageInsertParams(
|
||||
chatID uuid.UUID,
|
||||
role database.ChatMessageRole,
|
||||
content pqtype.NullRawMessage,
|
||||
visibility database.ChatMessageVisibility,
|
||||
modelConfigID uuid.UUID,
|
||||
contentVersion int16,
|
||||
createdBy uuid.UUID,
|
||||
) database.InsertChatMessagesParams {
|
||||
params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: chatID,
|
||||
}
|
||||
msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion)
|
||||
if createdBy != uuid.Nil {
|
||||
msg = msg.withCreatedBy(createdBy)
|
||||
}
|
||||
appendChatMessage(¶ms, msg)
|
||||
return params
|
||||
}
|
||||
|
||||
func insertUserMessageAndSetPending(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
@@ -3105,6 +3174,73 @@ func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func cloneChatMessagesForStream(messages []database.ChatMessage) []database.ChatMessage {
|
||||
cloned := slices.Clone(messages)
|
||||
for i := range cloned {
|
||||
cloned[i].Content.RawMessage = slices.Clone(cloned[i].Content.RawMessage)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// streamSharedHistoryFetchContext detaches subscriber cancellation from a
|
||||
// shared history fetch and runs it under a server-owned timeout budget.
|
||||
// Shared work should not inherit the winner's request deadline.
|
||||
func streamSharedHistoryFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.WithoutCancel(ctx), chatStreamHistoryFetchTimeout)
|
||||
}
|
||||
|
||||
// streamSubscriberControlFetchContext keeps a control-path lookup tied to the
|
||||
// requesting subscriber while applying a fallback timeout when the caller has
|
||||
// no deadline.
|
||||
func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if _, ok := ctx.Deadline(); ok {
|
||||
return ctx, func() {}
|
||||
}
|
||||
return context.WithTimeout(ctx, chatStreamControlFetchTimeout)
|
||||
}
|
||||
|
||||
// getStreamChatMessages loads durable chat messages for an already-authorized
|
||||
// subscriber. Subscribe() must validate the caller before this helper is used.
|
||||
// The shared fetch intentionally runs as chatd so request identity and timeout
|
||||
// policy come from chatd rather than whichever caller won singleflight.
|
||||
func (p *Server) getStreamChatMessages(
|
||||
ctx context.Context,
|
||||
params database.GetChatMessagesByChatIDParams,
|
||||
) ([]database.ChatMessage, error) {
|
||||
messages, err := singleflightDoChan(
|
||||
ctx,
|
||||
&p.streamMessageFetches,
|
||||
fmt.Sprintf("chat-messages:%s:after:%d", params.ChatID, params.AfterID),
|
||||
func() ([]database.ChatMessage, error) {
|
||||
fetchCtx, cancel := streamSharedHistoryFetchContext(ctx)
|
||||
defer cancel()
|
||||
//nolint:gocritic // SubscribeAuthorized already validated the
|
||||
// caller; the shared singleflight fetch runs as chatd so the
|
||||
// leader's request identity cannot affect other authorized waiters.
|
||||
return p.db.GetChatMessagesByChatID(dbauthz.AsChatd(fetchCtx), params)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cloneChatMessagesForStream(messages), nil
|
||||
}
|
||||
|
||||
func subscribeWithInitialError(chatID uuid.UUID, message string) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
bool,
|
||||
) {
|
||||
events := make(chan codersdk.ChatStreamEvent)
|
||||
close(events)
|
||||
return []codersdk.ChatStreamEvent{{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{Message: message},
|
||||
}}, events, func() {}, true
|
||||
}
|
||||
|
||||
func (p *Server) Subscribe(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
@@ -3119,9 +3255,40 @@ func (p *Server) Subscribe(
|
||||
if p == nil {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
|
||||
chat, err := p.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
p.logger.Warn(ctx, "failed to load chat for stream subscription",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return subscribeWithInitialError(chatID, "failed to load initial snapshot")
|
||||
}
|
||||
return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID)
|
||||
}
|
||||
|
||||
// SubscribeAuthorized subscribes an already-authorized chat to merged stream
|
||||
// updates. The passed chat row proves authorization, but SubscribeAuthorized
|
||||
// still reloads the chat after the stream subscriptions are armed so the
|
||||
// initial status and relay setup use fresh state.
|
||||
func (p *Server) SubscribeAuthorized(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
requestHeader http.Header,
|
||||
afterMessageID int64,
|
||||
) (
|
||||
[]codersdk.ChatStreamEvent,
|
||||
<-chan codersdk.ChatStreamEvent,
|
||||
func(),
|
||||
bool,
|
||||
) {
|
||||
if p == nil {
|
||||
return nil, nil, nil, false
|
||||
}
|
||||
chatID := chat.ID
|
||||
|
||||
// Subscribe to the local stream for message_parts and same-replica
|
||||
// persisted messages.
|
||||
@@ -3185,6 +3352,34 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
cancel := func() {
|
||||
mergedCancel()
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-read the chat after the local/pubsub subscriptions are active so
|
||||
// the initial status event and any enterprise relay setup use fresh
|
||||
// state instead of the middleware-loaded row.
|
||||
refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx)
|
||||
snapshotChat, err := func() (database.Chat, error) {
|
||||
defer refreshCancel()
|
||||
//nolint:gocritic // SubscribeAuthorized already validated the
|
||||
// caller; this refresh only loads the latest status/worker for
|
||||
// the already-authorized stream subscription.
|
||||
return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID)
|
||||
}()
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
snapshotChat = chat
|
||||
}
|
||||
|
||||
// Build initial snapshot synchronously. The pubsub subscription
|
||||
// is already active so no notifications can be lost during this
|
||||
// window.
|
||||
@@ -3200,7 +3395,7 @@ func (p *Server) Subscribe(
|
||||
// caller already has messages up to that ID (e.g. from the REST
|
||||
// endpoint), so we only fetch newer ones to avoid sending
|
||||
// duplicate data.
|
||||
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
messages, err := p.getStreamChatMessages(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: afterMessageID,
|
||||
})
|
||||
@@ -3225,8 +3420,12 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
|
||||
// Load initial queue.
|
||||
queued, err := p.db.GetChatQueuedMessages(ctx, chatID)
|
||||
// Load initial queue. Queue snapshots are intentionally not
|
||||
// singleflighted because a chat-scoped key cannot distinguish the
|
||||
// pre- and post-notification queue state.
|
||||
queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx)
|
||||
queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
||||
queueCancel()
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to load initial queued messages",
|
||||
slog.Error(err),
|
||||
@@ -3245,35 +3444,18 @@ func (p *Server) Subscribe(
|
||||
})
|
||||
}
|
||||
|
||||
// Get initial chat state to determine if we need a relay.
|
||||
chat, chatErr := p.db.GetChatByID(ctx, chatID)
|
||||
|
||||
// Include the current chat status in the snapshot so the
|
||||
// frontend can gate message_part processing correctly from
|
||||
// the very first batch, without waiting for a separate REST
|
||||
// query.
|
||||
if chatErr != nil {
|
||||
p.logger.Error(ctx, "failed to load initial chat state",
|
||||
slog.Error(chatErr),
|
||||
slog.F("chat_id", chatID),
|
||||
)
|
||||
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeError,
|
||||
ChatID: chatID,
|
||||
Error: &codersdk.ChatStreamError{Message: "failed to load initial snapshot"},
|
||||
})
|
||||
} else {
|
||||
statusEvent := codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatus(chat.Status),
|
||||
},
|
||||
}
|
||||
// Prepend so the frontend sees the status before any
|
||||
// message_part events.
|
||||
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
|
||||
// Include the current chat status in the snapshot so the frontend can gate
|
||||
// message_part processing correctly from the very first batch, without
|
||||
// waiting for a separate REST query.
|
||||
statusEvent := codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeStatus,
|
||||
ChatID: chatID,
|
||||
Status: &codersdk.ChatStreamStatus{
|
||||
Status: codersdk.ChatStatus(snapshotChat.Status),
|
||||
},
|
||||
}
|
||||
// Prepend so the frontend sees the status before any message_part events.
|
||||
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
|
||||
|
||||
// Track the highest durable message ID delivered to this subscriber,
|
||||
// whether it came from the initial DB snapshot, the same-replica local
|
||||
@@ -3283,18 +3465,17 @@ func (p *Server) Subscribe(
|
||||
lastMessageID = messages[len(messages)-1].ID
|
||||
}
|
||||
|
||||
// When an enterprise SubscribeFn is provided and the chat
|
||||
// lookup succeeded, call it to get relay events (message_parts
|
||||
// from remote replicas). OSS now owns pubsub subscription,
|
||||
// message catch-up, queue updates, and status forwarding;
|
||||
// enterprise only manages relay dialing.
|
||||
// When an enterprise SubscribeFn is provided, call it to get relay events
|
||||
// (message_parts from remote replicas). OSS owns pubsub subscription,
|
||||
// message catch-up, queue updates, and status forwarding; enterprise only
|
||||
// manages relay dialing.
|
||||
var relayEvents <-chan codersdk.ChatStreamEvent
|
||||
var statusNotifications chan StatusNotification
|
||||
if p.subscribeFn != nil && chatErr == nil {
|
||||
if p.subscribeFn != nil {
|
||||
statusNotifications = make(chan StatusNotification, 10)
|
||||
relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{
|
||||
ChatID: chatID,
|
||||
Chat: chat,
|
||||
Chat: snapshotChat,
|
||||
WorkerID: p.workerID,
|
||||
StatusNotifications: statusNotifications,
|
||||
RequestHeader: requestHeader,
|
||||
@@ -3351,7 +3532,7 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
lastMessageID = event.Message.ID
|
||||
}
|
||||
} else if newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
|
||||
} else if newMessages, msgErr := p.getStreamChatMessages(mergedCtx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: lastMessageID,
|
||||
}); msgErr != nil {
|
||||
@@ -3439,7 +3620,9 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}
|
||||
if notify.QueueUpdate {
|
||||
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(mergedCtx, chatID)
|
||||
queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx)
|
||||
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
||||
queueCancel()
|
||||
if queueErr != nil {
|
||||
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
|
||||
slog.F("chat_id", chatID),
|
||||
@@ -3515,14 +3698,6 @@ func (p *Server) Subscribe(
|
||||
}
|
||||
}()
|
||||
|
||||
cancel := func() {
|
||||
mergedCancel()
|
||||
for _, cancelFn := range allCancels {
|
||||
if cancelFn != nil {
|
||||
cancelFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
return initialSnapshot, mergedEvents, cancel, true
|
||||
}
|
||||
|
||||
@@ -3571,7 +3746,7 @@ func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.C
|
||||
}
|
||||
|
||||
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsub.ChatEventKind) {
|
||||
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
|
||||
for _, chat := range chats {
|
||||
p.publishChatPubsubEvent(chat, kind, nil)
|
||||
}
|
||||
@@ -3579,7 +3754,7 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind coderdpubsu
|
||||
|
||||
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
||||
// pubsub so that all replicas can push updates to watching clients.
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
||||
if p.pubsub == nil {
|
||||
return
|
||||
}
|
||||
@@ -3591,7 +3766,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
if diffStatus != nil {
|
||||
sdkChat.DiffStatus = diffStatus
|
||||
}
|
||||
event := coderdpubsub.ChatEvent{
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: kind,
|
||||
Chat: sdkChat,
|
||||
}
|
||||
@@ -3603,7 +3778,7 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("kind", kind),
|
||||
@@ -3636,8 +3811,8 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
toolCalls := pendingToStreamToolCalls(pending)
|
||||
sdkChat := db2sdk.Chat(chat, nil, nil)
|
||||
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindActionRequired,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindActionRequired,
|
||||
Chat: sdkChat,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
@@ -3649,7 +3824,7 @@ func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloo
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
@@ -3677,7 +3852,7 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID)
|
||||
}
|
||||
|
||||
sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange, &sdkStatus)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4159,7 +4334,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
if title, ok := generatedTitle.Load(); ok {
|
||||
updatedChat.Title = title
|
||||
}
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
||||
|
||||
// When the chat is parked in requires_action,
|
||||
// publish the stream event and global pubsub event
|
||||
@@ -4430,13 +4605,21 @@ func (p *Server) runChat(
|
||||
// the workspace agent has changed (e.g. workspace rebuilt).
|
||||
needsInstructionPersist := false
|
||||
hasContextFiles := false
|
||||
persistedSkills := skillsFromParts(messages)
|
||||
latestInjectedAgentID, hasLatestInjectedAgent := latestContextAgentID(messages)
|
||||
currentWorkspaceAgentID := uuid.Nil
|
||||
hasCurrentWorkspaceAgent := false
|
||||
if chat.WorkspaceID.Valid {
|
||||
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
|
||||
currentWorkspaceAgentID = agent.ID
|
||||
hasCurrentWorkspaceAgent = true
|
||||
}
|
||||
persistedAgentID, found := contextFileAgentID(messages)
|
||||
hasContextFiles = found
|
||||
if !hasContextFiles {
|
||||
if !hasPersistedInstructionFiles(messages) {
|
||||
needsInstructionPersist = true
|
||||
} else if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil && agent.ID != persistedAgentID {
|
||||
// Agent changed — persist fresh instruction files.
|
||||
} else if hasCurrentWorkspaceAgent && currentWorkspaceAgentID != persistedAgentID {
|
||||
// Agent changed. Persist fresh instruction files.
|
||||
// Old context-file messages remain in the conversation
|
||||
// to preserve the prompt cache prefix.
|
||||
needsInstructionPersist = true
|
||||
@@ -4459,7 +4642,8 @@ func (p *Server) runChat(
|
||||
if needsInstructionPersist {
|
||||
g2.Go(func() error {
|
||||
var persistErr error
|
||||
instruction, skills, persistErr = p.persistInstructionFiles(
|
||||
var discoveredSkills []chattool.SkillMeta
|
||||
instruction, discoveredSkills, persistErr = p.persistInstructionFiles(
|
||||
ctx,
|
||||
chat,
|
||||
modelConfig.ID,
|
||||
@@ -4471,6 +4655,12 @@ func (p *Server) runChat(
|
||||
return workspaceCtx.getWorkspaceConn(instructionCtx)
|
||||
},
|
||||
)
|
||||
skills = selectSkillMetasForInstructionRefresh(
|
||||
persistedSkills,
|
||||
discoveredSkills,
|
||||
uuid.NullUUID{UUID: currentWorkspaceAgentID, Valid: hasCurrentWorkspaceAgent},
|
||||
uuid.NullUUID{UUID: latestInjectedAgentID, Valid: hasLatestInjectedAgent},
|
||||
)
|
||||
if persistErr != nil {
|
||||
p.logger.Warn(ctx, "failed to persist instruction files",
|
||||
slog.F("chat_id", chat.ID),
|
||||
@@ -4485,7 +4675,7 @@ func (p *Server) runChat(
|
||||
// re-injected via InsertSystem after compaction drops
|
||||
// those messages. No workspace dial needed.
|
||||
instruction = instructionFromContextFiles(messages)
|
||||
skills = skillsFromParts(messages)
|
||||
skills = persistedSkills
|
||||
}
|
||||
g2.Go(func() error {
|
||||
resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID)
|
||||
@@ -5103,14 +5293,14 @@ func (p *Server) runChat(
|
||||
// assistant and tool messages that the provider already has.
|
||||
chainModeActive := chatprovider.IsResponsesStoreEnabled(providerOptions) &&
|
||||
chainInfo.previousResponseID != "" &&
|
||||
chainInfo.trailingUserCount > 0 &&
|
||||
chainInfo.contributingTrailingUserCount > 0 &&
|
||||
chainInfo.modelConfigID == modelConfig.ID
|
||||
if chainModeActive {
|
||||
providerOptions = chatprovider.CloneWithPreviousResponseID(
|
||||
providerOptions,
|
||||
chainInfo.previousResponseID,
|
||||
)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo)
|
||||
}
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
@@ -5164,7 +5354,7 @@ func (p *Server) runChat(
|
||||
if chainModeActive {
|
||||
reloadedPrompt = filterPromptForChainMode(
|
||||
reloadedPrompt,
|
||||
chainInfo.trailingUserCount,
|
||||
chainInfo,
|
||||
)
|
||||
}
|
||||
return reloadedPrompt, nil
|
||||
@@ -5537,8 +5727,9 @@ func refreshChatWorkspaceSnapshot(
|
||||
}
|
||||
|
||||
// contextFileAgentID extracts the workspace agent ID from the most
|
||||
// recent persisted context-file parts. Returns uuid.Nil, false if no
|
||||
// context-file parts exist.
|
||||
// recent persisted instruction-file parts. The skill-only sentinel is
|
||||
// ignored because it does not represent persisted instruction content.
|
||||
// Returns uuid.Nil, false if no instruction-file parts exist.
|
||||
func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
var lastID uuid.UUID
|
||||
found := false
|
||||
@@ -5551,11 +5742,14 @@ func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
continue
|
||||
}
|
||||
for _, p := range parts {
|
||||
if p.Type == codersdk.ChatMessagePartTypeContextFile && p.ContextFileAgentID.Valid {
|
||||
lastID = p.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
if p.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!p.ContextFileAgentID.Valid ||
|
||||
p.ContextFilePath == AgentChatContextSentinelPath {
|
||||
continue
|
||||
}
|
||||
lastID = p.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return lastID, found
|
||||
@@ -5625,13 +5819,14 @@ func (p *Server) persistInstructionFiles(
|
||||
// agent cannot know its own UUID, OS metadata, or
|
||||
// directory — those are added here at the trust boundary.
|
||||
var discoveredSkills []chattool.SkillMeta
|
||||
var hasContent bool
|
||||
var hasContent, hasContextFilePart bool
|
||||
agentID := uuid.NullUUID{UUID: agent.ID, Valid: true}
|
||||
|
||||
for i := range agentParts {
|
||||
agentParts[i].ContextFileAgentID = agentID
|
||||
switch agentParts[i].Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFilePart = true
|
||||
agentParts[i].ContextFileContent = SanitizePromptText(agentParts[i].ContextFileContent)
|
||||
agentParts[i].ContextFileOS = agent.OperatingSystem
|
||||
agentParts[i].ContextFileDirectory = directory
|
||||
@@ -5652,13 +5847,13 @@ func (p *Server) persistInstructionFiles(
|
||||
if !workspaceConnOK {
|
||||
return "", nil, nil
|
||||
}
|
||||
// Persist a sentinel (plus any skill-only parts) so
|
||||
// subsequent turns skip the workspace agent dial.
|
||||
if len(agentParts) == 0 {
|
||||
agentParts = []codersdk.ChatMessagePart{{
|
||||
// Persist a blank context-file marker (plus any skill-only
|
||||
// parts) so subsequent turns skip the workspace agent dial.
|
||||
if !hasContextFilePart {
|
||||
agentParts = append([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFileAgentID: agentID,
|
||||
}}
|
||||
}}, agentParts...)
|
||||
}
|
||||
content, err := chatprompt.MarshalParts(agentParts)
|
||||
if err != nil {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -100,6 +101,10 @@ type RunOptions struct {
|
||||
// first stream part before the attempt is canceled and
|
||||
// retried. Zero uses the production default.
|
||||
StartupTimeout time.Duration
|
||||
// Clock creates startup guard timers. In production use a
|
||||
// real clock; tests can inject quartz.NewMock(t) to make
|
||||
// startup timeout behavior deterministic.
|
||||
Clock quartz.Clock
|
||||
|
||||
ActiveTools []string
|
||||
ContextLimitFallback int64
|
||||
@@ -289,6 +294,9 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
if opts.StartupTimeout <= 0 {
|
||||
opts.StartupTimeout = defaultStartupTimeout
|
||||
}
|
||||
if opts.Clock == nil {
|
||||
opts.Clock = quartz.NewReal()
|
||||
}
|
||||
|
||||
publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
|
||||
if opts.PublishMessagePart == nil {
|
||||
@@ -364,6 +372,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
attempt, streamErr := guardedStream(
|
||||
retryCtx,
|
||||
opts.Model.Provider(),
|
||||
opts.Clock,
|
||||
opts.StartupTimeout,
|
||||
func(attemptCtx context.Context) (fantasy.StreamResponse, error) {
|
||||
return opts.Model.Stream(attemptCtx, call)
|
||||
@@ -660,17 +669,18 @@ type guardedAttempt struct {
|
||||
// stream startup. Exactly one outcome wins: the timer cancels
|
||||
// the attempt, or the first-part path disarms the timer.
|
||||
type startupGuard struct {
|
||||
timer *time.Timer
|
||||
timer *quartz.Timer
|
||||
cancel context.CancelCauseFunc
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newStartupGuard(
|
||||
clock quartz.Clock,
|
||||
timeout time.Duration,
|
||||
cancel context.CancelCauseFunc,
|
||||
) *startupGuard {
|
||||
guard := &startupGuard{cancel: cancel}
|
||||
guard.timer = time.AfterFunc(timeout, guard.onTimeout)
|
||||
guard.timer = clock.AfterFunc(timeout, guard.onTimeout, "startupGuard")
|
||||
return guard
|
||||
}
|
||||
|
||||
@@ -707,11 +717,12 @@ func classifyStartupTimeout(
|
||||
func guardedStream(
|
||||
parent context.Context,
|
||||
provider string,
|
||||
clock quartz.Clock,
|
||||
timeout time.Duration,
|
||||
openStream func(context.Context) (fantasy.StreamResponse, error),
|
||||
) (guardedAttempt, error) {
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(parent)
|
||||
guard := newStartupGuard(timeout, cancelAttempt)
|
||||
guard := newStartupGuard(clock, timeout, cancelAttempt)
|
||||
var releaseOnce sync.Once
|
||||
release := func() {
|
||||
releaseOnce.Do(func() {
|
||||
|
||||
@@ -19,10 +19,24 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const activeToolName = "read_file"
|
||||
|
||||
func awaitRunResult(ctx context.Context, t *testing.T, done <-chan error) error {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for Run to complete")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -202,7 +216,7 @@ func TestStartupGuard_DisarmAndFireRace(t *testing.T) {
|
||||
|
||||
for range 128 {
|
||||
var cancels atomic.Int32
|
||||
guard := newStartupGuard(time.Hour, func(err error) {
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, func(err error) {
|
||||
if errors.Is(err, errStartupTimeout) {
|
||||
cancels.Add(1)
|
||||
}
|
||||
@@ -240,7 +254,7 @@ func TestStartupGuard_DisarmPreservesPermanentError(t *testing.T) {
|
||||
attemptCtx, cancelAttempt := context.WithCancelCause(context.Background())
|
||||
defer cancelAttempt(nil)
|
||||
|
||||
guard := newStartupGuard(time.Hour, cancelAttempt)
|
||||
guard := newStartupGuard(quartz.NewReal(), time.Hour, cancelAttempt)
|
||||
guard.Disarm()
|
||||
guard.onTimeout()
|
||||
|
||||
@@ -259,6 +273,16 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
@@ -278,23 +302,32 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, chaterror.KindStartupTimeout, retries[0].Kind)
|
||||
@@ -305,7 +338,12 @@ func TestRun_RetriesStartupTimeoutWhileOpeningStream(t *testing.T) {
|
||||
"OpenAI did not start responding in time.",
|
||||
retries[0].Message,
|
||||
)
|
||||
require.ErrorIs(t, <-attemptCause, errStartupTimeout)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
@@ -313,6 +351,16 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
@@ -337,23 +385,32 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, chaterror.KindStartupTimeout, retries[0].Kind)
|
||||
@@ -364,7 +421,12 @@ func TestRun_RetriesStartupTimeoutBeforeFirstPart(t *testing.T) {
|
||||
"OpenAI did not start responding in time.",
|
||||
retries[0].Message,
|
||||
)
|
||||
require.ErrorIs(t, <-attemptCause, errStartupTimeout)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
@@ -372,8 +434,19 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
|
||||
attempts := 0
|
||||
retried := false
|
||||
firstPartYielded := make(chan struct{}, 1)
|
||||
continueStream := make(chan struct{})
|
||||
model := &loopTestModel{
|
||||
provider: "openai",
|
||||
streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
@@ -382,18 +455,19 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
|
||||
return
|
||||
}
|
||||
|
||||
timer := time.NewTimer(startupTimeout * 2)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case firstPartYielded <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-continueStream:
|
||||
case <-ctx.Done():
|
||||
_ = yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeError,
|
||||
Error: ctx.Err(),
|
||||
})
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
parts := []fantasy.StreamPart{
|
||||
@@ -410,23 +484,40 @@ func TestRun_FirstPartDisarmsStartupTimeout(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
_ chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retried = true
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
_ chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retried = true
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
trap.Close()
|
||||
|
||||
select {
|
||||
case <-firstPartYielded:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for first stream part")
|
||||
}
|
||||
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
close(continueStream)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 1, attempts)
|
||||
require.False(t, retried)
|
||||
}
|
||||
@@ -479,6 +570,16 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
|
||||
const startupTimeout = 5 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
testutil.WaitShort,
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
trap := mClock.Trap().AfterFunc("startupGuard")
|
||||
defer trap.Close()
|
||||
|
||||
attempts := 0
|
||||
attemptCause := make(chan error, 1)
|
||||
var retries []chatretry.ClassifiedError
|
||||
@@ -499,23 +600,32 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
MaxSteps: 1,
|
||||
StartupTimeout: startupTimeout,
|
||||
Clock: mClock,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
OnRetry: func(
|
||||
_ int,
|
||||
_ error,
|
||||
classified chatretry.ClassifiedError,
|
||||
_ time.Duration,
|
||||
) {
|
||||
retries = append(retries, classified)
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
mClock.Advance(startupTimeout).MustWait(ctx)
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
require.NoError(t, awaitRunResult(ctx, t, done))
|
||||
require.Equal(t, 2, attempts)
|
||||
require.Len(t, retries, 1)
|
||||
require.Equal(t, chaterror.KindStartupTimeout, retries[0].Kind)
|
||||
@@ -526,7 +636,12 @@ func TestRun_RetriesStartupTimeoutWhenStreamClosesSilently(t *testing.T) {
|
||||
"OpenAI did not start responding in time.",
|
||||
retries[0].Message,
|
||||
)
|
||||
require.ErrorIs(t, <-attemptCause, errStartupTimeout)
|
||||
select {
|
||||
case cause := <-attemptCause:
|
||||
require.ErrorIs(t, cause, errStartupTimeout)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for startup timeout cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// AgentChatContextSentinelPath marks the synthetic empty context-file
|
||||
// part used to preserve skill-only workspace-agent additions across
|
||||
// turns without treating them as persisted instruction files.
|
||||
const AgentChatContextSentinelPath = ".coder/agent-chat-context-sentinel"
|
||||
|
||||
// FilterContextParts keeps only context-file and skill parts from parts.
|
||||
// When keepEmptyContextFiles is false, context-file parts with empty
|
||||
// content are dropped. When keepEmptyContextFiles is true, empty
|
||||
// context-file parts are preserved.
|
||||
// revive:disable-next-line:flag-parameter // Required by shared helper callers.
|
||||
func FilterContextParts(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
keepEmptyContextFiles bool,
|
||||
) []codersdk.ChatMessagePart {
|
||||
var filtered []codersdk.ChatMessagePart
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
if !keepEmptyContextFiles && part.ContextFileContent == "" {
|
||||
continue
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
default:
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, part)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// CollectContextPartsFromMessages unmarshals chat message content and
|
||||
// collects the context-file and skill parts it contains. When
|
||||
// keepEmptyContextFiles is false, empty context-file parts are skipped.
|
||||
// When it is true, empty context-file parts are included in the result.
|
||||
func CollectContextPartsFromMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
messages []database.ChatMessage,
|
||||
keepEmptyContextFiles bool,
|
||||
) ([]codersdk.ChatMessagePart, error) {
|
||||
var collected []codersdk.ChatMessagePart
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
logger.Warn(ctx, "skipping malformed chat context message",
|
||||
slog.F("chat_message_id", msg.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
collected = append(
|
||||
collected,
|
||||
FilterContextParts(parts, keepEmptyContextFiles)...,
|
||||
)
|
||||
}
|
||||
|
||||
return collected, nil
|
||||
}
|
||||
|
||||
func latestContextAgentIDFromParts(parts []codersdk.ChatMessagePart) (uuid.UUID, bool) {
|
||||
var lastID uuid.UUID
|
||||
found := false
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!part.ContextFileAgentID.Valid {
|
||||
continue
|
||||
}
|
||||
lastID = part.ContextFileAgentID.UUID
|
||||
found = true
|
||||
}
|
||||
return lastID, found
|
||||
}
|
||||
|
||||
// FilterContextPartsToLatestAgent keeps parts stamped with the latest
|
||||
// workspace-agent ID seen in the slice, plus legacy unstamped parts.
|
||||
// When no stamped context-file parts exist, it returns the original
|
||||
// slice unchanged.
|
||||
func FilterContextPartsToLatestAgent(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
|
||||
latestAgentID, ok := latestContextAgentIDFromParts(parts)
|
||||
if !ok {
|
||||
return parts
|
||||
}
|
||||
|
||||
filtered := make([]codersdk.ChatMessagePart, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile,
|
||||
codersdk.ChatMessagePartTypeSkill:
|
||||
if part.ContextFileAgentID.Valid &&
|
||||
part.ContextFileAgentID.UUID != latestAgentID {
|
||||
continue
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, part)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// BuildLastInjectedContext filters parts down to non-empty context-file
|
||||
// and skill parts, strips their internal fields, and marshals the
|
||||
// result for LastInjectedContext. A nil or fully filtered input returns
|
||||
// an invalid NullRawMessage.
|
||||
func BuildLastInjectedContext(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
) (pqtype.NullRawMessage, error) {
|
||||
if parts == nil {
|
||||
return pqtype.NullRawMessage{Valid: false}, nil
|
||||
}
|
||||
|
||||
filtered := FilterContextParts(parts, false)
|
||||
if len(filtered) == 0 {
|
||||
return pqtype.NullRawMessage{Valid: false}, nil
|
||||
}
|
||||
|
||||
stripped := make([]codersdk.ChatMessagePart, 0, len(filtered))
|
||||
for _, part := range filtered {
|
||||
cp := part
|
||||
cp.StripInternal()
|
||||
stripped = append(stripped, cp)
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(stripped)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf(
|
||||
"marshal injected context: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return pqtype.NullRawMessage{RawMessage: raw, Valid: true}, nil
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -57,6 +59,34 @@ func formatSystemInstructions(
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// latestContextAgentID returns the most recent workspace-agent ID seen
|
||||
// on any persisted context-file part, including the skill-only sentinel.
|
||||
// Returns uuid.Nil, false when no stamped context-file parts exist.
|
||||
func latestContextAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
||||
var lastID uuid.UUID
|
||||
found := false
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid ||
|
||||
!bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!part.ContextFileAgentID.Valid {
|
||||
continue
|
||||
}
|
||||
lastID = part.ContextFileAgentID.UUID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return lastID, found
|
||||
}
|
||||
|
||||
// instructionFromContextFiles reconstructs the formatted instruction
|
||||
// string from persisted context-file parts. This is used on non-first
|
||||
// turns so the instruction can be re-injected after compaction
|
||||
@@ -64,6 +94,7 @@ func formatSystemInstructions(
|
||||
func instructionFromContextFiles(
|
||||
messages []database.ChatMessage,
|
||||
) string {
|
||||
filterAgentID, filterByAgent := latestContextAgentID(messages)
|
||||
var contextParts []codersdk.ChatMessagePart
|
||||
var os, dir string
|
||||
for _, msg := range messages {
|
||||
@@ -79,6 +110,10 @@ func instructionFromContextFiles(
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile {
|
||||
continue
|
||||
}
|
||||
if filterByAgent && part.ContextFileAgentID.Valid &&
|
||||
part.ContextFileAgentID.UUID != filterAgentID {
|
||||
continue
|
||||
}
|
||||
if part.ContextFileOS != "" {
|
||||
os = part.ContextFileOS
|
||||
}
|
||||
@@ -93,6 +128,80 @@ func instructionFromContextFiles(
|
||||
return formatSystemInstructions(os, dir, contextParts)
|
||||
}
|
||||
|
||||
// hasPersistedInstructionFiles reports whether messages include a
|
||||
// persisted context-file part that should suppress another baseline
|
||||
// instruction-file lookup. The workspace-agent skill-only sentinel is
|
||||
// ignored so default instructions still load on fresh chats.
|
||||
func hasPersistedInstructionFiles(
|
||||
messages []database.ChatMessage,
|
||||
) bool {
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid ||
|
||||
!bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeContextFile ||
|
||||
!part.ContextFileAgentID.Valid ||
|
||||
part.ContextFilePath == AgentChatContextSentinelPath {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeSkillMetas(
|
||||
persisted []chattool.SkillMeta,
|
||||
discovered []chattool.SkillMeta,
|
||||
) []chattool.SkillMeta {
|
||||
if len(persisted) == 0 {
|
||||
return discovered
|
||||
}
|
||||
if len(discovered) == 0 {
|
||||
return persisted
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(persisted)+len(discovered))
|
||||
merged := make([]chattool.SkillMeta, 0, len(persisted)+len(discovered))
|
||||
appendUnique := func(skill chattool.SkillMeta) {
|
||||
if _, ok := seen[skill.Name]; ok {
|
||||
return
|
||||
}
|
||||
seen[skill.Name] = struct{}{}
|
||||
merged = append(merged, skill)
|
||||
}
|
||||
for _, skill := range discovered {
|
||||
appendUnique(skill)
|
||||
}
|
||||
for _, skill := range persisted {
|
||||
appendUnique(skill)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// selectSkillMetasForInstructionRefresh chooses which skill metadata
|
||||
// should be injected on a turn that refreshes instruction files.
|
||||
func selectSkillMetasForInstructionRefresh(
|
||||
persisted []chattool.SkillMeta,
|
||||
discovered []chattool.SkillMeta,
|
||||
currentAgentID uuid.NullUUID,
|
||||
latestInjectedAgentID uuid.NullUUID,
|
||||
) []chattool.SkillMeta {
|
||||
if currentAgentID.Valid && latestInjectedAgentID.Valid && latestInjectedAgentID.UUID == currentAgentID.UUID {
|
||||
return mergeSkillMetas(persisted, discovered)
|
||||
}
|
||||
if !currentAgentID.Valid && len(discovered) == 0 {
|
||||
return persisted
|
||||
}
|
||||
return discovered
|
||||
}
|
||||
|
||||
// skillsFromParts reconstructs skill metadata from persisted
|
||||
// skill parts. This is analogous to instructionFromContextFiles
|
||||
// so the skill index can be re-injected after compaction without
|
||||
@@ -100,6 +209,7 @@ func instructionFromContextFiles(
|
||||
func skillsFromParts(
|
||||
messages []database.ChatMessage,
|
||||
) []chattool.SkillMeta {
|
||||
filterAgentID, filterByAgent := latestContextAgentID(messages)
|
||||
var skills []chattool.SkillMeta
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid ||
|
||||
@@ -114,6 +224,10 @@ func skillsFromParts(
|
||||
if part.Type != codersdk.ChatMessagePartTypeSkill {
|
||||
continue
|
||||
}
|
||||
if filterByAgent && part.ContextFileAgentID.Valid &&
|
||||
part.ContextFileAgentID.UUID != filterAgentID {
|
||||
continue
|
||||
}
|
||||
skills = append(skills, chattool.SkillMeta{
|
||||
Name: part.SkillName,
|
||||
Description: part.SkillDescription,
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -160,7 +159,7 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
chat.Title = title
|
||||
generatedTitle.Store(title)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindTitleChange, nil)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+148
-54
@@ -2,8 +2,11 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -13,71 +16,60 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type recordingResult struct {
|
||||
recordingFileID string
|
||||
thumbnailFileID string
|
||||
}
|
||||
|
||||
// stopAndStoreRecording stops the desktop recording, downloads the
|
||||
// MP4, and stores it in chat_files. Only called when the subagent
|
||||
// completed successfully. Returns the file ID on success, empty
|
||||
// string on any failure. All errors are logged but not propagated
|
||||
// — recording is best-effort.
|
||||
// multipart response containing the MP4 and optional thumbnail, and
|
||||
// stores them in chat_files. Only called when the subagent completed
|
||||
// successfully. Returns file IDs on success, empty fields on any
|
||||
// failure. All errors are logged but not propagated; recording is
|
||||
// best-effort.
|
||||
func (p *Server) stopAndStoreRecording(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
recordingID string,
|
||||
ownerID uuid.UUID,
|
||||
workspaceID uuid.NullUUID,
|
||||
) string {
|
||||
) recordingResult {
|
||||
var result recordingResult
|
||||
|
||||
select {
|
||||
case p.recordingSem <- struct{}{}:
|
||||
defer func() { <-p.recordingSem }()
|
||||
case <-ctx.Done():
|
||||
p.logger.Warn(ctx, "context canceled waiting for recording semaphore", slog.Error(ctx.Err()))
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
|
||||
body, err := conn.StopDesktopRecording(ctx,
|
||||
resp, err := conn.StopDesktopRecording(ctx,
|
||||
workspacesdk.StopDesktopRecordingRequest{RecordingID: recordingID})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to stop desktop recording",
|
||||
slog.Error(err))
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
type readResult struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
ch := make(chan readResult, 1)
|
||||
go func() {
|
||||
data, err := io.ReadAll(io.LimitReader(body, workspacesdk.MaxRecordingSize+1))
|
||||
ch <- readResult{data, err}
|
||||
}()
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data []byte
|
||||
select {
|
||||
case res := <-ch:
|
||||
body.Close()
|
||||
data = res.data
|
||||
if res.err != nil {
|
||||
p.logger.Warn(ctx, "failed to read recording data", slog.Error(res.err))
|
||||
return ""
|
||||
}
|
||||
case <-ctx.Done():
|
||||
body.Close()
|
||||
p.logger.Warn(ctx, "context canceled while reading recording data", slog.Error(ctx.Err()))
|
||||
return ""
|
||||
_, params, err := mime.ParseMediaType(resp.ContentType)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to parse content type from recording response",
|
||||
slog.F("content_type", resp.ContentType),
|
||||
slog.Error(err))
|
||||
return result
|
||||
}
|
||||
if len(data) > workspacesdk.MaxRecordingSize {
|
||||
p.logger.Warn(ctx, "recording data exceeds maximum size, skipping store",
|
||||
slog.F("size", len(data)),
|
||||
slog.F("max_size", workspacesdk.MaxRecordingSize))
|
||||
return ""
|
||||
}
|
||||
if len(data) == 0 {
|
||||
p.logger.Warn(ctx, "recording data is empty, skipping store")
|
||||
return ""
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
p.logger.Warn(ctx, "missing boundary in recording response content type",
|
||||
slog.F("content_type", resp.ContentType))
|
||||
return result
|
||||
}
|
||||
|
||||
if !workspaceID.Valid {
|
||||
p.logger.Warn(ctx, "chat has no workspace, cannot store recording")
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
|
||||
// The chatd actor is used here because the recording is stored on
|
||||
@@ -87,21 +79,123 @@ func (p *Server) stopAndStoreRecording(
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to resolve workspace for recording",
|
||||
slog.Error(err))
|
||||
return ""
|
||||
return result
|
||||
}
|
||||
|
||||
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
|
||||
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
|
||||
Mimetype: "video/mp4",
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to store recording in database",
|
||||
slog.Error(err))
|
||||
return ""
|
||||
mr := multipart.NewReader(resp.Body, boundary)
|
||||
// Context cancellation is checked between parts. Within a
|
||||
// part read, cancellation relies on Go's HTTP transport closing
|
||||
// the underlying connection when the context is done, which
|
||||
// interrupts the blocked io.ReadAll.
|
||||
// First pass: parse all multipart parts into memory.
|
||||
// The agent sends at most two parts: one video/mp4 and one
|
||||
// optional image/jpeg thumbnail. Cap the number of parts to
|
||||
// prevent a malicious or broken agent from forcing the server
|
||||
// into an unbounded parsing loop.
|
||||
const maxParts = 2
|
||||
var videoData, thumbnailData []byte
|
||||
for range maxParts {
|
||||
if ctx.Err() != nil {
|
||||
p.logger.Warn(ctx, "context canceled while reading recording parts", slog.Error(ctx.Err()))
|
||||
break
|
||||
}
|
||||
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "error reading next multipart part", slog.Error(err))
|
||||
break
|
||||
}
|
||||
|
||||
contentType := part.Header.Get("Content-Type")
|
||||
|
||||
// Select the read limit based on content type so that
|
||||
// thumbnails (image/jpeg) do not allocate up to
|
||||
// MaxRecordingSize (100 MB) before the size check rejects
|
||||
// them. Unknown types use a small default since they are
|
||||
// discarded below.
|
||||
maxSize := int64(1 << 20) // 1 MB default for unknown types
|
||||
switch contentType {
|
||||
case "video/mp4":
|
||||
maxSize = int64(workspacesdk.MaxRecordingSize)
|
||||
case "image/jpeg":
|
||||
maxSize = int64(workspacesdk.MaxThumbnailSize)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(part, maxSize+1))
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to read recording part data",
|
||||
slog.F("content_type", contentType),
|
||||
slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if int64(len(data)) > maxSize {
|
||||
p.logger.Warn(ctx, "recording part exceeds maximum size, skipping",
|
||||
slog.F("content_type", contentType),
|
||||
slog.F("size", len(data)),
|
||||
slog.F("max_size", maxSize))
|
||||
continue
|
||||
}
|
||||
if len(data) == 0 {
|
||||
p.logger.Warn(ctx, "recording part is empty, skipping",
|
||||
slog.F("content_type", contentType))
|
||||
continue
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case "video/mp4":
|
||||
if videoData != nil {
|
||||
p.logger.Warn(ctx, "duplicate video/mp4 part in recording response, skipping")
|
||||
continue
|
||||
}
|
||||
videoData = data
|
||||
case "image/jpeg":
|
||||
if thumbnailData != nil {
|
||||
p.logger.Warn(ctx, "duplicate image/jpeg part in recording response, skipping")
|
||||
continue
|
||||
}
|
||||
thumbnailData = data
|
||||
default:
|
||||
p.logger.Debug(ctx, "skipping unknown part content type",
|
||||
slog.F("content_type", contentType))
|
||||
}
|
||||
}
|
||||
return row.ID.String()
|
||||
|
||||
// Second pass: store the collected data in the database.
|
||||
if videoData != nil {
|
||||
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
|
||||
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
Name: fmt.Sprintf("recording-%s.mp4", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
|
||||
Mimetype: "video/mp4",
|
||||
Data: videoData,
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to store recording in database",
|
||||
slog.Error(err))
|
||||
} else {
|
||||
result.recordingFileID = row.ID.String()
|
||||
}
|
||||
}
|
||||
if thumbnailData != nil && result.recordingFileID != "" {
|
||||
//nolint:gocritic // AsChatd is required to insert chat files from the recording pipeline.
|
||||
row, err := p.db.InsertChatFile(dbauthz.AsChatd(ctx), database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
Name: fmt.Sprintf("thumbnail-%s.jpg", p.clock.Now().UTC().Format("2006-01-02T15-04-05Z")),
|
||||
Mimetype: "image/jpeg",
|
||||
Data: thumbnailData,
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "failed to store thumbnail in database",
|
||||
slog.Error(err))
|
||||
} else {
|
||||
result.thumbnailFileID = row.ID.String()
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -34,6 +36,30 @@ func (zeroReader) Read(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// partSpec describes a single part for buildMultipartResponse.
|
||||
type partSpec struct {
|
||||
contentType string
|
||||
data []byte
|
||||
}
|
||||
|
||||
// buildMultipartResponse constructs a StopDesktopRecordingResponse
|
||||
// with the given content type/data pairs encoded as multipart/mixed.
|
||||
func buildMultipartResponse(parts ...partSpec) workspacesdk.StopDesktopRecordingResponse {
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
for _, p := range parts {
|
||||
partWriter, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {p.contentType},
|
||||
})
|
||||
_, _ = partWriter.Write(p.data)
|
||||
}
|
||||
_ = mw.Close()
|
||||
return workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: io.NopCloser(bytes.NewReader(buf.Bytes())),
|
||||
ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
|
||||
}
|
||||
}
|
||||
|
||||
// createComputerUseParentChild creates a parent chat and a
|
||||
// computer_use child chat bound to the given workspace/agent.
|
||||
// Both chats are inserted directly via DB to avoid triggering
|
||||
@@ -170,8 +196,7 @@ func TestWaitAgentComputerUseRecording(t *testing.T) {
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(io.NopCloser(bytes.NewReader(fakeMp4)), nil).
|
||||
Times(1)
|
||||
Return(buildMultipartResponse(partSpec{"video/mp4", fakeMp4}), nil).Times(1)
|
||||
|
||||
// Invoke wait_agent via the tool closure.
|
||||
resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5)
|
||||
@@ -198,6 +223,87 @@ func TestWaitAgentComputerUseRecording(t *testing.T) {
|
||||
assert.Equal(t, fakeMp4, chatFile.Data)
|
||||
}
|
||||
|
||||
// TestWaitAgentComputerUseRecordingWithThumbnail verifies the
|
||||
// recording flow when the agent produces both video and thumbnail:
|
||||
// both file IDs appear in the wait_agent tool response.
|
||||
func TestWaitAgentComputerUseRecordingWithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, agent := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
parent, child := createComputerUseParentChild(
|
||||
ctx, t, server, user, model, workspace, agent,
|
||||
"parent-recording-thumb", "computer-use-child-thumb",
|
||||
)
|
||||
|
||||
server.drainInflight()
|
||||
|
||||
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
require.Equal(t, agent.ID, agentID)
|
||||
return mockConn, func() {}, nil
|
||||
}
|
||||
|
||||
insertAssistantMessage(ctx, t, db, child.ID, model.ID, "I opened Firefox and took a screenshot.")
|
||||
|
||||
setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "")
|
||||
|
||||
fakeMp4 := []byte("fake-mp4-data-with-thumbnail-test")
|
||||
fakeThumb := []byte("fake-jpeg-thumbnail-data")
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, req workspacesdk.StartDesktopRecordingRequest) error {
|
||||
require.NotEmpty(t, req.RecordingID)
|
||||
return nil
|
||||
}).
|
||||
Times(1)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", fakeMp4},
|
||||
partSpec{"image/jpeg", fakeThumb},
|
||||
), nil).Times(1)
|
||||
|
||||
resp, err := invokeWaitAgentTool(ctx, t, server, db, parent.ID, child.ID, 5)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError, "expected successful response, got: %s", resp.Content)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
|
||||
// Verify recording_file_id is present and valid.
|
||||
storedFileID, ok := result["recording_file_id"].(string)
|
||||
require.True(t, ok, "recording_file_id must be present in response")
|
||||
require.NotEmpty(t, storedFileID)
|
||||
fileUUID, err := uuid.Parse(storedFileID)
|
||||
require.NoError(t, err)
|
||||
chatFile, err := db.GetChatFileByID(ctx, fileUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", chatFile.Mimetype)
|
||||
assert.Equal(t, fakeMp4, chatFile.Data)
|
||||
|
||||
// Verify thumbnail_file_id is present and valid.
|
||||
thumbFileID, ok := result["thumbnail_file_id"].(string)
|
||||
require.True(t, ok, "thumbnail_file_id must be present in response")
|
||||
require.NotEmpty(t, thumbFileID)
|
||||
thumbUUID, err := uuid.Parse(thumbFileID)
|
||||
require.NoError(t, err)
|
||||
thumbFile, err := db.GetChatFileByID(ctx, thumbUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image/jpeg", thumbFile.Mimetype)
|
||||
assert.Equal(t, fakeThumb, thumbFile.Data)
|
||||
}
|
||||
|
||||
// TestWaitAgentNonComputerUseNoRecording verifies that when the
|
||||
// child chat is NOT a computer_use chat, no recording is attempted.
|
||||
// StartDesktopRecording must never be called.
|
||||
@@ -342,7 +448,7 @@ func TestWaitAgentRecordingStopFails(t *testing.T) {
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(nil, xerrors.New("disk full")).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{}, xerrors.New("disk full")).
|
||||
Times(1)
|
||||
|
||||
// Invoke wait_agent via the tool closure.
|
||||
@@ -446,10 +552,10 @@ func TestWaitAgentTimeoutLeavesRecordingRunning(t *testing.T) {
|
||||
assert.Contains(t, result.resp.Content, "timed out")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecordingOversized verifies that when the recording
|
||||
// data exceeds MaxRecordingSize, stopAndStoreRecording returns an
|
||||
// empty string and does NOT call InsertChatFile.
|
||||
func TestStopAndStoreRecordingOversized(t *testing.T) {
|
||||
// TestStopAndStoreRecording_Oversized verifies that when the
|
||||
// recording data exceeds MaxRecordingSize, stopAndStoreRecording
|
||||
// returns an empty string and does NOT call InsertChatFile.
|
||||
func TestStopAndStoreRecording_Oversized(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -463,29 +569,146 @@ func TestStopAndStoreRecordingOversized(t *testing.T) {
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
// Create a reader that produces MaxRecordingSize+1 bytes without
|
||||
// allocating the full buffer in memory.
|
||||
oversizedReader := io.LimitReader(
|
||||
&zeroReader{},
|
||||
int64(workspacesdk.MaxRecordingSize+1),
|
||||
)
|
||||
// Build a streaming multipart response with a video/mp4 part
|
||||
// that exceeds MaxRecordingSize without allocating the full
|
||||
// buffer in memory.
|
||||
pr, pw := io.Pipe()
|
||||
mw := multipart.NewWriter(pw)
|
||||
go func() {
|
||||
partWriter, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
// Stream MaxRecordingSize+1 zero bytes.
|
||||
_, _ = io.Copy(partWriter, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxRecordingSize+1)))
|
||||
_ = mw.Close()
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(io.NopCloser(oversizedReader), nil).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: pr,
|
||||
ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
storedFileID := server.stopAndStoreRecording(
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
assert.Empty(t, storedFileID, "oversized recording should not be stored")
|
||||
assert.Empty(t, result.recordingFileID, "oversized recording should not be stored")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecordingEmpty verifies that when the recording
|
||||
// TestStopAndStoreRecording_OversizedThumbnail verifies that when the
|
||||
// thumbnail part exceeds MaxThumbnailSize it is skipped while the
|
||||
// normal-sized video part is still stored.
|
||||
func TestStopAndStoreRecording_OversizedThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := bytes.Repeat([]byte{0xAA}, 1024)
|
||||
|
||||
// Build a streaming multipart response with a normal video part
|
||||
// and an oversized thumbnail part.
|
||||
pr, pw := io.Pipe()
|
||||
mw := multipart.NewWriter(pw)
|
||||
go func() {
|
||||
vw, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
_, _ = vw.Write(videoData)
|
||||
tw, _ := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
// Stream MaxThumbnailSize+1 zero bytes for the thumbnail.
|
||||
_, _ = io.Copy(tw, io.LimitReader(&zeroReader{}, int64(workspacesdk.MaxThumbnailSize+1)))
|
||||
_ = mw.Close()
|
||||
_ = pw.Close()
|
||||
}()
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: pr,
|
||||
ContentType: "multipart/mixed; boundary=" + mw.Boundary(),
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Video should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
// Thumbnail should be skipped (oversized).
|
||||
assert.Empty(t, result.thumbnailFileID, "oversized thumbnail should not be stored")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_DuplicatePartsIgnored verifies that when
|
||||
// a multipart response contains two video/mp4 parts, only the first
|
||||
// is stored and the duplicate is skipped.
|
||||
func TestStopAndStoreRecording_DuplicatePartsIgnored(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
firstVideo := bytes.Repeat([]byte{0x01}, 512)
|
||||
secondVideo := bytes.Repeat([]byte{0x02}, 512)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", firstVideo},
|
||||
partSpec{"video/mp4", secondVideo},
|
||||
), nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Only the first video part should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err)
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, firstVideo, recFile.Data, "first video part should be stored, not the duplicate")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_Empty verifies that when the recording
|
||||
// data is empty, stopAndStoreRecording returns an empty string and
|
||||
// does NOT call InsertChatFile.
|
||||
func TestStopAndStoreRecordingEmpty(t *testing.T) {
|
||||
func TestStopAndStoreRecording_Empty(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -499,16 +722,265 @@ func TestStopAndStoreRecordingEmpty(t *testing.T) {
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
// Return empty data.
|
||||
// Build a multipart response with an empty video/mp4 part.
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(io.NopCloser(bytes.NewReader(nil)), nil).
|
||||
Times(1)
|
||||
Return(buildMultipartResponse(partSpec{"video/mp4", nil}), nil).Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
storedFileID := server.stopAndStoreRecording(
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
assert.Empty(t, storedFileID, "empty recording should not be stored")
|
||||
assert.Empty(t, result.recordingFileID, "empty recording should not be stored")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_WithThumbnail verifies that a multipart
|
||||
// response containing both a video/mp4 part and an image/jpeg part
|
||||
// results in both files being stored with correct mimetypes.
|
||||
func TestStopAndStoreRecording_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := bytes.Repeat([]byte{0xDE, 0xAD}, 512) // 1024 bytes
|
||||
thumbData := bytes.Repeat([]byte{0xFF, 0xD8}, 256) // 512 bytes
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", videoData},
|
||||
partSpec{"image/jpeg", thumbData},
|
||||
), nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Both file IDs should be valid UUIDs.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
|
||||
thumbUUID, err := uuid.Parse(result.thumbnailFileID)
|
||||
require.NoError(t, err, "ThumbnailFileID should be a valid UUID")
|
||||
// Verify the recording file in the database.
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
// Verify the thumbnail file in the database.
|
||||
thumbFile, err := db.GetChatFileByID(ctx, thumbUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image/jpeg", thumbFile.Mimetype)
|
||||
assert.Equal(t, thumbData, thumbFile.Data)
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_VideoOnly verifies that a multipart
|
||||
// response with only a video/mp4 part stores the recording but
|
||||
// leaves thumbnailFileID empty.
|
||||
func TestStopAndStoreRecording_VideoOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := make([]byte, 1024)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(partSpec{"video/mp4", videoData}), nil).Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Recording should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
// No thumbnail.
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty when no thumbnail part is present")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_DownloadFailure verifies that when
|
||||
// StopDesktopRecording returns an error, stopAndStoreRecording
|
||||
// returns an empty recordingResult without panicking.
|
||||
func TestStopAndStoreRecording_DownloadFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{}, xerrors.New("network error")).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty on download failure")
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty on download failure")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_UnknownPartIgnored verifies that parts
|
||||
// with unrecognized content types are silently skipped while known
|
||||
// parts (video/mp4 and image/jpeg) are still stored.
|
||||
func TestStopAndStoreRecording_UnknownPartIgnored(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
videoData := make([]byte, 1024)
|
||||
thumbData := make([]byte, 512)
|
||||
unknownData := make([]byte, 256)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(buildMultipartResponse(
|
||||
partSpec{"video/mp4", videoData},
|
||||
partSpec{"image/jpeg", thumbData},
|
||||
partSpec{"application/octet-stream", unknownData},
|
||||
), nil).Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
// Both known parts should be stored.
|
||||
recUUID, err := uuid.Parse(result.recordingFileID)
|
||||
require.NoError(t, err, "RecordingFileID should be a valid UUID")
|
||||
|
||||
thumbUUID, err := uuid.Parse(result.thumbnailFileID)
|
||||
require.NoError(t, err, "ThumbnailFileID should be a valid UUID")
|
||||
|
||||
// Verify only 2 files exist (unknown part was skipped).
|
||||
recFile, err := db.GetChatFileByID(ctx, recUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "video/mp4", recFile.Mimetype)
|
||||
assert.Equal(t, videoData, recFile.Data)
|
||||
|
||||
thumbFile, err := db.GetChatFileByID(ctx, thumbUUID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image/jpeg", thumbFile.Mimetype)
|
||||
assert.Equal(t, thumbData, thumbFile.Data)
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_MalformedContentType verifies that a
|
||||
// response with an unparseable Content-Type returns an empty result.
|
||||
func TestStopAndStoreRecording_MalformedContentType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||
ContentType: "",
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty for malformed content type")
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty for malformed content type")
|
||||
}
|
||||
|
||||
// TestStopAndStoreRecording_MissingBoundary verifies that a
|
||||
// multipart response without a boundary parameter returns an empty
|
||||
// result.
|
||||
func TestStopAndStoreRecording_MissingBoundary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := chatdTestContext(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
user, _ := seedInternalChatDeps(ctx, t, db)
|
||||
workspace, _, _ := seedWorkspaceBinding(t, db, user.ID)
|
||||
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
mockConn.EXPECT().
|
||||
StopDesktopRecording(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StopDesktopRecordingResponse{
|
||||
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||
ContentType: "multipart/mixed",
|
||||
}, nil).
|
||||
Times(1)
|
||||
|
||||
recordingID := uuid.New().String()
|
||||
result := server.stopAndStoreRecording(
|
||||
ctx, mockConn, recordingID, user.ID,
|
||||
uuid.NullUUID{UUID: workspace.ID, Valid: true},
|
||||
)
|
||||
|
||||
assert.Empty(t, result.recordingFileID, "RecordingFileID should be empty when boundary is missing")
|
||||
assert.Empty(t, result.thumbnailFileID, "ThumbnailFileID should be empty when boundary is missing")
|
||||
}
|
||||
|
||||
+275
-65
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
@@ -233,13 +234,13 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
}
|
||||
|
||||
// Only stop and store the recording on success.
|
||||
var storedFileID string
|
||||
var recResult recordingResult
|
||||
if recordingID != "" && agentConn != nil {
|
||||
// Use a fresh context for cleanup so a canceled
|
||||
// parent context doesn't prevent recording storage.
|
||||
stopCtx, stopCancel := context.WithTimeout(context.WithoutCancel(ctx), 90*time.Second)
|
||||
defer stopCancel()
|
||||
storedFileID = p.stopAndStoreRecording(stopCtx, agentConn,
|
||||
recResult = p.stopAndStoreRecording(stopCtx, agentConn,
|
||||
recordingID, parent.OwnerID, parent.WorkspaceID)
|
||||
}
|
||||
resp := map[string]any{
|
||||
@@ -248,8 +249,11 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
"report": report,
|
||||
"status": string(targetChat.Status),
|
||||
}
|
||||
if storedFileID != "" {
|
||||
resp["recording_file_id"] = storedFileID
|
||||
if recResult.recordingFileID != "" {
|
||||
resp["recording_file_id"] = recResult.recordingFileID
|
||||
}
|
||||
if recResult.thumbnailFileID != "" {
|
||||
resp["thumbnail_file_id"] = recResult.thumbnailFileID
|
||||
}
|
||||
return toolJSONResponse(resp), nil
|
||||
},
|
||||
@@ -358,48 +362,19 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(args.Prompt)
|
||||
if prompt == "" {
|
||||
return fantasy.NewTextErrorResponse("prompt is required"), nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(args.Title)
|
||||
if title == "" {
|
||||
title = subagentFallbackChatTitle(prompt)
|
||||
}
|
||||
|
||||
rootChatID := parent.ID
|
||||
if parent.RootChatID.Valid {
|
||||
rootChatID = parent.RootChatID.UUID
|
||||
}
|
||||
if parent.LastModelConfigID == uuid.Nil {
|
||||
return fantasy.NewTextErrorResponse("parent chat model config id is required"), nil
|
||||
}
|
||||
|
||||
// Create the child chat with Mode set to
|
||||
// computer_use. This signals runChat to use the
|
||||
// predefined computer use model and include the
|
||||
// computer tool.
|
||||
childChat, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
childChat, err := p.createChildSubagentChatWithOptions(
|
||||
ctx,
|
||||
parent,
|
||||
args.Prompt,
|
||||
args.Title,
|
||||
childSubagentChatOptions{
|
||||
chatMode: database.NullChatMode{
|
||||
ChatMode: database.ChatModeComputerUse,
|
||||
Valid: true,
|
||||
},
|
||||
systemPrompt: computerUseSubagentSystemPrompt + "\n\n" + strings.TrimSpace(args.Prompt),
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: computerUseSubagentSystemPrompt + "\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
MCPServerIDs: parent.MCPServerIDs,
|
||||
})
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
@@ -424,11 +399,26 @@ func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
|
||||
return chatID, nil
|
||||
}
|
||||
|
||||
type childSubagentChatOptions struct {
|
||||
chatMode database.NullChatMode
|
||||
systemPrompt string
|
||||
}
|
||||
|
||||
func (p *Server) createChildSubagentChat(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
prompt string,
|
||||
title string,
|
||||
) (database.Chat, error) {
|
||||
return p.createChildSubagentChatWithOptions(ctx, parent, prompt, title, childSubagentChatOptions{})
|
||||
}
|
||||
|
||||
func (p *Server) createChildSubagentChatWithOptions(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
prompt string,
|
||||
title string,
|
||||
opts childSubagentChatOptions,
|
||||
) (database.Chat, error) {
|
||||
if parent.ParentChatID.Valid {
|
||||
return database.Chat{}, xerrors.New("delegated chats cannot create child subagents")
|
||||
@@ -452,31 +442,251 @@ func (p *Server) createChildSubagentChat(
|
||||
return database.Chat{}, xerrors.New("parent chat model config id is required")
|
||||
}
|
||||
|
||||
child, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
MCPServerIDs: parent.MCPServerIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
|
||||
mcpServerIDs := parent.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
|
||||
labelsJSON, err := json.Marshal(database.StringMap{})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("marshal labels: %w", err)
|
||||
}
|
||||
childSystemPrompt := SanitizePromptText(opts.systemPrompt)
|
||||
|
||||
var child database.Chat
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
if limitErr := p.checkUsageLimit(ctx, tx, parent.OwnerID); limitErr != nil {
|
||||
return limitErr
|
||||
}
|
||||
|
||||
insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
BuildID: parent.BuildID,
|
||||
AgentID: parent.AgentID,
|
||||
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true},
|
||||
LastModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
Mode: opts.chatMode,
|
||||
Status: database.ChatStatusPending,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: pqtype.NullRawMessage{
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
DynamicTools: pqtype.NullRawMessage{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert child chat: %w", err)
|
||||
}
|
||||
|
||||
deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx)
|
||||
workspaceAwareness := "There is no workspace associated with this chat yet. Create one using the create_workspace tool before using workspace tools like execute, read_file, write_file, etc."
|
||||
if insertedChat.WorkspaceID.Valid {
|
||||
workspaceAwareness = "This chat is attached to a workspace. You can use workspace tools like execute, read_file, write_file, etc."
|
||||
}
|
||||
workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(workspaceAwareness),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal workspace awareness: %w", err)
|
||||
}
|
||||
userContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal initial user content: %w", err)
|
||||
}
|
||||
|
||||
systemParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: insertedChat.ID,
|
||||
}
|
||||
if deploymentPrompt != "" {
|
||||
deploymentContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(deploymentPrompt),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal deployment system prompt: %w", err)
|
||||
}
|
||||
appendChatMessage(&systemParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
deploymentContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
}
|
||||
if childSystemPrompt != "" {
|
||||
childSystemPromptContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(childSystemPrompt),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal child system prompt: %w", err)
|
||||
}
|
||||
appendChatMessage(&systemParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
childSystemPromptContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
}
|
||||
appendChatMessage(&systemParams, newChatMessage(
|
||||
database.ChatMessageRoleSystem,
|
||||
workspaceAwarenessContent,
|
||||
database.ChatMessageVisibilityModel,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
))
|
||||
if _, err := tx.InsertChatMessages(ctx, systemParams); err != nil {
|
||||
return xerrors.Errorf("insert initial child system messages: %w", err)
|
||||
}
|
||||
|
||||
child = insertedChat
|
||||
|
||||
// Copy persisted context before the initial child prompt so the
|
||||
// child cannot be acquired until its inherited context is in
|
||||
// place. signalWake runs only after commit.
|
||||
copiedContextParts, err := copyParentContextMessages(ctx, p.logger, tx, parent, child)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("copy parent context messages: %w", err)
|
||||
}
|
||||
if err := updateChildLastInjectedContext(ctx, p.logger, tx, child.ID, copiedContextParts); err != nil {
|
||||
return xerrors.Errorf("update child injected context: %w", err)
|
||||
}
|
||||
|
||||
userParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: insertedChat.ID,
|
||||
}
|
||||
appendChatMessage(&userParams, newChatMessage(
|
||||
database.ChatMessageRoleUser,
|
||||
userContent,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
parent.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(parent.OwnerID))
|
||||
if _, err := tx.InsertChatMessages(ctx, userParams); err != nil {
|
||||
return xerrors.Errorf("insert initial child user message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if txErr != nil {
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr)
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(child, codersdk.ChatWatchEventKindCreated, nil)
|
||||
p.signalWake()
|
||||
return child, nil
|
||||
}
|
||||
|
||||
// copyParentContextMessages reads persisted context-file and skill
|
||||
// messages from the parent chat and inserts copies into the child
|
||||
// chat. This ensures sub-agents inherit the same instruction and
|
||||
// skill context as their parent without independently re-fetching
|
||||
// from the agent.
|
||||
func copyParentContextMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
store database.Store,
|
||||
parent database.Chat,
|
||||
child database.Chat,
|
||||
) ([]codersdk.ChatMessagePart, error) {
|
||||
parentMessages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: parent.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get parent messages: %w", err)
|
||||
}
|
||||
|
||||
var (
|
||||
copiedParts []codersdk.ChatMessagePart
|
||||
copiedRole database.ChatMessageRole
|
||||
copiedVisibility database.ChatMessageVisibility
|
||||
copiedVersion int16
|
||||
)
|
||||
for _, msg := range parentMessages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
||||
logger.Warn(ctx, "failed to unmarshal parent context message",
|
||||
slog.F("parent_chat_id", parent.ID),
|
||||
slog.F("message_id", msg.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
messageContextParts := FilterContextParts(parts, true)
|
||||
if len(messageContextParts) == 0 {
|
||||
continue
|
||||
}
|
||||
if copiedParts == nil {
|
||||
copiedRole = msg.Role
|
||||
copiedVisibility = msg.Visibility
|
||||
copiedVersion = msg.ContentVersion
|
||||
}
|
||||
copiedParts = append(copiedParts, messageContextParts...)
|
||||
}
|
||||
if len(copiedParts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
copiedParts = FilterContextPartsToLatestAgent(copiedParts)
|
||||
filteredContent, err := chatprompt.MarshalParts(copiedParts)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("marshal filtered context parts: %w", err)
|
||||
}
|
||||
|
||||
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: child.ID,
|
||||
}
|
||||
appendChatMessage(&msgParams, newChatMessage(
|
||||
copiedRole,
|
||||
filteredContent,
|
||||
copiedVisibility,
|
||||
child.LastModelConfigID,
|
||||
copiedVersion,
|
||||
))
|
||||
if _, err := store.InsertChatMessages(ctx, msgParams); err != nil {
|
||||
return nil, xerrors.Errorf("insert context message: %w", err)
|
||||
}
|
||||
|
||||
return copiedParts, nil
|
||||
}
|
||||
|
||||
func updateChildLastInjectedContext(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
store database.Store,
|
||||
chatID uuid.UUID,
|
||||
parts []codersdk.ChatMessagePart,
|
||||
) error {
|
||||
parts = FilterContextPartsToLatestAgent(parts)
|
||||
param, err := BuildLastInjectedContext(parts)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal inherited injected context",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("marshal inherited injected context: %w", err)
|
||||
}
|
||||
if _, err := store.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
logger.Warn(ctx, "failed to update inherited injected context",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("update inherited injected context: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Server) sendSubagentMessage(
|
||||
ctx context.Context,
|
||||
parentChatID uuid.UUID,
|
||||
|
||||
@@ -0,0 +1,506 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestCollectContextPartsFromMessagesSkipsSentinelContextFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "my-skill",
|
||||
SkillDescription: "A test skill",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/AGENTS.md",
|
||||
ContextFileContent: "# Project instructions",
|
||||
},
|
||||
codersdk.ChatMessageText("ignored"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test.
|
||||
{
|
||||
ID: 1,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: content,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parts, 2)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeSkill, parts[0].Type)
|
||||
require.Equal(t, "my-skill", parts[0].SkillName)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeContextFile, parts[1].Type)
|
||||
require.Equal(t, "/home/coder/project/AGENTS.md", parts[1].ContextFilePath)
|
||||
require.Equal(t, "# Project instructions", parts[1].ContextFileContent)
|
||||
}
|
||||
|
||||
func TestCollectContextPartsFromMessagesKeepsEmptyContextFilesWhenRequested(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
content, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: uuid.New(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "my-skill",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parts, err := CollectContextPartsFromMessages(context.Background(), slog.Make(), []database.ChatMessage{ //nolint:exhaustruct // Only content fields matter for this unit test.
|
||||
{
|
||||
ID: 1,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: content,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parts, 2)
|
||||
require.Equal(t, AgentChatContextSentinelPath, parts[0].ContextFilePath)
|
||||
require.Equal(t, "my-skill", parts[1].SkillName)
|
||||
}
|
||||
|
||||
func TestFilterContextPartsToLatestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/legacy/AGENTS.md",
|
||||
ContextFileContent: "legacy instructions",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-legacy",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: newAgentID,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "repo-helper-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
got := FilterContextPartsToLatestAgent(parts)
|
||||
require.Len(t, got, 4)
|
||||
require.Equal(t, "/legacy/AGENTS.md", got[0].ContextFilePath)
|
||||
require.Equal(t, "repo-helper-legacy", got[1].SkillName)
|
||||
require.Equal(t, AgentChatContextSentinelPath, got[2].ContextFilePath)
|
||||
require.Equal(t, "repo-helper-new", got[3].SkillName)
|
||||
}
|
||||
|
||||
func createParentChatWithInheritedContext(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
server *Server,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-with-context",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
inheritedParts := []codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/AGENTS.md",
|
||||
ContextFileContent: "# Project instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/home/coder/project",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "my-skill",
|
||||
SkillDescription: "A test skill",
|
||||
SkillDir: "/home/coder/project/.agents/skills/my-skill",
|
||||
ContextFileSkillMetaFile: "SKILL.md",
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project/.agents/skills/my-skill/SKILL.md",
|
||||
},
|
||||
}
|
||||
content, err := json.Marshal(inheritedParts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: parent.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID},
|
||||
ModelConfigID: []uuid.UUID{model.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
||||
Content: []string{string(content)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
return parentChat
|
||||
}
|
||||
|
||||
func assertChildInheritedContext(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
childID uuid.UUID,
|
||||
prompt string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.LastInjectedContext.Valid)
|
||||
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 2)
|
||||
|
||||
var sawContextFile bool
|
||||
var sawSkill bool
|
||||
for _, part := range cached {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
sawContextFile = true
|
||||
require.Equal(t, "/home/coder/project/AGENTS.md", part.ContextFilePath)
|
||||
require.Empty(t, part.ContextFileContent)
|
||||
require.Empty(t, part.ContextFileOS)
|
||||
require.Empty(t, part.ContextFileDirectory)
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
sawSkill = true
|
||||
require.Equal(t, "my-skill", part.SkillName)
|
||||
require.Equal(t, "A test skill", part.SkillDescription)
|
||||
require.Empty(t, part.SkillDir)
|
||||
require.Empty(t, part.ContextFileSkillMetaFile)
|
||||
default:
|
||||
t.Fatalf("unexpected cached part type %q", part.Type)
|
||||
}
|
||||
}
|
||||
require.True(t, sawContextFile)
|
||||
require.True(t, sawSkill)
|
||||
|
||||
childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: childID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
contextMessageIndexes []int
|
||||
userPromptIndex = -1
|
||||
sawDBAgentsContextFile bool
|
||||
sawDBSkillCompanionContext bool
|
||||
sawDBSkill bool
|
||||
)
|
||||
for i, msg := range childMessages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts))
|
||||
|
||||
if len(parts) == 1 && parts[0].Type == codersdk.ChatMessagePartTypeText && parts[0].Text == prompt {
|
||||
require.Equal(t, database.ChatMessageRoleUser, msg.Role)
|
||||
userPromptIndex = i
|
||||
continue
|
||||
}
|
||||
|
||||
hasInheritedContext := false
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasInheritedContext = true
|
||||
switch part.ContextFilePath {
|
||||
case "/home/coder/project/AGENTS.md":
|
||||
sawDBAgentsContextFile = true
|
||||
require.Equal(t, "# Project instructions", part.ContextFileContent)
|
||||
require.Equal(t, "linux", part.ContextFileOS)
|
||||
require.Equal(t, "/home/coder/project", part.ContextFileDirectory)
|
||||
case "/home/coder/project/.agents/skills/my-skill/SKILL.md":
|
||||
sawDBSkillCompanionContext = true
|
||||
require.Empty(t, part.ContextFileContent)
|
||||
require.Empty(t, part.ContextFileOS)
|
||||
require.Empty(t, part.ContextFileDirectory)
|
||||
default:
|
||||
t.Fatalf("unexpected child inherited context file path %q", part.ContextFilePath)
|
||||
}
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
hasInheritedContext = true
|
||||
sawDBSkill = true
|
||||
require.Equal(t, "my-skill", part.SkillName)
|
||||
require.Equal(t, "A test skill", part.SkillDescription)
|
||||
require.Equal(t, "/home/coder/project/.agents/skills/my-skill", part.SkillDir)
|
||||
require.Equal(t, "SKILL.md", part.ContextFileSkillMetaFile)
|
||||
default:
|
||||
t.Fatalf("unexpected child inherited part type %q", part.Type)
|
||||
}
|
||||
}
|
||||
if hasInheritedContext {
|
||||
require.Equal(t, database.ChatMessageRoleUser, msg.Role)
|
||||
contextMessageIndexes = append(contextMessageIndexes, i)
|
||||
}
|
||||
}
|
||||
|
||||
require.NotEmpty(t, contextMessageIndexes)
|
||||
require.NotEqual(t, -1, userPromptIndex)
|
||||
for _, idx := range contextMessageIndexes {
|
||||
require.Less(t, idx, userPromptIndex)
|
||||
}
|
||||
require.True(t, sawDBAgentsContextFile)
|
||||
require.True(t, sawDBSkillCompanionContext)
|
||||
require.True(t, sawDBSkill)
|
||||
}
|
||||
|
||||
func createParentChatWithRotatedInheritedContext(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
server *Server,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-with-rotated-context",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project-old/AGENTS.md",
|
||||
ContextFileContent: "# Old instructions",
|
||||
ContextFileOS: "darwin",
|
||||
ContextFileDirectory: "/home/coder/project-old",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "old-skill",
|
||||
SkillDescription: "Old skill",
|
||||
SkillDir: "/home/coder/project-old/.agents/skills/old-skill",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
newContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/home/coder/project-new/AGENTS.md",
|
||||
ContextFileContent: "# New instructions",
|
||||
ContextFileOS: "linux",
|
||||
ContextFileDirectory: "/home/coder/project-new",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
{
|
||||
Type: codersdk.ChatMessagePartTypeSkill,
|
||||
SkillName: "new-skill",
|
||||
SkillDescription: "New skill",
|
||||
SkillDir: "/home/coder/project-new/.agents/skills/new-skill",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: parent.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID, user.ID},
|
||||
ModelConfigID: []uuid.UUID{model.ID, model.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleUser},
|
||||
Content: []string{string(oldContent), string(newContent)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0, 0},
|
||||
OutputTokens: []int64{0, 0},
|
||||
TotalTokens: []int64{0, 0},
|
||||
ReasoningTokens: []int64{0, 0},
|
||||
CacheCreationTokens: []int64{0, 0},
|
||||
CacheReadTokens: []int64{0, 0},
|
||||
ContextLimit: []int64{0, 0},
|
||||
Compressed: []bool{false, false},
|
||||
TotalCostMicros: []int64{0, 0},
|
||||
RuntimeMs: []int64{0, 0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
return parentChat
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChatCopiesOnlyLatestAgentContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
parentChat := createParentChatWithRotatedInheritedContext(ctx, t, db, server)
|
||||
|
||||
child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.LastInjectedContext.Valid)
|
||||
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(childChat.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 2)
|
||||
require.Equal(t, "/home/coder/project-new/AGENTS.md", cached[0].ContextFilePath)
|
||||
require.Equal(t, "new-skill", cached[1].SkillName)
|
||||
|
||||
childMessages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: child.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var inherited [][]codersdk.ChatMessagePart
|
||||
for _, msg := range childMessages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
var parts []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(msg.Content.RawMessage, &parts))
|
||||
if len(parts) == 0 || parts[0].Type == codersdk.ChatMessagePartTypeText {
|
||||
continue
|
||||
}
|
||||
inherited = append(inherited, parts)
|
||||
}
|
||||
require.Len(t, inherited, 1)
|
||||
require.Len(t, inherited[0], 2)
|
||||
require.Equal(t, "/home/coder/project-new/AGENTS.md", inherited[0][0].ContextFilePath)
|
||||
require.Equal(t, "# New instructions", inherited[0][0].ContextFileContent)
|
||||
require.Equal(t, "new-skill", inherited[0][1].SkillName)
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChatUpdatesInheritedLastInjectedContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
parentChat := createParentChatWithInheritedContext(ctx, t, db, server)
|
||||
|
||||
child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assertChildInheritedContext(ctx, t, db, child.ID, "inspect bindings")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgentInheritsContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
parentChat := createParentChatWithInheritedContext(ctx, t, db, server)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
require.NotNil(t, tool)
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-context",
|
||||
Name: "spawn_computer_use_agent",
|
||||
Input: `{"prompt":"inspect bindings"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError, "expected success but got: %s", resp.Content)
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
childIDStr, ok := result["chat_id"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
childID, err := uuid.Parse(childIDStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.Mode.Valid)
|
||||
require.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode)
|
||||
|
||||
assertChildInheritedContext(ctx, t, db, childID, "inspect bindings")
|
||||
}
|
||||
@@ -892,3 +892,66 @@ func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*Reinitialization
|
||||
return &reinitEvent, nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddChatContextRequest is the request body for adding chat context.
|
||||
type AddChatContextRequest struct {
|
||||
// ChatID optionally identifies the chat to add context to.
|
||||
// If empty, auto-detection is used (CODER_CHAT_ID env, the
|
||||
// only active chat, or the only top-level active chat for this
|
||||
// agent).
|
||||
ChatID uuid.UUID `json:"chat_id,omitempty"`
|
||||
// Parts are the context-file and skill parts to add.
|
||||
Parts []codersdk.ChatMessagePart `json:"parts"`
|
||||
}
|
||||
|
||||
// AddChatContextResponse is the response for adding chat context.
|
||||
type AddChatContextResponse struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// ClearChatContextRequest is the request body for clearing chat context.
|
||||
type ClearChatContextRequest struct {
|
||||
// ChatID optionally identifies the chat to clear context from.
|
||||
// If empty, auto-detection is used (CODER_CHAT_ID env, the
|
||||
// only active chat, or the only top-level active chat for this
|
||||
// agent).
|
||||
ChatID uuid.UUID `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
// ClearChatContextResponse is the response for clearing chat context.
|
||||
type ClearChatContextResponse struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
}
|
||||
|
||||
// AddChatContext adds context-file and skill parts to an active chat.
|
||||
func (c *Client) AddChatContext(ctx context.Context, req AddChatContextRequest) (AddChatContextResponse, error) {
|
||||
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/experimental/chat-context", req)
|
||||
if err != nil {
|
||||
return AddChatContextResponse{}, xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return AddChatContextResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var resp AddChatContextResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// ClearChatContext soft-deletes context-file and skill messages from an active chat.
|
||||
func (c *Client) ClearChatContext(ctx context.Context, req ClearChatContextRequest) (ClearChatContextResponse, error) {
|
||||
res, err := c.SDK.Request(ctx, http.MethodDelete, "/api/v2/workspaceagents/me/experimental/chat-context", req)
|
||||
if err != nil {
|
||||
return ClearChatContextResponse{}, xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ClearChatContextResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
var resp ClearChatContextResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
@@ -127,6 +127,8 @@ type AIBridgeThread struct {
|
||||
Prompt *string `json:"prompt,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Provider string `json:"provider"`
|
||||
CredentialKind string `json:"credential_kind"`
|
||||
CredentialHint string `json:"credential_hint"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty" format:"date-time"`
|
||||
TokenUsage AIBridgeSessionThreadsTokenUsage `json:"token_usage"`
|
||||
|
||||
+9
-93
@@ -1130,11 +1130,6 @@ type ChatStreamEvent struct {
|
||||
ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"`
|
||||
}
|
||||
|
||||
type chatStreamEnvelope struct {
|
||||
Type ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
|
||||
type ChatCostSummaryOptions struct {
|
||||
StartDate time.Time
|
||||
@@ -1987,8 +1982,8 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
var batch []ChatStreamEvent
|
||||
if err := wsjson.Read(streamCtx, conn, &batch); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -2005,61 +2000,10 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var batch []ChatStreamEvent
|
||||
decodeErr := json.Unmarshal(envelope.Data, &batch)
|
||||
if decodeErr == nil {
|
||||
for _, streamedEvent := range batch {
|
||||
if !send(streamedEvent) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
{
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf(
|
||||
"decode chat stream event batch: %v",
|
||||
decodeErr,
|
||||
),
|
||||
},
|
||||
})
|
||||
for _, event := range batch {
|
||||
if !send(event) {
|
||||
return
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
message := "chat stream returned an error"
|
||||
if len(envelope.Data) > 0 {
|
||||
var response Response
|
||||
if err := json.Unmarshal(envelope.Data, &response); err == nil {
|
||||
message = formatChatStreamResponseError(response)
|
||||
} else {
|
||||
trimmed := strings.TrimSpace(string(envelope.Data))
|
||||
if trimmed != "" {
|
||||
message = trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
return
|
||||
default:
|
||||
_ = send(ChatStreamEvent{
|
||||
Type: ChatStreamEventTypeError,
|
||||
Error: &ChatStreamError{
|
||||
Message: fmt.Sprintf("unknown chat stream event type %q", envelope.Type),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -2098,8 +2042,8 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
var event ChatWatchEvent
|
||||
if err := wsjson.Read(streamCtx, conn, &event); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -2110,23 +2054,10 @@ func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEv
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var event ChatWatchEvent
|
||||
if err := json.Unmarshal(envelope.Data, &event); err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case events <- event:
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
return
|
||||
default:
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case events <- event:
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -2478,21 +2409,6 @@ func (c *ExperimentalClient) GetChatsByWorkspace(ctx context.Context, workspaceI
|
||||
return result, json.NewDecoder(res.Body).Decode(&result)
|
||||
}
|
||||
|
||||
func formatChatStreamResponseError(response Response) string {
|
||||
message := strings.TrimSpace(response.Message)
|
||||
detail := strings.TrimSpace(response.Detail)
|
||||
switch {
|
||||
case message == "" && detail == "":
|
||||
return "chat stream returned an error"
|
||||
case message == "":
|
||||
return detail
|
||||
case detail == "":
|
||||
return message
|
||||
default:
|
||||
return fmt.Sprintf("%s: %s", message, detail)
|
||||
}
|
||||
}
|
||||
|
||||
// PRInsightsResponse is the response from the PR insights endpoint.
|
||||
type PRInsightsResponse struct {
|
||||
Summary PRInsightsSummary `json:"summary"`
|
||||
|
||||
+1
-26
@@ -3624,29 +3624,6 @@ Write out the current server config as YAML to stdout.`,
|
||||
YAML: "acquireBatchSize",
|
||||
Hidden: true, // Hidden because most operators should not need to modify this.
|
||||
},
|
||||
{
|
||||
Name: "Chat: Pubsub Flush Interval",
|
||||
Description: "The maximum time accepted chatd pubsub publishes wait before the batching loop schedules a flush.",
|
||||
Flag: "chat-pubsub-flush-interval",
|
||||
Env: "CODER_CHAT_PUBSUB_FLUSH_INTERVAL",
|
||||
Value: &c.AI.Chat.PubsubFlushInterval,
|
||||
Default: "50ms",
|
||||
Group: &deploymentGroupChat,
|
||||
YAML: "pubsubFlushInterval",
|
||||
Annotations: serpent.Annotations{}.Mark(annotationFormatDuration, "true"),
|
||||
Hidden: true,
|
||||
},
|
||||
{
|
||||
Name: "Chat: Pubsub Queue Size",
|
||||
Description: "How many chatd pubsub publishes can wait in memory for the dedicated sender path when PostgreSQL falls behind.",
|
||||
Flag: "chat-pubsub-queue-size",
|
||||
Env: "CODER_CHAT_PUBSUB_QUEUE_SIZE",
|
||||
Value: &c.AI.Chat.PubsubQueueSize,
|
||||
Default: "8192",
|
||||
Group: &deploymentGroupChat,
|
||||
YAML: "pubsubQueueSize",
|
||||
Hidden: true,
|
||||
},
|
||||
// AI Bridge Options
|
||||
{
|
||||
Name: "AI Bridge Enabled",
|
||||
@@ -4113,9 +4090,7 @@ type AIBridgeProxyConfig struct {
|
||||
}
|
||||
|
||||
type ChatConfig struct {
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
PubsubFlushInterval serpent.Duration `json:"pubsub_flush_interval" typescript:",notnull"`
|
||||
PubsubQueueSize serpent.Int64 `json:"pubsub_queue_size" typescript:",notnull"`
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
|
||||
@@ -75,6 +75,49 @@ func (req *OAuth2ClientRegistrationRequest) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRedirectURIScheme reports whether the callback URL's scheme is
|
||||
// safe to use as a redirect target. It returns an error when the scheme
|
||||
// is empty, an unsupported URN, or one of the schemes that are dangerous
|
||||
// in browser/HTML contexts (javascript, data, file, ftp).
|
||||
//
|
||||
// Legitimate custom schemes for native apps (e.g. vscode://, jetbrains://)
|
||||
// are allowed.
|
||||
// ValidateRedirectURIScheme reports whether the callback URL's scheme is
|
||||
// safe to use as a redirect target. It returns an error when the scheme
|
||||
// is empty, an unsupported URN, or one of the schemes that are dangerous
|
||||
// in browser/HTML contexts (javascript, data, file, ftp).
|
||||
//
|
||||
// Legitimate custom schemes for native apps (e.g. vscode://, jetbrains://)
|
||||
// are allowed.
|
||||
func ValidateRedirectURIScheme(u *url.URL) error {
|
||||
return validateScheme(u)
|
||||
}
|
||||
|
||||
func validateScheme(u *url.URL) error {
|
||||
if u.Scheme == "" {
|
||||
return xerrors.New("redirect URI must have a scheme")
|
||||
}
|
||||
|
||||
// Handle special URNs (RFC 6749 section 3.1.2.1).
|
||||
if u.Scheme == "urn" {
|
||||
if u.String() == "urn:ietf:wg:oauth:2.0:oob" {
|
||||
return nil
|
||||
}
|
||||
return xerrors.New("redirect URI uses unsupported URN scheme")
|
||||
}
|
||||
|
||||
// Block dangerous schemes for security (not allowed by RFCs
|
||||
// for OAuth2).
|
||||
dangerousSchemes := []string{"javascript", "data", "file", "ftp"}
|
||||
for _, dangerous := range dangerousSchemes {
|
||||
if strings.EqualFold(u.Scheme, dangerous) {
|
||||
return xerrors.Errorf("redirect URI uses dangerous scheme %s which is not allowed", dangerous)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRedirectURIs validates redirect URIs according to RFC 7591, 8252
|
||||
func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod) error {
|
||||
if len(uris) == 0 {
|
||||
@@ -91,27 +134,14 @@ func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndp
|
||||
return xerrors.Errorf("redirect URI at index %d is not a valid URL: %w", i, err)
|
||||
}
|
||||
|
||||
// Validate schemes according to RFC requirements
|
||||
if uri.Scheme == "" {
|
||||
return xerrors.Errorf("redirect URI at index %d must have a scheme", i)
|
||||
if err := validateScheme(uri); err != nil {
|
||||
return xerrors.Errorf("redirect URI at index %d: %w", i, err)
|
||||
}
|
||||
|
||||
// Handle special URNs (RFC 6749 section 3.1.2.1)
|
||||
// The urn:ietf:wg:oauth:2.0:oob scheme passed validation
|
||||
// above but needs no further checks.
|
||||
if uri.Scheme == "urn" {
|
||||
// Allow the out-of-band redirect URI for native apps
|
||||
if uriStr == "urn:ietf:wg:oauth:2.0:oob" {
|
||||
continue // This is valid for native apps
|
||||
}
|
||||
// Other URNs are not standard for OAuth2
|
||||
return xerrors.Errorf("redirect URI at index %d uses unsupported URN scheme", i)
|
||||
}
|
||||
|
||||
// Block dangerous schemes for security (not allowed by RFCs for OAuth2)
|
||||
dangerousSchemes := []string{"javascript", "data", "file", "ftp"}
|
||||
for _, dangerous := range dangerousSchemes {
|
||||
if strings.EqualFold(uri.Scheme, dangerous) {
|
||||
return xerrors.Errorf("redirect URI at index %d uses dangerous scheme %s which is not allowed", i, dangerous)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine if this is a public client based on token endpoint auth method
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -39,3 +43,67 @@ type UpdateUserSecretRequest struct {
|
||||
EnvName *string `json:"env_name,omitempty"`
|
||||
FilePath *string `json:"file_path,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Client) CreateUserSecret(ctx context.Context, user string, req CreateUserSecretRequest) (UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/users/%s/secrets", user), req)
|
||||
if err != nil {
|
||||
return UserSecret{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusCreated {
|
||||
return UserSecret{}, ReadBodyAsError(res)
|
||||
}
|
||||
var secret UserSecret
|
||||
return secret, json.NewDecoder(res.Body).Decode(&secret)
|
||||
}
|
||||
|
||||
func (c *Client) UserSecrets(ctx context.Context, user string) ([]UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/secrets", user), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var secrets []UserSecret
|
||||
return secrets, json.NewDecoder(res.Body).Decode(&secrets)
|
||||
}
|
||||
|
||||
func (c *Client) UserSecretByName(ctx context.Context, user string, name string) (UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/users/%s/secrets/%s", user, name), nil)
|
||||
if err != nil {
|
||||
return UserSecret{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return UserSecret{}, ReadBodyAsError(res)
|
||||
}
|
||||
var secret UserSecret
|
||||
return secret, json.NewDecoder(res.Body).Decode(&secret)
|
||||
}
|
||||
|
||||
func (c *Client) UpdateUserSecret(ctx context.Context, user string, name string, req UpdateUserSecretRequest) (UserSecret, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/users/%s/secrets/%s", user, name), req)
|
||||
if err != nil {
|
||||
return UserSecret{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return UserSecret{}, ReadBodyAsError(res)
|
||||
}
|
||||
var secret UserSecret
|
||||
return secret, json.NewDecoder(res.Body).Decode(&secret)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteUserSecret(ctx context.Context, user string, name string) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/v2/users/%s/secrets/%s", user, name), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ type AgentConn interface {
|
||||
ConnectDesktopVNC(ctx context.Context) (net.Conn, error)
|
||||
ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error)
|
||||
StartDesktopRecording(ctx context.Context, req StartDesktopRecordingRequest) error
|
||||
StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (io.ReadCloser, error)
|
||||
StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (StopDesktopRecordingResponse, error)
|
||||
}
|
||||
|
||||
// AgentConn represents a connection to a workspace agent.
|
||||
@@ -610,11 +610,25 @@ type StopDesktopRecordingRequest struct {
|
||||
RecordingID string `json:"recording_id"`
|
||||
}
|
||||
|
||||
// StopDesktopRecordingResponse wraps the response from stopping a
|
||||
// desktop recording. Body contains the recording data as a
|
||||
// multipart/mixed stream. ContentType holds the Content-Type
|
||||
// header (including boundary) so callers can parse the body.
|
||||
type StopDesktopRecordingResponse struct {
|
||||
Body io.ReadCloser
|
||||
ContentType string
|
||||
}
|
||||
|
||||
// MaxRecordingSize is the largest desktop recording (in bytes)
|
||||
// that will be accepted. Used by both the agent-side stop handler
|
||||
// and the server-side storage pipeline.
|
||||
const MaxRecordingSize = 100 << 20 // 100 MB
|
||||
|
||||
// MaxThumbnailSize is the largest thumbnail (in bytes) that will
|
||||
// be accepted. Applied both agent-side (before streaming) and
|
||||
// server-side (when parsing multipart parts).
|
||||
const MaxThumbnailSize = 10 << 20 // 10 MB
|
||||
|
||||
// ExecuteDesktopAction executes a mouse/keyboard/scroll action on the
|
||||
// agent's desktop.
|
||||
func (c *agentConn) ExecuteDesktopAction(ctx context.Context, action DesktopAction) (DesktopActionResponse, error) {
|
||||
@@ -681,22 +695,27 @@ func (c *agentConn) StartDesktopRecording(ctx context.Context, req StartDesktopR
|
||||
}
|
||||
|
||||
// StopDesktopRecording stops a desktop recording session on the
|
||||
// agent and returns the MP4 data as an io.ReadCloser. The caller
|
||||
// is responsible for closing the returned reader. Idempotent —
|
||||
// safe to call on an already-stopped recording.
|
||||
func (c *agentConn) StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (io.ReadCloser, error) {
|
||||
// agent and returns the recording as a StopDesktopRecordingResponse.
|
||||
// The response body is a multipart/mixed stream containing the
|
||||
// video (and optionally a JPEG thumbnail). The caller is
|
||||
// responsible for closing the returned Body. Idempotent — safe
|
||||
// to call on an already-stopped recording.
|
||||
func (c *agentConn) StopDesktopRecording(ctx context.Context, req StopDesktopRecordingRequest) (StopDesktopRecordingResponse, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/desktop/recording/stop", req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("stop recording request: %w", err)
|
||||
return StopDesktopRecordingResponse{}, xerrors.Errorf("stop recording request: %w", err)
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
defer res.Body.Close()
|
||||
return nil, codersdk.ReadBodyAsError(res)
|
||||
return StopDesktopRecordingResponse{}, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
// Caller is responsible for closing res.Body.
|
||||
return res.Body, nil
|
||||
return StopDesktopRecordingResponse{
|
||||
Body: res.Body,
|
||||
ContentType: res.Header.Get("Content-Type"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteDevcontainer deletes the provided devcontainer.
|
||||
|
||||
@@ -580,10 +580,10 @@ func (mr *MockAgentConnMockRecorder) StartProcess(ctx, req any) *gomock.Call {
|
||||
}
|
||||
|
||||
// StopDesktopRecording mocks base method.
|
||||
func (m *MockAgentConn) StopDesktopRecording(ctx context.Context, req workspacesdk.StopDesktopRecordingRequest) (io.ReadCloser, error) {
|
||||
func (m *MockAgentConn) StopDesktopRecording(ctx context.Context, req workspacesdk.StopDesktopRecordingRequest) (workspacesdk.StopDesktopRecordingResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StopDesktopRecording", ctx, req)
|
||||
ret0, _ := ret[0].(io.ReadCloser)
|
||||
ret0, _ := ret[0].(workspacesdk.StopDesktopRecordingResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ CODER_EXPERIMENTS=oauth2
|
||||
2. Click **Create Application**
|
||||
3. Fill in the application details:
|
||||
- **Name**: Your application name
|
||||
- **Callback URL**: `https://yourapp.example.com/callback`
|
||||
- **Callback URL**: `https://yourapp.example.com/callback` (web) or `myapp://callback` (native/desktop)
|
||||
- **Icon**: Optional icon URL
|
||||
|
||||
### Method 2: Management API
|
||||
@@ -251,16 +251,31 @@ Add `oauth2` to your experiment flags: `coder server --experiments oauth2`
|
||||
|
||||
Ensure the redirect URI in your request exactly matches the one registered for your application.
|
||||
|
||||
### "Invalid Callback URL" on the consent page
|
||||
|
||||
If you see this error when authorizing, the registered callback URL uses a
|
||||
blocked scheme (`javascript:`, `data:`, `file:`, or `ftp:`). Update the
|
||||
application's callback URL to a valid scheme (see
|
||||
[Callback URL schemes](#callback-url-schemes)).
|
||||
|
||||
### "PKCE verification failed"
|
||||
|
||||
Verify that the `code_verifier` used in the token request matches the one used to generate the `code_challenge`.
|
||||
|
||||
## Callback URL schemes
|
||||
|
||||
Custom URI schemes (`myapp://`, `vscode://`, `jetbrains://`, etc.) are fully supported for native and desktop applications. The OS routes the redirect back to the registered application without requiring a running HTTP server.
|
||||
|
||||
The following schemes are blocked for security reasons: `javascript:`, `data:`, `file:`, `ftp:`.
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- **Use HTTPS**: Always use HTTPS in production to protect tokens in transit
|
||||
- **Implement PKCE**: PKCE is mandatory for all authorization code clients
|
||||
(public and confidential)
|
||||
- **Validate redirect URLs**: Only register trusted redirect URIs for your applications
|
||||
- **Validate redirect URLs**: Only register trusted redirect URIs. Dangerous
|
||||
schemes (`javascript:`, `data:`, `file:`, `ftp:`) are blocked by the server,
|
||||
but custom URI schemes for native apps (`myapp://`) are permitted
|
||||
- **Rotate secrets**: Periodically rotate client secrets using the management API
|
||||
|
||||
## Limitations
|
||||
|
||||
@@ -150,12 +150,6 @@ deployment. They will always be available from the agent.
|
||||
| `coder_derp_server_sent_pong_total` | counter | Total pongs sent. | |
|
||||
| `coder_derp_server_unknown_frames_total` | counter | Total unknown frames received. | |
|
||||
| `coder_derp_server_watchers` | gauge | Current watchers. | |
|
||||
| `coder_pubsub_batch_delegate_fallbacks_total` | counter | The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage. | `channel_class` `reason` `stage` |
|
||||
| `coder_pubsub_batch_flush_duration_seconds` | histogram | The time spent flushing one chatd batch to PostgreSQL. | `reason` |
|
||||
| `coder_pubsub_batch_queue_depth` | gauge | The number of chatd notifications waiting in the batching queue. | |
|
||||
| `coder_pubsub_batch_sender_reset_failures_total` | counter | The number of batched pubsub sender reset attempts that failed. | |
|
||||
| `coder_pubsub_batch_sender_resets_total` | counter | The number of successful batched pubsub sender resets after flush failures. | |
|
||||
| `coder_pubsub_batch_size` | histogram | The number of logical notifications sent in each chatd batch flush. | |
|
||||
| `coder_pubsub_connected` | gauge | Whether we are connected (1) or not connected (0) to postgres | |
|
||||
| `coder_pubsub_current_events` | gauge | The current number of pubsub event channels listened for | |
|
||||
| `coder_pubsub_current_subscribers` | gauge | The current number of active pubsub subscribers | |
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Data Retention
|
||||
|
||||
Coder supports configurable retention policies that automatically purge old
|
||||
Audit Logs, Connection Logs, Workspace Agent Logs, API keys, and AI Bridge
|
||||
Audit Logs, Connection Logs, Workspace Agent Logs, API keys, and AI Gateway
|
||||
records. These policies help manage database growth by removing records older
|
||||
than a specified duration.
|
||||
|
||||
@@ -33,11 +33,11 @@ a YAML configuration file.
|
||||
| Connection Logs | `--connection-logs-retention` | `CODER_CONNECTION_LOGS_RETENTION` | `0` (disabled) | How long to retain Connection Logs |
|
||||
| API Keys | `--api-keys-retention` | `CODER_API_KEYS_RETENTION` | `7d` | How long to retain expired API keys |
|
||||
| Workspace Agent Logs | `--workspace-agent-logs-retention` | `CODER_WORKSPACE_AGENT_LOGS_RETENTION` | `7d` | How long to retain workspace agent logs |
|
||||
| AI Bridge | `--aibridge-retention` | `CODER_AIBRIDGE_RETENTION` | `60d` | How long to retain AI Bridge records |
|
||||
| AI Gateway | `--aibridge-retention` | `CODER_AIBRIDGE_RETENTION` | `60d` | How long to retain AI Gateway records |
|
||||
|
||||
> [!NOTE]
|
||||
> AI Bridge retention is configured separately from other retention settings.
|
||||
> See [AI Bridge Setup](../../ai-coder/ai-bridge/setup.md#data-retention) for
|
||||
> AI Gateway retention is configured separately from other retention settings.
|
||||
> See [AI Gateway Setup](../../ai-coder/ai-gateway/setup.md#data-retention) for
|
||||
> detailed configuration options.
|
||||
|
||||
### Duration Format
|
||||
@@ -128,15 +128,15 @@ For non-latest builds, logs are deleted if the agent hasn't connected within the
|
||||
retention period. Setting `--workspace-agent-logs-retention=7d` deletes logs for
|
||||
agents that haven't connected in 7 days (excluding those from the latest build).
|
||||
|
||||
### AI Bridge Data Behavior
|
||||
### AI Gateway Data Behavior
|
||||
|
||||
AI Bridge retention applies to interception records and all related data,
|
||||
AI Gateway retention applies to interception records and all related data,
|
||||
including token usage, prompts, and tool invocations. The default of 60 days
|
||||
provides a reasonable balance between storage costs and the ability to analyze
|
||||
usage patterns.
|
||||
|
||||
For details on what data is retained, see the
|
||||
[AI Bridge Data Retention](../../ai-coder/ai-bridge/setup.md#data-retention)
|
||||
[AI Gateway Data Retention](../../ai-coder/ai-gateway/setup.md#data-retention)
|
||||
documentation.
|
||||
|
||||
## Best Practices
|
||||
@@ -199,7 +199,7 @@ retention:
|
||||
workspace_agent_logs: 0s # Keep workspace agent logs forever
|
||||
|
||||
aibridge:
|
||||
retention: 0s # Keep AI Bridge records forever
|
||||
retention: 0s # Keep AI Gateway records forever
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
@@ -214,9 +214,9 @@ containing the table name (e.g., `audit_logs`, `connection_logs`, `api_keys`).
|
||||
purge procedures.
|
||||
- [Connection Logs](../monitoring/connection-logs.md): Learn about Connection
|
||||
Logs and monitoring.
|
||||
- [AI Bridge](../../ai-coder/ai-bridge/index.md): Learn about AI Bridge for
|
||||
- [AI Gateway](../../ai-coder/ai-gateway/index.md): Learn about AI Gateway for
|
||||
centralized LLM and MCP proxy management.
|
||||
- [AI Bridge Setup](../../ai-coder/ai-bridge/setup.md#data-retention): Configure
|
||||
AI Bridge data retention.
|
||||
- [AI Bridge Monitoring](../../ai-coder/ai-bridge/monitoring.md): Monitor AI
|
||||
Bridge usage and metrics.
|
||||
- [AI Gateway Setup](../../ai-coder/ai-gateway/setup.md#data-retention): Configure
|
||||
AI Gateway data retention.
|
||||
- [AI Gateway Monitoring](../../ai-coder/ai-gateway/monitoring.md): Monitor AI
|
||||
Gateway usage and metrics.
|
||||
|
||||
@@ -1,27 +1,32 @@
|
||||
# Agent Boundaries
|
||||
# Agent Firewall
|
||||
|
||||
Agent Boundaries are process-level firewalls that restrict and audit what
|
||||
Agent Firewall is a process-level firewall that restricts and audits what
|
||||
autonomous programs, such as AI agents, can access and use.
|
||||
|
||||
Example
|
||||
of Agent Boundaries blocking a process.
|
||||
Example
|
||||
of Agent Firewall blocking a process.
|
||||
|
||||
> [!NOTE]
|
||||
> Agent Firewall was previously known as "Agent Boundaries". Some
|
||||
> configuration options and internal references still use the old name
|
||||
> and will be updated in a future release.
|
||||
|
||||
## Supported Agents
|
||||
|
||||
Agent Boundaries support the securing of any terminal-based agent, including
|
||||
Agent Firewall supports the securing of any terminal-based agent, including
|
||||
your own custom agents.
|
||||
|
||||
## Features
|
||||
|
||||
Agent Boundaries offer network policy enforcement, which blocks domains and HTTP
|
||||
Agent Firewall offers network policy enforcement, which blocks domains and HTTP
|
||||
verbs to prevent exfiltration, and writes logs to the workspace.
|
||||
|
||||
Agent Boundaries also stream audit logs to Coder's control plane for centralized
|
||||
Agent Firewall also streams audit logs to Coder's control plane for centralized
|
||||
monitoring of HTTP requests.
|
||||
|
||||
## Getting Started with Agent Boundaries
|
||||
## Getting Started with Agent Firewall
|
||||
|
||||
The easiest way to use Agent Boundaries is through existing Coder modules, such
|
||||
The easiest way to use Agent Firewall is through existing Coder modules, such
|
||||
as the
|
||||
[Claude Code module](https://registry.coder.com/modules/coder/claude-code). It
|
||||
can also be ran directly in the terminal by installing the
|
||||
@@ -32,10 +37,10 @@ can also be ran directly in the terminal by installing the
|
||||
> [!NOTE]
|
||||
> For information about version requirements and compatibility, see the [Version Requirements](./version.md) documentation.
|
||||
|
||||
Agent Boundaries is configured using a `config.yaml` file. This allows you to
|
||||
Agent Firewall is configured using a `config.yaml` file. This allows you to
|
||||
maintain allow lists and share detailed policies with teammates.
|
||||
|
||||
In your Terraform module, enable Agent Boundaries with minimal configuration:
|
||||
In your Terraform module, enable Agent Firewall with minimal configuration:
|
||||
|
||||
```tf
|
||||
module "claude-code" {
|
||||
@@ -63,7 +68,7 @@ log_level: warn
|
||||
|
||||
For a basic recommendation of what to allow for agents, see the
|
||||
[Anthropic documentation on default allowed domains](https://code.claude.com/docs/en/claude-code-on-the-web#default-allowed-domains).
|
||||
For a comprehensive example of a production Agent Boundaries configuration, see
|
||||
For a comprehensive example of a production Agent Firewall configuration, see
|
||||
the
|
||||
[Coder dogfood policy example](https://github.com/coder/coder/blob/main/dogfood/coder/boundary-config.yaml).
|
||||
|
||||
@@ -85,9 +90,9 @@ resource "coder_script" "boundary_config_setup" {
|
||||
}
|
||||
```
|
||||
|
||||
Agent Boundaries automatically reads `config.yaml` from
|
||||
Agent Firewall automatically reads `config.yaml` from
|
||||
`~/.config/coder_boundary/` when it starts, so everyone who launches Agent
|
||||
Boundaries manually inside the workspace picks up the same configuration without
|
||||
Firewall manually inside the workspace picks up the same configuration without
|
||||
extra flags. This is especially convenient for managing extensive allow lists in
|
||||
version control.
|
||||
|
||||
@@ -108,8 +113,8 @@ version control.
|
||||
`landjail`. See [Jail Types](#jail-types) for a detailed comparison.
|
||||
- `log_dir` defines where boundary writes log files.
|
||||
- `log_level` defines the verbosity at which requests are logged. Agent
|
||||
Boundaries uses the following verbosity levels:
|
||||
- `WARN`: logs only requests that have been blocked by Agent Boundaries
|
||||
Firewall uses the following verbosity levels:
|
||||
- `WARN`: logs only requests that have been blocked by Agent Firewall
|
||||
- `INFO`: logs all requests at a high level
|
||||
- `DEBUG`: logs all requests in detail
|
||||
- `no_user_namespace` disables creation of a user namespace inside the jail.
|
||||
@@ -124,7 +129,7 @@ version control.
|
||||
For detailed information about the rules engine and how to construct allowlist
|
||||
rules, see the [rules engine documentation](./rules-engine.md).
|
||||
|
||||
You can also run Agent Boundaries directly in your workspace and configure it
|
||||
You can also run Agent Firewall directly in your workspace and configure it
|
||||
per template. You can do so by installing the
|
||||
[binary](https://github.com/coder/boundary) into the workspace image or at
|
||||
start-up. You can do so with the following command:
|
||||
@@ -135,7 +140,7 @@ curl -fsSL https://raw.githubusercontent.com/coder/boundary/main/install.sh | ba
|
||||
|
||||
## Jail Types
|
||||
|
||||
Agent Boundaries supports two different jail types for process isolation, each
|
||||
Agent Firewall supports two different jail types for process isolation, each
|
||||
with different characteristics and requirements:
|
||||
|
||||
1. **nsjail** - Uses Linux namespaces for isolation. This is the default jail
|
||||
@@ -168,31 +173,31 @@ environments where namespace capabilities are limited or unavailable.
|
||||
|
||||
## Audit Logs
|
||||
|
||||
Agent Boundaries stream audit logs to the Coder control plane, providing
|
||||
Agent Firewall streams audit logs to the Coder control plane, providing
|
||||
centralized visibility into HTTP requests made within workspaces—whether from AI
|
||||
agents or ad-hoc commands run with `boundary`.
|
||||
|
||||
Audit logs are independent of application logs:
|
||||
|
||||
- **Audit logs** record Agent Boundaries' policy decisions: whether each HTTP
|
||||
- **Audit logs** record Agent Firewall's policy decisions: whether each HTTP
|
||||
request was allowed or denied based on the allowlist rules. These are always
|
||||
sent to the control plane regardless of Agent Boundaries' configured log
|
||||
sent to the control plane regardless of Agent Firewall's configured log
|
||||
level.
|
||||
- **Application logs** are Agent Boundaries' operational logs written locally to
|
||||
- **Application logs** are Agent Firewall's operational logs written locally to
|
||||
the workspace. These include startup messages, internal errors, and debugging
|
||||
information controlled by the `log_level` setting.
|
||||
|
||||
For example, if a request to `api.example.com` is allowed by Agent Boundaries
|
||||
For example, if a request to `api.example.com` is allowed by Agent Firewall
|
||||
but the remote server returns a 500 error, the audit log records
|
||||
`decision=allow` because Agent Boundaries permitted the request. The HTTP
|
||||
`decision=allow` because Agent Firewall permitted the request. The HTTP
|
||||
response status is not tracked in audit logs.
|
||||
|
||||
> [!NOTE]
|
||||
> Requires Coder v2.30+ and Agent Boundaries v0.5.2+.
|
||||
> Requires Coder v2.30+ and Agent Firewall v0.5.2+.
|
||||
|
||||
### Audit Log Contents
|
||||
|
||||
Each Agent Boundaries audit log entry includes:
|
||||
Each Agent Firewall audit log entry includes:
|
||||
|
||||
| Field | Description |
|
||||
|-----------------------|-----------------------------------------------------------------------------------------|
|
||||
@@ -209,7 +214,7 @@ Each Agent Boundaries audit log entry includes:
|
||||
|
||||
### Viewing Audit Logs
|
||||
|
||||
Agent Boundaries audit logs are emitted as structured log entries from the Coder
|
||||
Agent Firewall audit logs are emitted as structured log entries from the Coder
|
||||
server. You can collect and analyze these logs using any log aggregation system
|
||||
such as Grafana Loki.
|
||||
|
||||
+2
-2
@@ -1,11 +1,11 @@
|
||||
# landjail Jail Type
|
||||
|
||||
landjail is Agent Boundaries' alternative jail type that uses Landlock V4 for
|
||||
landjail is Agent Firewall's alternative jail type that uses Landlock V4 for
|
||||
network isolation.
|
||||
|
||||
## Overview
|
||||
|
||||
Agent Boundaries uses Landlock V4 to enforce network restrictions:
|
||||
Agent Firewall uses Landlock V4 to enforce network restrictions:
|
||||
|
||||
- All `bind` syscalls are forbidden
|
||||
- All `connect` syscalls are forbidden except to the port that is used by http
|
||||
+8
-8
@@ -1,19 +1,19 @@
|
||||
# nsjail on Docker
|
||||
|
||||
This page describes the runtime and permission requirements for running Agent
|
||||
Boundaries with the **nsjail** jail type on **Docker**.
|
||||
Firewall with the **nsjail** jail type on **Docker**.
|
||||
|
||||
For an overview of nsjail, see [nsjail](./index.md).
|
||||
|
||||
## Runtime & Permission Requirements for Running Boundary in Docker
|
||||
|
||||
This section describes the Linux capabilities and runtime configurations
|
||||
required to run Agent Boundaries with nsjail inside a Docker container.
|
||||
required to run Agent Firewall with nsjail inside a Docker container.
|
||||
Requirements vary depending on the OCI runtime and the seccomp profile in use.
|
||||
|
||||
### 1. Default `runc` runtime with `CAP_NET_ADMIN`
|
||||
|
||||
When using Docker's default `runc` runtime, Agent Boundaries requires the
|
||||
When using Docker's default `runc` runtime, Agent Firewall requires the
|
||||
container to have `CAP_NET_ADMIN`. This is the minimal capability needed for
|
||||
configuring virtual networking inside the container.
|
||||
|
||||
@@ -30,10 +30,10 @@ For development or testing environments, you may grant the container
|
||||
`CAP_SYS_ADMIN`, which implicitly bypasses many of the restrictions in Docker's
|
||||
default seccomp profile.
|
||||
|
||||
- Agent Boundaries does not require `CAP_SYS_ADMIN` itself.
|
||||
- Agent Firewall does not require `CAP_SYS_ADMIN` itself.
|
||||
- However, Docker's default seccomp policy commonly blocks namespace-related
|
||||
syscalls unless `CAP_SYS_ADMIN` is present.
|
||||
- Granting `CAP_SYS_ADMIN` enables Agent Boundaries to run without modifying the
|
||||
- Granting `CAP_SYS_ADMIN` enables Agent Firewall to run without modifying the
|
||||
seccomp profile.
|
||||
|
||||
⚠️ Warning: `CAP_SYS_ADMIN` is extremely powerful and should not be used in
|
||||
@@ -41,7 +41,7 @@ production unless absolutely necessary.
|
||||
|
||||
### 3. `sysbox-runc` runtime with `CAP_NET_ADMIN`
|
||||
|
||||
When using the `sysbox-runc` runtime (from Nestybox), Agent Boundaries can run
|
||||
When using the `sysbox-runc` runtime (from Nestybox), Agent Firewall can run
|
||||
with only:
|
||||
|
||||
- `CAP_NET_ADMIN`
|
||||
@@ -53,8 +53,8 @@ seccomp profile modifications.
|
||||
## Docker Seccomp Profile Considerations
|
||||
|
||||
Docker's default seccomp profile frequently blocks the `clone` syscall, which is
|
||||
required by Agent Boundaries when creating unprivileged network namespaces. If
|
||||
the `clone` syscall is denied, Agent Boundaries will fail to start.
|
||||
required by Agent Firewall when creating unprivileged network namespaces. If
|
||||
the `clone` syscall is denied, Agent Firewall will fail to start.
|
||||
|
||||
To address this, you may need to modify or override the seccomp profile used by
|
||||
your container to explicitly allow the required `clone` variants.
|
||||
+5
-5
@@ -1,9 +1,9 @@
|
||||
# nsjail on ECS
|
||||
|
||||
This page describes the runtime and permission requirements for running
|
||||
Boundary with the **nsjail** jail type on **Amazon ECS**.
|
||||
This page describes the runtime and permission requirements for running Agent
|
||||
Firewall with the **nsjail** jail type on **Amazon ECS**.
|
||||
|
||||
## Runtime & Permission Requirements for Running Boundary in ECS
|
||||
## Runtime & Permission Requirements for Running Agent Firewall in ECS
|
||||
|
||||
The setup for ECS is similar to [nsjail on Kubernetes](./k8s.md); that environment
|
||||
is better explored and tested, so the Kubernetes page is a useful reference. On
|
||||
@@ -15,9 +15,9 @@ following examples use **ECS with Self Managed Node Groups** (EC2 launch type).
|
||||
### Example 1: ECS + Self Managed Node Groups + Amazon Linux
|
||||
|
||||
On **Amazon Linux** nodes with ECS, the default Docker seccomp profile enforced
|
||||
by ECS blocks the syscalls needed for Boundary. Because it is difficult to
|
||||
by ECS blocks the syscalls needed for Agent Firewall. Because it is difficult to
|
||||
disable or modify the seccomp profile on ECS, you must grant `SYS_ADMIN` (along
|
||||
with `NET_ADMIN`) so that Boundary can create namespaces and run nsjail.
|
||||
with `NET_ADMIN`) so that Agent Firewall can create namespaces and run nsjail.
|
||||
|
||||
**Task definition (Terraform) — `linuxParameters`:**
|
||||
|
||||
+3
-3
@@ -1,6 +1,6 @@
|
||||
# nsjail Jail Type
|
||||
|
||||
nsjail is Agent Boundaries' default jail type that uses Linux namespaces to
|
||||
nsjail is Agent Firewall's default jail type that uses Linux namespaces to
|
||||
provide process isolation. It creates unprivileged network namespaces to control
|
||||
and monitor network access for processes running under Boundary.
|
||||
|
||||
@@ -14,8 +14,8 @@ and permission requirements:
|
||||
## Overview
|
||||
|
||||
nsjail leverages Linux namespace technology to isolate processes at the network
|
||||
level. When Agent Boundaries runs with nsjail, it creates a separate network
|
||||
namespace for the isolated process, allowing Agent Boundaries to intercept and
|
||||
level. When Agent Firewall runs with nsjail, it creates a separate network
|
||||
namespace for the isolated process, allowing Agent Firewall to intercept and
|
||||
filter all network traffic according to the configured policy.
|
||||
|
||||
This jail type requires Linux capabilities to create and manage network
|
||||
+1
-1
@@ -1,7 +1,7 @@
|
||||
# nsjail on Kubernetes
|
||||
|
||||
This page describes the runtime and permission requirements for running Agent
|
||||
Boundaries with the **nsjail** jail type on **Kubernetes**.
|
||||
Firewall with the **nsjail** jail type on **Kubernetes**.
|
||||
|
||||
## Runtime & Permission Requirements for Running Boundary in Kubernetes
|
||||
|
||||
@@ -7,7 +7,7 @@ v4.7.0 or newer**.
|
||||
|
||||
### Coder v2.30.0+
|
||||
|
||||
Since Coder v2.30.0, Agent Boundaries is embedded inside the Coder binary, and
|
||||
Since Coder v2.30.0, Agent Firewall is embedded inside the Coder binary, and
|
||||
you don't need to install it separately. The `coder boundary` subcommand is
|
||||
available directly from the Coder CLI.
|
||||
|
||||
@@ -26,7 +26,7 @@ the `coder boundary` subcommand isn't available in your Coder installation. In
|
||||
this case, you need to:
|
||||
|
||||
1. Set `use_boundary_directly = true` in your Terraform module configuration
|
||||
2. Explicitly set `boundary_version` to specify which Agent Boundaries version
|
||||
2. Explicitly set `boundary_version` to specify which Agent Firewall version
|
||||
to install
|
||||
|
||||
Example configuration:
|
||||
@@ -44,7 +44,7 @@ module "claude-code" {
|
||||
### Using Claude Code Module Before v4.7.0
|
||||
|
||||
If you're using Claude Code module before v4.7.0, the module expects to use
|
||||
Agent Boundaries directly. You need to explicitly set `boundary_version` in your
|
||||
Agent Firewall directly. You need to explicitly set `boundary_version` in your
|
||||
Terraform configuration:
|
||||
|
||||
```tf
|
||||
@@ -257,12 +257,12 @@ until you add a new personal key.
|
||||
## Using an LLM proxy
|
||||
|
||||
Organizations that route LLM traffic through a centralized proxy — such as
|
||||
Coder's AI Bridge or third parties like LiteLLM — can point any provider's **Base URL** at their proxy endpoint.
|
||||
Coder's AI Gateway or third parties like LiteLLM — can point any provider's **Base URL** at their proxy endpoint.
|
||||
|
||||
For example, to route all OpenAI traffic through Coder's AI Bridge:
|
||||
For example, to route all OpenAI traffic through Coder's AI Gateway:
|
||||
|
||||
1. Add or edit the **OpenAI** provider.
|
||||
1. Set the **Base URL** to your AI Bridge endpoint
|
||||
1. Set the **Base URL** to your AI Gateway endpoint
|
||||
(e.g., `https://example.coder.com/api/v2/aibridge/openai/v1`).
|
||||
1. Enter the API key your proxy expects.
|
||||
|
||||
|
||||
@@ -113,14 +113,14 @@ This setting is available under **Agents** > **Settings** > **Behavior**.
|
||||
The maximum configurable value is 30 days. When disabled, workspaces follow
|
||||
their template's autostop rules (or none, if the template does not define any).
|
||||
|
||||
### Usage limits and analytics
|
||||
### Spend management
|
||||
|
||||
Administrators can set spend limits to cap LLM usage per user within a rolling
|
||||
time period, with per-user and per-group overrides. The cost tracking dashboard
|
||||
provides visibility into per-user spending, token consumption, and per-model
|
||||
breakdowns.
|
||||
|
||||
See [Usage & Analytics](./usage-insights.md) for details.
|
||||
See [Spend Management](./usage-insights.md) for details.
|
||||
|
||||
### Data retention
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# Usage and Analytics
|
||||
# Spend Management
|
||||
|
||||
Coder provides two admin-only views for monitoring and controlling agent
|
||||
Coder provides admin-only controls for monitoring and controlling agent
|
||||
spend: usage limits and cost tracking.
|
||||
|
||||
## Usage limits
|
||||
|
||||
Navigate to **Agents** > **Settings** > **Limits**.
|
||||
Navigate to **Agents** > **Settings** > **Spend**.
|
||||
|
||||
Usage limits cap how much each user can spend on LLM usage within a rolling
|
||||
time period. When enabled, the system checks the user's current spend before
|
||||
@@ -53,7 +53,7 @@ their effective limit, current spend, and when the current period resets.
|
||||
|
||||
## Cost tracking
|
||||
|
||||
Navigate to **Agents** > **Settings** > **Usage**.
|
||||
Navigate to **Agents** > **Settings** > **Spend**.
|
||||
|
||||
This view shows deployment-wide LLM chat costs with per-user drill-down.
|
||||
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
# AI Bridge Proxy
|
||||
|
||||
AI Bridge Proxy extends [AI Bridge](../index.md) to support clients that don't allow base URL overrides.
|
||||
While AI Bridge requires clients to support custom base URLs, many popular AI coding tools lack this capability.
|
||||
|
||||
AI Bridge Proxy solves this by acting as an HTTP proxy that intercepts traffic to supported AI providers and forwards it to AI Bridge. Since most clients respect proxy configurations even when they don't support base URL overrides, this provides a universal compatibility layer for AI Bridge.
|
||||
|
||||
For a list of clients supported through AI Bridge Proxy, see [Client Configuration](../clients/index.md).
|
||||
|
||||
## How it works
|
||||
|
||||
AI Bridge Proxy operates in two modes depending on the destination:
|
||||
|
||||
* MITM (Man-in-the-Middle) mode for allowlisted AI provider domains:
|
||||
* Intercepts and decrypts HTTPS traffic using a configured CA certificate
|
||||
* Forwards requests to AI Bridge for authentication, auditing, and routing
|
||||
* Supports: Anthropic, OpenAI, GitHub Copilot
|
||||
|
||||
* Tunnel mode for all other traffic:
|
||||
* Passes requests through without decryption
|
||||
|
||||
Clients authenticate by passing their Coder token in the proxy credentials.
|
||||
|
||||
<!-- TODO(ssncferreira): Add diagram showing how AI Bridge Proxy works in tunnel and MITM modes -->
|
||||
|
||||
## When to use AI Bridge Proxy
|
||||
|
||||
Use AI Bridge Proxy when your AI tools don't support base URL overrides but do respect standard proxy configurations.
|
||||
|
||||
For clients that support base URL configuration, you can use [AI Bridge](../index.md) directly.
|
||||
Nevertheless, clients with base URL overrides also work with the proxy, in case you want to use multiple AI clients and some of them do not support base URL configuration.
|
||||
|
||||
## Next steps
|
||||
|
||||
* [Set up AI Bridge Proxy](./setup.md) on your Coder deployment
|
||||
@@ -0,0 +1,35 @@
|
||||
# AI Gateway Proxy
|
||||
|
||||
AI Gateway Proxy extends [AI Gateway](../index.md) to support clients that don't allow base URL overrides.
|
||||
While AI Gateway requires clients to support custom base URLs, many popular AI coding tools lack this capability.
|
||||
|
||||
AI Gateway Proxy solves this by acting as an HTTP proxy that intercepts traffic to supported AI providers and forwards it to AI Gateway. Since most clients respect proxy configurations even when they don't support base URL overrides, this provides a universal compatibility layer for AI Gateway.
|
||||
|
||||
For a list of clients supported through AI Gateway Proxy, see [Client Configuration](../clients/index.md).
|
||||
|
||||
## How it works
|
||||
|
||||
AI Gateway Proxy operates in two modes depending on the destination:
|
||||
|
||||
* MITM (Man-in-the-Middle) mode for allowlisted AI provider domains:
|
||||
* Intercepts and decrypts HTTPS traffic using a configured CA certificate
|
||||
* Forwards requests to AI Gateway for authentication, auditing, and routing
|
||||
* Supports: Anthropic, OpenAI, GitHub Copilot
|
||||
|
||||
* Tunnel mode for all other traffic:
|
||||
* Passes requests through without decryption
|
||||
|
||||
Clients authenticate by passing their Coder token in the proxy credentials.
|
||||
|
||||
<!-- TODO(ssncferreira): Add diagram showing how AI Gateway Proxy works in tunnel and MITM modes -->
|
||||
|
||||
## When to use AI Gateway Proxy
|
||||
|
||||
Use AI Gateway Proxy when your AI tools don't support base URL overrides but do respect standard proxy configurations.
|
||||
|
||||
For clients that support base URL configuration, you can use [AI Gateway](../index.md) directly.
|
||||
Nevertheless, clients with base URL overrides also work with the proxy, in case you want to use multiple AI clients and some of them do not support base URL configuration.
|
||||
|
||||
## Next steps
|
||||
|
||||
* [Set up AI Gateway Proxy](./setup.md) on your Coder deployment
|
||||
+27
-27
@@ -1,18 +1,18 @@
|
||||
# Setup
|
||||
|
||||
AI Bridge Proxy runs inside the Coder control plane (`coderd`), requiring no separate compute to deploy or scale.
|
||||
Once enabled, `coderd` runs the `aibridgeproxyd` in-memory and intercepts traffic to supported AI providers, forwarding it to AI Bridge.
|
||||
AI Gateway Proxy runs inside the Coder control plane (`coderd`), requiring no separate compute to deploy or scale.
|
||||
Once enabled, `coderd` runs the `aibridgeproxyd` in-memory and intercepts traffic to supported AI providers, forwarding it to AI Gateway.
|
||||
|
||||
**Required:**
|
||||
|
||||
1. AI Bridge must be enabled and configured (requires a **Premium** license with the [AI Governance Add-On](../../ai-governance.md)). See [AI Bridge Setup](../setup.md) for further information.
|
||||
1. AI Bridge Proxy must be [enabled](#proxy-configuration) using the server flag.
|
||||
1. AI Gateway must be enabled and configured (requires a **Premium** license with the [AI Governance Add-On](../../ai-governance.md)). See [AI Gateway Setup](../setup.md) for further information.
|
||||
1. AI Gateway Proxy must be [enabled](#proxy-configuration) using the server flag.
|
||||
1. A [CA certificate](#ca-certificate) must be configured for MITM interception.
|
||||
1. [Clients](#client-configuration) must be configured to use the proxy and trust the CA certificate.
|
||||
|
||||
## Proxy Configuration
|
||||
|
||||
AI Bridge Proxy is disabled by default. To enable it, set the following configuration options:
|
||||
AI Gateway Proxy is disabled by default. To enable it, set the following configuration options:
|
||||
|
||||
```shell
|
||||
CODER_AIBRIDGE_ENABLED=true \
|
||||
@@ -28,7 +28,7 @@ coder server \
|
||||
--aibridge-proxy-key-file=/path/to/ca.key
|
||||
```
|
||||
|
||||
Both the certificate and private key are required for AI Bridge Proxy to start.
|
||||
Both the certificate and private key are required for AI Gateway Proxy to start.
|
||||
See [CA Certificate](#ca-certificate) for how to generate and obtain these files.
|
||||
|
||||
By default, the proxy listener accepts plain HTTP connections.
|
||||
@@ -46,7 +46,7 @@ Both files must be provided together.
|
||||
The TLS certificate must include a Subject Alternative Name (SAN) matching the hostname or IP address that clients use to connect to the proxy.
|
||||
See [Proxy TLS Configuration](#proxy-tls-configuration) for how to generate and configure these files.
|
||||
|
||||
The AI Bridge Proxy only intercepts and forwards traffic to AI Bridge for the supported AI provider domains:
|
||||
The AI Gateway Proxy only intercepts and forwards traffic to AI Gateway for the supported AI provider domains:
|
||||
|
||||
* [Anthropic](https://www.anthropic.com/): `api.anthropic.com`
|
||||
* [OpenAI](https://openai.com/): `api.openai.com`
|
||||
@@ -59,7 +59,7 @@ For additional configuration options, see the [Coder server configuration](../..
|
||||
## Security Considerations
|
||||
|
||||
> [!WARNING]
|
||||
> The AI Bridge Proxy should only be accessible within a trusted network and **must not** be directly exposed to the public internet.
|
||||
> The AI Gateway Proxy should only be accessible within a trusted network and **must not** be directly exposed to the public internet.
|
||||
> Without proper network restrictions, unauthorized users could route traffic through the proxy or intercept credentials.
|
||||
|
||||
### Encrypting client connections
|
||||
@@ -68,7 +68,7 @@ By default, AI tools send the Coder session token in the proxy credentials over
|
||||
This only applies to the initial connection between the client and the proxy.
|
||||
Once connected:
|
||||
|
||||
* MITM mode: A TLS connection is established between the AI tool and the proxy (using the configured CA certificate), then traffic is forwarded securely to AI Bridge.
|
||||
* MITM mode: A TLS connection is established between the AI tool and the proxy (using the configured CA certificate), then traffic is forwarded securely to AI Gateway.
|
||||
* Tunnel mode: A TLS connection is established directly between the AI tool and the destination, passing through the proxy without decryption.
|
||||
|
||||
As a best practice, apply one or more of the following to protect credentials during the initial connection:
|
||||
@@ -85,15 +85,15 @@ To prevent unauthorized use, restrict network access to the proxy so that only a
|
||||
|
||||
## CA Certificate
|
||||
|
||||
AI Bridge Proxy uses a CA (Certificate Authority) certificate to perform MITM interception of HTTPS traffic.
|
||||
AI Gateway Proxy uses a CA (Certificate Authority) certificate to perform MITM interception of HTTPS traffic.
|
||||
When AI tools connect to AI provider domains through the proxy, the proxy presents a certificate signed by this CA.
|
||||
AI tools must trust this CA certificate, otherwise, the connection will fail.
|
||||
|
||||
### Self-signed certificate
|
||||
|
||||
Use a self-signed certificate when your organization doesn't have an internal CA, or when you want a dedicated CA specifically for AI Bridge Proxy.
|
||||
Use a self-signed certificate when your organization doesn't have an internal CA, or when you want a dedicated CA specifically for AI Gateway Proxy.
|
||||
|
||||
Generate a CA certificate specifically for AI Bridge Proxy:
|
||||
Generate a CA certificate specifically for AI Gateway Proxy:
|
||||
|
||||
1) Generate a private key:
|
||||
|
||||
@@ -108,10 +108,10 @@ chmod 400 ca.key
|
||||
openssl req -new -x509 -days 3650 \
|
||||
-key ca.key \
|
||||
-out ca.crt \
|
||||
-subj "/CN=AI Bridge Proxy CA"
|
||||
-subj "/CN=AI Gateway Proxy CA"
|
||||
```
|
||||
|
||||
Configure AI Bridge Proxy with both files:
|
||||
Configure AI Gateway Proxy with both files:
|
||||
|
||||
```shell
|
||||
CODER_AIBRIDGE_PROXY_CERT_FILE=/path/to/ca.crt
|
||||
@@ -120,7 +120,7 @@ CODER_AIBRIDGE_PROXY_KEY_FILE=/path/to/ca.key
|
||||
|
||||
### Corporate CA certificate
|
||||
|
||||
If your organization has an internal CA that clients already trust, you can have it issue an intermediate CA certificate for AI Bridge Proxy.
|
||||
If your organization has an internal CA that clients already trust, you can have it issue an intermediate CA certificate for AI Gateway Proxy.
|
||||
This simplifies deployment since AI tools that already trust your organization's root CA will automatically trust certificates signed by the intermediate.
|
||||
|
||||
Your organization's CA issues a certificate and private key pair for the proxy. Configure the proxy with both files:
|
||||
@@ -158,14 +158,14 @@ How you configure AI tools to trust the certificate depends on the tool and oper
|
||||
|
||||
## Proxy TLS Configuration
|
||||
|
||||
By default, the AI Bridge Proxy listener accepts plain HTTP connections.
|
||||
By default, the AI Gateway Proxy listener accepts plain HTTP connections.
|
||||
When TLS is enabled, the proxy serves over HTTPS, encrypting the connection between AI tools and the proxy.
|
||||
|
||||
The TLS certificate is separate from the [MITM CA certificate](#ca-certificate).
|
||||
The CA certificate is used to sign dynamically generated certificates during MITM interception.
|
||||
The TLS certificate identifies the proxy itself, like any standard web server certificate.
|
||||
|
||||
The AI Bridge Proxy enforces a minimum TLS version of 1.2.
|
||||
The AI Gateway Proxy enforces a minimum TLS version of 1.2.
|
||||
|
||||
### Configuration
|
||||
|
||||
@@ -183,7 +183,7 @@ Both files must be provided together. If only one is set, the proxy will fail to
|
||||
|
||||
### Self-signed certificate
|
||||
|
||||
Use a self-signed certificate when your organization doesn't have an internal CA, or when you want a dedicated certificate specifically for the AI Bridge Proxy.
|
||||
Use a self-signed certificate when your organization doesn't have an internal CA, or when you want a dedicated certificate specifically for the AI Gateway Proxy.
|
||||
|
||||
The TLS certificate must include a Subject Alternative Name (SAN) matching the hostname or IP address that clients use to connect to the proxy.
|
||||
Without a matching SAN, clients will reject the connection.
|
||||
@@ -225,20 +225,20 @@ See [Client Configuration](#client-configuration) for details.
|
||||
|
||||
## Upstream proxy
|
||||
|
||||
If your organization requires all outbound traffic to pass through a corporate proxy, you can configure AI Bridge Proxy to chain requests to an upstream proxy.
|
||||
If your organization requires all outbound traffic to pass through a corporate proxy, you can configure AI Gateway Proxy to chain requests to an upstream proxy.
|
||||
|
||||
> [!NOTE]
|
||||
> AI Bridge Proxy must be the first proxy in the chain.
|
||||
> AI tools must be configured to connect directly to AI Bridge Proxy, which then forwards tunneled traffic to the upstream proxy.
|
||||
> AI Gateway Proxy must be the first proxy in the chain.
|
||||
> AI tools must be configured to connect directly to AI Gateway Proxy, which then forwards tunneled traffic to the upstream proxy.
|
||||
|
||||
### How it works
|
||||
|
||||
Tunneled requests (non-allowlisted domains) are forwarded to the upstream proxy configured via [`CODER_AIBRIDGE_PROXY_UPSTREAM`](../../../reference/cli/server.md#--aibridge-proxy-upstream).
|
||||
|
||||
MITM'd requests (AI provider domains) are forwarded to AI Bridge, which then communicates with AI providers.
|
||||
To ensure AI Bridge also routes requests through the upstream proxy, make sure to configure the proxy settings for the Coder server process.
|
||||
MITM'd requests (AI provider domains) are forwarded to AI Gateway, which then communicates with AI providers.
|
||||
To ensure AI Gateway also routes requests through the upstream proxy, make sure to configure the proxy settings for the Coder server process.
|
||||
|
||||
<!-- TODO(ssncferreira): Add diagram showing how AI Bridge Proxy integrates with upstream proxies -->
|
||||
<!-- TODO(ssncferreira): Add diagram showing how AI Gateway Proxy integrates with upstream proxies -->
|
||||
|
||||
### Configuration
|
||||
|
||||
@@ -263,7 +263,7 @@ If the system already trusts the upstream proxy's CA certificate, [`CODER_AIBRID
|
||||
|
||||
## Client Configuration
|
||||
|
||||
To use AI Bridge Proxy, AI tools must be configured to:
|
||||
To use AI Gateway Proxy, AI tools must be configured to:
|
||||
|
||||
1. Route traffic through the proxy
|
||||
1. Trust the proxy's CA certificate
|
||||
@@ -287,7 +287,7 @@ Note: if [TLS is not enabled](#proxy-tls-configuration) on the proxy, replace `h
|
||||
> `HTTP_PROXY` is not required since AI providers only use `HTTPS`.
|
||||
> Leaving it unset avoids routing unnecessary traffic through the proxy.
|
||||
|
||||
In order for AI tools that communicate with AI Bridge Proxy to authenticate with Coder via AI Bridge, the Coder session token needs to be passed in the proxy credentials as the password field.
|
||||
In order for AI tools that communicate with AI Gateway Proxy to authenticate with Coder via AI Gateway, the Coder session token needs to be passed in the proxy credentials as the password field.
|
||||
|
||||
### Trusting the CA certificate
|
||||
|
||||
@@ -356,6 +356,6 @@ For other operating systems, refer to the system's documentation for instruction
|
||||
For AI tools running inside Coder workspaces, template administrators can pre-configure the proxy settings and CA certificate in the workspace template.
|
||||
This provides a seamless experience where users don't need to configure anything manually.
|
||||
|
||||
<!-- TODO(ssncferreira): Add registry link for AI Bridge Proxy module for Coder workspaces: https://github.com/coder/internal/issues/1187 -->
|
||||
<!-- TODO(ssncferreira): Add registry link for AI Gateway Proxy module for Coder workspaces: https://github.com/coder/internal/issues/1187 -->
|
||||
|
||||
For tool-specific configuration details, check the [client compatibility table](../clients/index.md#compatibility) for clients that require proxy-based integration.
|
||||
@@ -1,6 +1,6 @@
|
||||
# Auditing AI Sessions
|
||||
|
||||
AI Bridge groups intercepted requests into **sessions** and **threads** to show
|
||||
AI Gateway groups intercepted requests into **sessions** and **threads** to show
|
||||
the causal relationships between human prompts and agent actions. This
|
||||
structure gives auditors clear provenance over who initiated what, and why.
|
||||
|
||||
@@ -15,7 +15,7 @@ structure gives auditors clear provenance over who initiated what, and why.
|
||||
|
||||
## Human vs. Agent attribution
|
||||
|
||||
AI Bridge distinguishes between human-initiated and agent-initiated requests
|
||||
AI Gateway distinguishes between human-initiated and agent-initiated requests
|
||||
using the `role` property:
|
||||
|
||||
- A message with `role="user"` indicates a human-initiated action (i.e. prompt).
|
||||
@@ -24,16 +24,16 @@ using the `role` property:
|
||||
|
||||
The `user` role is currently overloaded by clients like Claude Code and Codex;
|
||||
they inject system instructions
|
||||
within `role="user"` blocks when using agents. AI Bridge applies a heuristic
|
||||
within `role="user"` blocks when using agents. AI Gateway applies a heuristic
|
||||
of storing only the **last** prompt from a block of `role="user"` messages.
|
||||
|
||||
> [!NOTE]
|
||||
> AI Bridge cannot declare with certainty whether a request was human- or
|
||||
> AI Gateway cannot declare with certainty whether a request was human- or
|
||||
> agent-initiated.
|
||||
|
||||
## LLM reasoning capture
|
||||
|
||||
AI Bridge captures model reasoning and thinking content when available. Both
|
||||
AI Gateway captures model reasoning and thinking content when available. Both
|
||||
Anthropic (extended thinking) and OpenAI (reasoning summaries) support this
|
||||
feature. Reasoning data gives auditors insight into **why** a tool was called,
|
||||
not just what was called.
|
||||
@@ -77,7 +77,7 @@ When investigating an incident (policy violation, destructive action, etc.):
|
||||
|
||||
## What we store
|
||||
|
||||
AI Bridge captures the following data from each request/response:
|
||||
AI Gateway captures the following data from each request/response:
|
||||
|
||||
- Last user prompt
|
||||
- Token usage
|
||||
@@ -105,5 +105,5 @@ session data is kept.
|
||||
## Next steps
|
||||
|
||||
- [Monitoring](./monitoring.md) — Dashboards, data export, and tracing
|
||||
- [Setup](./setup.md) — Configure AI Bridge and data retention
|
||||
- [Setup](./setup.md) — Configure AI Gateway and data retention
|
||||
- [Reference](./reference.md) — API and technical reference
|
||||
+9
-9
@@ -1,27 +1,27 @@
|
||||
# Claude Code
|
||||
|
||||
Claude Code can be configured using environment variables. All modes require a **[Coder session token](../../../admin/users/sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-yourself)** for authentication with AI Bridge.
|
||||
Claude Code can be configured using environment variables. All modes require a **[Coder session token](../../../admin/users/sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-yourself)** for authentication with AI Gateway.
|
||||
|
||||
## Centralized API Key
|
||||
|
||||
```bash
|
||||
# AI Bridge base URL.
|
||||
# AI Gateway base URL.
|
||||
export ANTHROPIC_BASE_URL="<your-deployment-url>/api/v2/aibridge/anthropic"
|
||||
|
||||
# Your Coder session token, used for authentication with AI Bridge.
|
||||
# Your Coder session token, used for authentication with AI Gateway.
|
||||
export ANTHROPIC_AUTH_TOKEN="<your-coder-session-token>"
|
||||
```
|
||||
|
||||
## BYOK (Personal API Key)
|
||||
|
||||
```bash
|
||||
# AI Bridge base URL.
|
||||
# AI Gateway base URL.
|
||||
export ANTHROPIC_BASE_URL="<your-deployment-url>/api/v2/aibridge/anthropic"
|
||||
|
||||
# Your personal Anthropic API key, forwarded to Anthropic.
|
||||
export ANTHROPIC_API_KEY="<your-anthropic-api-key>"
|
||||
|
||||
# Your Coder session token, used for authentication with AI Bridge.
|
||||
# Your Coder session token, used for authentication with AI Gateway.
|
||||
export ANTHROPIC_CUSTOM_HEADERS="X-Coder-AI-Governance-Token: <your-coder-session-token>"
|
||||
|
||||
# Ensure no auth token is set so Claude Code uses the API key instead.
|
||||
@@ -31,10 +31,10 @@ unset ANTHROPIC_AUTH_TOKEN
|
||||
## BYOK (Claude Subscription)
|
||||
|
||||
```bash
|
||||
# AI Bridge base URL.
|
||||
# AI Gateway base URL.
|
||||
export ANTHROPIC_BASE_URL="<your-deployment-url>/api/v2/aibridge/anthropic"
|
||||
|
||||
# Your Coder session token, used for authentication with AI Bridge.
|
||||
# Your Coder session token, used for authentication with AI Gateway.
|
||||
export ANTHROPIC_CUSTOM_HEADERS="X-Coder-AI-Governance-Token: <your-coder-session-token>"
|
||||
|
||||
# Ensure no auth token is set so Claude Code uses subscription login instead.
|
||||
@@ -46,7 +46,7 @@ account.
|
||||
|
||||
## Pre-configuring in Templates
|
||||
|
||||
Template admins can pre-configure Claude Code for a seamless experience. Admins can automatically inject the user's Coder session token and the AI Bridge base URL into the workspace environment.
|
||||
Template admins can pre-configure Claude Code for a seamless experience. Admins can automatically inject the user's Coder session token and the AI Gateway base URL into the workspace environment.
|
||||
|
||||
```hcl
|
||||
module "claude-code" {
|
||||
@@ -77,7 +77,7 @@ module "claude-code" {
|
||||
workdir = "/path/to/project" # Set to your project directory
|
||||
ai_prompt = data.coder_task.me.prompt
|
||||
|
||||
# Route through AI Bridge (Premium feature)
|
||||
# Route through AI Gateway (Premium feature)
|
||||
enable_aibridge = true
|
||||
}
|
||||
```
|
||||
+2
-2
@@ -1,10 +1,10 @@
|
||||
# Cline
|
||||
|
||||
Cline supports both OpenAI and Anthropic models and can be configured to use AI Bridge by setting providers.
|
||||
Cline supports both OpenAI and Anthropic models and can be configured to use AI Gateway by setting providers.
|
||||
|
||||
## Configuration
|
||||
|
||||
To configure Cline to use AI Bridge, follow these steps:
|
||||
To configure Cline to use AI Gateway, follow these steps:
|
||||

|
||||
|
||||
<div class="tabs">
|
||||
+5
-5
@@ -1,10 +1,10 @@
|
||||
# Codex CLI
|
||||
|
||||
Codex CLI can be configured to use AI Bridge by setting up a custom model provider.
|
||||
Codex CLI can be configured to use AI Gateway by setting up a custom model provider.
|
||||
|
||||
## Centralized API Key
|
||||
|
||||
To configure Codex CLI to use AI Bridge, set the following configuration options in your Codex configuration file (e.g., `~/.codex/config.toml`):
|
||||
To configure Codex CLI to use AI Gateway, set the following configuration options in your Codex configuration file (e.g., `~/.codex/config.toml`):
|
||||
|
||||
```toml
|
||||
model_provider = "aibridge"
|
||||
@@ -16,7 +16,7 @@ env_key = "OPENAI_API_KEY"
|
||||
wire_api = "responses"
|
||||
```
|
||||
|
||||
To authenticate with AI Bridge, get your **[Coder session token](../../../admin/users/sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-yourself)** and set it in your environment:
|
||||
To authenticate with AI Gateway, get your **[Coder session token](../../../admin/users/sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-yourself)** and set it in your environment:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY="<your-coder-session-token>"
|
||||
@@ -45,7 +45,7 @@ Set both environment variables:
|
||||
# Your personal OpenAI API key, forwarded to OpenAI.
|
||||
export OPENAI_API_KEY="<your-openai-api-key>"
|
||||
|
||||
# Your Coder session token, used for authentication with AI Bridge.
|
||||
# Your Coder session token, used for authentication with AI Gateway.
|
||||
export CODER_SESSION_TOKEN="<your-coder-session-token>"
|
||||
```
|
||||
|
||||
@@ -70,7 +70,7 @@ env_http_headers = { "X-Coder-AI-Governance-Token" = "CODER_SESSION_TOKEN" }
|
||||
Set your Coder session token and ensure `OPENAI_API_KEY` is not set:
|
||||
|
||||
```bash
|
||||
# Your Coder session token, used for authentication with AI Bridge.
|
||||
# Your Coder session token, used for authentication with AI Gateway.
|
||||
export CODER_SESSION_TOKEN="<your-coder-session-token>"
|
||||
|
||||
# Ensure no OpenAI API key is set so Codex uses ChatGPT login instead.
|
||||
+22
-22
@@ -1,15 +1,15 @@
|
||||
# GitHub Copilot
|
||||
|
||||
[GitHub Copilot](https://github.com/features/copilot) is an AI coding assistant that doesn't support custom base URLs but does respect proxy configurations.
|
||||
This makes it compatible with [AI Bridge Proxy](../ai-bridge-proxy/index.md), which integrates with [AI Bridge](../index.md) for full access to auditing and governance features.
|
||||
To use Copilot with AI Bridge, make sure AI Bridge Proxy is properly configured, see [AI Bridge Proxy Setup](../ai-bridge-proxy/setup.md) for instructions.
|
||||
This makes it compatible with [AI Gateway Proxy](../ai-gateway-proxy/index.md), which integrates with [AI Gateway](../index.md) for full access to auditing and governance features.
|
||||
To use Copilot with AI Gateway, make sure AI Gateway Proxy is properly configured, see [AI Gateway Proxy Setup](../ai-gateway-proxy/setup.md) for instructions.
|
||||
|
||||
Copilot uses **per-user tokens** tied to GitHub accounts rather than a shared API key.
|
||||
Users must still authenticate with GitHub to use Copilot.
|
||||
|
||||
For general information about GitHub Copilot, see the [GitHub Copilot documentation](https://docs.github.com/en/copilot).
|
||||
|
||||
For general client configuration requirements, see [AI Bridge Proxy Client Configuration](../ai-bridge-proxy/setup.md#client-configuration).
|
||||
For general client configuration requirements, see [AI Gateway Proxy Client Configuration](../ai-gateway-proxy/setup.md#client-configuration).
|
||||
The sections below cover Copilot-specific setup for each client.
|
||||
|
||||
## Copilot CLI
|
||||
@@ -24,9 +24,9 @@ Set the `HTTPS_PROXY` environment variable:
|
||||
export HTTPS_PROXY="https://coder:${CODER_SESSION_TOKEN}@<proxy-host>:8888"
|
||||
```
|
||||
|
||||
Replace `<proxy-host>` with your AI Bridge Proxy hostname.
|
||||
Replace `<proxy-host>` with your AI Gateway Proxy hostname.
|
||||
|
||||
Note: if [TLS is not enabled](../ai-bridge-proxy/setup.md#proxy-tls-configuration) on the proxy, replace `https://` with `http://` in the proxy URL.
|
||||
Note: if [TLS is not enabled](../ai-gateway-proxy/setup.md#proxy-tls-configuration) on the proxy, replace `https://` with `http://` in the proxy URL.
|
||||
|
||||
### CA certificate trust
|
||||
|
||||
@@ -36,9 +36,9 @@ Copilot CLI is built on Node.js and uses the `NODE_EXTRA_CA_CERTS` environment v
|
||||
export NODE_EXTRA_CA_CERTS="/path/to/coder-aibridge-proxy-ca.pem"
|
||||
```
|
||||
|
||||
See [Client Configuration CA certificate trust](../ai-bridge-proxy/setup.md#trusting-the-ca-certificate) for details on how to obtain the certificate file.
|
||||
See [Client Configuration CA certificate trust](../ai-gateway-proxy/setup.md#trusting-the-ca-certificate) for details on how to obtain the certificate file.
|
||||
|
||||
When [TLS is enabled](../ai-bridge-proxy/setup.md#proxy-tls-configuration) on the proxy, combine the MITM CA certificate and the TLS certificate into a single file:
|
||||
When [TLS is enabled](../ai-gateway-proxy/setup.md#proxy-tls-configuration) on the proxy, combine the MITM CA certificate and the TLS certificate into a single file:
|
||||
|
||||
```shell
|
||||
cat coder-aibridge-proxy-ca.pem listener.crt > combined-ca.pem
|
||||
@@ -47,7 +47,7 @@ export NODE_EXTRA_CA_CERTS="/path/to/combined-ca.pem"
|
||||
|
||||
Copilot CLI may start MCP server processes that use runtimes other than Node.js (e.g. Go).
|
||||
These processes inherit environment variables like `HTTPS_PROXY` but may not respect `NODE_EXTRA_CA_CERTS`.
|
||||
Adding the TLS certificate to the [system trust store](../ai-bridge-proxy/setup.md#system-trust-store) ensures all processes trust it.
|
||||
Adding the TLS certificate to the [system trust store](../ai-gateway-proxy/setup.md#system-trust-store) ensures all processes trust it.
|
||||
|
||||
## VS Code Copilot Extension
|
||||
|
||||
@@ -56,7 +56,7 @@ For installation instructions, see [Installing the GitHub Copilot extension in V
|
||||
### Proxy configuration
|
||||
|
||||
You can configure the proxy using environment variables or VS Code settings.
|
||||
For environment variables, see [AI Bridge Proxy client configuration](../ai-bridge-proxy/setup.md#configuring-the-proxy).
|
||||
For environment variables, see [AI Gateway Proxy client configuration](../ai-gateway-proxy/setup.md#configuring-the-proxy).
|
||||
|
||||
Alternatively, you can configure the proxy directly in VS Code settings:
|
||||
|
||||
@@ -72,10 +72,10 @@ Or add directly to your `settings.json`:
|
||||
}
|
||||
```
|
||||
|
||||
Note: if [TLS is not enabled](../ai-bridge-proxy/setup.md#proxy-tls-configuration) on the proxy, replace `https://` with `http://` in the proxy URL.
|
||||
Note: if [TLS is not enabled](../ai-gateway-proxy/setup.md#proxy-tls-configuration) on the proxy, replace `https://` with `http://` in the proxy URL.
|
||||
|
||||
The `http.proxy` setting is used for both HTTP and HTTPS requests.
|
||||
Replace `<proxy-host>` with your AI Bridge Proxy hostname and `<CODER_SESSION_TOKEN>` with your coder session token.
|
||||
Replace `<proxy-host>` with your AI Gateway Proxy hostname and `<CODER_SESSION_TOKEN>` with your coder session token.
|
||||
|
||||
Restart VS Code for changes to take effect.
|
||||
|
||||
@@ -83,19 +83,19 @@ For more details, see [Configuring proxy settings for Copilot](https://docs.gith
|
||||
|
||||
### CA certificate trust
|
||||
|
||||
Add the AI Bridge Proxy CA certificate to your operating system's trust store.
|
||||
Add the AI Gateway Proxy CA certificate to your operating system's trust store.
|
||||
By default, VS Code loads system certificates, controlled by the `http.systemCertificates` setting.
|
||||
|
||||
See [Client Configuration CA certificate trust](../ai-bridge-proxy/setup.md#trusting-the-ca-certificate) for details on how to obtain the certificate file.
|
||||
See [Client Configuration CA certificate trust](../ai-gateway-proxy/setup.md#trusting-the-ca-certificate) for details on how to obtain the certificate file.
|
||||
|
||||
When [TLS is enabled](../ai-bridge-proxy/setup.md#proxy-tls-configuration) on the proxy, add the TLS certificate to the system trust store as well.
|
||||
When [TLS is enabled](../ai-gateway-proxy/setup.md#proxy-tls-configuration) on the proxy, add the TLS certificate to the system trust store as well.
|
||||
|
||||
### Using Coder Remote extension
|
||||
|
||||
When connecting to a Coder workspace with the [Coder extension](https://marketplace.visualstudio.com/items?itemName=coder.coder-remote), the Copilot extension runs inside the Coder workspace and not on your local machine.
|
||||
This means proxy and certificate configuration must be done in the Coder workspace environment.
|
||||
|
||||
When [TLS is enabled](../ai-bridge-proxy/setup.md#proxy-tls-configuration) on the proxy, add the TLS certificate to the workspace's system trust store as well.
|
||||
When [TLS is enabled](../ai-gateway-proxy/setup.md#proxy-tls-configuration) on the proxy, add the TLS certificate to the workspace's system trust store as well.
|
||||
|
||||
#### Proxy configuration
|
||||
|
||||
@@ -107,14 +107,14 @@ Configure the proxy in VS Code's remote settings:
|
||||
1. Search for `HTTP: Proxy`
|
||||
1. Set the proxy URL using the format `https://coder:<CODER_SESSION_TOKEN>@<proxy-host>:8888`
|
||||
|
||||
Note: if [TLS is not enabled](../ai-bridge-proxy/setup.md#proxy-tls-configuration) on the proxy, replace `https://` with `http://` in the proxy URL.
|
||||
Note: if [TLS is not enabled](../ai-gateway-proxy/setup.md#proxy-tls-configuration) on the proxy, replace `https://` with `http://` in the proxy URL.
|
||||
|
||||
Replace `<proxy-host>` with your AI Bridge Proxy hostname and `<CODER_SESSION_TOKEN>` with your coder session token.
|
||||
Replace `<proxy-host>` with your AI Gateway Proxy hostname and `<CODER_SESSION_TOKEN>` with your coder session token.
|
||||
|
||||
#### CA certificate trust
|
||||
|
||||
Since the Copilot extension runs inside the Coder workspace, add the [AI Bridge Proxy CA certificate](../ai-bridge-proxy/setup.md#trusting-the-ca-certificate) to the Coder workspace's system trust store.
|
||||
See [System trust store](../ai-bridge-proxy/setup.md#system-trust-store) for instructions on how to do this on Linux.
|
||||
Since the Copilot extension runs inside the Coder workspace, add the [AI Gateway Proxy CA certificate](../ai-gateway-proxy/setup.md#trusting-the-ca-certificate) to the Coder workspace's system trust store.
|
||||
See [System trust store](../ai-gateway-proxy/setup.md#system-trust-store) for instructions on how to do this on Linux.
|
||||
|
||||
Restart VS Code for changes to take effect.
|
||||
|
||||
@@ -140,10 +140,10 @@ For more details, see [Configuring proxy settings for Copilot](https://docs.gith
|
||||
|
||||
### CA certificate trust
|
||||
|
||||
Add the AI Bridge Proxy CA certificate to your operating system's trust store.
|
||||
Add the AI Gateway Proxy CA certificate to your operating system's trust store.
|
||||
If the certificate is in the system trust store, no additional IDE configuration is needed.
|
||||
|
||||
When [TLS is enabled](../ai-bridge-proxy/setup.md#proxy-tls-configuration) on the proxy, add the TLS certificate to the system trust store as well, or add it under `Accepted certificates` in the IDE settings below.
|
||||
When [TLS is enabled](../ai-gateway-proxy/setup.md#proxy-tls-configuration) on the proxy, add the TLS certificate to the system trust store as well, or add it under `Accepted certificates` in the IDE settings below.
|
||||
|
||||
Alternatively, you can configure the IDE to accept the certificate:
|
||||
|
||||
@@ -155,4 +155,4 @@ Alternatively, you can configure the IDE to accept the certificate:
|
||||
|
||||
For more details, see [Trusted root certificates](https://www.jetbrains.com/help/idea/ssl-certificates.html) in the JetBrains documentation.
|
||||
|
||||
See [Client Configuration CA certificate trust](../ai-bridge-proxy/setup.md#trusting-the-ca-certificate) for details on how to obtain the certificate file.
|
||||
See [Client Configuration CA certificate trust](../ai-gateway-proxy/setup.md#trusting-the-ca-certificate) for details on how to obtain the certificate file.
|
||||
+2
-2
@@ -1,11 +1,11 @@
|
||||
# Factory
|
||||
|
||||
Factort's Droid agent can be configured to use AI Bridge by setting up custom models for OpenAI and Anthropic.
|
||||
Factort's Droid agent can be configured to use AI Gateway by setting up custom models for OpenAI and Anthropic.
|
||||
|
||||
## Configuration
|
||||
|
||||
1. Open `~/.factory/settings.json` (create it if it does not exist).
|
||||
2. Add a `customModels` entry for each provider you want to use with AI Bridge.
|
||||
2. Add a `customModels` entry for each provider you want to use with AI Gateway.
|
||||
3. Replace `coder.example.com` with your Coder deployment URL.
|
||||
4. Use a **[Coder session token](../../../admin/users/sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-yourself)** for `apiKey`.
|
||||
|
||||
+20
-20
@@ -1,11 +1,11 @@
|
||||
# Client Configuration
|
||||
|
||||
Once AI Bridge is setup on your deployment, the AI coding tools used by your users will need to be configured to route requests via AI Bridge.
|
||||
Once AI Gateway is setup on your deployment, the AI coding tools used by your users will need to be configured to route requests via AI Gateway.
|
||||
|
||||
There are two ways to connect AI tools to AI Bridge:
|
||||
There are two ways to connect AI tools to AI Gateway:
|
||||
|
||||
- Base URL configuration (Recommended): Most AI tools allow customizing the base URL for API requests. This is the preferred approach when supported.
|
||||
- AI Bridge Proxy: For tools that don't support base URL configuration, [AI Bridge Proxy](../ai-bridge-proxy/index.md) can intercept traffic and forward it to AI Bridge.
|
||||
- AI Gateway Proxy: For tools that don't support base URL configuration, [AI Gateway Proxy](../ai-gateway-proxy/index.md) can intercept traffic and forward it to AI Gateway.
|
||||
|
||||
## Base URLs
|
||||
|
||||
@@ -20,14 +20,14 @@ Replace `coder.example.com` with your actual Coder deployment URL.
|
||||
|
||||
## Authentication
|
||||
|
||||
Instead of distributing provider-specific API keys (OpenAI/Anthropic keys) to users, they authenticate to AI Bridge using their **Coder session token** or **API key**:
|
||||
Instead of distributing provider-specific API keys (OpenAI/Anthropic keys) to users, they authenticate to AI Gateway using their **Coder session token** or **API key**:
|
||||
|
||||
- **OpenAI clients**: Users set `OPENAI_API_KEY` to their Coder session token or API key
|
||||
- **Anthropic clients**: Users set `ANTHROPIC_API_KEY` to their Coder session token or API key
|
||||
|
||||
> [!NOTE]
|
||||
> Only Coder-issued tokens can authenticate users against AI Bridge.
|
||||
> AI Bridge will use provider-specific API keys to [authenticate against upstream AI services](https://coder.com/docs/ai-coder/ai-bridge/setup#configure-providers).
|
||||
> Only Coder-issued tokens can authenticate users against AI Gateway.
|
||||
> AI Gateway will use provider-specific API keys to [authenticate against upstream AI services](../setup.md#configure-providers).
|
||||
|
||||
Again, the exact environment variable or setting naming may differ from tool to tool. See a list of [supported clients](#all-supported-clients) below and consult your tool's documentation for details.
|
||||
|
||||
@@ -45,22 +45,22 @@ Alternatively, [generate a long-lived API token](../../../admin/users/sessions-t
|
||||
|
||||
## Bring Your Own Key (BYOK)
|
||||
|
||||
In addition to centralized key management, AI Bridge supports **Bring Your
|
||||
In addition to centralized key management, AI Gateway supports **Bring Your
|
||||
Own Key** (BYOK) mode. Users can provide their own LLM API keys or use
|
||||
provider subscriptions (such as Claude Pro/Max or ChatGPT Plus/Pro) while
|
||||
AI Bridge continues to provide observability and governance.
|
||||
AI Gateway continues to provide observability and governance.
|
||||
|
||||

|
||||
|
||||
In BYOK mode, users need two credentials:
|
||||
|
||||
- A **Coder session token** to authenticate with AI Bridge.
|
||||
- Their **own LLM credential** (personal API key or subscription token) which AI Bridge forwards
|
||||
- A **Coder session token** to authenticate with AI Gateway.
|
||||
- Their **own LLM credential** (personal API key or subscription token) which AI Gateway forwards
|
||||
to the upstream provider.
|
||||
|
||||
BYOK and centralized modes can be used together. When a user provides
|
||||
their own credential, AI Bridge forwards it directly. When no user
|
||||
credential is present, AI Bridge falls back to the admin-configured
|
||||
their own credential, AI Gateway forwards it directly. When no user
|
||||
credential is present, AI Gateway falls back to the admin-configured
|
||||
provider key. This lets organizations offer centralized keys as a default
|
||||
while allowing individual users to bring their own.
|
||||
|
||||
@@ -68,7 +68,7 @@ See individual client pages for configuration details.
|
||||
|
||||
## Compatibility
|
||||
|
||||
The table below shows tested AI clients and their compatibility with AI Bridge.
|
||||
The table below shows tested AI clients and their compatibility with AI Gateway.
|
||||
|
||||
| Client | OpenAI | Anthropic | Notes |
|
||||
|----------------------------------|--------|-----------|--------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
@@ -83,7 +83,7 @@ The table below shows tested AI clients and their compatibility with AI Bridge.
|
||||
| [VS Code](./vscode.md) | ✅ | ❌ | Only supports Custom Base URL for OpenAI. |
|
||||
| [JetBrains IDEs](./jetbrains.md) | ✅ | ❌ | Works in Chat mode via "Bring Your Own Key". |
|
||||
| [Zed](./zed.md) | ✅ | ✅ | |
|
||||
| [GitHub Copilot](./copilot.md) | ⚙️ | - | Requires [AI Bridge Proxy](../ai-bridge-proxy/index.md). Uses per-user GitHub tokens. |
|
||||
| [GitHub Copilot](./copilot.md) | ⚙️ | - | Requires [AI Gateway Proxy](../ai-gateway-proxy/index.md). Uses per-user GitHub tokens. |
|
||||
| WindSurf | ❌ | ❌ | No option to override base URL. |
|
||||
| Cursor | ❌ | ❌ | Override for OpenAI broken ([upstream issue](https://forum.cursor.com/t/requests-are-sent-to-incorrect-endpoint-when-using-base-url-override/144894)). |
|
||||
| Sourcegraph Amp | ❌ | ❌ | No option to override base URL. |
|
||||
@@ -92,15 +92,15 @@ The table below shows tested AI clients and their compatibility with AI Bridge.
|
||||
| Antigravity | ❌ | ❌ | No option to override base URL. |
|
||||
|
|
||||
|
||||
*Legend: ✅ supported, ⚙️ requires AI Bridge Proxy, ❌ not supported, - not applicable.*
|
||||
*Legend: ✅ supported, ⚙️ requires AI Gateway Proxy, ❌ not supported, - not applicable.*
|
||||
|
||||
## Configuring In-Workspace Tools
|
||||
|
||||
AI coding tools running inside a Coder workspace, such as IDE extensions, can be configured to use AI Bridge.
|
||||
AI coding tools running inside a Coder workspace, such as IDE extensions, can be configured to use AI Gateway.
|
||||
|
||||
While users can manually configure these tools with a long-lived API key, template admins can provide a more seamless experience by pre-configuring them. Admins can automatically inject the user's session token with `data.coder_workspace_owner.me.session_token` and the AI Bridge base URL into the workspace environment.
|
||||
While users can manually configure these tools with a long-lived API key, template admins can provide a more seamless experience by pre-configuring them. Admins can automatically inject the user's session token with `data.coder_workspace_owner.me.session_token` and the AI Gateway base URL into the workspace environment.
|
||||
|
||||
In this example, Claude Code respects these environment variables and will route all requests via AI Bridge.
|
||||
In this example, Claude Code respects these environment variables and will route all requests via AI Gateway.
|
||||
|
||||
```hcl
|
||||
data "coder_workspace_owner" "me" {}
|
||||
@@ -121,9 +121,9 @@ resource "coder_agent" "dev" {
|
||||
|
||||
## External and Desktop Clients
|
||||
|
||||
You can also configure AI tools running outside of a Coder workspace, such as local IDE extensions or desktop applications, to connect to AI Bridge.
|
||||
You can also configure AI tools running outside of a Coder workspace, such as local IDE extensions or desktop applications, to connect to AI Gateway.
|
||||
|
||||
The configuration is the same: point the tool to the AI Bridge [base URL](#base-urls) and use a Coder API key for authentication.
|
||||
The configuration is the same: point the tool to the AI Gateway [base URL](#base-urls) and use a Coder API key for authentication.
|
||||
|
||||
Users can generate a long-lived API key from the Coder UI or CLI. Follow the instructions at [Sessions and API tokens](../../../admin/users/sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-yourself) to create one.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user