Compare commits

..

2 Commits

Commit 09d799606b by Mathias Fredriksson (2026-03-13 23:30:54 +00:00)
fix(site): apply stream parts during pending and clear synchronously

shouldApplyMessagePart dropped parts when chatStatus was "pending",
causing early LLM response content to disappear when pubsub status
delivery was slower than message_part delivery. Remove the "pending"
guard since stream state is already cleared on transition to pending.

Replace RAF-deferred stream reset with synchronous clearStreamState
on durable message arrival. The RAF approach allowed stale tool calls
to persist when new parts canceled the scheduled reset.
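The RAF race described in this commit can be illustrated with a small model. All names come from the commit message, but the hook internals are hypothetical, and requestAnimationFrame is stood in for by a manually flushed slot:

```javascript
// Hypothetical model of the stream-reset race described above.
// requestAnimationFrame is modeled as a single manually flushed slot.
let scheduledReset = null;
const state = { toolCalls: ["staleToolCall"] };

const clearStreamState = () => { state.toolCalls = []; };
const scheduleStreamReset = () => { scheduledReset = clearStreamState; };
const cancelScheduledStreamReset = () => { scheduledReset = null; };
const flushFrame = () => {
  if (scheduledReset) scheduledReset();
  scheduledReset = null;
};

// Buggy flow: a durable message schedules a deferred reset...
scheduleStreamReset();
// ...but a new stream part arrives before the frame fires and cancels
// it, so the stale tool call survives the step transition.
cancelScheduledStreamReset();
flushFrame();
console.log(state.toolCalls); // ["staleToolCall"]

// Fixed flow: clear synchronously on durable message arrival.
clearStreamState();
console.log(state.toolCalls); // []
```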
Commit 6e5ac02c96 by Mathias Fredriksson (2026-03-13 22:36:25 +00:00)
test(site): assert correct stream state behavior during pending and step transitions

Bug A: Parts arriving during 'pending' chatStatus are silently dropped
by the hook's shouldApplyMessagePart closure and never recovered.

Bug F: Tool calls from a previous step persist when
cancelScheduledStreamReset prevents clearStreamState between steps.
539 changed files with 15997 additions and 58611 deletions
-72
@@ -1,72 +0,0 @@
---
name: pull-requests
description: "Guide for creating, updating, and following up on pull requests in the Coder repository. Use when asked to open a PR, update a PR, rewrite a PR description, or follow up on CI/check failures."
---
# Pull Request Skill
## When to Use This Skill
Use this skill when asked to:
- Create a pull request for the current branch.
- Update an existing PR branch or description.
- Rewrite a PR body.
- Follow up on CI or check failures for an existing PR.
## References
Use the canonical docs for shared conventions and validation guidance:
- PR title and description conventions:
`.claude/docs/PR_STYLE_GUIDE.md`
- Local validation commands and git hooks: `AGENTS.md` (Essential Commands and
Git Hooks sections)
## Lifecycle Rules
1. **Check for an existing PR** before creating a new one:
```bash
gh pr list --head "$(git branch --show-current)" --author @me --json number --jq '.[0].number // empty'
```
If that returns a number, update that PR. If it returns empty output,
create a new one.
2. **Check you are not on main.** If the current branch is `main` or `master`,
create a feature branch before doing PR work.
3. **Default to draft.** Use `gh pr create --draft` unless the user explicitly
asks for ready-for-review.
4. **Keep description aligned with the full diff.** Re-read the diff against
the base branch before writing or updating the title and body. Describe the
entire PR diff, not just the last commit.
5. **Never auto-merge.** Do not merge or mark ready for review unless the user
explicitly asks.
6. **Never push to main or master.**
## CI / Checks Follow-up
**Always watch CI checks after pushing.** Do not push and walk away.
After pushing:
- Monitor CI with `gh pr checks <PR_NUMBER> --watch`.
- Use `gh pr view <PR_NUMBER> --json statusCheckRollup` for programmatic check
status.
If checks fail:
1. Find the failed run ID from the `gh pr checks` output.
2. Read the logs with `gh run view <run-id> --log-failed`.
3. Fix the problem locally.
4. Run `make pre-commit`.
5. Push the fix.
## What Not to Do
- Do not reference or call helper scripts that do not exist in this
repository.
- Do not auto-merge or mark ready for review without explicit user request.
- Do not push to `origin/main` or `origin/master`.
- Do not skip local validation before pushing.
- Do not fabricate or embellish PR descriptions.
+1 -1
@@ -113,7 +113,7 @@ Coder emphasizes clear error handling, with specific patterns required:
All tests should run in parallel using `t.Parallel()` to ensure efficient testing and expose potential race conditions. The codebase is rigorously linted with golangci-lint to maintain consistent code quality.
Git contributions follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/). See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Git contributions follow a standard format with commit messages structured as `type: <message>`, where type is one of `feat`, `fix`, or `chore`.
## Development Workflow
+14 -5
@@ -4,13 +4,22 @@ This guide documents the PR description style used in the Coder repository, base
## PR Title Format
Format: `type(scope): description`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/) format:
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
```text
type(scope): brief description
```
Examples:
**Common types:**
- `feat`: New features
- `fix`: Bug fixes
- `refactor`: Code refactoring without behavior change
- `perf`: Performance improvements
- `docs`: Documentation changes
- `chore`: Dependency updates, tooling changes
**Examples:**
- `feat: add tracing to aibridge`
- `fix: move contexts to appropriate locations`
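The title rules above are enforced by a CI script elsewhere in this diff; a minimal standalone sketch of the same check (the regex matches the one in the workflow script, while the wrapper function and its return shape are illustrative only):

```javascript
// Sketch of the conventional-commit title check described above.
// Allowed types are taken from the repository's title-validation CI.
const allowedTypes = [
  "feat", "fix", "docs", "style", "refactor",
  "perf", "test", "build", "ci", "chore", "revert",
];

function validateTitle(title) {
  // Parse conventional commit format: type(scope)!: description
  const match = title.match(/^(\w+)(\(([^)]*)\))?(!)?\s*:\s*.+/);
  if (!match) return { ok: false, reason: "format" };
  const type = match[1];
  if (!allowedTypes.includes(type)) return { ok: false, reason: "type" };
  // scope is undefined when no parentheses are present.
  return { ok: true, type, scope: match[3] };
}

console.log(validateTitle("feat: add tracing to aibridge").ok); // true
console.log(validateTitle("update stuff").ok); // false
```

Scope-vs-changed-files validation (directory, exact file, or file stem) would go on top of this, as the full CI script does.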
+3 -5
@@ -136,11 +136,9 @@ Then make your changes and push normally. Don't use `git push --force` unless th
## Commit Style
Format: `type(scope): message`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
- Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/)
- Format: `type(scope): message`
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
- Keep message titles concise (~70 characters)
- Use imperative, present tense in commit titles
-9
@@ -1,9 +0,0 @@
paths:
# The triage workflow uses a quoted heredoc (<<'EOF') with ${VAR}
# placeholders that envsubst expands later. Shellcheck's SC2016
# warns about unexpanded variables in single-quoted strings, but
# the non-expansion is intentional here. Actionlint doesn't honor
# inline shellcheck disable directives inside heredocs.
.github/workflows/triage-via-chat-api.yaml:
ignore:
- 'SC2016'
+34 -96
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -191,7 +191,7 @@ jobs:
# Check for any typos
- name: Check for typos
uses: crate-ci/typos@631208b7aac2daa8b707f55e7331f9112b0e062d # v1.44.0
uses: crate-ci/typos@2d0ce569feab1f8752f1dde43cc2f2aa53236e06 # v1.40.0
with:
config: .github/workflows/typos.toml
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -327,7 +327,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -379,7 +379,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -537,7 +537,7 @@ jobs:
embedded-pg-cache: ${{ steps.embedded-pg-cache.outputs.embedded-pg-cache }}
- name: Upload failed test db dumps
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
@@ -575,7 +575,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -637,7 +637,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -709,7 +709,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -736,7 +736,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -769,7 +769,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -818,7 +818,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -826,7 +826,7 @@ jobs:
- name: Upload debug log
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coderd-debug-logs${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/e2e/test-results/debug.log
@@ -834,7 +834,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -849,7 +849,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -930,7 +930,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1005,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1043,7 +1043,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1097,7 +1097,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1108,7 +1108,7 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -1198,7 +1198,7 @@ jobs:
make -j \
build/coder_linux_{amd64,arm64,armv7} \
build/coder_"$version"_windows_amd64.zip \
build/coder_"$version"_linux_{amd64,arm64,armv7}.{tar.gz,deb}
build/coder_"$version"_linux_amd64.{tar.gz,deb}
env:
# The Windows and Darwin slim binaries must be signed for Coder
# Desktop to accept them.
@@ -1216,28 +1216,11 @@ jobs:
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
JSIGN_PATH: /tmp/jsign-6.0.jar
# Free up disk space before building Docker images. The preceding
# Build step produces ~2 GB of binaries and packages, the Go build
# cache is ~1.3 GB, and node_modules is ~500 MB. Docker image
# builds, pushes, and SBOM generation need headroom that isn't
# available without reclaiming some of that space.
- name: Clean up build cache
run: |
set -euxo pipefail
# Go caches are no longer needed — binaries are already compiled.
go clean -cache -modcache
# Remove .apk and .rpm packages that are not uploaded as
# artifacts and were only built as make prerequisites.
rm -f ./build/*.apk ./build/*.rpm
- name: Build Linux Docker images
id: build-docker
env:
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
DOCKER_CLI_EXPERIMENTAL: "enabled"
# Skip building .deb/.rpm/.apk/.tar.gz as prerequisites for
# the Docker image targets — they were already built above.
DOCKER_IMAGE_NO_PREREQUISITES: "true"
run: |
set -euxo pipefail
@@ -1319,7 +1302,7 @@ jobs:
id: attest_main
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:main"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1356,7 +1339,7 @@ jobs:
id: attest_latest
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:latest"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1393,7 +1376,7 @@ jobs:
id: attest_version
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1455,60 +1438,15 @@ jobs:
^v
prune-untagged: true
- name: Upload build artifact (coder-linux-amd64.tar.gz)
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-amd64.tar.gz
path: ./build/*_linux_amd64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-amd64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-amd64.deb
path: ./build/*_linux_amd64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-arm64.tar.gz
path: ./build/*_linux_arm64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-arm64.deb
path: ./build/*_linux_arm64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-armv7.tar.gz
path: ./build/*_linux_armv7.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-armv7.deb
path: ./build/*_linux_armv7.deb
retention-days: 7
- name: Upload build artifact (coder-windows-amd64.zip)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-windows-amd64.zip
path: ./build/*_windows_amd64.zip
name: coder
path: |
./build/*.zip
./build/*.tar.gz
./build/*.deb
retention-days: 7
# Deploy is handled in deploy.yaml so we can apply concurrency limits.
@@ -1543,7 +1481,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-141
@@ -23,44 +23,6 @@ permissions:
concurrency: pr-${{ github.ref }}
jobs:
community-label:
runs-on: ubuntu-latest
permissions:
pull-requests: write
if: >-
${{
github.event_name == 'pull_request_target' &&
github.event.action == 'opened' &&
github.event.pull_request.author_association != 'MEMBER' &&
github.event.pull_request.author_association != 'COLLABORATOR' &&
github.event.pull_request.author_association != 'OWNER'
}}
steps:
- name: Add community label
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const params = {
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
}
const labels = context.payload.pull_request.labels.map((label) => label.name)
if (labels.includes("community")) {
console.log('PR already has "community" label.')
return
}
console.log(
'Adding "community" label for author association "%s".',
context.payload.pull_request.author_association,
)
await github.rest.issues.addLabels({
...params,
labels: ["community"],
})
cla:
runs-on: ubuntu-latest
permissions:
@@ -83,109 +45,6 @@ jobs:
# Some users have signed a corporate CLA with Coder so are exempt from signing our community one.
allowlist: "coryb,aaronlehmann,dependabot*,blink-so*,blinkagent*"
title:
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request_target' }}
steps:
- name: Validate PR title
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const { pull_request } = context.payload;
const title = pull_request.title;
const repo = { owner: context.repo.owner, repo: context.repo.repo };
const allowedTypes = [
"feat", "fix", "docs", "style", "refactor",
"perf", "test", "build", "ci", "chore", "revert",
];
const expectedFormat = `"type(scope): description" or "type: description"`;
const guidelinesLink = `See: https://github.com/coder/coder/blob/main/docs/about/contributing/CONTRIBUTING.md#commit-messages`;
const scopeHint = (type) =>
`Use a broader scope or no scope (e.g., "${type}: ...") for cross-cutting changes.\n` +
guidelinesLink;
console.log("Title: %s", title);
// Parse conventional commit format: type(scope)!: description
const match = title.match(/^(\w+)(\(([^)]*)\))?(!)?\s*:\s*.+/);
if (!match) {
core.setFailed(
`PR title does not match conventional commit format.\n` +
`Expected: ${expectedFormat}\n` +
`Allowed types: ${allowedTypes.join(", ")}\n` +
guidelinesLink
);
return;
}
const type = match[1];
const scope = match[3]; // undefined if no parentheses
// Validate type.
if (!allowedTypes.includes(type)) {
core.setFailed(
`PR title has invalid type "${type}".\n` +
`Expected: ${expectedFormat}\n` +
`Allowed types: ${allowedTypes.join(", ")}\n` +
guidelinesLink
);
return;
}
// If no scope, we're done.
if (!scope) {
console.log("No scope provided, title is valid.");
return;
}
console.log("Scope: %s", scope);
// Fetch changed files.
const files = await github.paginate(github.rest.pulls.listFiles, {
...repo,
pull_number: pull_request.number,
per_page: 100,
});
const changedPaths = files.map(f => f.filename);
console.log("Changed files: %d", changedPaths.length);
// Derive scope type from the changed files. The diff is the
// source of truth: if files exist under the scope, the path
// exists on the PR branch. No need for Contents API calls.
const isDir = changedPaths.some(f => f.startsWith(scope + "/"));
const isFile = changedPaths.some(f => f === scope);
const isStem = changedPaths.some(f => f.startsWith(scope + "."));
if (!isDir && !isFile && !isStem) {
core.setFailed(
`PR title scope "${scope}" does not match any files changed in this PR.\n` +
`Scopes must reference a path (directory or file stem) that contains changed files.\n` +
scopeHint(type)
);
return;
}
// Verify all changed files fall under the scope.
const outsideFiles = changedPaths.filter(f => {
if (isDir && f.startsWith(scope + "/")) return false;
if (f === scope) return false;
if (isStem && f.startsWith(scope + ".")) return false;
return true;
});
if (outsideFiles.length > 0) {
const listed = outsideFiles.map(f => " - " + f).join("\n");
core.setFailed(
`PR title scope "${scope}" does not contain all changed files.\n` +
`Files outside scope:\n${listed}\n\n` +
scopeHint(type)
);
return;
}
console.log("PR title is valid.");
release-labels:
runs-on: ubuntu-latest
permissions:
+19 -15
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -61,11 +61,11 @@ jobs:
if: needs.should-deploy.outputs.verdict == 'DEPLOY'
permissions:
contents: read
id-token: write # to authenticate to EKS cluster
id-token: write
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,29 +76,33 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
with:
role-to-assume: ${{ vars.AWS_DOGFOOD_DEPLOY_ROLE }}
aws-region: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
- name: Get Cluster Credentials
run: aws eks update-kubeconfig --name "$AWS_DOGFOOD_CLUSTER_NAME" --region "$AWS_DOGFOOD_DEPLOY_REGION"
env:
AWS_DOGFOOD_CLUSTER_NAME: ${{ vars.AWS_DOGFOOD_CLUSTER_NAME }}
AWS_DOGFOOD_DEPLOY_REGION: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Set up Flux CLI
uses: fluxcd/flux2/action@8454b02a32e48d775b9f563cb51fdcb1787b5b93 # v2.7.5
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.8.2"
version: "2.7.0"
- name: Get Cluster Credentials
uses: google-github-actions/get-gke-credentials@3da1e46a907576cefaa90c484278bb5b259dd395 # v3.0.0
with:
cluster_name: dogfood-v2
location: us-central1-a
project_id: coder-dogfood-v2
# Retag image as dogfood while maintaining the multi-arch manifest
- name: Tag image as dogfood
@@ -142,7 +146,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -48,7 +48,7 @@ jobs:
persist-credentials: false
- name: Docker login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
+1 -1
@@ -30,7 +30,7 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v45.0.7
- uses: tj-actions/changed-files@e0021407031f5be11a464abee9a0776171c79891 # v45.0.7
id: changed-files
with:
files: |
+4 -4
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -78,11 +78,11 @@ jobs:
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Login to DockerHub
if: github.ref == 'refs/heads/main'
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
@@ -30,7 +30,7 @@ jobs:
- name: Sync issues
id: sync
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0.5.0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: sync
@@ -52,7 +52,7 @@ jobs:
- name: Complete release
id: complete
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: complete
+1 -1
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+6 -6
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -248,7 +248,7 @@ jobs:
uses: ./.github/actions/setup-sqlc
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
@@ -14,12 +14,12 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: Run Schmoder CI
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
with:
workflow: ci.yaml
repo: coder/schmoder
+12 -10
@@ -80,7 +80,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -155,7 +155,7 @@ jobs:
cat "$CODER_RELEASE_NOTES_FILE"
- name: Docker Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -358,7 +358,7 @@ jobs:
id: attest_base
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.image-base-tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -474,7 +474,7 @@ jobs:
id: attest_main
if: ${{ !inputs.dry_run }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -518,7 +518,7 @@ jobs:
id: attest_latest
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.latest_tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -665,7 +665,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: release-artifacts
path: |
@@ -681,7 +681,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
@@ -700,11 +700,13 @@ jobs:
name: Publish to Homebrew tap
runs-on: ubuntu-latest
needs: release
if: ${{ !inputs.dry_run && inputs.release_channel == 'mainline' }}
if: ${{ !inputs.dry_run }}
steps:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -780,7 +782,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: SARIF file
path: results.sarif
+4 -4
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -69,7 +69,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.34.0
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
@@ -160,7 +160,7 @@ jobs:
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: trivy
path: trivy-results.sarif
+3 -3
View File
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-295
View File
@@ -1,295 +0,0 @@
# This workflow reimplements the AI Triage Automation using the Coder Chat API
# instead of the Tasks API. The Chat API (/api/experimental/chats) is a simpler
# interface that does not require a dedicated GitHub Action or workspace
# provisioning — we just create a chat, poll for completion, and link the
# result on the issue. All API calls use curl + jq directly.
#
# Key differences from the Tasks API workflow (traiage.yaml):
# - No checkout of coder/create-task-action; everything is inline curl/jq.
# - No template_name / template_preset / prefix inputs — the Chat API handles
# resource allocation internally.
# - Uses POST /api/experimental/chats to create a chat session.
# - Polls GET /api/experimental/chats/<id> until the agent finishes.
# - Chat URL format: ${CODER_URL}/agents?chat=${CHAT_ID}
name: AI Triage via Chat API
on:
issues:
types:
- labeled
workflow_dispatch:
inputs:
issue_url:
description: "GitHub Issue URL to process"
required: true
type: string
permissions:
contents: read
jobs:
triage-chat:
name: Triage GitHub Issue via Chat API
runs-on: ubuntu-latest
if: github.event.label.name == 'chat-triage' || github.event_name == 'workflow_dispatch'
timeout-minutes: 30
env:
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
permissions:
contents: read
issues: write
steps:
# ------------------------------------------------------------------
# Step 1: Determine the GitHub user and issue URL.
# Identical to the Tasks API workflow — resolve the actor for
# workflow_dispatch or the issue sender for label events.
# ------------------------------------------------------------------
- name: Determine Inputs
id: determine-inputs
if: always()
env:
GITHUB_ACTOR: ${{ github.actor }}
GITHUB_EVENT_ISSUE_HTML_URL: ${{ github.event.issue.html_url }}
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# For workflow_dispatch, use the actor who triggered it.
# For issues events, use the issue sender.
if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
if ! GITHUB_USER_ID=$(gh api "users/${GITHUB_ACTOR}" --jq '.id'); then
echo "::error::Failed to get GitHub user ID for actor ${GITHUB_ACTOR}"
exit 1
fi
echo "Using workflow_dispatch actor: ${GITHUB_ACTOR} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_ACTOR}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${INPUTS_ISSUE_URL}"
echo "issue_url=${INPUTS_ISSUE_URL}" >> "${GITHUB_OUTPUT}"
exit 0
elif [[ "${GITHUB_EVENT_NAME}" == "issues" ]]; then
GITHUB_USER_ID=${GITHUB_EVENT_USER_ID}
echo "Using issue author: ${GITHUB_EVENT_USER_LOGIN} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_EVENT_USER_LOGIN}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${GITHUB_EVENT_ISSUE_HTML_URL}"
echo "issue_url=${GITHUB_EVENT_ISSUE_HTML_URL}" >> "${GITHUB_OUTPUT}"
exit 0
else
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
exit 1
fi
# ------------------------------------------------------------------
# Step 2: Verify the triggering user has push access.
# Unchanged from the Tasks API workflow.
# ------------------------------------------------------------------
- name: Verify push access
env:
GITHUB_REPOSITORY: ${{ github.repository }}
GH_TOKEN: ${{ github.token }}
GITHUB_USERNAME: ${{ steps.determine-inputs.outputs.github_username }}
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
run: |
set -euo pipefail
can_push="$(gh api "/repos/${GITHUB_REPOSITORY}/collaborators/${GITHUB_USERNAME}/permission" --jq '.user.permissions.push')"
if [[ "${can_push}" != "true" ]]; then
echo "::error title=Access Denied::${GITHUB_USERNAME} does not have push access to ${GITHUB_REPOSITORY}"
exit 1
fi
# ------------------------------------------------------------------
# Step 3: Create a chat via the Coder Chat API.
# Unlike the Tasks API which provisions a full workspace, the Chat
# API creates a lightweight chat session. We POST to
# /api/experimental/chats with the triage prompt as the initial
# message and receive a chat ID back.
# ------------------------------------------------------------------
- name: Create chat via Coder Chat API
id: create-chat
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# Build the same triage prompt used by the Tasks API workflow.
TASK_PROMPT=$(cat <<'EOF'
Fix ${ISSUE_URL}
1. Use the gh CLI to read the issue description and comments.
2. Think carefully and try to understand the root cause. If the issue is unclear or not well defined, ask me to clarify and provide more information.
3. Write a proposed implementation plan to PLAN.md for me to review before starting implementation. Your plan should use TDD and only make the minimal changes necessary to fix the root cause.
4. When I approve your plan, start working on it. If you encounter issues with the plan, ask me for clarification and update the plan as required.
5. When you have finished implementation according to the plan, commit and push your changes, and create a PR using the gh CLI for me to review.
EOF
)
# Perform variable substitution on the prompt — scoped to $ISSUE_URL only.
# Using envsubst without arguments would expand every env var in scope
# (including CODER_SESSION_TOKEN), so we name the variable explicitly.
TASK_PROMPT=$(echo "${TASK_PROMPT}" | envsubst '$ISSUE_URL')
echo "Creating chat with prompt:"
echo "${TASK_PROMPT}"
# POST to the Chat API to create a new chat session.
RESPONSE=$(curl --silent --fail-with-body \
-X POST \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
-H "Content-Type: application/json" \
-d "$(jq -n --arg prompt "${TASK_PROMPT}" \
'{content: [{type: "text", text: $prompt}]}')" \
"${CODER_URL}/api/experimental/chats")
echo "Chat API response:"
echo "${RESPONSE}" | jq .
CHAT_ID=$(echo "${RESPONSE}" | jq -r '.id')
CHAT_STATUS=$(echo "${RESPONSE}" | jq -r '.status')
if [[ -z "${CHAT_ID}" || "${CHAT_ID}" == "null" ]]; then
echo "::error::Failed to create chat — no ID returned"
echo "Response: ${RESPONSE}"
exit 1
fi
# Validate that CHAT_ID is a UUID before using it in URL paths.
# This guards against unexpected API responses being interpolated
# into subsequent curl calls.
if [[ ! "${CHAT_ID}" =~ ^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$ ]]; then
echo "::error::CHAT_ID is not a valid UUID: ${CHAT_ID}"
exit 1
fi
CHAT_URL="${CODER_URL}/agents?chat=${CHAT_ID}"
echo "Chat created: ${CHAT_ID} (status: ${CHAT_STATUS})"
echo "Chat URL: ${CHAT_URL}"
echo "chat_id=${CHAT_ID}" >> "${GITHUB_OUTPUT}"
echo "chat_url=${CHAT_URL}" >> "${GITHUB_OUTPUT}"
# ------------------------------------------------------------------
# Step 4: Poll the chat status until the agent finishes.
# The Chat API is asynchronous — after creation the agent begins
# working in the background. We poll GET /api/experimental/chats/<id>
# every 5 seconds until the status is "waiting" (agent needs input),
# "completed" (agent finished), or "error". Timeout after 10 minutes.
# ------------------------------------------------------------------
- name: Poll chat status
id: poll-status
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
run: |
set -euo pipefail
POLL_INTERVAL=5
# 10 minutes = 600 seconds.
TIMEOUT=600
ELAPSED=0
echo "Polling chat ${CHAT_ID} every ${POLL_INTERVAL}s (timeout: ${TIMEOUT}s)..."
while true; do
RESPONSE=$(curl --silent --fail-with-body \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
"${CODER_URL}/api/experimental/chats/${CHAT_ID}")
STATUS=$(echo "${RESPONSE}" | jq -r '.status')
echo "[${ELAPSED}s] Chat status: ${STATUS}"
case "${STATUS}" in
waiting|completed)
echo "Chat reached terminal status: ${STATUS}"
echo "final_status=${STATUS}" >> "${GITHUB_OUTPUT}"
exit 0
;;
error)
echo "::error::Chat entered error state"
echo "${RESPONSE}" | jq .
echo "final_status=error" >> "${GITHUB_OUTPUT}"
exit 1
;;
pending|running)
# Still working — keep polling.
;;
*)
echo "::warning::Unknown chat status: ${STATUS}"
;;
esac
if [[ ${ELAPSED} -ge ${TIMEOUT} ]]; then
echo "::error::Timed out after ${TIMEOUT}s waiting for chat to finish"
echo "final_status=timeout" >> "${GITHUB_OUTPUT}"
exit 1
fi
sleep "${POLL_INTERVAL}"
ELAPSED=$((ELAPSED + POLL_INTERVAL))
done
# ------------------------------------------------------------------
# Step 5: Comment on the GitHub issue with a link to the chat.
# Only comment if the issue belongs to this repository (same guard
# as the Tasks API workflow).
# ------------------------------------------------------------------
- name: Comment on issue
if: startsWith(steps.determine-inputs.outputs.issue_url, format('{0}/{1}', github.server_url, github.repository))
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
COMMENT_BODY=$(cat <<EOF
🤖 **AI Triage Chat Created**
A Coder chat session has been created to investigate this issue.
**Chat URL:** ${CHAT_URL}
**Chat ID:** \`${CHAT_ID}\`
**Status:** ${FINAL_STATUS}
The agent is working on a triage plan. Visit the chat to follow progress or provide guidance.
EOF
)
gh issue comment "${ISSUE_URL}" --body "${COMMENT_BODY}"
echo "Comment posted on ${ISSUE_URL}"
# ------------------------------------------------------------------
# Step 6: Write a summary to the GitHub Actions step summary.
# ------------------------------------------------------------------
- name: Write summary
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
run: |
set -euo pipefail
{
echo "## AI Triage via Chat API"
echo ""
echo "**Issue:** ${ISSUE_URL}"
echo "**Chat ID:** \`${CHAT_ID}\`"
echo "**Chat URL:** ${CHAT_URL}"
echo "**Status:** ${FINAL_STATUS}"
} >> "${GITHUB_STEP_SUMMARY}"
-2
View File
@@ -29,8 +29,6 @@ EDE = "EDE"
HELO = "HELO"
LKE = "LKE"
byt = "byt"
cpy = "cpy"
Cpy = "Cpy"
typ = "typ"
# file extensions used in seti icon theme
styl = "styl"
+1 -1
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+5 -19
View File
@@ -146,20 +146,11 @@ git config core.hooksPath scripts/githooks
Two hooks run automatically:
- **pre-commit**: Classifies staged files by type and runs either
the full `make pre-commit` or the lightweight `make pre-commit-light`
depending on whether Go, TypeScript, SQL, proto, or Makefile
changes are present. Falls back to the full target when
`CODER_HOOK_RUN_ALL=1` is set. A markdown-only commit takes
seconds; a Go change takes several minutes.
- **pre-push**: Classifies changed files (vs remote branch or
merge-base) and runs `make pre-push` when Go, TypeScript, SQL,
proto, or Makefile changes are detected. Skips tests entirely
for lightweight changes. Allowlisted in
`scripts/githooks/pre-push`. Runs only for developers who opt
in. Falls back to `make pre-push` when the diff range can't
be determined or `CODER_HOOK_RUN_ALL=1` is set. Allow at least
15 minutes for a full run.
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
Fast checks that catch most CI failures. Allow at least 5 minutes.
- **pre-push**: `make pre-push` (heavier checks including tests).
Allowlisted in `scripts/githooks/pre-push`. Runs only for developers
who opt in. Allow at least 15 minutes.
`git commit` and `git push` will appear to hang while hooks run.
This is normal. Do not interrupt, retry, or reduce the timeout.
@@ -217,11 +208,6 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- Commit format: `type(scope): message`
- PR titles follow the same `type(scope): message` format.
- When you use a scope, it must be a real filesystem path containing every
changed file.
- Use a broader path scope, or omit the scope, for cross-cutting changes.
- Example: `fix(coderd/chatd): ...` for changes only in `coderd/chatd/`.
### Frontend Patterns
+10 -41
View File
@@ -136,10 +136,18 @@ endif
# the search path so that these exclusions match.
FIND_EXCLUSIONS= \
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
# Source files used for make targets, evaluated on use.
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
# Same as GO_SRC_FILES but excluding certain files that have problematic
# Makefile dependencies (e.g. pnpm).
MOST_GO_SRC_FILES := $(shell \
find . \
$(FIND_EXCLUSIONS) \
-type f \
-name '*.go' \
-not -name '*_test.go' \
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
)
# All the shell files in the repo, excluding ignored files.
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
@@ -506,12 +514,6 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
# Only wildcard the go files in the develop directory to avoid rebuilds
# when project files are changed. Technically changes to some imports may
# not be detected, but it's unlikely to cause any issues.
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
CGO_ENABLED=0 go build -o $@ ./scripts/develop
BOLD := $(shell tput bold 2>/dev/null)
GREEN := $(shell tput setaf 2 2>/dev/null)
RED := $(shell tput setaf 1 2>/dev/null)
@@ -522,10 +524,6 @@ RESET := $(shell tput sgr0 2>/dev/null)
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/biome fmt/markdown
.PHONY: fmt
# Subset of fmt that does not require Go or Node toolchains.
fmt-light: fmt/shfmt fmt/terraform fmt/markdown
.PHONY: fmt-light
fmt/go:
ifdef FILE
# Format single file
@@ -633,10 +631,6 @@ LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
.PHONY: lint
# Subset of lint that does not require Go or Node toolchains.
lint-light: lint/shellcheck lint/markdown lint/helm lint/bootstrap lint/migrations lint/actions/actionlint lint/typos
.PHONY: lint-light
lint/site-icons:
./scripts/check_site_icons.sh
.PHONY: lint/site-icons
@@ -779,25 +773,6 @@ pre-commit:
echo "$(GREEN)✓ pre-commit passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
.PHONY: pre-commit
# Lightweight pre-commit for changes that don't touch Go or
# TypeScript. Skips gen, lint/go, lint/ts, fmt/go, fmt/ts, and
# the binary build. Used by the pre-commit hook when only docs,
# shell, terraform, helm, or other fast-to-check files changed.
pre-commit-light:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit-light.XXXXXX")
echo "$(BOLD)pre-commit-light$(RESET) ($$logdir)"
echo "fmt:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir fmt-light
$(check-unstaged)
echo "lint:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir lint-light
$(check-unstaged)
$(check-untracked)
rm -rf $$logdir
echo "$(GREEN)✓ pre-commit-light passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
.PHONY: pre-commit-light
pre-push:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-push.XXXXXX")
@@ -806,7 +781,6 @@ pre-push:
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
test \
test-js \
test-storybook \
site/out/index.html
rm -rf $$logdir
echo "$(GREEN)✓ pre-push passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
@@ -1341,11 +1315,6 @@ test-js: site/node_modules/.installed
pnpm test:ci
.PHONY: test-js
test-storybook: site/node_modules/.installed
cd site/
pnpm exec vitest run --project=storybook
.PHONY: test-storybook
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
# dependency for any sqlc-cloud related targets.
sqlc-cloud-is-setup:
+2 -7
View File
@@ -385,16 +385,11 @@ func (a *agent) init() {
pathStore := agentgit.NewPathStore()
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore, func() string {
if m := a.manifest.Load(); m != nil {
return m.Directory
}
return ""
})
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.reconnectingPTYServer = reconnectingpty.NewServer(
-14
View File
@@ -57,26 +57,18 @@ type fakeContainerCLI struct {
}
func (f *fakeContainerCLI) List(_ context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.containers, f.listErr
}
func (f *fakeContainerCLI) DetectArchitecture(_ context.Context, _ string) (string, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.arch, f.archErr
}
func (f *fakeContainerCLI) Copy(ctx context.Context, name, src, dst string) error {
f.mu.Lock()
defer f.mu.Unlock()
return f.copyErr
}
func (f *fakeContainerCLI) ExecAs(ctx context.Context, name, user string, args ...string) ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()
return nil, f.execErr
}
@@ -2697,9 +2689,7 @@ func TestAPI(t *testing.T) {
// When: The container is recreated (new container ID) with config changes.
terraformContainer.ID = "new-container-id"
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = terraformContainer.ID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
Apps: []agentcontainers.SubAgentApp{{Slug: "app2"}}, // Changed app triggers recreation logic.
@@ -2831,9 +2821,7 @@ func TestAPI(t *testing.T) {
// Simulate container rebuild: new container ID, changed display apps.
newContainerID := "new-container-id"
terraformContainer.ID = newContainerID
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = newContainerID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
DisplayApps: map[codersdk.DisplayApp]bool{
@@ -4938,11 +4926,9 @@ func TestDevcontainerPrebuildSupport(t *testing.T) {
)
api.Start()
fCCLI.mu.Lock()
fCCLI.containers = codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{testContainer},
}
fCCLI.mu.Unlock()
// Given: We allow the dev container to be created.
fDCCLI.upID = testContainer.ID
+169 -24
View File
@@ -2,9 +2,13 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
@@ -20,6 +24,28 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
portableDesktopVersion = "v0.0.4"
downloadRetries = 3
downloadRetryDelay = time.Second
)
// platformBinaries maps GOARCH to download URL and expected SHA-256
// digest for each supported platform.
var platformBinaries = map[string]struct {
URL string
SHA256 string
}{
"amd64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
},
"arm64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
},
}
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
@@ -52,31 +78,43 @@ type screenshotOutput struct {
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
logger slog.Logger
execer agentexec.Execer
dataDir string // agent's ScriptDataDir, used for binary caching
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
// httpClient is used for downloading the binary. If nil,
// http.DefaultClient is used.
httpClient *http.Client
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary.
// CLI binary, using execer to spawn child processes. dataDir is used
// to cache the downloaded binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
dataDir string,
) Desktop {
return &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: execer,
dataDir: dataDir,
}
}
// httpDo returns the HTTP client to use for downloads.
func (p *portableDesktop) httpDo() *http.Client {
if p.httpClient != nil {
return p.httpClient
}
return http.DefaultClient
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
@@ -361,8 +399,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
return string(out), nil
}
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
// ensureBinary resolves or downloads the portabledesktop binary. It
// must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
@@ -377,23 +415,130 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
return nil
}
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
// 2. Platform checks.
if runtime.GOOS != "linux" {
return xerrors.New("portabledesktop is only supported on Linux")
}
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
}
// 3. Check cache.
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
cachedPath := filepath.Join(cacheDir, "portabledesktop")
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
// Verify it is executable.
if info.Mode()&0o100 != 0 {
p.logger.Info(ctx, "using cached portabledesktop binary",
slog.F("path", cachedPath),
)
p.binPath = scriptBinPath
p.binPath = cachedPath
return nil
}
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
)
}
// 4. Download with retry.
p.logger.Info(ctx, "downloading portabledesktop binary",
slog.F("url", bin.URL),
slog.F("version", portableDesktopVersion),
slog.F("arch", runtime.GOARCH),
)
var lastErr error
for attempt := range downloadRetries {
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
lastErr = err
p.logger.Warn(ctx, "download attempt failed",
slog.F("attempt", attempt+1),
slog.F("max_attempts", downloadRetries),
slog.Error(err),
)
if attempt < downloadRetries-1 {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(downloadRetryDelay):
}
}
continue
}
p.binPath = cachedPath
p.logger.Info(ctx, "downloaded portabledesktop binary",
slog.F("path", cachedPath),
)
return nil
}
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
}
// downloadBinary fetches a binary from url, verifies its SHA-256
// digest matches expectedSHA256, and atomically writes it to destPath.
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
return xerrors.Errorf("create cache directory: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return xerrors.Errorf("create HTTP request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return xerrors.Errorf("HTTP GET %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
}
// Write to a temp file in the same directory so the final rename
// is atomic on the same filesystem.
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
if err != nil {
return xerrors.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up the temp file on any error path.
success := false
defer func() {
if !success {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
}()
// Stream the response body while computing SHA-256.
hasher := sha256.New()
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
return xerrors.Errorf("download body: %w", err)
}
if err := tmpFile.Close(); err != nil {
return xerrors.Errorf("close temp file: %w", err)
}
// Verify digest.
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
if actualSHA256 != expectedSHA256 {
return xerrors.Errorf(
"SHA-256 mismatch: expected %s, got %s",
expectedSHA256, actualSHA256,
)
}
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
if err := os.Chmod(tmpPath, 0o700); err != nil {
return xerrors.Errorf("chmod: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
return xerrors.Errorf("rename to final path: %w", err)
}
success = true
return nil
}
@@ -2,6 +2,11 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
@@ -72,6 +77,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
// The "up" script prints the JSON line then sleeps until
// the context is canceled (simulating a long-running process).
@@ -82,13 +88,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
}
ctx := t.Context()
ctx := context.Background()
cfg, err := pd.Start(ctx)
require.NoError(t, err)
@@ -105,6 +111,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -113,13 +120,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
cfg1, err := pd.Start(ctx)
require.NoError(t, err)
@@ -147,6 +154,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -155,13 +163,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
require.NoError(t, err)
@@ -172,6 +180,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -180,13 +189,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Screenshot(ctx, ScreenshotOptions{
TargetWidth: 800,
TargetHeight: 600,
@@ -278,13 +287,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -363,13 +372,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -395,13 +404,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
x, y, err := pd.CursorPosition(t.Context())
x, y, err := pd.CursorPosition(context.Background())
require.NoError(t, err)
assert.Equal(t, 100, x)
assert.Equal(t, 200, y)
@@ -419,13 +428,13 @@ func TestPortableDesktop_Close(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Start(ctx)
require.NoError(t, err)
@@ -448,6 +457,81 @@ func TestPortableDesktop_Close(t *testing.T) {
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- downloadBinary tests ---
func TestDownloadBinary_Success(t *testing.T) {
t.Parallel()
binaryContent := []byte("#!/bin/sh\necho portable\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
require.NoError(t, err)
// Verify the file exists and has correct content.
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
// Verify executable permissions.
info, err := os.Stat(destPath)
require.NoError(t, err)
assert.NotZero(t, info.Mode()&0o111, "binary should be executable")
}
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("real binary content"))
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "SHA-256 mismatch")
// The destination file should not exist (temp file cleaned up).
_, statErr := os.Stat(destPath)
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
// No leftover temp files in the directory.
entries, err := os.ReadDir(destDir)
require.NoError(t, err)
assert.Empty(t, entries, "no leftover temp files should remain")
}
func TestDownloadBinary_HTTPError(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "status 404")
}
// --- ensureBinary tests ---
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
@@ -457,89 +541,173 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
// immediately without doing any work.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(),
binPath: "/already/set",
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: t.TempDir(),
binPath: "/already/set",
}
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, "/already/set", pd.binPath)
}
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
require.NoError(t, os.Chmod(binPath, 0o755))
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
}
dataDir := t.TempDir()
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
cachedPath := filepath.Join(cacheDir, "portabledesktop")
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, binPath, pd.binPath)
assert.Equal(t, cachedPath, pd.binPath)
}
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Windows does not support Unix permission bits")
}
func TestEnsureBinary_Downloads(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
// environment and we override the package-level platformBinaries.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
// Write without execute permission.
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
_ = binPath
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Save and restore platformBinaries for this test.
origBinaries := platformBinaries
platformBinaries = map[string]struct {
URL string
SHA256 string
}{
runtime.GOARCH: {
URL: srv.URL + "/portabledesktop",
SHA256: expectedSHA,
},
}
t.Cleanup(func() { platformBinaries = origBinaries })
dataDir := t.TempDir()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
httpClient: srv.Client(),
}
// Clear PATH so LookPath won't find a real binary.
// Ensure PATH doesn't contain a real portabledesktop binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
assert.Equal(t, expectedPath, pd.binPath)
// Verify the downloaded file has correct content.
got, err := os.ReadFile(expectedPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
}
func TestEnsureBinary_NotFound(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(), // empty directory
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
binaryContent := []byte("#!/bin/sh\necho retried\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
var mu sync.Mutex
attempt := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
current := attempt
attempt++
mu.Unlock()
// Fail the first 2 attempts, succeed on the third.
if current < 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Test downloadBinary directly to avoid time.Sleep in
// ensureBinary's retry loop. We call it 3 times to simulate
// what ensureBinary would do.
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
var lastErr error
for i := range 3 {
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
if lastErr == nil {
break
}
if i < 2 {
// In the real code, ensureBinary sleeps here.
// We skip the sleep in tests.
continue
}
}
require.NoError(t, lastErr, "download should succeed on the third attempt")
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
mu.Lock()
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
mu.Unlock()
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
// Keep compile-time references to agentexec.DefaultExecer and fmt so the
// imports are not flagged as unused if the tests exercising them are removed.
var (
	_ = agentexec.DefaultExecer
	_ = fmt.Sprintf
)
@@ -333,68 +333,22 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
return status, err
}
// Check if the target already exists so we can preserve its
// permissions on the temp file before rename.
var origMode os.FileMode
var haveOrigMode bool
if stat, serr := api.filesystem.Stat(path); serr == nil {
if stat.IsDir() {
return http.StatusBadRequest, xerrors.Errorf("open %s: is a directory", path)
}
origMode = stat.Mode()
haveOrigMode = true
}
// Write to a temp file in the same directory so the rename is
// always on the same device (atomic).
tmpfile, err := afero.TempFile(api.filesystem, dir, filepath.Base(path))
f, err := api.filesystem.Create(path)
if err != nil {
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.EISDIR):
status = http.StatusBadRequest
}
return status, err
}
tmpName := tmpfile.Name()
defer f.Close()
_, err = io.Copy(tmpfile, r.Body)
if err != nil && !errors.Is(err, io.EOF) {
_ = tmpfile.Close()
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Close before rename to flush buffered data and catch write
// errors (e.g. delayed allocation failures).
if err := tmpfile.Close(); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Set permissions on the temp file before rename so there is
// no window where the target has wrong permissions.
if haveOrigMode {
if err := api.filesystem.Chmod(tmpName, origMode); err != nil {
api.logger.Warn(ctx, "unable to set file permissions",
slog.F("path", path),
slog.Error(err),
)
}
}
if err := api.filesystem.Rename(tmpName, path); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
}
return status, err
_, err = io.Copy(f, r.Body)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
api.logger.Error(ctx, "workspace agent write file", slog.Error(err))
}
return 0, nil
@@ -493,10 +447,13 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
content := string(data)
for _, edit := range edits {
var err error
content, err = fuzzyReplace(content, edit)
if err != nil {
return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err)
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
}
@@ -506,135 +463,68 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
if err != nil {
return http.StatusInternalServerError, err
}
tmpName := tmpfile.Name()
defer tmpfile.Close()
if _, err := tmpfile.Write([]byte(content)); err != nil {
_ = tmpfile.Close()
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
// Close before rename to flush buffered data and catch write
// errors (e.g. delayed allocation failures).
if err := tmpfile.Close(); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
// Set permissions on the temp file before rename so there is
// no window where the target has wrong permissions.
if err := api.filesystem.Chmod(tmpName, stat.Mode()); err != nil {
api.logger.Warn(ctx, "unable to set file permissions",
slog.F("path", path),
slog.Error(err),
)
}
err = api.filesystem.Rename(tmpName, path)
err = api.filesystem.Rename(tmpfile.Name(), path)
if err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
}
return status, err
return http.StatusInternalServerError, err
}
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace it
// with `replace`. It uses a cascading match strategy inspired by
// fuzzyReplace attempts to find `search` inside `content` and replace it with
// `replace` (all occurrences on an exact match, the first fuzzy match
// otherwise). It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace
// (indentation-tolerant).
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When edit.ReplaceAll is false (the default), the search string must
// match exactly one location. If multiple matches are found, an error
// is returned asking the caller to include more context or set
// replace_all.
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still
// applied at the byte offsets of the original content so that surrounding
// text (including indentation of untouched lines) is preserved.
func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
search := edit.Search
replace := edit.Replace
// Pass 1: exact substring match.
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1: exact substring (replace all occurrences).
if strings.Contains(content, search) {
if edit.ReplaceAll {
return strings.ReplaceAll(content, search, replace), nil
}
count := strings.Count(content, search)
if count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
// Exactly one match.
return strings.Replace(content, search, replace, 1), nil
return strings.ReplaceAll(content, search, replace), true
}
// For line-level fuzzy matching we split both content and search
// into lines.
// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element
// from SplitAfter. Drop it so it doesn't interfere with line
// matching.
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
trimRight := func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}
trimAll := func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}
// Pass 2: trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3: trim all leading and trailing whitespace
// (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
// Pass 3: trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return "", xerrors.New("search string not found in file. Verify the search " +
"string matches the file content exactly, including whitespace " +
"and indentation")
return content, false
}
// seekLines scans contentLines looking for a contiguous subsequence that matches
@@ -659,26 +549,6 @@ outer:
return 0, 0, false
}
// countLineMatches counts how many non-overlapping contiguous
// subsequences of contentLines match searchLines according to eq.
func countLineMatches(contentLines, searchLines []string, eq func(a, b string) bool) int {
count := 0
if len(searchLines) == 0 || len(searchLines) > len(contentLines) {
return count
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
count++
i += len(searchLines) - 1 // skip past this match
}
return count
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
@@ -692,3 +562,10 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
@@ -14,7 +14,6 @@ import (
"strings"
"syscall"
"testing"
"testing/iotest"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -400,83 +399,6 @@ func TestWriteFile(t *testing.T) {
}
}
func TestWriteFile_ReportsIOError(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, nil)
tmpdir := os.TempDir()
path := filepath.Join(tmpdir, "write-io-error")
err := afero.WriteFile(fs, path, []byte("original"), 0o644)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// A reader that always errors simulates a failed body read
// (e.g. network interruption). The atomic write should leave
// the original file intact.
body := iotest.ErrReader(xerrors.New("simulated I/O error"))
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path), body)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusInternalServerError, w.Code)
got := &codersdk.Error{}
err = json.NewDecoder(w.Body).Decode(got)
require.NoError(t, err)
require.ErrorContains(t, got, "simulated I/O error")
// The original file must survive the failed write.
data, err := afero.ReadFile(fs, path)
require.NoError(t, err)
require.Equal(t, "original", string(data))
}
func TestWriteFile_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// Overwrite the file with new content.
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path),
bytes.NewReader([]byte("#!/bin/sh\necho world\n")))
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"write_file should preserve the original file's permissions")
}
func TestEditFiles(t *testing.T) {
t.Parallel()
@@ -654,9 +576,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
},
{
// When the second edit creates ambiguity (two "bar"
// occurrences), it should fail.
name: "EditEditAmbiguous",
name: "EditEdit", // Later edits apply to the result of earlier edits.
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
@@ -673,33 +593,7 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 2 occurrences"},
// File should not be modified on error.
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
},
{
// With replace_all the cascading edit replaces
// both occurrences.
name: "EditEditReplaceAll",
contents: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "edit-edit-ra"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
{
Search: "bar",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "qux qux"},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
},
{
name: "Multiline",
@@ -826,7 +720,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchErrors",
name: "NoMatchStillSucceeds",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
@@ -839,46 +733,9 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"search string not found in file"},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "AmbiguousExactMatch",
contents: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ambig-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 3 occurrences"},
expected: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
},
{
name: "ReplaceAllExact",
contents: map[string]string{filepath.Join(tmpdir, "ra-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ra-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
@@ -985,67 +842,6 @@ func TestEditFiles(t *testing.T) {
}
}
func TestEditFiles_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
// Sanity-check the initial mode.
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
body := workspacesdk.FileEditRequest{
Files: []workspacesdk.FileEdits{
{
Path: path,
Edits: []workspacesdk.FileEdit{
{
Search: "hello",
Replace: "world",
},
},
},
},
}
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(body)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
// Verify content was updated.
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
// Verify permissions are preserved after the
// temp-file-and-rename cycle.
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"edit_files should preserve the original file's permissions")
}
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
t.Parallel()
@@ -1,13 +1,11 @@
package agentproc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -20,13 +18,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// maxWaitDuration is the maximum time a blocking
// process output request can wait, regardless of
// what the client requests.
maxWaitDuration = 5 * time.Minute
)
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
@@ -35,10 +26,10 @@ type API struct {
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore, workingDir func() string) *API {
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv, workingDir),
manager: newManager(logger, execer, updateEnv),
pathStore: pathStore,
}
}
@@ -160,42 +151,6 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
return
}
// Enforce chat ID isolation. If the request carries
// a chat context, only allow access to processes
// belonging to that chat.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
if proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
// Check for blocking mode via query params.
waitStr := r.URL.Query().Get("wait")
wantWait := waitStr == "true"
if wantWait {
// Extend the write deadline so the HTTP server's
// WriteTimeout does not kill the connection while
// we block.
rc := http.NewResponseController(rw)
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration)); err != nil {
api.logger.Error(ctx, "extend write deadline for blocking process output",
slog.Error(err),
)
}
// Cap the wait at maxWaitDuration regardless of
// client-supplied timeout.
waitCtx, waitCancel := context.WithTimeout(ctx, maxWaitDuration)
defer waitCancel()
_ = proc.waitForOutput(waitCtx)
// Fall through to read snapshot below.
}
output, truncated := proc.output()
info := proc.info()
@@ -213,17 +168,6 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
// Enforce chat ID isolation.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
proc, procOK := api.manager.get(id)
if procOK && proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
var req workspacesdk.SignalProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
+3 -277
@@ -7,10 +7,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
@@ -78,22 +76,6 @@ func getOutput(t *testing.T, handler http.Handler, id string) *httptest.Response
return w
}
// getOutputWithHeaders sends a GET /{id}/output request with
// custom headers and returns the recorder.
func getOutputWithHeaders(t *testing.T, handler http.Handler, id string, headers http.Header) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
path := fmt.Sprintf("/%s/output", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
for k, v := range headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
@@ -115,25 +97,18 @@ func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, nil, nil)
return newTestAPIWithUpdateEnv(t, nil)
}
// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, updateEnv, nil)
}
// newTestAPIWithOptions creates a new API with optional
// updateEnv and workingDir hooks.
func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, error), workingDir func() string) http.Handler {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil, workingDir)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
t.Cleanup(func() {
_ = api.Close()
})
@@ -278,100 +253,6 @@ func TestStartProcess(t *testing.T) {
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirIsHome", func(t *testing.T) {
t.Parallel()
// No working directory closure, so the process
// should fall back to $HOME. We verify through
// the process list API which reports the resolved
// working directory using native OS paths,
// avoiding shell path format mismatches on
// Windows (Git Bash returns POSIX paths).
handler := newTestAPI(t)
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("DefaultWorkDirFromClosure", func(t *testing.T) {
t.Parallel()
// The closure provides a valid directory, so the
// process should start there. Use the marker file
// pattern to avoid path format mismatches on
// Windows.
tmpDir := t.TempDir()
handler := newTestAPIWithOptions(t, nil, func() string {
return tmpDir
})
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "touch marker.txt && ls marker.txt",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirClosureNonExistentFallsBackToHome", func(t *testing.T) {
t.Parallel()
// The closure returns a path that doesn't exist,
// so the process should fall back to $HOME.
handler := newTestAPIWithOptions(t, nil, func() string {
return "/tmp/nonexistent-dir-" + fmt.Sprintf("%d", time.Now().UnixNano())
})
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("CustomEnv", func(t *testing.T) {
t.Parallel()
@@ -756,161 +637,6 @@ func TestProcessOutput(t *testing.T) {
require.NoError(t, err)
require.Contains(t, resp.Message, "not found")
})
t.Run("ChatIDEnforcement", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a process with chat-a.
chatA := uuid.New()
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo secret",
Background: true,
}, http.Header{
workspacesdk.CoderChatIDHeader: {chatA.String()},
})
waitForExit(t, handler, id)
// Chat-b should NOT see this process.
chatB := uuid.New()
w1 := getOutputWithHeaders(t, handler, id, http.Header{
workspacesdk.CoderChatIDHeader: {chatB.String()},
})
require.Equal(t, http.StatusNotFound, w1.Code)
// Without any chat ID header, should return 200
// (backwards compatible).
w2 := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w2.Code)
})
t.Run("WaitForExit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello-wait && sleep 0.1",
})
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "hello-wait")
})
t.Run("WaitAlreadyExited", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, id)
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.Contains(t, resp.Output, "done")
})
t.Run("WaitTimeout", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.IntervalMedium)
defer cancel()
w := getOutputWithWaitCtx(ctx, t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Running)
// Kill and wait for the process so cleanup does
// not hang.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
waitForExit(t, handler, id)
})
t.Run("ConcurrentWaiters", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
var (
wg sync.WaitGroup
resps [2]workspacesdk.ProcessOutputResponse
codes [2]int
)
for i := range 2 {
wg.Add(1)
go func() {
defer wg.Done()
w := getOutputWithWait(t, handler, id)
codes[i] = w.Code
_ = json.NewDecoder(w.Body).Decode(&resps[i])
}()
}
// Signal the process to exit so both waiters unblock.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
wg.Wait()
for i := range 2 {
require.Equal(t, http.StatusOK, codes[i], "waiter %d", i)
require.False(t, resps[i].Running, "waiter %d", i)
}
})
}
func getOutputWithWait(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
return getOutputWithWaitCtx(ctx, t, handler, id)
}
func getOutputWithWaitCtx(ctx context.Context, t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
path := fmt.Sprintf("/%s/output?wait=true", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
func TestSignalProcess(t *testing.T) {
@@ -1055,7 +781,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
return current, nil
}, pathStore, nil)
}, pathStore)
defer api.Close()
routes := api.Routes()
+2 -19
@@ -39,13 +39,11 @@ const (
// how much output is written.
type HeadTailBuffer struct {
mu sync.Mutex
cond *sync.Cond
head []byte
tail []byte
tailPos int
tailFull bool
headFull bool
closed bool
totalBytes int
maxHead int
maxTail int
@@ -54,24 +52,20 @@ type HeadTailBuffer struct {
// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: MaxHeadBytes,
maxTail: MaxTailBytes,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: maxHead,
maxTail: maxTail,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// Write implements io.Writer. It is safe for concurrent use.
@@ -302,15 +296,6 @@ func truncateLines(s string) string {
return b.String()
}
// Close marks the buffer as closed and wakes any waiters.
// This is called when the process exits.
func (b *HeadTailBuffer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closed = true
b.cond.Broadcast()
}
// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
b.mu.Lock()
@@ -320,7 +305,5 @@ func (b *HeadTailBuffer) Reset() {
b.tailPos = 0
b.tailFull = false
b.headFull = false
b.closed = false
b.totalBytes = 0
b.cond.Broadcast()
}
-26
@@ -1,26 +0,0 @@
//go:build !windows
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Unix, Setpgid creates a new process group so
// that signals can be delivered to the entire group (the shell
// and all its children).
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}
// signalProcess sends a signal to the process group rooted at p.
// Using the negative PID sends the signal to every process in the
// group, ensuring child processes (e.g. from shell pipelines) are
// also signaled.
func signalProcess(p *os.Process, sig syscall.Signal) error {
return syscall.Kill(-p.Pid, sig)
}
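The removed Unix helpers combine `Setpgid` with a negative-PID kill so that signals reach an entire shell pipeline, not just the shell. A Unix-only sketch of the technique (`runAndKillGroup` is an illustrative helper):

```go
package main

import (
	"fmt"
	"os/exec"
	"syscall"
	"time"
)

// runAndKillGroup starts a shell pipeline in its own process group
// and kills the whole group. With Setpgid, the shell's PID doubles
// as the group ID, and signaling -pid reaches the pipeline's
// children (sleep, cat) as well as the shell itself.
func runAndKillGroup() error {
	cmd := exec.Command("sh", "-c", "sleep 30 | cat")
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
	if err := cmd.Start(); err != nil {
		return err
	}
	time.Sleep(100 * time.Millisecond) // give the shell time to fork
	if err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL); err != nil {
		return err
	}
	return cmd.Wait() // reports the kill signal
}

func main() {
	fmt.Println("wait:", runAndKillGroup())
}
```

Without the process group, `cmd.Process.Kill()` signals only the shell's PID and the pipeline's children can outlive it, which is exactly the regression this diff reintroduces on the removal side.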
-20
@@ -1,20 +0,0 @@
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Windows, process groups are not supported in the
// same way as Unix, so this returns an empty struct.
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{}
}
// signalProcess sends a signal directly to the process. Windows
// does not support process group signaling, so we fall back to
// sending the signal to the process itself.
func signalProcess(p *os.Process, _ syscall.Signal) error {
return p.Kill()
}
+21 -78
@@ -70,25 +70,23 @@ func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
// manager tracks processes spawned by the agent.
type manager struct {
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
workingDir func() string
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
}
// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), workingDir func() string) *manager {
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
return &manager{
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
workingDir: workingDir,
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
}
}
@@ -111,9 +109,10 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
// the process is not tied to any HTTP request.
ctx, cancel := context.WithCancel(context.Background())
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
cmd.Dir = m.resolveWorkDir(req.WorkDir)
if req.WorkDir != "" {
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
cmd.SysProcAttr = procSysProcAttr()
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
@@ -158,7 +157,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc := &process{
id: id,
command: req.Command,
workDir: cmd.Dir,
workDir: req.WorkDir,
background: req.Background,
chatID: chatID,
cmd: cmd,
@@ -208,9 +207,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc.exitCode = &code
proc.mu.Unlock()
// Wake any waiters blocked on new output or
// process exit before closing the done channel.
proc.buf.Close()
close(proc.done)
}()
@@ -276,15 +272,13 @@ func (m *manager) signal(id string, sig string) error {
switch sig {
case "kill":
// Use process group kill to ensure child processes
// (e.g. from shell pipelines) are also killed.
if err := signalProcess(proc.cmd.Process, syscall.SIGKILL); err != nil {
if err := proc.cmd.Process.Kill(); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
// Use process group signal to ensure child processes
// are also terminated.
if err := signalProcess(proc.cmd.Process, syscall.SIGTERM); err != nil {
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
@@ -322,54 +316,3 @@ func (m *manager) Close() error {
return nil
}
// waitForOutput blocks until the buffer is closed (process
// exited) or the context is canceled. Returns nil when the
// buffer closed, ctx.Err() when the context expired.
func (p *process) waitForOutput(ctx context.Context) error {
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
nevermind := make(chan struct{})
defer close(nevermind)
go func() {
select {
case <-ctx.Done():
// Acquire the lock before broadcasting to
// guarantee the waiter has entered cond.Wait()
// (which atomically releases the lock).
// Without this, a Broadcast between the loop
// predicate check and cond.Wait() is lost.
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
p.buf.cond.Broadcast()
case <-nevermind:
}
}()
for ctx.Err() == nil && !p.buf.closed {
p.buf.cond.Wait()
}
return ctx.Err()
}
// resolveWorkDir returns the directory a process should start in.
// Priority: explicit request dir > agent configured dir > $HOME.
// Falls through when a candidate is empty or does not exist on
// disk, matching the behavior of SSH sessions.
func (m *manager) resolveWorkDir(requested string) string {
if requested != "" {
return requested
}
if m.workingDir != nil {
if dir := m.workingDir(); dir != "" {
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return dir
}
}
}
if home, err := os.UserHomeDir(); err == nil {
return home
}
return ""
}
+2 -2
@@ -398,11 +398,11 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
},
})
if err != nil {
logger.Warn(ctx, "reporting script completed", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: %s", err.Error()))
}
})
if err != nil {
logger.Warn(ctx, "reporting script completed: track command goroutine", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: track command goroutine: %s", err.Error()))
}
}()
+27 -6
@@ -6,6 +6,7 @@ import (
"context"
"net"
"path/filepath"
"sync"
"testing"
"github.com/google/uuid"
@@ -22,6 +23,26 @@ import (
"github.com/coder/coder/v2/testutil"
)
// logSink captures structured log entries for testing.
type logSink struct {
mu sync.Mutex
entries []slog.SinkEntry
}
func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.mu.Lock()
defer s.mu.Unlock()
s.entries = append(s.entries, e)
}
func (*logSink) Sync() {}
func (s *logSink) getEntries() []slog.SinkEntry {
s.mu.Lock()
defer s.mu.Unlock()
return append([]slog.SinkEntry{}, s.entries...)
}
// getField returns the value of a field by name from a slog.Map.
func getField(fields slog.Map, name string) interface{} {
for _, f := range fields {
@@ -55,8 +76,8 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
sink := testutil.NewFakeSink(t)
logger := sink.Logger(slog.LevelInfo)
sink := &logSink{}
logger := slog.Make(sink)
workspaceID := uuid.New()
templateID := uuid.New()
templateVersionID := uuid.New()
@@ -97,10 +118,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 1
return len(sink.getEntries()) >= 1
}, testutil.WaitShort, testutil.IntervalFast)
entries := sink.Entries()
entries := sink.getEntries()
require.Len(t, entries, 1)
entry := entries[0]
require.Equal(t, slog.LevelInfo, entry.Level)
@@ -131,10 +152,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req2)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 2
return len(sink.getEntries()) >= 2
}, testutil.WaitShort, testutil.IntervalFast)
entries = sink.Entries()
entries = sink.getEntries()
entry = entries[1]
require.Len(t, entries, 2)
require.Equal(t, slog.LevelInfo, entry.Level)
-9
@@ -78,9 +78,6 @@ func withDone(t *testing.T) []reaper.Option {
// processes and passes their PIDs through the shared channel.
func TestReap(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
@@ -127,9 +124,6 @@ func TestReap(t *testing.T) {
//nolint:tparallel // Subtests must be sequential, each starts its own reaper.
func TestForkReapExitCodes(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
@@ -170,9 +164,6 @@ func TestForkReapExitCodes(t *testing.T) {
// ensures SIGINT cannot kill the parent test binary.
func TestReapInterrupt(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
-15
@@ -46,7 +46,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
autoUpdates string
copyParametersFrom string
useParameterDefaults bool
noWait bool
// Organization context is only required if more than 1 template
// shares the same name across multiple organizations.
orgContext = NewOrganizationContext()
@@ -373,14 +372,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
cliutil.WarnMatchedProvisioners(inv.Stderr, workspace.LatestBuild.MatchedProvisioners, workspace.LatestBuild.Job)
if noWait {
_, _ = fmt.Fprintf(inv.Stdout,
"\nThe %s workspace has been created and is building in the background.\n",
cliui.Keyword(workspace.Name),
)
return nil
}
err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
if err != nil {
return xerrors.Errorf("watch build: %w", err)
@@ -454,12 +445,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
Description: "Automatically accept parameter defaults when no value is provided.",
Value: serpent.BoolOf(&useParameterDefaults),
},
serpent.Option{
Flag: "no-wait",
Env: "CODER_CREATE_NO_WAIT",
Description: "Return immediately after creating the workspace. The build will run in the background.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
)
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
-75
@@ -603,81 +603,6 @@ func TestCreate(t *testing.T) {
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
}
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was actually created.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
})
t.Run("NoWaitWithParameterDefaults", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses([]*proto.RichParameter{
{Name: "region", Type: "string", DefaultValue: "us-east-1"},
{Name: "instance_type", Type: "string", DefaultValue: "t3.micro"},
}))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--use-parameter-defaults",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was created and parameters were applied.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
buildParams, err := member.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
require.NoError(t, err)
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "region", Value: "us-east-1"})
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "instance_type", Value: "t3.micro"})
})
}
func prepareEchoResponses(parameters []*proto.RichParameter, presets ...*proto.Preset) *echo.Responses {
-6
@@ -1000,12 +1000,6 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
Properties: sdkTool.Schema.Properties,
Required: sdkTool.Schema.Required,
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var buf bytes.Buffer
+1 -16
@@ -81,13 +81,7 @@ func TestExpMcpServer(t *testing.T) {
var toolsResponse struct {
Result struct {
Tools []struct {
Name string `json:"name"`
Annotations struct {
ReadOnlyHint *bool `json:"readOnlyHint"`
DestructiveHint *bool `json:"destructiveHint"`
IdempotentHint *bool `json:"idempotentHint"`
OpenWorldHint *bool `json:"openWorldHint"`
} `json:"annotations"`
Name string `json:"name"`
} `json:"tools"`
} `json:"result"`
}
@@ -100,15 +94,6 @@ func TestExpMcpServer(t *testing.T) {
}
slices.Sort(foundTools)
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
annotations := toolsResponse.Result.Tools[0].Annotations
require.NotNil(t, annotations.ReadOnlyHint)
require.NotNil(t, annotations.DestructiveHint)
require.NotNil(t, annotations.IdempotentHint)
require.NotNil(t, annotations.OpenWorldHint)
assert.True(t, *annotations.ReadOnlyHint)
assert.False(t, *annotations.DestructiveHint)
assert.True(t, *annotations.IdempotentHint)
assert.False(t, *annotations.OpenWorldHint)
// Call the tool and ensure it works.
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
+45 -74
@@ -1732,18 +1732,19 @@ const (
func (r *RootCmd) scaletestAutostart() *serpent.Command {
var (
workspaceCount int64
workspaceJobTimeout time.Duration
autostartBuildTimeout time.Duration
autostartDelay time.Duration
template string
noCleanup bool
workspaceCount int64
workspaceJobTimeout time.Duration
autostartDelay time.Duration
autostartTimeout time.Duration
template string
noCleanup bool
parameterFlags workspaceParameterFlags
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
@@ -1771,7 +1772,7 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("parse output flags: %w", err)
return xerrors.Errorf("could not parse --output flags")
}
tpl, err := parseTemplate(ctx, client, me.OrganizationIDs, template)
@@ -1802,41 +1803,15 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := autostart.NewMetrics(reg)
setupBarrier := new(sync.WaitGroup)
setupBarrier.Add(int(workspaceCount))
// The workspace-build-updates experiment must be enabled to use
// the centralized pubsub channel for coordinating workspace builds.
experiments, err := client.Experiments(ctx)
if err != nil {
return xerrors.Errorf("get experiments: %w", err)
}
if !experiments.Enabled(codersdk.ExperimentWorkspaceBuildUpdates) {
return xerrors.New("the workspace-build-updates experiment must be enabled to run the autostart scaletest")
}
workspaceNames := make([]string, 0, workspaceCount)
resultSink := make(chan autostart.RunResult, workspaceCount)
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i := range workspaceCount {
id := strconv.Itoa(int(i))
workspaceNames = append(workspaceNames, loadtestutil.GenerateDeterministicWorkspaceName(id))
}
dispatcher := autostart.NewWorkspaceDispatcher(workspaceNames)
decoder, err := client.WatchAllWorkspaceBuilds(ctx)
if err != nil {
return xerrors.Errorf("watch all workspace builds: %w", err)
}
defer decoder.Close()
// Start the dispatcher. It will run in a goroutine and automatically
// close all workspace channels when the build updates channel closes.
dispatcher.Start(ctx, decoder.Chan())
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for workspaceName, buildUpdatesChannel := range dispatcher.Channels {
id := strings.TrimPrefix(workspaceName, loadtestutil.ScaleTestPrefix+"-")
config := autostart.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
@@ -1846,16 +1821,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Request: codersdk.CreateWorkspaceRequest{
TemplateID: tpl.ID,
RichParameterValues: richParameters,
// Use deterministic workspace name so we can pre-create the channel.
Name: workspaceName,
},
},
WorkspaceJobTimeout: workspaceJobTimeout,
AutostartBuildTimeout: autostartBuildTimeout,
AutostartDelay: autostartDelay,
SetupBarrier: setupBarrier,
BuildUpdates: buildUpdatesChannel,
ResultSink: resultSink,
WorkspaceJobTimeout: workspaceJobTimeout,
AutostartDelay: autostartDelay,
AutostartTimeout: autostartTimeout,
Metrics: metrics,
SetupBarrier: setupBarrier,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
@@ -1877,11 +1849,18 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
th.AddRun(autostartTestName, id, runner)
}
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Running autostart load test...")
@@ -1892,40 +1871,31 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// Collect all metrics from the channel.
close(resultSink)
var runResults []autostart.RunResult
for r := range resultSink {
runResults = append(runResults, r)
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
_, _ = fmt.Fprintf(inv.Stderr, "\nAll %d autostart builds completed successfully (elapsed: %s)\n", res.TotalRuns, time.Duration(res.Elapsed).Round(time.Millisecond))
if len(runResults) > 0 {
results := autostart.NewRunResults(runResults)
for _, out := range outputs {
if err := out.write(results.ToHarnessResults(), inv.Stdout); err != nil {
return xerrors.Errorf("write output: %w", err)
}
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(context.Background())
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
_, _ = fmt.Fprintln(inv.Stderr, "Cleanup complete")
} else {
_, _ = fmt.Fprintln(inv.Stderr, "\nSkipping cleanup (--no-cleanup specified). Resources left running.")
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
@@ -1948,13 +1918,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Description: "Timeout for workspace jobs (e.g. build, start).",
Value: serpent.DurationOf(&workspaceJobTimeout),
},
{
Flag: "autostart-build-timeout",
Env: "CODER_SCALETEST_AUTOSTART_BUILD_TIMEOUT",
Default: "15m",
Description: "Timeout for the autostart build to complete. Must be longer than workspace-job-timeout to account for queueing time in high-load scenarios.",
Value: serpent.DurationOf(&autostartBuildTimeout),
},
{
Flag: "autostart-delay",
Env: "CODER_SCALETEST_AUTOSTART_DELAY",
@@ -1962,6 +1925,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Description: "How long after all the workspaces have been stopped to schedule them to be started again.",
Value: serpent.DurationOf(&autostartDelay),
},
{
Flag: "autostart-timeout",
Env: "CODER_SCALETEST_AUTOSTART_TIMEOUT",
Default: "5m",
Description: "Timeout for the autostart build to be initiated after the scheduled start time.",
Value: serpent.DurationOf(&autostartTimeout),
},
{
Flag: "template",
FlagShorthand: "t",
@@ -1980,9 +1950,10 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
tracingFlags.attach(&cmd.Options)
output.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
+1 -1
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
} else {
updated, err = client.CreateOrganizationRole(ctx, customRole)
if err != nil {
return xerrors.Errorf("create role: %w", err)
return xerrors.Errorf("patch role: %w", err)
}
}
-23
@@ -79,29 +79,6 @@ func (r *RootCmd) start() *serpent.Command {
)
build = workspace.LatestBuild
default:
// If the last build was a failed start, run a stop
// first to clean up any partially-provisioned
// resources.
if workspace.LatestBuild.Status == codersdk.WorkspaceStatusFailed &&
workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
_, _ = fmt.Fprintf(inv.Stdout, "The last start build failed. Cleaning up before retrying...\n")
stopBuild, stopErr := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStop,
})
if stopErr != nil {
return xerrors.Errorf("cleanup stop after failed start: %w", stopErr)
}
stopErr = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, stopBuild.ID)
if stopErr != nil {
return xerrors.Errorf("wait for cleanup stop: %w", stopErr)
}
// Re-fetch workspace after stop completes so
// startWorkspace sees the latest state.
workspace, err = namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return err
}
}
build, err = startWorkspace(inv, client, workspace, parameterFlags, bflags, WorkspaceStart)
// It's possible for a workspace build to fail due to the template requiring starting
// workspaces with the active version.
-52
@@ -534,55 +534,3 @@ func TestStart_WithReason(t *testing.T) {
workspace = coderdtest.MustWorkspace(t, member, workspace.ID)
require.Equal(t, codersdk.BuildReasonCLI, workspace.LatestBuild.Reason)
}
func TestStart_FailedStartCleansUp(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
store, ps := dbtestutil.NewDB(t)
client := coderdtest.New(t, &coderdtest.Options{
Database: store,
Pubsub: ps,
IncludeProvisionerDaemon: true,
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, memberClient, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// Insert a failed start build directly into the database so that
// the workspace's latest build is a failed "start" transition.
dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
ID: workspace.ID,
OwnerID: member.ID,
OrganizationID: owner.OrganizationID,
TemplateID: template.ID,
}).
Seed(database.WorkspaceBuild{
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
BuildNumber: workspace.LatestBuild.BuildNumber + 1,
}).
Failed().
Do()
inv, root := clitest.New(t, "start", workspace.Name)
clitest.SetupConfig(t, memberClient, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
// The CLI should detect the failed start and clean up first.
pty.ExpectMatch("Cleaning up before retrying")
pty.ExpectMatch("workspace has been started")
_ = testutil.TryReceive(ctx, t, doneChan)
}
+17 -26
@@ -113,20 +113,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
)
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
// Bypass rate limiting for support bundle collection since it makes many API calls.
// Note: this can only be done by the owner user.
if ok, err := support.CanGenerateFull(inv.Context(), client); err == nil && ok {
cliLog.Debug(inv.Context(), "running as owner")
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
} else if !ok {
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
} else {
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
}
// Check if we're running inside a workspace
if val, found := os.LookupEnv("CODER"); found && val == "true" {
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
@@ -214,6 +200,12 @@ func (r *RootCmd) supportBundle() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
}
// Bypass rate limiting for support bundle collection since it makes many API calls.
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
deps := support.Deps{
Client: client,
// Support adds a sink so we don't need to supply one ourselves.
@@ -362,20 +354,19 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
return
}
var docsURL string
if bun.Deployment.Config != nil {
docsURL = bun.Deployment.Config.Values.DocsURL.String()
} else {
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
if bun.Deployment.Config == nil {
cliui.Error(inv.Stdout, "No deployment configuration available!")
return
}
if bun.Deployment.HealthReport != nil {
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
} else {
cliui.Warn(inv.Stdout, "No deployment health report available.")
docsURL := bun.Deployment.Config.Values.DocsURL.String()
if bun.Deployment.HealthReport == nil {
cliui.Error(inv.Stdout, "No deployment health report available!")
return
}
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
if bun.Network.Netcheck == nil {
+22 -47
@@ -28,9 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/healthcheck"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
@@ -52,21 +50,9 @@ func TestSupportBundle(t *testing.T) {
dc.Values.Prometheus.Enable = true
secretValue := uuid.NewString()
seedSecretDeploymentOptions(t, &dc, secretValue)
// Use a mock healthcheck function to avoid flaky DERP health
// checks in CI. The DERP checker performs real network operations
// (portmapper gateway probing, STUN) that can hang for 60s+ on
// macOS CI runners. Since this test validates support bundle
// generation, not healthcheck correctness, a canned report is
// sufficient.
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
DeploymentValues: dc.Values,
HealthcheckFunc: func(_ context.Context, _ string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
return &healthsdk.HealthcheckReport{
Time: time.Now(),
Healthy: true,
Severity: health.SeverityOK,
}
},
DeploymentValues: dc.Values,
HealthcheckTimeout: testutil.WaitSuperLong,
})
t.Cleanup(func() { closer.Close() })
@@ -74,7 +60,7 @@ func TestSupportBundle(t *testing.T) {
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Set up test fixtures
setupCtx := testutil.Context(t, testutil.WaitLong)
setupCtx := testutil.Context(t, testutil.WaitSuperLong)
workspaceWithAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, func(agents []*proto.Agent) []*proto.Agent {
// This should not show up in the bundle output
agents[0].Env["SECRET_VALUE"] = secretValue
@@ -83,6 +69,22 @@ func TestSupportBundle(t *testing.T) {
workspaceWithoutAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, nil)
memberWorkspace := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, member.ID, nil)
// Wait for healthcheck to complete successfully before continuing with sub-tests.
// The result is cached so subsequent requests will be fast.
healthcheckDone := make(chan *healthsdk.HealthcheckReport)
go func() {
defer close(healthcheckDone)
hc, err := healthsdk.New(client).DebugHealth(setupCtx)
if err != nil {
assert.NoError(t, err, "seed healthcheck cache")
return
}
healthcheckDone <- &hc
}()
if _, ok := testutil.AssertReceive(setupCtx, t, healthcheckDone); !ok {
t.Fatal("healthcheck did not complete in time -- this may be a transient issue")
}
t.Run("WorkspaceWithAgent", func(t *testing.T) {
t.Parallel()
@@ -130,35 +132,12 @@ func TestSupportBundle(t *testing.T) {
assertBundleContents(t, path, true, false, []string{secretValue})
})
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
t.Run("NoPrivilege", func(t *testing.T) {
t.Parallel()
d := t.TempDir()
path := filepath.Join(d, "bundle.zip")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
clitest.SetupConfig(t, memberClient, root)
err := inv.Run()
require.NoError(t, err)
r, err := zip.OpenReader(path)
require.NoError(t, err, "open zip file")
defer r.Close()
fileNames := make(map[string]struct{}, len(r.File))
for _, f := range r.File {
fileNames[f.Name] = struct{}{}
}
// These should always be present in the zip structure, even if
// the content is null/empty for non-admin users.
for _, name := range []string{
"deployment/buildinfo.json",
"deployment/config.json",
"workspace/workspace.json",
"logs.txt",
"cli_logs.txt",
"network/netcheck.json",
"network/interfaces.json",
} {
require.Contains(t, fileNames, name)
}
require.ErrorContains(t, err, "failed authorization check")
})
// This ensures that the CLI does not panic when trying to generate a support bundle
@@ -180,10 +159,6 @@ func TestSupportBundle(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("received request: %s %s", r.Method, r.URL)
switch r.URL.Path {
case "/api/v2/users/me":
resp := codersdk.User{}
w.WriteHeader(http.StatusOK)
assert.NoError(t, json.NewEncoder(w).Encode(resp))
case "/api/v2/authcheck":
// Fake auth check
resp := codersdk.AuthorizationResponse{
-4
@@ -20,10 +20,6 @@ OPTIONS:
--copy-parameters-from string, $CODER_WORKSPACE_COPY_PARAMETERS_FROM
Specify the source workspace name to copy parameters from.
--no-wait bool, $CODER_CREATE_NO_WAIT
Return immediately after creating the workspace. The build will run in
the background.
--parameter string-array, $CODER_RICH_PARAMETER
Rich parameter value in the format "name=value".
+1 -1
@@ -7,7 +7,7 @@
"last_seen_at": "====[timestamp]=====",
"name": "test-daemon",
"version": "v0.0.0-devel",
"api_version": "1.16",
"api_version": "1.15",
"provisioners": [
"echo"
],
-6
@@ -170,12 +170,6 @@ AI BRIDGE OPTIONS:
exporting these records to external SIEM or observability systems.
AI BRIDGE PROXY OPTIONS:
--aibridge-proxy-allowed-private-cidrs string-array, $CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS
Comma-separated list of CIDR ranges that are permitted even though
they fall within blocked private/reserved IP ranges. By default all
private ranges are blocked to prevent SSRF attacks. Use this to allow
access to specific internal networks.
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
provider requests.
+10 -11
@@ -8,17 +8,16 @@ USAGE:
Aliases: user
SUBCOMMANDS:
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
oidc-claims Display the OIDC claims for the authenticated user.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user
cannot log into the platform
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user cannot
log into the platform
———
Run `coder --help` for a list of global options.
-4
@@ -24,10 +24,6 @@ OPTIONS:
-p, --password string
Specifies a password for the new user.
--service-account bool
Create a user account intended to be used by a service or as an
intermediary rather than by a human.
-u, --username string
Specifies a username for the new user.
-24
@@ -1,24 +0,0 @@
coder v0.0.0-devel
USAGE:
coder users oidc-claims [flags]
Display the OIDC claims for the authenticated user.
- Display your OIDC claims:
$ coder users oidc-claims
- Display your OIDC claims as JSON:
$ coder users oidc-claims -o json
OPTIONS:
-c, --column [key|value] (default: key,value)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
-11
@@ -752,11 +752,6 @@ workspace_prebuilds:
# limit; disabled when set to zero.
# (default: 3, type: int)
failure_hard_limit: 3
# Configure the background chat processing daemon.
chat:
# How many pending chats a worker should acquire per polling cycle.
# (default: 10, type: int)
acquireBatchSize: 10
aibridge:
# Whether to start an in-memory aibridged instance.
# (default: false, type: bool)
@@ -873,12 +868,6 @@ aibridgeproxy:
# by the system. If not provided, the system certificate pool is used.
# (default: <unset>, type: string)
upstream_proxy_ca: ""
# Comma-separated list of CIDR ranges that are permitted even though they fall
# within blocked private/reserved IP ranges. By default all private ranges are
# blocked to prevent SSRF attacks. Use this to allow access to specific internal
# networks.
# (default: <unset>, type: string-array)
allowed_private_cidrs: []
# Configure data retention policies for various database tables. Retention
# policies automatically purge old data to reduce database size and improve
# performance. Setting a retention duration to 0 disables automatic purging for
+12 -37
@@ -17,14 +17,13 @@ import (
func (r *RootCmd) userCreate() *serpent.Command {
var (
email string
username string
name string
password string
disableLogin bool
loginType string
serviceAccount bool
orgContext = NewOrganizationContext()
email string
username string
name string
password string
disableLogin bool
loginType string
orgContext = NewOrganizationContext()
)
cmd := &serpent.Command{
Use: "create",
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
if serviceAccount {
switch {
case loginType != "":
return xerrors.New("You cannot use --login-type with --service-account")
case password != "":
return xerrors.New("You cannot use --password with --service-account")
case email != "":
return xerrors.New("You cannot use --email with --service-account")
case disableLogin:
return xerrors.New("You cannot use --disable-login with --service-account")
}
}
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
client, err := r.InitClient(inv)
if err != nil {
return err
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
return err
}
}
if email == "" && !serviceAccount {
if email == "" {
email, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Email:",
Validate: func(s string) error {
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
}
}
userLoginType := codersdk.LoginTypePassword
if disableLogin || serviceAccount {
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
if disableLogin {
userLoginType = codersdk.LoginTypeNone
} else if loginType != "" {
userLoginType = codersdk.LoginType(loginType)
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
Password: password,
OrganizationIDs: []uuid.UUID{organization.ID},
UserLoginType: userLoginType,
ServiceAccount: serviceAccount,
})
if err != nil {
return err
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
case codersdk.LoginTypeOIDC:
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
}
if serviceAccount {
email = "n/a"
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
}
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
Share the instructions below to get them started.
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
)),
Value: serpent.StringOf(&loginType),
},
{
Flag: "service-account",
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
Value: serpent.BoolOf(&serviceAccount),
},
}
orgContext.AttachOptions(cmd)
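The refactor above consolidates the mutually exclusive flag check (`--disable-login` vs `--login-type`) next to where the login type is resolved. A minimal standalone sketch of that pattern, with a hypothetical `resolveLoginType` helper that is not part of the coder CLI:

```go
package main

import (
	"errors"
	"fmt"
)

// resolveLoginType is a hypothetical helper illustrating the consolidated
// validation: --disable-login and --login-type are mutually exclusive, and
// --disable-login maps to the "none" login type.
func resolveLoginType(disableLogin bool, loginType string) (string, error) {
	if disableLogin && loginType != "" {
		return "", errors.New("cannot specify both --disable-login and --login-type")
	}
	switch {
	case disableLogin:
		return "none", nil
	case loginType != "":
		return loginType, nil
	default:
		// Default login type when no flag overrides it.
		return "password", nil
	}
}

func main() {
	lt, _ := resolveLoginType(true, "")
	fmt.Println(lt)
}
```

Validating at the point of use, rather than at the top of the handler, keeps the check adjacent to the value it guards.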
-53
@@ -8,7 +8,6 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
assert.Equal(t, args[5], created.Username)
assert.Empty(t, created.Name)
})
tests := []struct {
name string
args []string
err string
}{
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
},
{
name: "ServiceAccountLoginType",
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
err: "You cannot use --login-type with --service-account",
},
{
name: "ServiceAccountDisableLogin",
args: []string{"--service-account", "-u", "dean", "--disable-login"},
err: "You cannot use --disable-login with --service-account",
},
{
name: "ServiceAccountEmail",
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
err: "You cannot use --email with --service-account",
},
{
name: "ServiceAccountPassword",
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
err: "You cannot use --password with --service-account",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
clitest.SetupConfig(t, client, root)
err := inv.Run()
if tt.err == "" {
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
created, err := client.User(ctx, "dean")
require.NoError(t, err)
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
} else {
require.Error(t, err)
require.ErrorContains(t, err, tt.err)
}
})
}
}
-79
@@ -1,79 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) userOIDCClaims() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.ChangeFormatterData(
cliui.TableFormat([]claimRow{}, []string{"key", "value"}),
func(data any) (any, error) {
resp, ok := data.(codersdk.OIDCClaimsResponse)
if !ok {
return nil, xerrors.Errorf("expected type %T, got %T", resp, data)
}
rows := make([]claimRow, 0, len(resp.Claims))
for k, v := range resp.Claims {
rows = append(rows, claimRow{
Key: k,
Value: fmt.Sprintf("%v", v),
})
}
return rows, nil
},
),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "oidc-claims",
Short: "Display the OIDC claims for the authenticated user.",
Long: FormatExamples(
Example{
Description: "Display your OIDC claims",
Command: "coder users oidc-claims",
},
Example{
Description: "Display your OIDC claims as JSON",
Command: "coder users oidc-claims -o json",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
resp, err := client.UserOIDCClaims(inv.Context())
if err != nil {
return xerrors.Errorf("get oidc claims: %w", err)
}
out, err := formatter.Format(inv.Context(), resp)
if err != nil {
return err
}
_, err = fmt.Fprintln(inv.Stdout, out)
return err
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
type claimRow struct {
Key string `json:"-" table:"key,default_sort"`
Value string `json:"-" table:"value"`
}
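The deleted formatter above turns an arbitrary OIDC claims map into `claimRow` values for table output. The core map-to-rows transformation can be sketched independently; sorting here is added for deterministic output, whereas the removed code relied on the table writer's `default_sort` tag:

```go
package main

import (
	"fmt"
	"sort"
)

// claimRow mirrors the removed row type, minus the table/json struct tags.
type claimRow struct {
	Key, Value string
}

// claimsToRows flattens a claims map into rows, rendering each value with
// %v exactly as the deleted formatter did.
func claimsToRows(claims map[string]any) []claimRow {
	rows := make([]claimRow, 0, len(claims))
	for k, v := range claims {
		rows = append(rows, claimRow{Key: k, Value: fmt.Sprintf("%v", v)})
	}
	// Map iteration order is random; sort by key for stable output.
	sort.Slice(rows, func(i, j int) bool { return rows[i].Key < rows[j].Key })
	return rows
}

func main() {
	rows := claimsToRows(map[string]any{"email": "alice@coder.com", "email_verified": true})
	for _, r := range rows {
		fmt.Printf("%s=%s\n", r.Key, r.Value)
	}
}
```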
-161
@@ -1,161 +0,0 @@
package cli_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestUserOIDCClaims(t *testing.T) {
t.Parallel()
newOIDCTest := func(t *testing.T) (*oidctest.FakeIDP, *codersdk.Client) {
t.Helper()
fake := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
ownerClient := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
})
return fake, ownerClient
}
t.Run("OwnClaims", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "alice@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
"groups": []string{"admin", "eng"},
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims", "-o", "json")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
var resp codersdk.OIDCClaimsResponse
err = json.Unmarshal(buf.Bytes(), &resp)
require.NoError(t, err, "unmarshal JSON output")
require.NotEmpty(t, resp.Claims, "claims should not be empty")
assert.Equal(t, "alice@coder.com", resp.Claims["email"])
})
t.Run("Table", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "bob@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
output := buf.String()
require.Contains(t, output, "email")
require.Contains(t, output, "bob@coder.com")
})
t.Run("NotOIDCUser", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, client, root)
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "not an OIDC user")
})
// Verify that two different OIDC users each only see their own
// claims. The endpoint has no user parameter, so there is no way
// to request another user's claims by design.
t.Run("OnlyOwnClaims", func(t *testing.T) {
t.Parallel()
aliceFake, aliceOwnerClient := newOIDCTest(t)
aliceClaims := jwt.MapClaims{
"email": "alice-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
aliceClient, aliceLoginResp := aliceFake.Login(t, aliceOwnerClient, aliceClaims)
defer aliceLoginResp.Body.Close()
bobFake, bobOwnerClient := newOIDCTest(t)
bobClaims := jwt.MapClaims{
"email": "bob-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
bobClient, bobLoginResp := bobFake.Login(t, bobOwnerClient, bobClaims)
defer bobLoginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
// Alice sees her own claims.
aliceResp, err := aliceClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "alice-isolation@coder.com", aliceResp.Claims["email"])
// Bob sees his own claims.
bobResp, err := bobClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "bob-isolation@coder.com", bobResp.Claims["email"])
})
t.Run("ClaimsNeverNull", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
// Use minimal claims — just enough for OIDC login.
claims := jwt.MapClaims{
"email": "minimal@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.UserOIDCClaims(ctx)
require.NoError(t, err)
require.NotNil(t, resp.Claims, "claims should never be nil, expected empty map")
})
}
-1
@@ -19,7 +19,6 @@ func (r *RootCmd) users() *serpent.Command {
r.userSingle(),
r.userDelete(),
r.userEditRoles(),
r.userOIDCClaims(),
r.createUserStatusCommand(codersdk.UserStatusActive),
r.createUserStatusCommand(codersdk.UserStatusSuspended),
},
-38
@@ -1,38 +0,0 @@
// Package aiseats is the AGPL version of the package.
// The actual implementation is in `enterprise/aiseats`.
package aiseats
import (
"context"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
)
type Reason struct {
EventType database.AiSeatUsageReason
Description string
}
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
func ReasonAIBridge(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
}
// ReasonTask constructs a reason for usage originating from tasks.
func ReasonTask(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
}
// SeatTracker records AI seat consumption state.
type SeatTracker interface {
// RecordUsage does not return an error to prevent blocking the user from using
// AI features. This method is used to record usage, not enforce it.
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
}
// Noop is an AGPL seat tracker that does nothing.
type Noop struct{}
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
+1721 -2063
File diff suppressed because it is too large.
+1719 -2039
File diff suppressed because it is too large.
+1 -2
@@ -32,8 +32,7 @@ type Auditable interface {
idpsync.OrganizationSyncSettings |
idpsync.GroupSyncSettings |
idpsync.RoleSyncSettings |
database.TaskTable |
database.AiSeatState
database.TaskTable
}
// Map is a map of changed fields in an audited resource. It maps field names to
-8
@@ -132,8 +132,6 @@ func ResourceTarget[T Auditable](tgt T) string {
return "Organization Role Sync"
case database.TaskTable:
return typed.Name
case database.AiSeatState:
return "AI Seat"
default:
panic(fmt.Sprintf("unknown resource %T for ResourceTarget", tgt))
}
@@ -198,8 +196,6 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
return noID // Org field on audit log has org id
case database.TaskTable:
return typed.ID
case database.AiSeatState:
return typed.UserID
default:
panic(fmt.Sprintf("unknown resource %T for ResourceID", tgt))
}
@@ -255,8 +251,6 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
return database.ResourceTypeIdpSyncSettingsGroup
case database.TaskTable:
return database.ResourceTypeTask
case database.AiSeatState:
return database.ResourceTypeAiSeat
default:
panic(fmt.Sprintf("unknown resource %T for ResourceType", typed))
}
@@ -315,8 +309,6 @@ func ResourceRequiresOrgID[T Auditable]() bool {
return true
case database.TaskTable:
return true
case database.AiSeatState:
return false
default:
panic(fmt.Sprintf("unknown resource %T for ResourceRequiresOrgID", tgt))
}
+559 -1162
File diff suppressed because it is too large.
-534
@@ -2,26 +2,13 @@ package chatd
import (
"context"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
@@ -97,524 +84,3 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: uuid.New(),
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
workspacesdk.LSResponse{},
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
).Times(1)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/project/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(
nil,
"",
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
).Times(1)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionCache: make(map[uuid.UUID]cachedInstruction),
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction := server.resolveInstructions(
ctx,
chat,
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.Contains(t, instruction, "Operating System: linux")
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{initialAgent}, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialed []uuid.UUID
server := &Server{db: db}
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == initialAgent.ID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
}
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
localMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishMessage(chatID, localMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
cachedMessage := codersdk.ChatMessage{
ID: 2,
ChatID: chatID,
Role: codersdk.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &cachedMessage,
})
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
catchupMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 1,
}).Return([]database.ChatMessage{catchupMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
editedMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{editedMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishEditedMessage(chatID, editedMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(1), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
retryingAt := time.Unix(1_700_000_000, 0).UTC()
expected := &codersdk.ChatStreamRetry{
Attempt: 1,
DelayMs: (1500 * time.Millisecond).Milliseconds(),
Error: "rate limit exceeded",
RetryingAt: retryingAt,
}
server.publishRetry(chatID, expected)
event := requireStreamRetryEvent(t, events)
require.Equal(t, expected, event.Retry)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
t.Helper()
return &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
pubsub: dbpubsub.NewInMemory(),
}
}
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type)
require.NotNil(t, event.Message)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream message event")
return codersdk.ChatStreamEvent{}
}
}
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type)
require.NotNil(t, event.Retry)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream retry event")
return codersdk.ChatStreamEvent{}
}
}
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
t.Helper()
select {
case event, ok := <-events:
if !ok {
t.Fatal("chat stream closed unexpectedly")
}
t.Fatalf("unexpected chat stream event: %+v", event)
case <-time.After(wait):
}
}
// TestPublishToStream_DropWarnRateLimiting walks through a
// realistic lifecycle: buffer fills up, subscriber channel fills
// up, counters get reset between steps. It verifies that WARN
// logs are rate-limited to at most once per streamDropWarnInterval
// and that counter resets re-enable an immediate WARN.
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
mClock := quartz.NewMock(t)
server := &Server{
logger: sink.Logger(),
clock: mClock,
}
chatID := uuid.New()
subCh := make(chan codersdk.ChatStreamEvent, 1)
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
// Set up state that mirrors a running chat: buffer at capacity,
// buffering enabled, one saturated subscriber.
state := &chatStreamState{
buffering: true,
buffer: make([]codersdk.ChatStreamEvent, maxStreamBufferSize),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
uuid.New(): subCh,
},
}
server.chatStreams.Store(chatID, state)
bufferMsg := "chat stream buffer full, dropping oldest event"
subMsg := "dropping chat stream event"
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
return func(e slog.SinkEntry) bool {
return e.Level == level && e.Message == msg
}
}
// --- Phase 1: buffer-full rate limiting ---
// message_part events hit both the buffer-full and subscriber-full
// paths. The first publish triggers a WARN for each; the rest
// within the window are DEBUG.
partEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}
for i := 0; i < 50; i++ {
server.publishToStream(chatID, partEvent)
}
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
// Subscriber also saw 50 drops (one per publish).
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
// --- Phase 2: clock advance triggers second WARN with count ---
mClock.Advance(streamDropWarnInterval + time.Second)
server.publishToStream(chatID, partEvent)
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 2)
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 2)
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
// --- Phase 3: counter reset (simulates step persist) ---
state.mu.Lock()
state.buffer = make([]codersdk.ChatStreamEvent, maxStreamBufferSize)
state.resetDropCounters()
state.mu.Unlock()
// The very next drop should WARN immediately — the reset zeroed
// lastWarnAt so the interval check passes.
server.publishToStream(chatID, partEvent)
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
}
// requireFieldValue asserts that a SinkEntry contains a field with
// the given name and value.
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
t.Helper()
for _, f := range entry.Fields {
if f.Name == name {
require.Equal(t, expected, f.Value, "field %q value mismatch", name)
return
}
}
t.Fatalf("field %q not found in log entry", name)
}
+132 -1719
View File
File diff suppressed because it is too large
+6 -32
View File
@@ -42,11 +42,6 @@ type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
// Runtime is the wall-clock duration of this step,
// covering LLM streaming, tool execution, and retries.
// Zero indicates the duration was not measured (e.g.
// interrupted steps).
Runtime time.Duration
}
// RunOptions configures a single streaming chat loop run.
@@ -127,7 +122,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
switch c.GetType() {
case fantasy.ContentTypeText:
text, ok := fantasy.AsContentType[fantasy.TextContent](c)
if !ok || strings.TrimSpace(text.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.TextPart{
@@ -136,7 +131,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
})
case fantasy.ContentTypeReasoning:
reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
if !ok || strings.TrimSpace(reasoning.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.ReasoningPart{
@@ -265,7 +260,6 @@ func Run(ctx context.Context, opts RunOptions) error {
for step := 0; totalSteps < opts.MaxSteps; step++ {
totalSteps++
stepStart := time.Now()
// Copy messages so that provider-specific caching
// mutations don't leak back to the caller's slice.
// copy copies Message structs by value, so field
@@ -371,7 +365,6 @@ func Run(ctx context.Context, opts RunOptions) error {
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
@@ -617,12 +610,10 @@ func processStepStream(
result.providerMetadata = part.ProviderMetadata
case fantasy.StreamPartTypeError:
// Detect interruption: the stream may surface the
// cancel as context.Canceled or propagate the
// ErrInterrupted cause directly, depending on
// the provider implementation.
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
// Detect interruption: context canceled with
// ErrInterrupted as the cause.
if errors.Is(part.Error, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
// Flush in-progress content so that
// persistInterruptedStep has access to partial
// text, reasoning, and tool calls that were
@@ -640,23 +631,6 @@ func processStepStream(
}
}
// The stream iterator may stop yielding parts without
// producing a StreamPartTypeError when the context is
// canceled (e.g. some providers close the response body
// silently). Detect this case and flush partial content
// so that persistInterruptedStep can save it.
if ctx.Err() != nil &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
flushActiveState(
&result,
activeTextContent,
activeReasoningContent,
activeToolCalls,
toolNames,
)
return result, ErrInterrupted
}
hasLocalToolCalls := false
for _, tc := range result.toolCalls {
if !tc.ProviderExecuted {
-81
View File
@@ -7,7 +7,6 @@ import (
"strings"
"sync"
"testing"
"time"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
@@ -65,8 +64,6 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
require.Equal(t, 1, persistStepCalls)
require.True(t, persistedStep.ContextLimit.Valid)
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
require.Greater(t, persistedStep.Runtime, time.Duration(0),
"step runtime should be positive")
require.NotEmpty(t, capturedCall.Prompt)
require.False(t, containsPromptSentinel(capturedCall.Prompt))
@@ -578,84 +575,6 @@ func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *test
assert.False(t, localTR.ProviderExecuted)
}
func TestToResponseMessages_FiltersEmptyTextAndReasoningParts(t *testing.T) {
t.Parallel()
sr := stepResult{
content: []fantasy.Content{
// Empty text — should be filtered.
fantasy.TextContent{Text: ""},
// Whitespace-only text — should be filtered.
fantasy.TextContent{Text: " \t\n"},
// Empty reasoning — should be filtered.
fantasy.ReasoningContent{Text: ""},
// Whitespace-only reasoning — should be filtered.
fantasy.ReasoningContent{Text: " \n"},
// Non-empty text — should pass through.
fantasy.TextContent{Text: "hello world"},
// Leading/trailing whitespace with content — kept
// with the original value (not trimmed).
fantasy.TextContent{Text: " hello "},
// Non-empty reasoning — should pass through.
fantasy.ReasoningContent{Text: "let me think"},
// Tool call — should be unaffected by filtering.
fantasy.ToolCallContent{
ToolCallID: "tc-1",
ToolName: "read_file",
Input: `{"path":"main.go"}`,
},
// Local tool result — should be unaffected by filtering.
fantasy.ToolResultContent{
ToolCallID: "tc-1",
ToolName: "read_file",
Result: fantasy.ToolResultOutputContentText{Text: "file contents"},
},
},
}
msgs := sr.toResponseMessages()
require.Len(t, msgs, 2, "expected assistant + tool messages")
// First message: assistant role with non-empty text, reasoning,
// and the tool call. The four empty/whitespace-only parts must
// have been dropped.
assistantMsg := msgs[0]
assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role)
require.Len(t, assistantMsg.Content, 4,
"assistant message should have 2x TextPart, ReasoningPart, and ToolCallPart")
// Part 0: non-empty text.
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[0])
require.True(t, ok, "part 0 should be TextPart")
assert.Equal(t, "hello world", textPart.Text)
// Part 1: padded text — original whitespace preserved.
paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[1])
require.True(t, ok, "part 1 should be TextPart")
assert.Equal(t, " hello ", paddedPart.Text)
// Part 2: non-empty reasoning.
reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](assistantMsg.Content[2])
require.True(t, ok, "part 2 should be ReasoningPart")
assert.Equal(t, "let me think", reasoningPart.Text)
// Part 3: tool call (unaffected by text/reasoning filtering).
toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[3])
require.True(t, ok, "part 3 should be ToolCallPart")
assert.Equal(t, "tc-1", toolCallPart.ToolCallID)
assert.Equal(t, "read_file", toolCallPart.ToolName)
// Second message: tool role with the local tool result.
toolMsg := msgs[1]
assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
require.Len(t, toolMsg.Content, 1,
"tool message should have only the local ToolResultPart")
toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
require.True(t, ok, "tool part should be ToolResultPart")
assert.Equal(t, "tc-1", toolResultPart.ToolCallID)
}
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
if len(message.ProviderOptions) == 0 {
return false
@@ -1,399 +0,0 @@
package chatloop
import (
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testProviderData implements fantasy.ProviderOptionsData so we can
// construct arbitrary ProviderMetadata for extractContextLimit tests.
type testProviderData struct {
data map[string]any
}
func (*testProviderData) Options() {}
func (d *testProviderData) MarshalJSON() ([]byte, error) {
return json.Marshal(d.data)
}
// Required by the ProviderOptionsData interface; unused in tests.
func (d *testProviderData) UnmarshalJSON(b []byte) error {
return json.Unmarshal(b, &d.data)
}
func TestNormalizeMetadataKey(t *testing.T) {
t.Parallel()
tests := []struct {
name string
key string
want string
}{
{name: "lowercase", key: "camelCase", want: "camelcase"},
{name: "hyphens stripped", key: "kebab-case", want: "kebabcase"},
{name: "underscores stripped", key: "snake_case", want: "snakecase"},
{name: "uppercase", key: "UPPER", want: "upper"},
{name: "spaces stripped", key: "with spaces", want: "withspaces"},
{name: "empty", key: "", want: ""},
{name: "digits preserved", key: "123", want: "123"},
{name: "mixed separators", key: "Max_Context-Tokens", want: "maxcontexttokens"},
{name: "dots stripped", key: "context.limit", want: "contextlimit"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := normalizeMetadataKey(tt.key)
require.Equal(t, tt.want, got)
})
}
}
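The normalization this table pins down — lowercase everything and strip separators while preserving digits — can be sketched in a few lines. The implementation below is an assumption consistent with every row of the table (keeping only ASCII letters and digits), not the actual function from the deleted package:

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeMetadataKey lowercases the key and keeps only letters and
// digits, so "Max_Context-Tokens", "max context tokens", and
// "maxContextTokens" all normalize to "maxcontexttokens".
func normalizeMetadataKey(key string) string {
	var b strings.Builder
	for _, r := range strings.ToLower(key) {
		if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') {
			b.WriteRune(r)
		}
	}
	return b.String()
}

func main() {
	fmt.Println(normalizeMetadataKey("Max_Context-Tokens")) // maxcontexttokens
	fmt.Println(normalizeMetadataKey("context.limit"))      // contextlimit
}
```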
func TestIsContextLimitKey(t *testing.T) {
t.Parallel()
tests := []struct {
name string
key string
want bool
skip bool
}{ // Exact matches after normalization.
{name: "context_limit", key: "context_limit", want: true},
{name: "context_window", key: "context_window", want: true},
{name: "context_length", key: "context_length", want: true},
{name: "max_context", key: "max_context", want: true},
{name: "max_context_tokens", key: "max_context_tokens", want: true},
{name: "max_input_tokens", key: "max_input_tokens", want: true},
{name: "max_input_token", key: "max_input_token", want: true},
{name: "input_token_limit", key: "input_token_limit", want: true},
// Case and separator variations.
{name: "Context-Window mixed case", key: "Context-Window", want: true},
{name: "MAX_CONTEXT_TOKENS screaming", key: "MAX_CONTEXT_TOKENS", want: true},
{name: "contextLimit camelCase", key: "contextLimit", want: true},
// Fallback heuristic: contains "context" + limit/window/length.
{name: "model_context_limit", key: "model_context_limit", want: true},
{name: "context_window_size", key: "context_window_size", want: true},
{name: "context_length_max", key: "context_length_max", want: true},
// Fallback heuristic: starts with "max" + contains "context".
// BUG(isContextLimitKey): "max_context_version" matches
// because it contains "context" and starts with "max",
// but a version field is not a context limit.
// TODO: Fix the heuristic and remove this skip.
{name: "max_context_version false positive", key: "max_context_version", want: false, skip: true},
// Non-matching keys.
{name: "context_id no limit keyword", key: "context_id", want: false},
{name: "empty string", key: "", want: false},
{name: "unrelated key", key: "model_name", want: false},
{name: "limit without context", key: "rate_limit", want: false},
{name: "max without context", key: "max_tokens", want: false},
{name: "context alone", key: "context", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if tt.skip {
t.Skip("known bug: isContextLimitKey false positive")
}
got := isContextLimitKey(tt.key)
require.Equal(t, tt.want, got)
})
}
}
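A heuristic consistent with this table — including the documented `max_context_version` false positive — can be sketched as an exact-match set over normalized keys plus two fallbacks. Both helpers here are assumptions reconstructed from the test expectations, not the deleted implementation:

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeMetadataKey lowercases and strips separators, matching
// the behavior the normalization tests describe.
func normalizeMetadataKey(key string) string {
	var b strings.Builder
	for _, r := range strings.ToLower(key) {
		if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') {
			b.WriteRune(r)
		}
	}
	return b.String()
}

// isContextLimitKey accepts a fixed set of normalized spellings,
// then falls back to: contains "context" plus limit/window/length,
// or contains "context" and starts with "max". The second fallback
// is what produces the "max_context_version" false positive the
// test table skips.
func isContextLimitKey(key string) bool {
	k := normalizeMetadataKey(key)
	exact := map[string]bool{
		"contextlimit": true, "contextwindow": true, "contextlength": true,
		"maxcontext": true, "maxcontexttokens": true,
		"maxinputtokens": true, "maxinputtoken": true, "inputtokenlimit": true,
	}
	if exact[k] {
		return true
	}
	if strings.Contains(k, "context") {
		if strings.Contains(k, "limit") || strings.Contains(k, "window") ||
			strings.Contains(k, "length") {
			return true
		}
		if strings.HasPrefix(k, "max") {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(isContextLimitKey("Context-Window"))      // true
	fmt.Println(isContextLimitKey("max_context_version")) // true (false positive)
	fmt.Println(isContextLimitKey("rate_limit"))          // false
}
```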
func TestNumericContextLimitValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value any
want int64
wantOK bool
}{
// float64: the default numeric type from json.Unmarshal.
{name: "float64 integer", value: float64(128000), want: 128000, wantOK: true},
{name: "float64 fractional rejected", value: float64(128000.5), want: 0, wantOK: false},
{name: "float64 zero rejected", value: float64(0), want: 0, wantOK: false},
{name: "float64 negative rejected", value: float64(-1), want: 0, wantOK: false},
// int64
{name: "int64 positive", value: int64(200000), want: 200000, wantOK: true},
{name: "int64 zero rejected", value: int64(0), want: 0, wantOK: false},
{name: "int64 negative rejected", value: int64(-1), want: 0, wantOK: false},
// int32
{name: "int32 positive", value: int32(50000), want: 50000, wantOK: true},
{name: "int32 zero rejected", value: int32(0), want: 0, wantOK: false},
// int
{name: "int positive", value: int(50000), want: 50000, wantOK: true},
{name: "int zero rejected", value: int(0), want: 0, wantOK: false},
// string
{name: "string numeric", value: "128000", want: 128000, wantOK: true},
{name: "string trimmed", value: " 128000 ", want: 128000, wantOK: true},
{name: "string non-numeric rejected", value: "not a number", want: 0, wantOK: false},
{name: "string empty rejected", value: "", want: 0, wantOK: false},
{name: "string zero rejected", value: "0", want: 0, wantOK: false},
{name: "string negative rejected", value: "-1", want: 0, wantOK: false},
// json.Number
{name: "json.Number valid", value: json.Number("200000"), want: 200000, wantOK: true},
{name: "json.Number invalid rejected", value: json.Number("invalid"), want: 0, wantOK: false},
{name: "json.Number zero rejected", value: json.Number("0"), want: 0, wantOK: false},
// Unhandled types.
{name: "bool rejected", value: true, want: 0, wantOK: false},
{name: "nil rejected", value: nil, want: 0, wantOK: false},
{name: "slice rejected", value: []int{1}, want: 0, wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, ok := numericContextLimitValue(tt.value)
require.Equal(t, tt.wantOK, ok)
require.Equal(t, tt.want, got)
})
}
}
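The coercion rules this table specifies — accept several numeric and string representations, reject fractional, zero, and negative values — can be sketched as a type switch funneling into a single positivity check. This is a reconstruction from the test expectations, assuming the `positiveInt64` helper tested below:

```go
package main

import (
	"encoding/json"
	"fmt"
	"math"
	"strconv"
	"strings"
)

// positiveInt64 returns v only if it is strictly positive.
func positiveInt64(v int64) (int64, bool) {
	if v <= 0 {
		return 0, false
	}
	return v, true
}

// numericContextLimitValue coerces the value shapes json.Unmarshal
// can produce into a positive int64: float64 (only if integral),
// the integer types, trimmed decimal strings, and json.Number.
// Everything else is rejected.
func numericContextLimitValue(v any) (int64, bool) {
	switch n := v.(type) {
	case float64:
		if n != math.Trunc(n) {
			return 0, false // fractional values are rejected
		}
		return positiveInt64(int64(n))
	case int64:
		return positiveInt64(n)
	case int32:
		return positiveInt64(int64(n))
	case int:
		return positiveInt64(int64(n))
	case string:
		parsed, err := strconv.ParseInt(strings.TrimSpace(n), 10, 64)
		if err != nil {
			return 0, false
		}
		return positiveInt64(parsed)
	case json.Number:
		parsed, err := n.Int64()
		if err != nil {
			return 0, false
		}
		return positiveInt64(parsed)
	default:
		return 0, false
	}
}

func main() {
	fmt.Println(numericContextLimitValue(float64(128000)))     // 128000 true
	fmt.Println(numericContextLimitValue(" 128000 "))          // 128000 true
	fmt.Println(numericContextLimitValue(json.Number("200000"))) // 200000 true
}
```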
func TestPositiveInt64(t *testing.T) {
t.Parallel()
got, ok := positiveInt64(42)
require.True(t, ok)
require.Equal(t, int64(42), got)
got, ok = positiveInt64(0)
require.False(t, ok)
require.Equal(t, int64(0), got)
got, ok = positiveInt64(-1)
require.False(t, ok)
require.Equal(t, int64(0), got)
}
func TestCollectContextLimitValues(t *testing.T) {
t.Parallel()
t.Run("FlatMap", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"context_limit": float64(200000),
"other_key": float64(999),
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Equal(t, []int64{200000}, collected)
})
t.Run("NestedMaps", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"provider": map[string]any{
"info": map[string]any{
"context_window": float64(100000),
},
},
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Equal(t, []int64{100000}, collected)
})
t.Run("ArrayTraversal", func(t *testing.T) {
t.Parallel()
input := []any{
map[string]any{"context_limit": float64(50000)},
map[string]any{"context_limit": float64(80000)},
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Len(t, collected, 2)
require.Contains(t, collected, int64(50000))
require.Contains(t, collected, int64(80000))
})
t.Run("MixedNesting", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"models": []any{
map[string]any{
"context_limit": float64(128000),
},
},
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Equal(t, []int64{128000}, collected)
})
t.Run("NonMatchingKey", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"model_name": "gpt-4",
"tokens": float64(1000),
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Empty(t, collected)
})
t.Run("ScalarIgnored", func(t *testing.T) {
t.Parallel()
var collected []int64
collectContextLimitValues("just a string", func(v int64) {
collected = append(collected, v)
})
require.Empty(t, collected)
})
}
func TestFindContextLimitValue(t *testing.T) {
t.Parallel()
t.Run("SingleCandidate", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"context_limit": float64(200000),
}
limit, ok := findContextLimitValue(input)
require.True(t, ok)
require.Equal(t, int64(200000), limit)
})
t.Run("MultipleCandidatesTakesMax", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"a": map[string]any{"context_limit": float64(50000)},
"b": map[string]any{"context_limit": float64(200000)},
}
limit, ok := findContextLimitValue(input)
require.True(t, ok)
require.Equal(t, int64(200000), limit)
})
t.Run("NoCandidates", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"model": "gpt-4",
}
_, ok := findContextLimitValue(input)
require.False(t, ok)
})
t.Run("NilInput", func(t *testing.T) {
t.Parallel()
_, ok := findContextLimitValue(nil)
require.False(t, ok)
})
}
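The traversal these tests describe — recursively walking nested maps and arrays, collecting every value under a context-limit-looking key, and taking the maximum — can be sketched as follows. The two helper predicates are deliberately simplified stand-ins inlined to keep the sketch self-contained; the full heuristics tested earlier accept more key spellings and value types:

```go
package main

import (
	"fmt"
	"strings"
)

// isContextLimitKey is a simplified stand-in for the fuller key
// heuristic: normalized key must contain "context" plus one of
// limit/window/length.
func isContextLimitKey(key string) bool {
	k := strings.ToLower(strings.NewReplacer("_", "", "-", "", ".", "", " ", "").Replace(key))
	return strings.Contains(k, "context") &&
		(strings.Contains(k, "limit") || strings.Contains(k, "window") ||
			strings.Contains(k, "length"))
}

// numericContextLimitValue is simplified to float64 only, the type
// json.Unmarshal produces for JSON numbers.
func numericContextLimitValue(v any) (int64, bool) {
	n, ok := v.(float64)
	if !ok || n <= 0 || n != float64(int64(n)) {
		return 0, false
	}
	return int64(n), true
}

// collectContextLimitValues walks nested maps and arrays and yields
// every positive value stored under a matching key. Scalars at the
// top level are ignored, matching the ScalarIgnored case.
func collectContextLimitValues(v any, yield func(int64)) {
	switch node := v.(type) {
	case map[string]any:
		for key, val := range node {
			if isContextLimitKey(key) {
				if n, ok := numericContextLimitValue(val); ok {
					yield(n)
				}
			}
			collectContextLimitValues(val, yield)
		}
	case []any:
		for _, item := range node {
			collectContextLimitValues(item, yield)
		}
	}
}

// findContextLimitValue takes the maximum across all candidates, as
// the MultipleCandidatesTakesMax cases assert.
func findContextLimitValue(v any) (int64, bool) {
	var best int64
	collectContextLimitValues(v, func(n int64) {
		if n > best {
			best = n
		}
	})
	return best, best > 0
}

func main() {
	input := map[string]any{
		"a": map[string]any{"context_limit": float64(50000)},
		"b": []any{map[string]any{"context_window": float64(200000)}},
	}
	limit, ok := findContextLimitValue(input)
	fmt.Println(limit, ok) // 200000 true
}
```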
func TestExtractContextLimit(t *testing.T) {
t.Parallel()
t.Run("AnthropicStyle", func(t *testing.T) {
t.Parallel()
metadata := fantasy.ProviderMetadata{
"anthropic": &testProviderData{
data: map[string]any{
"cache_read_input_tokens": float64(100),
"context_limit": float64(200000),
},
},
}
result := extractContextLimit(metadata)
require.True(t, result.Valid)
require.Equal(t, int64(200000), result.Int64)
})
t.Run("OpenAIStyle", func(t *testing.T) {
t.Parallel()
metadata := fantasy.ProviderMetadata{
"openai": &testProviderData{
data: map[string]any{
"max_context_tokens": float64(128000),
},
},
}
result := extractContextLimit(metadata)
require.True(t, result.Valid)
require.Equal(t, int64(128000), result.Int64)
})
t.Run("NestedDeeply", func(t *testing.T) {
t.Parallel()
metadata := fantasy.ProviderMetadata{
"provider": &testProviderData{
data: map[string]any{
"info": map[string]any{
"context_window": float64(100000),
},
},
},
}
result := extractContextLimit(metadata)
require.True(t, result.Valid)
require.Equal(t, int64(100000), result.Int64)
})
t.Run("MultipleCandidatesTakesMax", func(t *testing.T) {
t.Parallel()
metadata := fantasy.ProviderMetadata{
"a": &testProviderData{
data: map[string]any{
"context_limit": float64(50000),
},
},
"b": &testProviderData{
data: map[string]any{
"context_limit": float64(200000),
},
},
}
result := extractContextLimit(metadata)
require.True(t, result.Valid)
require.Equal(t, int64(200000), result.Int64)
})
t.Run("NoMatchingKeys", func(t *testing.T) {
t.Parallel()
metadata := fantasy.ProviderMetadata{
"openai": &testProviderData{
data: map[string]any{
"model": "gpt-4",
"tokens": float64(1000),
},
},
}
result := extractContextLimit(metadata)
assert.False(t, result.Valid)
})
t.Run("NilMetadata", func(t *testing.T) {
t.Parallel()
result := extractContextLimit(nil)
assert.False(t, result.Valid)
})
t.Run("EmptyMetadata", func(t *testing.T) {
t.Parallel()
result := extractContextLimit(fantasy.ProviderMetadata{})
assert.False(t, result.Valid)
})
}
+3 -215
@@ -139,13 +139,9 @@ func ConvertMessagesWithFiles(
},
})
case codersdk.ChatMessageRoleUser:
userParts := partsToMessageParts(logger, pm.parts, resolved)
if len(userParts) == 0 {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: userParts,
Content: partsToMessageParts(logger, pm.parts, resolved),
})
case codersdk.ChatMessageRoleAssistant:
fantasyParts := normalizeAssistantToolCallInputs(
@@ -157,9 +153,6 @@ func ConvertMessagesWithFiles(
}
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
}
if len(fantasyParts) == 0 {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: fantasyParts,
@@ -173,13 +166,9 @@ func ConvertMessagesWithFiles(
}
}
}
toolParts := partsToMessageParts(logger, pm.parts, resolved)
if len(toolParts) == 0 {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: toolParts,
Content: partsToMessageParts(logger, pm.parts, resolved),
})
}
}
@@ -332,7 +321,6 @@ func parseContentV1(role codersdk.ChatMessageRole, raw pqtype.NullRawMessage) ([
if err := json.Unmarshal(raw.RawMessage, &parts); err != nil {
return nil, xerrors.Errorf("parse %s content: %w", role, err)
}
decodeNulInParts(parts)
return parts, nil
}
@@ -1030,16 +1018,11 @@ func sanitizeToolCallID(id string) string {
}
// MarshalParts encodes SDK chat message parts for persistence.
// NUL characters in string fields are encoded as PUA sentinel
// pairs (U+E000 U+E001) before marshaling so the resulting JSON
// never contains \u0000 (rejected by PostgreSQL jsonb). The
// encoding operates on Go string values, not JSON bytes, so it
// survives jsonb text normalization.
func MarshalParts(parts []codersdk.ChatMessagePart) (pqtype.NullRawMessage, error) {
if len(parts) == 0 {
return pqtype.NullRawMessage{}, nil
}
data, err := json.Marshal(encodeNulInParts(parts))
data, err := json.Marshal(parts)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode chat message parts: %w", err)
}
@@ -1186,23 +1169,11 @@ func partsToMessageParts(
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeText:
// Anthropic rejects empty text content blocks with
// "text content blocks must be non-empty". Empty parts
// can arise when a stream sends TextStart/TextEnd with
// no delta in between. We filter them here rather than
// at persistence time to preserve the raw record.
if strings.TrimSpace(part.Text) == "" {
continue
}
result = append(result, fantasy.TextPart{
Text: part.Text,
ProviderOptions: providerMetadataToOptions(logger, part.ProviderMetadata),
})
case codersdk.ChatMessagePartTypeReasoning:
// Same guard as text parts above.
if strings.TrimSpace(part.Text) == "" {
continue
}
result = append(result, fantasy.ReasoningPart{
Text: part.Text,
ProviderOptions: providerMetadataToOptions(logger, part.ProviderMetadata),
@@ -1245,186 +1216,3 @@ func partsToMessageParts(
}
return result
}
// encodeNulInString replaces NUL (U+0000) characters in s with
// the sentinel pair U+E000 U+E001, and doubles any pre-existing
// U+E000 to U+E000 U+E000 so the encoding is reversible.
// Operates on Unicode code points, not JSON escape sequences,
// making it safe through jsonb round-trips (jsonb stores parsed
// characters, not original escape text).
func encodeNulInString(s string) string {
if !strings.ContainsRune(s, 0) && !strings.ContainsRune(s, '\uE000') {
return s
}
var b strings.Builder
b.Grow(len(s))
for _, r := range s {
switch r {
case '\uE000':
_, _ = b.WriteRune('\uE000')
_, _ = b.WriteRune('\uE000')
case 0:
_, _ = b.WriteRune('\uE000')
_, _ = b.WriteRune('\uE001')
default:
_, _ = b.WriteRune(r)
}
}
return b.String()
}
// decodeNulInString reverses encodeNulInString: U+E000 U+E000
// becomes U+E000, and U+E000 U+E001 becomes NUL.
func decodeNulInString(s string) string {
if !strings.ContainsRune(s, '\uE000') {
return s
}
var b strings.Builder
b.Grow(len(s))
runes := []rune(s)
for i := 0; i < len(runes); i++ {
if runes[i] == '\uE000' && i+1 < len(runes) {
switch runes[i+1] {
case '\uE000':
_, _ = b.WriteRune('\uE000')
i++
case '\uE001':
_, _ = b.WriteRune(0)
i++
default:
// Unpaired sentinel — preserve as-is.
_, _ = b.WriteRune(runes[i])
}
} else {
_, _ = b.WriteRune(runes[i])
}
}
return b.String()
}
// encodeNulInValue recursively walks a JSON value (as produced
// by json.Unmarshal with UseNumber) and applies
// encodeNulInString to every string, including map keys.
func encodeNulInValue(v any) any {
switch val := v.(type) {
case string:
return encodeNulInString(val)
case map[string]any:
out := make(map[string]any, len(val))
for k, elem := range val {
out[encodeNulInString(k)] = encodeNulInValue(elem)
}
return out
case []any:
out := make([]any, len(val))
for i, elem := range val {
out[i] = encodeNulInValue(elem)
}
return out
default:
return v // numbers, bools, nil
}
}
// decodeNulInValue recursively walks a JSON value and applies
// decodeNulInString to every string, including map keys.
func decodeNulInValue(v any) any {
switch val := v.(type) {
case string:
return decodeNulInString(val)
case map[string]any:
out := make(map[string]any, len(val))
for k, elem := range val {
out[decodeNulInString(k)] = decodeNulInValue(elem)
}
return out
case []any:
out := make([]any, len(val))
for i, elem := range val {
out[i] = decodeNulInValue(elem)
}
return out
default:
return v
}
}
// encodeNulInJSON walks all string values (and keys) inside a
// json.RawMessage and applies encodeNulInString. Returns the
// original unchanged when the raw message does not contain NUL
// escapes or U+E000 bytes, or when parsing fails.
func encodeNulInJSON(raw json.RawMessage) json.RawMessage {
if len(raw) == 0 {
return raw
}
// Quick exit: no \u0000 escape and no U+E000 UTF-8 bytes.
if !bytes.Contains(raw, []byte(`\u0000`)) &&
!bytes.Contains(raw, []byte{0xEE, 0x80, 0x80}) {
return raw
}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
var v any
if err := dec.Decode(&v); err != nil {
return raw
}
result, err := json.Marshal(encodeNulInValue(v))
if err != nil {
return raw
}
return result
}
// decodeNulInJSON walks all string values (and keys) inside a
// json.RawMessage and applies decodeNulInString.
func decodeNulInJSON(raw json.RawMessage) json.RawMessage {
if len(raw) == 0 {
return raw
}
// U+E000 encoded as UTF-8 is 0xEE 0x80 0x80.
if !bytes.Contains(raw, []byte{0xEE, 0x80, 0x80}) {
return raw
}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
var v any
if err := dec.Decode(&v); err != nil {
return raw
}
result, err := json.Marshal(decodeNulInValue(v))
if err != nil {
return raw
}
return result
}
// encodeNulInParts returns a shallow copy of parts with all
// string and json.RawMessage fields NUL-encoded. The caller's
// slice is not modified.
func encodeNulInParts(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
encoded := make([]codersdk.ChatMessagePart, len(parts))
copy(encoded, parts)
for i := range encoded {
p := &encoded[i]
p.Text = encodeNulInString(p.Text)
p.Content = encodeNulInString(p.Content)
p.Args = encodeNulInJSON(p.Args)
p.ArgsDelta = encodeNulInString(p.ArgsDelta)
p.Result = encodeNulInJSON(p.Result)
p.ResultDelta = encodeNulInString(p.ResultDelta)
}
return encoded
}
// decodeNulInParts reverses encodeNulInParts in place.
func decodeNulInParts(parts []codersdk.ChatMessagePart) {
for i := range parts {
p := &parts[i]
p.Text = decodeNulInString(p.Text)
p.Content = decodeNulInString(p.Content)
p.Args = decodeNulInJSON(p.Args)
p.ArgsDelta = decodeNulInString(p.ArgsDelta)
p.Result = decodeNulInJSON(p.Result)
p.ResultDelta = decodeNulInString(p.ResultDelta)
}
}
-327
@@ -17,10 +17,7 @@ import (
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// testMsg builds a database.ChatMessage for ParseContent tests.
@@ -1444,327 +1441,3 @@ func extractToolResultIDs(t *testing.T, msgs ...fantasy.Message) []string {
}
return ids
}
func TestNulEscapeRoundTrip(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
// Seed minimal dependencies for the DB round-trip path:
// user, provider, model config, chat.
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "openai",
APIKey: "test-key",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "nul-roundtrip-test",
})
require.NoError(t, err)
textTests := []struct {
name string
input string
hasNul bool // Whether the input contains actual NUL bytes.
}{
// --- basic ---
{"NoNul", "hello world", false},
{"SingleNul", "a\x00b", true},
{"MultipleNuls", "a\x00b\x00c", true},
{"ConsecutiveNuls", "\x00\x00\x00", true},
// --- boundaries ---
{"EmptyString", "", false},
{"NulOnly", "\x00", true},
{"NulAtStart", "\x00hello", true},
{"NulAtEnd", "hello\x00", true},
// --- sentinel / marker in original data ---
// U+E000 is the sentinel character. The encoder must
// double it so it round-trips without being mistaken
// for an encoded NUL.
{"SentinelInOriginal", "a\uE000b", false},
{"ConsecutiveSentinels", "\uE000\uE000\uE000", false},
// U+E001 is the marker character used in the NUL pair.
{"MarkerCharInOriginal", "a\uE001b", false},
// U+E000 followed by U+E001 looks exactly like an
// encoded NUL in the encoded form, so the encoder must
// double the U+E000 to avoid confusion.
{"SentinelThenMarkerChar", "\uE000\uE001", false},
{"NulAndSentinel", "a\x00b\uE000c", true},
// Both orders: sentinel adjacent to NUL.
{"SentinelThenNul", "\uE000\x00", true},
{"NulThenSentinel", "\x00\uE000", true},
{"AlternatingSentinelNul", "\x00\uE000\x00\uE000", true},
// --- strings containing backslashes ---
// Backslashes are normal characters at the Go string
// level; no special handling needed (unlike the old
// JSON-byte approach).
{"BackslashU0000Text", "\\u0000", false},
{"BackslashThenNul", "\\\x00", true},
// --- literal text that looks like escape patterns ---
{"LiteralTextU0000", "the value is u0000 here", false},
{"LiteralTextUE000", "sentinel uE000 text", false},
// --- other control characters mixed with NUL ---
{"ControlCharsMixedWithNul", "\x01\x00\x02\x00\x1f", true},
// --- long / stress ---
{"LongNulRun", "\x00\x00\x00\x00\x00\x00\x00\x00", true},
// Simulated find -print0 output.
{"FindPrint0", "/usr/bin/ls\x00/usr/bin/cat\x00/usr/bin/grep\x00", true},
}
for _, tc := range textTests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(tc.input),
}
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
// When the input has real NUL bytes, the stored JSON
// must not contain the \u0000 escape sequence.
if tc.hasNul {
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
"encoded JSON must not contain \\u0000")
}
// In-memory round-trip through ParseContent.
msg := testMsgV1(codersdk.ChatMessageRoleAssistant, encoded)
decoded, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.Len(t, decoded, 1)
require.Equal(t, tc.input, decoded[0].Text)
// Full DB round-trip: write to PostgreSQL jsonb, read
// back, and verify the value survives storage.
ctx := testutil.Context(t, testutil.WaitShort)
dbMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{user.ID},
ModelConfigID: []uuid.UUID{model.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{string(encoded.RawMessage)},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
require.Len(t, dbMsgs, 1)
readBack, err := db.GetChatMessageByID(ctx, dbMsgs[0].ID)
require.NoError(t, err)
dbDecoded, err := chatprompt.ParseContent(readBack)
require.NoError(t, err)
require.Len(t, dbDecoded, 1)
require.Equal(t, tc.input, dbDecoded[0].Text)
})
}
// Tool result with NUL in the result JSON value.
t.Run("ToolResultWithNul", func(t *testing.T) {
t.Parallel()
resultJSON := json.RawMessage(`"output:\u0000done"`)
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult("call-1", "my_tool", resultJSON, false),
}
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
"encoded JSON must not contain \\u0000")
msg := testMsgV1(codersdk.ChatMessageRoleTool, encoded)
decoded, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.Len(t, decoded, 1)
// JSON re-serialization may reformat, so compare
// semantically.
assert.JSONEq(t, string(resultJSON), string(decoded[0].Result))
})
// Multiple parts in one message: one with NUL, one without.
t.Run("MultiPartMixed", func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText("clean text"),
codersdk.ChatMessageText("has\x00nul"),
}
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
"encoded JSON must not contain \\u0000")
msg := testMsgV1(codersdk.ChatMessageRoleAssistant, encoded)
decoded, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.Len(t, decoded, 2)
require.Equal(t, "clean text", decoded[0].Text)
require.Equal(t, "has\x00nul", decoded[1].Text)
})
}
func TestConvertMessagesWithFiles_FiltersEmptyTextAndReasoningParts(t *testing.T) {
t.Parallel()
// Helper to build a DB message from SDK parts.
makeMsg := func(t *testing.T, role database.ChatMessageRole, parts []codersdk.ChatMessagePart) database.ChatMessage {
t.Helper()
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
return database.ChatMessage{
Role: role,
Visibility: database.ChatMessageVisibilityBoth,
Content: encoded,
ContentVersion: chatprompt.CurrentContentVersion,
}
}
t.Run("UserRole", func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(""), // empty — filtered
codersdk.ChatMessageText(" \t\n "), // whitespace — filtered
codersdk.ChatMessageReasoning(""), // empty — filtered
codersdk.ChatMessageReasoning(" \n"), // whitespace — filtered
codersdk.ChatMessageText("hello"), // kept
codersdk.ChatMessageText(" hello "), // kept with original whitespace
codersdk.ChatMessageReasoning("thinking deeply"), // kept
codersdk.ChatMessageToolCall("call-1", "my_tool", json.RawMessage(`{"x":1}`)),
codersdk.ChatMessageToolResult("call-1", "my_tool", json.RawMessage(`{"ok":true}`), false),
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleUser, parts)},
nil,
slogtest.Make(t, nil),
)
require.NoError(t, err)
require.Len(t, prompt, 1)
resultParts := prompt[0].Content
require.Len(t, resultParts, 5, "expected 5 parts after filtering empty text/reasoning")
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[0])
require.True(t, ok, "expected TextPart at index 0")
require.Equal(t, "hello", textPart.Text)
// Leading/trailing whitespace is preserved — only
// all-whitespace parts are dropped.
paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[1])
require.True(t, ok, "expected TextPart at index 1")
require.Equal(t, " hello ", paddedPart.Text)
reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](resultParts[2])
require.True(t, ok, "expected ReasoningPart at index 2")
require.Equal(t, "thinking deeply", reasoningPart.Text)
toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](resultParts[3])
require.True(t, ok, "expected ToolCallPart at index 3")
require.Equal(t, "call-1", toolCallPart.ToolCallID)
toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](resultParts[4])
require.True(t, ok, "expected ToolResultPart at index 4")
require.Equal(t, "call-1", toolResultPart.ToolCallID)
})
t.Run("AssistantRole", func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(""), // empty — filtered
codersdk.ChatMessageText(" "), // whitespace — filtered
codersdk.ChatMessageReasoning(""), // empty — filtered
codersdk.ChatMessageText(" reply "), // kept with whitespace
codersdk.ChatMessageToolCall("tc-1", "read_file", json.RawMessage(`{"path":"x"}`)),
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleAssistant, parts)},
nil,
slogtest.Make(t, nil),
)
require.NoError(t, err)
// 2 messages: assistant + synthetic tool result injected
// by injectMissingToolResults for the unmatched tool call.
require.Len(t, prompt, 2)
resultParts := prompt[0].Content
require.Len(t, resultParts, 2, "expected text + tool-call after filtering")
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[0])
require.True(t, ok, "expected TextPart")
require.Equal(t, " reply ", textPart.Text)
tcPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](resultParts[1])
require.True(t, ok, "expected ToolCallPart")
require.Equal(t, "tc-1", tcPart.ToolCallID)
})
t.Run("AllEmptyDropsMessage", func(t *testing.T) {
t.Parallel()
// When every part is filtered, the message itself should
// be dropped rather than appending an empty-content message.
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(""),
codersdk.ChatMessageText(" "),
codersdk.ChatMessageReasoning(""),
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleAssistant, parts)},
nil,
slogtest.Make(t, nil),
)
require.NoError(t, err)
require.Empty(t, prompt, "all-empty message should be dropped entirely")
})
}
+1 -9
@@ -1083,7 +1083,6 @@ func openAIProviderOptionsFromChatConfig(
SafetyIdentifier: normalizedStringPointer(options.SafetyIdentifier),
ServiceTier: openAIServiceTierFromChat(options.ServiceTier),
StrictJSONSchema: options.StrictJSONSchema,
Store: boolPtrOrDefault(options.Store, true),
TextVerbosity: OpenAITextVerbosityFromChat(options.TextVerbosity),
User: normalizedStringPointer(options.User),
}
@@ -1100,7 +1099,7 @@ func openAIProviderOptionsFromChatConfig(
MaxCompletionTokens: options.MaxCompletionTokens,
TextVerbosity: normalizedStringPointer(options.TextVerbosity),
Prediction: options.Prediction,
Store: boolPtrOrDefault(options.Store, true),
Store: options.Store,
Metadata: options.Metadata,
PromptCacheKey: normalizedStringPointer(options.PromptCacheKey),
SafetyIdentifier: normalizedStringPointer(options.SafetyIdentifier),
@@ -1281,13 +1280,6 @@ func useOpenAIResponsesOptions(model fantasy.LanguageModel) bool {
}
}
func boolPtrOrDefault(value *bool, def bool) *bool {
if value != nil {
return value
}
return &def
}
func normalizedStringPointer(value *string) *string {
if value == nil {
return nil
+34 -23
@@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
@@ -26,37 +25,37 @@ func TestReasoningEffortFromChat(t *testing.T) {
{
name: "OpenAICaseInsensitive",
provider: "openai",
input: ptr.Ref(" HIGH "),
want: ptr.Ref(string(fantasyopenai.ReasoningEffortHigh)),
input: stringPtr(" HIGH "),
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
},
{
name: "AnthropicEffort",
provider: "anthropic",
input: ptr.Ref("max"),
want: ptr.Ref(string(fantasyanthropic.EffortMax)),
input: stringPtr("max"),
want: stringPtr(string(fantasyanthropic.EffortMax)),
},
{
name: "OpenRouterEffort",
provider: "openrouter",
input: ptr.Ref("medium"),
want: ptr.Ref(string(fantasyopenrouter.ReasoningEffortMedium)),
input: stringPtr("medium"),
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
},
{
name: "VercelEffort",
provider: "vercel",
input: ptr.Ref("xhigh"),
want: ptr.Ref(string(fantasyvercel.ReasoningEffortXHigh)),
input: stringPtr("xhigh"),
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
},
{
name: "InvalidEffortReturnsNil",
provider: "openai",
input: ptr.Ref("unknown"),
input: stringPtr("unknown"),
want: nil,
},
{
name: "UnsupportedProviderReturnsNil",
provider: "bedrock",
input: ptr.Ref("high"),
input: stringPtr("high"),
want: nil,
},
{
@@ -83,8 +82,8 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
options := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Enabled: ptr.Ref(true),
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(true),
},
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"openai"},
@@ -93,22 +92,22 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
}
defaults := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Enabled: ptr.Ref(false),
Exclude: ptr.Ref(true),
MaxTokens: ptr.Ref[int64](123),
Effort: ptr.Ref("high"),
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(false),
Exclude: boolPtr(true),
MaxTokens: int64Ptr(123),
Effort: stringPtr("high"),
},
IncludeUsage: ptr.Ref(true),
IncludeUsage: boolPtr(true),
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"anthropic"},
AllowFallbacks: ptr.Ref(true),
RequireParameters: ptr.Ref(false),
DataCollection: ptr.Ref("allow"),
AllowFallbacks: boolPtr(true),
RequireParameters: boolPtr(false),
DataCollection: stringPtr("allow"),
Only: []string{"openai"},
Ignore: []string{"foo"},
Quantizations: []string{"int8"},
Sort: ptr.Ref("latency"),
Sort: stringPtr("latency"),
},
},
}
@@ -137,3 +136,15 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
}
func stringPtr(value string) *string {
return &value
}
func boolPtr(value bool) *bool {
return &value
}
func int64Ptr(value int64) *int64 {
return &value
}
+61 -138
@@ -3,12 +3,14 @@ package chattool
import (
"context"
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
@@ -21,10 +23,9 @@ const (
// maxOutputToModel is the maximum output sent to the LLM.
maxOutputToModel = 32 << 10 // 32KB
// snapshotTimeout is how long a non-blocking fallback
// request is allowed to take when retrieving a process
// output snapshot after a blocking wait times out.
snapshotTimeout = 30 * time.Second
// pollInterval is how often we check for process completion
// in foreground mode.
pollInterval = 200 * time.Millisecond
)
// nonInteractiveEnvVars are set on every process to prevent
@@ -77,10 +78,10 @@ type ProcessToolOptions struct {
// ExecuteArgs are the parameters accepted by the execute tool.
type ExecuteArgs struct {
Command string `json:"command" description:"The shell command to execute."`
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
Command string `json:"command"`
Timeout *string `json:"timeout,omitempty"`
WorkDir *string `json:"workdir,omitempty"`
RunInBackground *bool `json:"run_in_background,omitempty"`
}
// Execute returns an AgentTool that runs a shell command in the
@@ -88,7 +89,7 @@ type ExecuteArgs struct {
func Execute(options ExecuteOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"execute",
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command times out, the response includes a background_process_id so you can retrieve output later with process_output.",
"Execute a shell command in the workspace.",
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
@@ -121,16 +122,6 @@ func executeTool(
background := args.RunInBackground != nil && *args.RunInBackground
// Detect shell-style backgrounding (trailing &) and promote to
// background mode. Models sometimes use "cmd &" instead of the
// run_in_background parameter, which causes the shell to fork
// and exit immediately, leaving an untracked orphan process.
trimmed := strings.TrimSpace(args.Command)
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") && !strings.HasSuffix(trimmed, "|&") {
background = true
args.Command = strings.TrimSpace(strings.TrimSuffix(trimmed, "&"))
}
var workDir string
if args.WorkDir != nil {
workDir = *args.WorkDir
@@ -172,7 +163,7 @@ func executeBackground(
return fantasy.NewTextResponse(string(data))
}
// executeForeground starts a process and waits for its
// executeForeground starts a process and polls for its
// completion, enforcing the configured timeout.
func executeForeground(
ctx context.Context,
@@ -211,7 +202,7 @@ func executeForeground(
return errorResult(fmt.Sprintf("start process: %v", err))
}
result := waitForProcess(cmdCtx, conn, resp.ID, timeout)
result := pollProcess(cmdCtx, conn, resp.ID, timeout)
result.WallDurationMs = time.Since(start).Milliseconds()
// Add an advisory note for file-dump commands.
@@ -236,84 +227,62 @@ func truncateOutput(output string) string {
return output
}
// waitForProcess waits for process completion using the
// blocking process output API instead of polling.
func waitForProcess(
// pollProcess polls for process output until the process exits
// or the context times out.
func pollProcess(
ctx context.Context,
conn workspacesdk.AgentConn,
processID string,
timeout time.Duration,
) ExecuteResult {
// Block until the process exits or the context is
// canceled.
resp, err := conn.ProcessOutput(ctx, processID, &workspacesdk.ProcessOutputOptions{
Wait: true,
})
if err != nil {
if ctx.Err() != nil {
// Timeout: fetch final snapshot with a fresh
// context. The blocking request was canceled
// so the response body was lost.
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
// Timeout — get whatever output we have. Use a
// fresh context since cmdCtx is already canceled.
bgCtx, bgCancel := context.WithTimeout(
context.Background(),
snapshotTimeout,
5*time.Second,
)
defer bgCancel()
resp, err = conn.ProcessOutput(bgCtx, processID, nil)
outputResp, outputErr := conn.ProcessOutput(bgCtx, processID)
bgCancel()
output := truncateOutput(outputResp.Output)
timeoutErr := xerrors.Errorf("command timed out after %s", timeout)
if outputErr != nil {
timeoutErr = errors.Join(timeoutErr, xerrors.Errorf("failed to get output: %w", outputErr))
}
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: timeoutErr.Error(),
Truncated: outputResp.Truncated,
}
case <-ticker.C:
outputResp, err := conn.ProcessOutput(ctx, processID)
if err != nil {
return ExecuteResult{
Success: false,
ExitCode: -1,
Error: fmt.Sprintf("command timed out after %s; failed to get output: %v", timeout, err),
BackgroundProcessID: processID,
Success: false,
Error: fmt.Sprintf("get process output: %v", err),
}
}
output := truncateOutput(resp.Output)
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: fmt.Sprintf("command timed out after %s", timeout),
Truncated: resp.Truncated,
BackgroundProcessID: processID,
if !outputResp.Running {
exitCode := 0
if outputResp.ExitCode != nil {
exitCode = *outputResp.ExitCode
}
output := truncateOutput(outputResp.Output)
return ExecuteResult{
Success: exitCode == 0,
Output: output,
ExitCode: exitCode,
Truncated: outputResp.Truncated,
}
}
}
return ExecuteResult{
Success: false,
Error: fmt.Sprintf("get process output: %v", err),
}
}
// The server-side wait may return before the
// process exits if maxWaitDuration is shorter than
// the client's timeout. Retry if our context still
// has time left.
if resp.Running {
if ctx.Err() == nil {
// Still within the caller's timeout, retry.
return waitForProcess(ctx, conn, processID, timeout)
}
output := truncateOutput(resp.Output)
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: fmt.Sprintf("command timed out after %s", timeout),
Truncated: resp.Truncated,
BackgroundProcessID: processID,
}
}
exitCode := 0
if resp.ExitCode != nil {
exitCode = *resp.ExitCode
}
output := truncateOutput(resp.Output)
return ExecuteResult{
Success: exitCode == 0,
Output: output,
ExitCode: exitCode,
Truncated: resp.Truncated,
}
}
@@ -343,19 +312,10 @@ func detectFileDump(command string) string {
return ""
}
const (
// defaultProcessOutputTimeout is the default time the
// process_output tool blocks waiting for new output or
// process exit before returning. This avoids polling
// loops that waste tokens and HTTP round-trips.
defaultProcessOutputTimeout = 10 * time.Second
)
// ProcessOutputArgs are the parameters accepted by the
// process_output tool.
type ProcessOutputArgs struct {
ProcessID string `json:"process_id"`
WaitTimeout *string `json:"wait_timeout,omitempty" description:"Override the default 10s block duration. The call blocks until the process exits or this timeout is reached. Set to '0s' for an immediate snapshot without waiting."`
ProcessID string `json:"process_id"`
}
// ProcessOutput returns an AgentTool that retrieves the output
@@ -365,13 +325,9 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
"process_output",
"Retrieve output from a background process. "+
"Use the process_id returned by execute with "+
"run_in_background=true or from a timed-out "+
"execute's background_process_id. Blocks up to "+
"10s for the process to exit, then returns the "+
"output and exit_code. If still running after "+
"the timeout, returns the output so far. Use "+
"wait_timeout to override the default 10s wait "+
"(e.g. '30s', or '0s' for an immediate snapshot).",
"run_in_background=true. Returns the current output, "+
"whether the process is still running, and the exit "+
"code if it has finished.",
func(ctx context.Context, args ProcessOutputArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
@@ -383,42 +339,9 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
timeout := defaultProcessOutputTimeout
if args.WaitTimeout != nil {
parsed, err := time.ParseDuration(*args.WaitTimeout)
if err != nil {
return fantasy.NewTextErrorResponse(
fmt.Sprintf("invalid wait_timeout %q: %v", *args.WaitTimeout, err),
), nil
}
timeout = parsed
}
var opts *workspacesdk.ProcessOutputOptions
// Save parent context before applying timeout.
parentCtx := ctx
if timeout > 0 {
opts = &workspacesdk.ProcessOutputOptions{
Wait: true,
}
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
resp, err := conn.ProcessOutput(ctx, args.ProcessID, opts)
resp, err := conn.ProcessOutput(ctx, args.ProcessID)
if err != nil {
// If our wait timed out but the parent is still alive,
// fetch a non-blocking snapshot.
if ctx.Err() == nil || parentCtx.Err() != nil {
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
}
bgCtx, bgCancel := context.WithTimeout(parentCtx, snapshotTimeout)
defer bgCancel()
resp, err = conn.ProcessOutput(bgCtx, args.ProcessID, nil)
if err != nil {
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
}
// Fall through to normal response handling below.
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
}
output := truncateOutput(resp.Output)
exitCode := 0
@@ -432,7 +355,7 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
Truncated: resp.Truncated,
}
if resp.Running {
// Process is still running, success is not
// Process is still running; success is not
// yet determined.
result.Success = true
result.Note = "process is still running"
@@ -1,100 +0,0 @@
package chattool
import (
"context"
"encoding/json"
"strings"
"testing"
"unicode/utf8"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
)
func TestTruncateOutput(t *testing.T) {
t.Parallel()
t.Run("EmptyOutput", func(t *testing.T) {
t.Parallel()
result := runForegroundWithOutput(t, "")
assert.Empty(t, result.Output)
})
t.Run("ShortOutput", func(t *testing.T) {
t.Parallel()
result := runForegroundWithOutput(t, "short")
assert.Equal(t, "short", result.Output)
})
t.Run("ExactlyAtLimit", func(t *testing.T) {
t.Parallel()
output := strings.Repeat("a", maxOutputToModel)
result := runForegroundWithOutput(t, output)
assert.Equal(t, maxOutputToModel, len(result.Output))
assert.Equal(t, output, result.Output)
})
t.Run("OverLimit", func(t *testing.T) {
t.Parallel()
output := strings.Repeat("b", maxOutputToModel+1024)
result := runForegroundWithOutput(t, output)
assert.Equal(t, maxOutputToModel, len(result.Output))
})
t.Run("MultiByteCutMidCharacter", func(t *testing.T) {
t.Parallel()
// Build output that places a 3-byte UTF-8 character
// (U+2603, snowman ☃) right at the truncation boundary
// so the cut falls mid-character.
padding := strings.Repeat("x", maxOutputToModel-1)
output := padding + "☃" // ☃ is 3 bytes, only 1 byte fits
result := runForegroundWithOutput(t, output)
assert.LessOrEqual(t, len(result.Output), maxOutputToModel)
assert.True(t, utf8.ValidString(result.Output),
"truncated output must be valid UTF-8")
})
}
// runForegroundWithOutput runs a foreground command through the
// Execute tool with a mock that returns the given output, and
// returns the parsed result.
func runForegroundWithOutput(t *testing.T, output string) ExecuteResult {
t.Helper()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
exitCode := 0
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
Output: output,
}, nil)
tool := Execute(ExecuteOptions{
GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
return mockConn, nil
},
})
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo test"}`,
})
require.NoError(t, err)
var result ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
return result
}
-493
@@ -1,493 +0,0 @@
package chattool_test
import (
"context"
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
)
func TestExecuteTool(t *testing.T) {
t.Parallel()
t.Run("EmptyCommand", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
tool := newExecuteTool(t, mockConn)
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":""}`,
})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.Contains(t, resp.Content, "command is required")
})
t.Run("AmpersandDetection", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
command string
runInBackground *bool
wantCommand string
wantBackground bool
wantBackgroundResp bool // true if the response should contain a background_process_id
comment string
}{
{
name: "SimpleBackground",
command: "cmd &",
wantCommand: "cmd",
wantBackground: true,
wantBackgroundResp: true,
comment: "Trailing & is correctly detected and stripped.",
},
{
name: "TrailingDoubleAmpersand",
command: "cmd &&",
wantCommand: "cmd &&",
wantBackground: false,
wantBackgroundResp: false,
comment: "Ends with &&, excluded by the && suffix check.",
},
{
name: "NoAmpersand",
command: "cmd",
wantCommand: "cmd",
wantBackground: false,
wantBackgroundResp: false,
},
{
name: "ChainThenBackground",
command: "cmd1 && cmd2 &",
wantCommand: "cmd1 && cmd2",
wantBackground: true,
wantBackgroundResp: true,
comment: "Ends with & but not &&, so it gets promoted " +
"to background and the trailing & is stripped. " +
"The remaining command runs in background mode.",
},
{
// "|&" is bash's pipe-stderr operator, not
// backgrounding. It must not be detected as a
// trailing "&".
name: "BashPipeStderr",
command: "cmd |&",
wantCommand: "cmd |&",
wantBackground: false,
wantBackgroundResp: false,
},
{
name: "AlreadyBackgroundWithTrailingAmpersand",
command: "cmd &",
runInBackground: ptr(true),
wantCommand: "cmd &",
wantBackground: true,
wantBackgroundResp: true,
comment: "When run_in_background is already true, " +
"the stripping logic is skipped, preserving " +
"the original command.",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
var capturedReq workspacesdk.StartProcessRequest
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
capturedReq = req
return workspacesdk.StartProcessResponse{ID: "proc-1"}, nil
})
// For foreground cases, ProcessOutput is polled.
exitCode := 0
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
}, nil).
AnyTimes()
tool := newExecuteTool(t, mockConn)
input := map[string]any{"command": tc.command}
if tc.runInBackground != nil {
input["run_in_background"] = *tc.runInBackground
}
inputJSON, err := json.Marshal(input)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: string(inputJSON),
})
require.NoError(t, err)
assert.False(t, resp.IsError, "response should not be an error")
assert.Equal(t, tc.wantCommand, capturedReq.Command,
"command passed to StartProcess")
assert.Equal(t, tc.wantBackground, capturedReq.Background,
"background flag passed to StartProcess")
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
if tc.wantBackgroundResp {
assert.NotEmpty(t, result.BackgroundProcessID,
"expected background_process_id in response")
} else {
assert.Empty(t, result.BackgroundProcessID,
"expected no background_process_id")
}
})
}
})
t.Run("ForegroundSuccess", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
var capturedReq workspacesdk.StartProcessRequest
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
capturedReq = req
return workspacesdk.StartProcessResponse{ID: "proc-1"}, nil
})
exitCode := 0
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
Output: "hello world",
}, nil)
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hello"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.True(t, result.Success)
assert.Equal(t, 0, result.ExitCode)
assert.Equal(t, "hello world", result.Output)
assert.Empty(t, result.BackgroundProcessID)
assert.Equal(t, "true", capturedReq.Env["CODER_CHAT_AGENT"])
})
t.Run("ForegroundNonZeroExit", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
exitCode := 42
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
Output: "something failed",
}, nil)
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"exit 42"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Equal(t, 42, result.ExitCode)
assert.Equal(t, "something failed", result.Output)
})
t.Run("BackgroundExecution", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
assert.True(t, req.Background)
return workspacesdk.StartProcessResponse{ID: "bg-42"}, nil
})
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"sleep 999","run_in_background":true}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.True(t, result.Success)
assert.Equal(t, "bg-42", result.BackgroundProcessID)
})
t.Run("Timeout", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
// First call (blocking wait) returns context error
// because the 50ms timeout expires.
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
DoAndReturn(func(ctx context.Context, _ string, _ *workspacesdk.ProcessOutputOptions) (workspacesdk.ProcessOutputResponse, error) {
<-ctx.Done()
return workspacesdk.ProcessOutputResponse{}, ctx.Err()
})
// Second call (snapshot fallback) returns partial output.
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: true,
Output: "partial output",
}, nil)
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
// 50ms timeout expires during the blocking wait.
Input: `{"command":"sleep 999","timeout":"50ms"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Equal(t, -1, result.ExitCode)
assert.Contains(t, result.Error, "timed out")
assert.Equal(t, "partial output", result.Output)
})
t.Run("StartProcessError", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{}, xerrors.New("connection lost"))
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
// Errors from StartProcess are returned as a JSON body
// with success=false, not as a ToolResponse error.
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Contains(t, result.Error, "connection lost")
})
t.Run("ProcessOutputError", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("agent disconnected"))
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Contains(t, result.Error, "agent disconnected")
})
t.Run("GetWorkspaceConnNil", func(t *testing.T) {
t.Parallel()
tool := chattool.Execute(chattool.ExecuteOptions{
GetWorkspaceConn: nil,
})
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.Contains(t, resp.Content, "not configured")
})
t.Run("GetWorkspaceConnError", func(t *testing.T) {
t.Parallel()
tool := chattool.Execute(chattool.ExecuteOptions{
GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
return nil, xerrors.New("workspace offline")
},
})
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.Contains(t, resp.Content, "workspace offline")
})
}
func TestDetectFileDump(t *testing.T) {
t.Parallel()
tests := []struct {
name string
command string
wantHit bool
}{
{
name: "CatFile",
command: "cat foo.txt",
wantHit: true,
},
{
name: "NotCatPrefix",
command: "concatenate foo",
wantHit: false,
},
{
name: "GrepIncludeAll",
command: "grep --include-all pattern",
wantHit: true,
},
{
name: "RgListFiles",
command: "rg -l pattern",
wantHit: true,
},
{
name: "GrepRecursive",
command: "grep -r pattern",
wantHit: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
exitCode := 0
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
Output: "output",
}, nil)
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
input, err := json.Marshal(map[string]any{
"command": tc.command,
})
require.NoError(t, err)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: string(input),
})
require.NoError(t, err)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
if tc.wantHit {
assert.Contains(t, result.Note, "read_file",
"expected advisory note for %q", tc.command)
} else {
assert.Empty(t, result.Note,
"expected no note for %q", tc.command)
}
})
}
}
// newExecuteTool creates an Execute tool wired to the given mock.
func newExecuteTool(t *testing.T, mockConn *agentconnmock.MockAgentConn) fantasy.AgentTool {
t.Helper()
return chattool.Execute(chattool.ExecuteOptions{
GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
return mockConn, nil
},
})
}
func ptr[T any](v T) *T {
return &v
}
+2 -152
@@ -92,7 +92,7 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
// Verify the chat completed and messages were persisted.
chatData, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID)
require.NoError(t, err)
t.Logf("Chat status after step 1: %s, messages: %d",
chatData.Status, len(chatMsgs.Messages))
@@ -154,7 +154,7 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
// Verify the follow-up completed and produced content.
chatData2, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID)
require.NoError(t, err)
t.Logf("Chat status after step 2: %s, messages: %d",
chatData2.Status, len(chatMsgs2.Messages))
@@ -272,156 +272,6 @@ func logMessages(t *testing.T, msgs []codersdk.ChatMessage) {
}
}
// TestOpenAIReasoningRoundTrip is an integration test that verifies
// reasoning items from OpenAI's Responses API survive the full
// persist → reconstruct → re-send cycle when Store: true. It sends
// a query to a reasoning model, waits for completion, then sends a
// follow-up message. If reasoning items are sent back without their
// required following output item, the API rejects the second request:
//
// Item 'rs_xxx' of type 'reasoning' was provided without its
// required following item.
//
// The test requires OPENAI_API_KEY to be set.
func TestOpenAIReasoningRoundTrip(t *testing.T) {
t.Parallel()
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
t.Skip("OPENAI_API_KEY not set; skipping OpenAI integration test")
}
baseURL := os.Getenv("OPENAI_BASE_URL")
ctx := testutil.Context(t, testutil.WaitSuperLong)
// Stand up a full coderd with the agents experiment.
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
client := coderdtest.New(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
})
_ = coderdtest.CreateFirstUser(t, client)
// Configure an OpenAI provider with the real API key.
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: apiKey,
BaseURL: baseURL,
})
require.NoError(t, err)
// Create a model config for a reasoning model with Store: true
// (the default). Using o4-mini because it always produces
// reasoning items.
contextLimit := int64(200000)
isDefault := true
reasoningSummary := "auto"
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
Model: "o4-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
ModelConfig: &codersdk.ChatModelCallConfig{
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
Store: ptr.Ref(true),
ReasoningSummary: &reasoningSummary,
},
},
},
})
require.NoError(t, err)
// --- Step 1: Send a message that triggers reasoning ---
t.Log("Creating chat with reasoning query...")
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "What is 2+2? Be brief.",
},
},
})
require.NoError(t, err)
t.Logf("Chat created: %s (status=%s)", chat.ID, chat.Status)
// Stream events until the chat reaches a terminal status.
events, closer, err := client.StreamChat(ctx, chat.ID, nil)
require.NoError(t, err)
defer closer.Close()
waitForChatDone(ctx, t, events, "step 1")
// Verify the chat completed and messages were persisted.
chatData, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
t.Logf("Chat status after step 1: %s, messages: %d",
chatData.Status, len(chatMsgs.Messages))
logMessages(t, chatMsgs.Messages)
require.Equal(t, codersdk.ChatStatusWaiting, chatData.Status,
"chat should be in waiting status after step 1")
// Verify the assistant message has reasoning content.
assistantMsg := findAssistantWithText(t, chatMsgs.Messages)
require.NotNil(t, assistantMsg,
"expected an assistant message with text content after step 1")
partTypes := partTypeSet(assistantMsg.Content)
require.Contains(t, partTypes, codersdk.ChatMessagePartTypeReasoning,
"assistant message should contain reasoning parts from o4-mini")
require.Contains(t, partTypes, codersdk.ChatMessagePartTypeText,
"assistant message should contain a text part")
// --- Step 2: Send a follow-up message ---
// This is the critical test: if reasoning items are sent back
// without their required following item, the API will reject
// the request with:
// Item 'rs_xxx' of type 'reasoning' was provided without its
// required following item.
t.Log("Sending follow-up message...")
_, err = client.CreateChatMessage(ctx, chat.ID,
codersdk.CreateChatMessageRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "And what is 3+3? Be brief.",
},
},
})
require.NoError(t, err)
// Stream the follow-up response.
events2, closer2, err := client.StreamChat(ctx, chat.ID, nil)
require.NoError(t, err)
defer closer2.Close()
waitForChatDone(ctx, t, events2, "step 2")
// Verify the follow-up completed and produced content.
chatData2, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
t.Logf("Chat status after step 2: %s, messages: %d",
chatData2.Status, len(chatMsgs2.Messages))
logMessages(t, chatMsgs2.Messages)
require.Equal(t, codersdk.ChatStatusWaiting, chatData2.Status,
"chat should be in waiting status after step 2")
require.Greater(t, len(chatMsgs2.Messages), len(chatMsgs.Messages),
"follow-up should have added more messages")
// The last assistant message should have text.
lastAssistant := findLastAssistantWithText(t, chatMsgs2.Messages)
require.NotNil(t, lastAssistant,
"expected an assistant message with text in the follow-up")
t.Log("OpenAI reasoning round-trip test passed.")
}
// partTypeSet returns the set of part types present in a message.
func partTypeSet(parts []codersdk.ChatMessagePart) map[codersdk.ChatMessagePartType]struct{} {
set := make(map[codersdk.ChatMessagePartType]struct{}, len(parts))
-542
@@ -1,542 +0,0 @@
package mcpclient
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"strings"
"sync"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/database"
)
// toolNameSep separates the server slug from the original tool
// name in prefixed tool names. Double underscore avoids collisions
// with tool names that may contain single underscores.
//
// TODO: tool names that themselves contain "__" produce ambiguous
// prefixed names (e.g. "srv__my__tool" is indistinguishable from
// slug "srv" + tool "my__tool" vs slug "srv__my" + tool "tool").
// This doesn't affect tool invocation since originalName is used
// directly when calling the remote server.
const toolNameSep = "__"
// connectTimeout bounds how long we wait for a single MCP server
// to start its transport and complete initialization. Servers that
// take longer are skipped so one slow server cannot block the
// entire chat startup.
const connectTimeout = 10 * time.Second
// toolCallTimeout bounds how long a single tool invocation may
// take before being canceled.
const toolCallTimeout = 60 * time.Second
// ConnectAll connects to all configured MCP servers, discovers
// their tools, and returns them as fantasy.AgentTool values. It
// skips servers that fail to connect and logs warnings. The
// returned cleanup function must be called to close all
// connections.
func ConnectAll(
ctx context.Context,
logger slog.Logger,
configs []database.MCPServerConfig,
tokens []database.MCPServerUserToken,
) ([]fantasy.AgentTool, func()) {
// Index tokens by server config ID so auth header
// construction is O(1) per server.
tokensByConfigID := make(
map[uuid.UUID]database.MCPServerUserToken, len(tokens),
)
for _, tok := range tokens {
tokensByConfigID[tok.MCPServerConfigID] = tok
}
var (
mu sync.Mutex
clients []*client.Client
tools []fantasy.AgentTool
)
// Build cleanup eagerly so it always closes any clients
// that connected, even if a later connection fails.
cleanup := func() {
mu.Lock()
defer mu.Unlock()
for _, c := range clients {
_ = c.Close()
}
clients = nil
}
var eg errgroup.Group
for _, cfg := range configs {
if !cfg.Enabled {
continue
}
eg.Go(func() error {
serverTools, mcpClient, connectErr := connectOne(
ctx, logger, cfg, tokensByConfigID,
)
if connectErr != nil {
logger.Warn(ctx,
"skipping MCP server due to connection failure",
slog.F("server_slug", cfg.Slug),
slog.F("server_url", RedactURL(cfg.Url)),
slog.F("error", redactErrorURL(connectErr)),
)
// Connection failures are not propagated — the
// LLM simply won't have this server's tools.
return nil
}
mu.Lock()
clients = append(clients, mcpClient)
tools = append(tools, serverTools...)
mu.Unlock()
return nil
})
}
// All goroutines return nil; error is intentionally
// discarded.
_ = eg.Wait()
return tools, cleanup
}
// connectOne establishes a connection to a single MCP server,
// discovers its tools, and wraps each one as an AgentTool with
// the server slug prefix applied.
func connectOne(
ctx context.Context,
logger slog.Logger,
cfg database.MCPServerConfig,
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
) ([]fantasy.AgentTool, *client.Client, error) {
headers := buildAuthHeaders(ctx, logger, cfg, tokensByConfigID)
tr, err := createTransport(cfg, headers)
if err != nil {
return nil, nil, xerrors.Errorf(
"create transport: %w", err,
)
}
mcpClient := client.NewClient(tr)
// The timeout covers the entire connect+init+list sequence,
// not each phase individually.
connectCtx, cancel := context.WithTimeout(
ctx, connectTimeout,
)
defer cancel()
if err := mcpClient.Start(connectCtx); err != nil {
_ = mcpClient.Close()
return nil, nil, xerrors.Errorf(
"start transport: %w", err,
)
}
_, err = mcpClient.Initialize(
connectCtx,
mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "coder",
Version: buildinfo.Version(),
},
},
},
)
if err != nil {
// Best-effort close so we don't leak the transport.
_ = mcpClient.Close()
return nil, nil, xerrors.Errorf("initialize: %w", err)
}
toolsResult, err := mcpClient.ListTools(
connectCtx, mcp.ListToolsRequest{},
)
if err != nil {
_ = mcpClient.Close()
return nil, nil, xerrors.Errorf("list tools: %w", err)
}
var tools []fantasy.AgentTool
for _, mcpTool := range toolsResult.Tools {
if !isToolAllowed(
mcpTool.Name,
cfg.ToolAllowList,
cfg.ToolDenyList,
) {
logger.Debug(ctx, "skipping denied MCP tool",
slog.F("server_slug", cfg.Slug),
slog.F("tool_name", mcpTool.Name),
)
continue
}
tools = append(
tools, newMCPTool(cfg.Slug, mcpTool, mcpClient),
)
}
// If no tools passed filtering, close the client early
// to avoid holding an idle connection.
if len(tools) == 0 {
_ = mcpClient.Close()
return nil, nil, nil
}
return tools, mcpClient, nil
}
// createTransport builds the appropriate mcp-go transport based
// on the server's configured transport type.
func createTransport(
cfg database.MCPServerConfig,
headers map[string]string,
) (transport.Interface, error) {
switch cfg.Transport {
case "sse":
return transport.NewSSE(
cfg.Url,
transport.WithHeaders(headers),
)
case "", "streamable_http":
// Default to streamable HTTP, the newer transport.
return transport.NewStreamableHTTP(
cfg.Url,
transport.WithHTTPHeaders(headers),
)
default:
return nil, xerrors.Errorf(
"unsupported transport %q", cfg.Transport,
)
}
}
// buildAuthHeaders constructs HTTP headers for authenticating
// with the MCP server based on the configured auth type.
func buildAuthHeaders(
ctx context.Context,
logger slog.Logger,
cfg database.MCPServerConfig,
tokensByConfigID map[uuid.UUID]database.MCPServerUserToken,
) map[string]string {
// Using map[string]string rather than http.Header because
// the mcp-go transport options accept map[string]string.
// MCP servers typically don't require multi-valued headers.
headers := make(map[string]string)
switch cfg.AuthType {
case "oauth2":
tok, ok := tokensByConfigID[cfg.ID]
if !ok {
logger.Warn(ctx,
"no oauth2 token found for MCP server",
slog.F("server_slug", cfg.Slug),
)
break
}
if tok.Expiry.Valid && tok.Expiry.Time.Before(time.Now()) {
logger.Warn(ctx,
"oauth2 token for MCP server is expired",
slog.F("server_slug", cfg.Slug),
slog.F("expired_at", tok.Expiry.Time),
)
}
if tok.AccessToken == "" {
logger.Warn(ctx,
"oauth2 token record has empty access token",
slog.F("server_slug", cfg.Slug),
)
break
}
tokenType := tok.TokenType
if tokenType == "" {
tokenType = "Bearer"
}
headers["Authorization"] = tokenType + " " + tok.AccessToken
case "api_key":
if cfg.APIKeyHeader != "" && cfg.APIKeyValue != "" {
headers[cfg.APIKeyHeader] = cfg.APIKeyValue
}
case "custom_headers":
if cfg.CustomHeaders != "" {
var custom map[string]string
if err := json.Unmarshal(
[]byte(cfg.CustomHeaders), &custom,
); err != nil {
logger.Warn(ctx,
"failed to parse custom headers JSON",
slog.F("server_slug", cfg.Slug),
slog.Error(err),
)
} else {
for k, v := range custom {
headers[k] = v
}
}
}
case "none", "":
// No auth headers needed.
}
return headers
}
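The oauth2 branch above defaults an empty stored token type to "Bearer" before composing the Authorization header. A minimal standalone sketch of that defaulting rule; `authHeader` is an illustrative helper, not part of the package:

```go
package main

import "fmt"

// authHeader composes an Authorization header value, defaulting the
// token type to "Bearer" when the stored record omits it. An empty
// access token yields no header at all, mirroring the early break in
// buildAuthHeaders.
func authHeader(tokenType, accessToken string) string {
	if accessToken == "" {
		return ""
	}
	if tokenType == "" {
		tokenType = "Bearer"
	}
	return tokenType + " " + accessToken
}

func main() {
	fmt.Println(authHeader("", "test-token-abc")) // Bearer test-token-abc
	fmt.Println(authHeader("MAC", "abc"))         // MAC abc
}
```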
// isToolAllowed checks a tool name against the allow and deny
// lists. When the allow list is non-empty only tools in it are
// permitted and the deny list is ignored. When the allow list
// is empty and the deny list is non-empty, tools in the deny
// list are rejected. Both lists use exact string matching
// against the original (non-prefixed) tool name.
func isToolAllowed(
toolName string,
allowList []string,
denyList []string,
) bool {
if len(allowList) > 0 {
for _, allowed := range allowList {
if allowed == toolName {
return true
}
}
// Allow list is set but the tool isn't in it.
return false
}
for _, denied := range denyList {
if denied == toolName {
return false
}
}
return true
}
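The precedence described above (a non-empty allow list is authoritative, and the deny list only applies when the allow list is empty) can be exercised standalone; `allowed` below mirrors the function's logic for illustration:

```go
package main

import "fmt"

// allowed mirrors isToolAllowed: when the allow list is non-empty,
// only exact matches pass and the deny list is ignored; otherwise the
// deny list rejects exact matches and everything else passes.
func allowed(tool string, allow, deny []string) bool {
	if len(allow) > 0 {
		for _, a := range allow {
			if a == tool {
				return true
			}
		}
		return false
	}
	for _, d := range deny {
		if d == tool {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(allowed("echo", []string{"echo"}, []string{"echo"})) // true: allow list wins
	fmt.Println(allowed("greet", nil, []string{"greet"}))            // false: denied
	fmt.Println(allowed("echo", nil, nil))                           // true: no lists configured
}
```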
// RedactURL strips userinfo and query parameters from a URL
// to avoid logging embedded credentials. Query params are
// removed because API keys are sometimes passed as
// ?api_key=sk-... in server URLs.
func RedactURL(rawURL string) string {
u, err := url.Parse(rawURL)
if err != nil {
return rawURL
}
u.User = nil
u.RawQuery = ""
u.Fragment = ""
return u.String()
}
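RedactURL's behavior can be reproduced with the standard `net/url` package alone. The `redact` helper below is an illustrative standalone copy, not the package function:

```go
package main

import (
	"fmt"
	"net/url"
)

// redact strips userinfo, query parameters, and the fragment from a
// URL, returning unparsable input unchanged, as RedactURL does.
func redact(raw string) string {
	u, err := url.Parse(raw)
	if err != nil {
		return raw
	}
	u.User = nil
	u.RawQuery = ""
	u.Fragment = ""
	return u.String()
}

func main() {
	// Both embedded credentials and ?api_key=... style secrets are removed.
	fmt.Println(redact("https://user:pass@mcp.example.com/v1?api_key=sk-123"))
}
```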
// redactErrorURL rewrites URLs in an error string to strip
// credentials. Go's net/http embeds the full request URL in
// *url.Error messages, which can leak userinfo.
func redactErrorURL(err error) string {
if err == nil {
return ""
}
var urlErr *url.Error
if errors.As(err, &urlErr) {
urlErr.URL = RedactURL(urlErr.URL)
return urlErr.Error()
}
return err.Error()
}
// mcpToolWrapper adapts a single MCP tool into a
// fantasy.AgentTool. It stores the prefixed name for Info() but
// strips the prefix when forwarding calls to the remote server.
type mcpToolWrapper struct {
prefixedName string
originalName string
description string
parameters map[string]any
required []string
client *client.Client
providerOptions fantasy.ProviderOptions
}
// newMCPTool creates an mcpToolWrapper from an mcp.Tool
// discovered on a remote server.
func newMCPTool(
serverSlug string,
tool mcp.Tool,
mcpClient *client.Client,
) *mcpToolWrapper {
return &mcpToolWrapper{
prefixedName: serverSlug + toolNameSep + tool.Name,
originalName: tool.Name,
description: tool.Description,
parameters: tool.InputSchema.Properties,
required: tool.InputSchema.Required,
client: mcpClient,
}
}
func (t *mcpToolWrapper) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{
Name: t.prefixedName,
Description: t.description,
Parameters: t.parameters,
Required: t.required,
Parallel: true,
}
}
func (t *mcpToolWrapper) Run(
ctx context.Context,
params fantasy.ToolCall,
) (fantasy.ToolResponse, error) {
var args map[string]any
if params.Input != "" {
if err := json.Unmarshal(
[]byte(params.Input), &args,
); err != nil {
return fantasy.NewTextErrorResponse(
"invalid JSON input: " + err.Error(),
), nil
}
}
callCtx, cancel := context.WithTimeout(ctx, toolCallTimeout)
defer cancel()
result, err := t.client.CallTool(
callCtx,
mcp.CallToolRequest{
Params: mcp.CallToolParams{
Name: t.originalName,
Arguments: args,
},
},
)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return convertCallResult(result), nil
}
func (t *mcpToolWrapper) ProviderOptions() fantasy.ProviderOptions {
return t.providerOptions
}
func (t *mcpToolWrapper) SetProviderOptions(
opts fantasy.ProviderOptions,
) {
t.providerOptions = opts
}
// convertCallResult translates an MCP CallToolResult into a
// fantasy.ToolResponse. The fantasy response model supports a
// single content type per response, so we prioritize text. All
// text items are collected first. Binary items (image or audio)
// are only returned when no text content is available.
func convertCallResult(
result *mcp.CallToolResult,
) fantasy.ToolResponse {
if result == nil {
return fantasy.NewTextResponse("")
}
var (
textParts []string
binaryResult *fantasy.ToolResponse
)
for _, item := range result.Content {
switch c := item.(type) {
case mcp.TextContent:
textParts = append(textParts, c.Text)
case mcp.ImageContent:
data, err := base64.StdEncoding.DecodeString(
c.Data,
)
if err != nil {
textParts = append(textParts,
"[image decode error: "+err.Error()+"]",
)
continue
}
if binaryResult == nil {
r := fantasy.ToolResponse{
Type: "image",
Data: data,
MediaType: c.MIMEType,
IsError: result.IsError,
}
binaryResult = &r
}
case mcp.AudioContent:
data, err := base64.StdEncoding.DecodeString(
c.Data,
)
if err != nil {
textParts = append(textParts,
"[audio decode error: "+err.Error()+"]",
)
continue
}
if binaryResult == nil {
r := fantasy.ToolResponse{
Type: "media",
Data: data,
MediaType: c.MIMEType,
IsError: result.IsError,
}
binaryResult = &r
}
default:
textParts = append(textParts,
fmt.Sprintf("[unsupported content type: %T]", c),
)
}
}
// If structured content is present, marshal it to JSON and
// append as a text part so the data is preserved for the LLM.
if result.StructuredContent != nil {
data, err := json.Marshal(result.StructuredContent)
if err != nil {
textParts = append(textParts,
"[structured content marshal error: "+
err.Error()+"]",
)
} else {
textParts = append(textParts, string(data))
}
}
// Prefer text content. Only fall back to binary when no
// text was collected.
if len(textParts) > 0 {
resp := fantasy.NewTextResponse(
strings.Join(textParts, "\n"),
)
resp.IsError = result.IsError
return resp
}
if binaryResult != nil {
return *binaryResult
}
return fantasy.NewTextResponse("")
}
-659
View File
@@ -1,659 +0,0 @@
package mcpclient_test
import (
"context"
"database/sql"
"encoding/json"
"net/http/httptest"
"sync"
"testing"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/chatd/mcpclient"
"github.com/coder/coder/v2/coderd/database"
)
// newTestMCPServer creates a streamable HTTP MCP server with the
// given tools. The caller must close the returned *httptest.Server.
func newTestMCPServer(t *testing.T, tools ...mcpserver.ServerTool) *httptest.Server {
t.Helper()
srv := mcpserver.NewMCPServer("test-server", "1.0.0")
srv.AddTools(tools...)
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
ts := httptest.NewServer(httpSrv)
t.Cleanup(ts.Close)
return ts
}
// echoTool returns a ServerTool that echoes its "input" argument
// prefixed with "echo: ".
func echoTool() mcpserver.ServerTool {
return mcpserver.ServerTool{
Tool: mcp.NewTool("echo",
mcp.WithDescription("Echoes the input"),
mcp.WithString("input", mcp.Description("The input"), mcp.Required()),
),
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
input, _ := req.GetArguments()["input"].(string)
return mcp.NewToolResultText("echo: " + input), nil
},
}
}
// greetTool returns a ServerTool that greets by name.
func greetTool() mcpserver.ServerTool {
return mcpserver.ServerTool{
Tool: mcp.NewTool("greet",
mcp.WithDescription("Greets the user"),
mcp.WithString("name", mcp.Description("Name to greet"), mcp.Required()),
),
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
name, _ := req.GetArguments()["name"].(string)
return mcp.NewToolResultText("hello " + name), nil
},
}
}
// makeConfig builds a database.MCPServerConfig suitable for tests.
func makeConfig(slug, url string) database.MCPServerConfig {
return database.MCPServerConfig{
ID: uuid.New(),
Slug: slug,
DisplayName: slug,
Url: url,
Transport: "streamable_http",
AuthType: "none",
Enabled: true,
}
}
func TestConnectAll_DiscoverTools(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool(), greetTool())
cfg := makeConfig("myserver", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
// Two tools should be discovered, namespaced with the server slug.
require.Len(t, tools, 2)
names := toolNames(tools)
assert.Contains(t, names, "myserver__echo")
assert.Contains(t, names, "myserver__greet")
// Verify the description is preserved.
foundEcho := findTool(tools, "myserver__echo")
require.NotNilf(t, foundEcho, "expected to find myserver__echo")
echoInfo := foundEcho.Info()
assert.Equal(t, "Echoes the input", echoInfo.Description)
}
func TestConnectAll_CallTool(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
tool := tools[0]
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "srv__echo",
Input: `{"input":"hello world"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
assert.Equal(t, "echo: hello world", resp.Content)
}
func TestConnectAll_ToolAllowList(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool(), greetTool())
cfg := makeConfig("filtered", ts.URL)
// Only allow the "echo" tool.
cfg.ToolAllowList = []string{"echo"}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
assert.Equal(t, "filtered__echo", tools[0].Info().Name)
}
func TestConnectAll_ToolDenyList(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool(), greetTool())
cfg := makeConfig("filtered", ts.URL)
// Deny the "greet" tool, so only "echo" remains.
cfg.ToolDenyList = []string{"greet"}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
assert.Equal(t, "filtered__echo", tools[0].Info().Name)
}
func TestConnectAll_ConnectionFailure(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
cfg := makeConfig("bad", "http://127.0.0.1:0/does-not-exist")
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
assert.Empty(t, tools, "no tools should be returned for an unreachable server")
}
func TestConnectAll_MultipleServers(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, echoTool())
ts2 := newTestMCPServer(t, greetTool())
cfg1 := makeConfig("alpha", ts1.URL)
cfg2 := makeConfig("beta", ts2.URL)
tools, cleanup := mcpclient.ConnectAll(
ctx, logger,
[]database.MCPServerConfig{cfg1, cfg2},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 2)
names := toolNames(tools)
assert.Contains(t, names, "alpha__echo")
assert.Contains(t, names, "beta__greet")
}
func TestConnectAll_AuthHeaders(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// Create a server whose tool handler records the Authorization
// header it receives on each request.
var (
mu sync.Mutex
seenHeaders []string
)
srv := mcpserver.NewMCPServer("auth-server", "1.0.0")
srv.AddTools(mcpserver.ServerTool{
Tool: mcp.NewTool("whoami",
mcp.WithDescription("Returns the auth header"),
),
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
auth := req.Header.Get("Authorization")
mu.Lock()
seenHeaders = append(seenHeaders, auth)
mu.Unlock()
return mcp.NewToolResultText("auth:" + auth), nil
},
})
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
ts := httptest.NewServer(httpSrv)
t.Cleanup(ts.Close)
configID := uuid.New()
cfg := database.MCPServerConfig{
ID: configID,
Slug: "auth-srv",
DisplayName: "Auth Server",
Url: ts.URL,
Transport: "streamable_http",
AuthType: "oauth2",
Enabled: true,
}
token := database.MCPServerUserToken{
MCPServerConfigID: configID,
AccessToken: "test-token-abc",
TokenType: "Bearer",
}
tools, cleanup := mcpclient.ConnectAll(
ctx, logger,
[]database.MCPServerConfig{cfg},
[]database.MCPServerUserToken{token},
)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
// Call the tool and verify the response includes the auth header
// that was sent.
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
ID: "call-auth",
Name: "auth-srv__whoami",
Input: "{}",
})
require.NoError(t, err)
assert.False(t, resp.IsError)
assert.Equal(t, "auth:Bearer test-token-abc", resp.Content)
// Also verify the handler actually observed the header.
mu.Lock()
defer mu.Unlock()
require.NotEmpty(t, seenHeaders)
assert.Equal(t, "Bearer test-token-abc", seenHeaders[len(seenHeaders)-1])
}
// --- helpers ---
func toolNames(tools []fantasy.AgentTool) []string {
names := make([]string, 0, len(tools))
for _, t := range tools {
names = append(names, t.Info().Name)
}
return names
}
func findTool(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
for _, t := range tools {
if t.Info().Name == name {
return t
}
}
return nil
}
// TestConnectAll_DisabledServer verifies that disabled configs are
// silently skipped.
func TestConnectAll_DisabledServer(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("disabled", ts.URL)
cfg.Enabled = false
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
assert.Empty(t, tools)
}
// TestConnectAll_CallToolInvalidInput verifies that malformed JSON
// input returns an error response rather than a Go error.
func TestConnectAll_CallToolInvalidInput(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
// Pass syntactically invalid JSON as tool input.
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
ID: "call-bad",
Name: "srv__echo",
Input: `{not json`,
})
require.NoError(t, err, "Run should not return a Go error for bad input")
assert.True(t, resp.IsError)
assert.Contains(t, resp.Content, "invalid JSON input")
}
// TestConnectAll_ToolInfoParameters verifies that tool input schema
// parameters are propagated to the ToolInfo.
func TestConnectAll_ToolInfoParameters(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
info := tools[0].Info()
// The echo tool has a required "input" string parameter.
require.NotNil(t, info.Parameters)
_, hasInput := info.Parameters["input"]
assert.True(t, hasInput, "parameters should contain 'input'")
// The "input" field should also appear in Required.
inputProp, ok := info.Parameters["input"].(map[string]any)
assert.True(t, ok, "input parameter should be a map")
if ok {
propBytes, _ := json.Marshal(inputProp)
assert.Contains(t, string(propBytes), "string")
}
assert.Contains(t, info.Required, "input")
}
// TestConnectAll_APIKeyAuth verifies that api_key auth sends the
// configured header and value on every request.
func TestConnectAll_APIKeyAuth(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
var (
mu sync.Mutex
seenHeaders []string
)
srv := mcpserver.NewMCPServer("apikey-server", "1.0.0")
srv.AddTools(mcpserver.ServerTool{
Tool: mcp.NewTool("check",
mcp.WithDescription("Returns the API key header"),
),
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
val := req.Header.Get("X-API-Key")
mu.Lock()
seenHeaders = append(seenHeaders, val)
mu.Unlock()
return mcp.NewToolResultText("key:" + val), nil
},
})
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
ts := httptest.NewServer(httpSrv)
t.Cleanup(ts.Close)
cfg := makeConfig("apikey", ts.URL)
cfg.AuthType = "api_key"
cfg.APIKeyHeader = "X-API-Key"
cfg.APIKeyValue = "secret-123"
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
ID: "call-apikey",
Name: "apikey__check",
Input: "{}",
})
require.NoError(t, err)
assert.False(t, resp.IsError)
assert.Equal(t, "key:secret-123", resp.Content)
mu.Lock()
defer mu.Unlock()
require.NotEmpty(t, seenHeaders)
assert.Equal(t, "secret-123", seenHeaders[len(seenHeaders)-1])
}
// TestConnectAll_CustomHeadersAuth verifies that custom_headers
// auth sends the configured headers on every request.
func TestConnectAll_CustomHeadersAuth(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
var (
mu sync.Mutex
seenHeaders []string
)
srv := mcpserver.NewMCPServer("custom-server", "1.0.0")
srv.AddTools(mcpserver.ServerTool{
Tool: mcp.NewTool("check",
mcp.WithDescription("Returns the custom auth header"),
),
Handler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
val := req.Header.Get("X-Custom-Auth")
mu.Lock()
seenHeaders = append(seenHeaders, val)
mu.Unlock()
return mcp.NewToolResultText("custom:" + val), nil
},
})
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
ts := httptest.NewServer(httpSrv)
t.Cleanup(ts.Close)
cfg := makeConfig("custom", ts.URL)
cfg.AuthType = "custom_headers"
cfg.CustomHeaders = `{"X-Custom-Auth":"custom-val"}`
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
ID: "call-custom",
Name: "custom__check",
Input: "{}",
})
require.NoError(t, err)
assert.False(t, resp.IsError)
assert.Equal(t, "custom:custom-val", resp.Content)
mu.Lock()
defer mu.Unlock()
require.NotEmpty(t, seenHeaders)
assert.Equal(t, "custom-val", seenHeaders[len(seenHeaders)-1])
}
// TestConnectAll_CustomHeadersInvalidJSON verifies that invalid
// JSON in CustomHeaders does not prevent the server from
// connecting. The auth headers are silently skipped.
func TestConnectAll_CustomHeadersInvalidJSON(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("badjson", ts.URL)
cfg.AuthType = "custom_headers"
cfg.CustomHeaders = "{not json}"
tools, cleanup := mcpclient.ConnectAll(
ctx, logger, []database.MCPServerConfig{cfg}, nil,
)
t.Cleanup(cleanup)
// The server should still connect; only auth headers are
// skipped.
require.Len(t, tools, 1)
assert.Equal(t, "badjson__echo", tools[0].Info().Name)
}
// TestConnectAll_ParallelConnections verifies that connecting to
// multiple MCP servers simultaneously returns all discovered
// tools with the correct server slug prefixes.
func TestConnectAll_ParallelConnections(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, echoTool())
ts2 := newTestMCPServer(t, greetTool())
ts3 := newTestMCPServer(t, echoTool())
cfg1 := makeConfig("srv1", ts1.URL)
cfg2 := makeConfig("srv2", ts2.URL)
cfg3 := makeConfig("srv3", ts3.URL)
tools, cleanup := mcpclient.ConnectAll(
ctx, logger,
[]database.MCPServerConfig{cfg1, cfg2, cfg3},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 3)
names := toolNames(tools)
assert.Contains(t, names, "srv1__echo")
assert.Contains(t, names, "srv2__greet")
assert.Contains(t, names, "srv3__echo")
}
func TestRedactURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expected string
}{
{"plain", "https://mcp.example.com/v1", "https://mcp.example.com/v1"},
{"with userinfo", "https://user:secret@mcp.example.com/v1", "https://mcp.example.com/v1"},
{"with query params", "https://mcp.example.com/v1?api_key=sk-123", "https://mcp.example.com/v1"},
{"with both", "https://user:pass@host/p?key=val", "https://host/p"},
{"invalid url", "://not-a-url", "://not-a-url"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := mcpclient.RedactURL(tt.input)
assert.Equal(t, tt.expected, got)
})
}
}
func TestConnectAll_ExpiredToken(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
configID := uuid.New()
cfg := database.MCPServerConfig{
ID: configID,
Slug: "expired-srv",
DisplayName: "Expired Server",
Url: ts.URL,
Transport: "streamable_http",
AuthType: "oauth2",
Enabled: true,
}
// Token exists but is expired.
token := database.MCPServerUserToken{
MCPServerConfigID: configID,
AccessToken: "expired-token",
TokenType: "Bearer",
Expiry: sql.NullTime{Time: time.Now().Add(-1 * time.Hour), Valid: true},
}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token})
t.Cleanup(cleanup)
// The server accepts any auth, so the tool is still discovered
// despite the expired token. Expiry only triggers a warning log and
// never causes the server to be skipped. Note that slogtest's
// IgnoreErrors option merely keeps logged errors from failing the
// test; it does not assert that the warning was actually emitted.
require.NotEmpty(t, tools)
}
func TestConnectAll_EmptyAccessToken(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
configID := uuid.New()
cfg := database.MCPServerConfig{
ID: configID,
Slug: "empty-tok",
DisplayName: "Empty Token Server",
Url: ts.URL,
Transport: "streamable_http",
AuthType: "oauth2",
Enabled: true,
}
// Token record exists but AccessToken is empty.
token := database.MCPServerUserToken{
MCPServerConfigID: configID,
AccessToken: "",
TokenType: "Bearer",
}
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, []database.MCPServerUserToken{token})
t.Cleanup(cleanup)
// Tool is still discovered (server doesn't require auth), but
// no Authorization header was sent. The warning about empty
// access token is logged.
require.NotEmpty(t, tools)
}
func TestConnectAll_CallToolError(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// Server with a tool that always returns an error result.
srv := mcpserver.NewMCPServer("error-server", "1.0.0")
srv.AddTools(mcpserver.ServerTool{
Tool: mcp.NewTool("fail_tool",
mcp.WithDescription("Always fails"),
),
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{mcp.NewTextContent("something broke")},
IsError: true,
}, nil
},
})
httpSrv := mcpserver.NewStreamableHTTPServer(srv)
ts := httptest.NewServer(httpSrv)
t.Cleanup(ts.Close)
cfg := makeConfig("err-srv", ts.URL)
tools, cleanup := mcpclient.ConnectAll(ctx, logger, []database.MCPServerConfig{cfg}, nil)
t.Cleanup(cleanup)
require.Len(t, tools, 1)
resp, err := tools[0].Run(ctx, fantasy.ToolCall{
ID: "call-err",
Name: "err-srv__fail_tool",
Input: "{}",
})
require.NoError(t, err, "Run should not return a Go error for MCP-level errors")
assert.True(t, resp.IsError, "response should be flagged as error")
assert.Contains(t, resp.Content, "something broke")
}
+1 -3
View File
@@ -62,7 +62,6 @@ func (p *Server) maybeGenerateChatTitle(
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
keys chatprovider.ProviderAPIKeys,
generatedTitle *generatedChatTitle,
logger slog.Logger,
) {
input, ok := titleInput(chat, messages)
@@ -112,8 +111,7 @@ func (p *Server) maybeGenerateChatTitle(
return
}
chat.Title = title
generatedTitle.Store(title)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
return
}
+3 -10
View File
@@ -84,14 +84,6 @@ func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
return false
}
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
enabled, err := p.db.GetChatDesktopEnabled(ctx)
if err != nil {
return false
}
return enabled
}
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
tools := []fantasy.AgentTool{
fantasy.NewAgentTool(
@@ -261,8 +253,9 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
}
// Only include the computer use tool when an Anthropic
// provider is configured and desktop is enabled.
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
// provider is configured, since it requires an Anthropic
// model.
if p.isAnthropicConfigured(ctx) {
tools = append(tools, fantasy.NewAgentTool(
"spawn_computer_use_agent",
"Spawn a dedicated computer use agent that can see the desktop "+
+3 -173
View File
@@ -15,7 +15,6 @@ import (
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
@@ -145,20 +144,14 @@ func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
return nil
}
func chatdTestContext(t *testing.T) context.Context {
t.Helper()
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
}
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// No Anthropic key in ProviderAPIKeys.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -183,13 +176,12 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the provider check passes.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -240,42 +232,16 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
}
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
parent, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "parent-desktop-disabled",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
parentChat, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
tool := findToolByName(tools, "spawn_computer_use_agent")
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
}
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the tool can proceed.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// The parent uses an OpenAI model.
@@ -332,139 +298,3 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
assert.Equal(t, "anthropic", chattool.ComputerUseModelProvider)
assert.NotEmpty(t, chattool.ComputerUseModelName)
}
func TestIsSubagentDescendant(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
// Build a chain: root -> child -> grandchild.
root, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "root",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("root")},
})
require.NoError(t, err)
child, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
ParentChatID: uuid.NullUUID{
UUID: root.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: root.ID,
Valid: true,
},
Title: "child",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("child")},
})
require.NoError(t, err)
grandchild, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
ParentChatID: uuid.NullUUID{
UUID: child.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: root.ID,
Valid: true,
},
Title: "grandchild",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("grandchild")},
})
require.NoError(t, err)
// Build a separate, unrelated chain.
unrelated, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "unrelated-root",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("unrelated")},
})
require.NoError(t, err)
unrelatedChild, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
ParentChatID: uuid.NullUUID{
UUID: unrelated.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: unrelated.ID,
Valid: true,
},
Title: "unrelated-child",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("unrelated-child")},
})
require.NoError(t, err)
tests := []struct {
name string
ancestor uuid.UUID
target uuid.UUID
want bool
}{
{
name: "SameID",
ancestor: root.ID,
target: root.ID,
want: false,
},
{
name: "DirectChild",
ancestor: root.ID,
target: child.ID,
want: true,
},
{
name: "GrandChild",
ancestor: root.ID,
target: grandchild.ID,
want: true,
},
{
name: "Unrelated",
ancestor: root.ID,
target: unrelatedChild.ID,
want: false,
},
{
name: "RootChat",
ancestor: child.ID,
target: root.ID,
want: false,
},
{
name: "BrokenChain",
ancestor: root.ID,
target: uuid.New(),
want: false,
},
{
name: "NotDescendant",
ancestor: unrelated.ID,
target: child.ID,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := chatdTestContext(t)
got, err := isSubagentDescendant(ctx, db, tt.ancestor, tt.target)
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
-128
View File
@@ -1,128 +0,0 @@
package chatd
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
)
// ComputeUsagePeriodBounds returns the UTC-aligned start and end bounds for the
// active usage-limit period containing now.
func ComputeUsagePeriodBounds(now time.Time, period codersdk.ChatUsageLimitPeriod) (start, end time.Time) {
utcNow := now.UTC()
switch period {
case codersdk.ChatUsageLimitPeriodDay:
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
end = start.AddDate(0, 0, 1)
case codersdk.ChatUsageLimitPeriodWeek:
// Walk backward to Monday of the current ISO week.
// ISO 8601 weeks always start on Monday, so this never
// crosses an ISO-week boundary.
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
for start.Weekday() != time.Monday {
start = start.AddDate(0, 0, -1)
}
end = start.AddDate(0, 0, 7)
case codersdk.ChatUsageLimitPeriodMonth:
start = time.Date(utcNow.Year(), utcNow.Month(), 1, 0, 0, 0, 0, time.UTC)
end = start.AddDate(0, 1, 0)
default:
panic(fmt.Sprintf("unknown chat usage limit period: %q", period))
}
return start, end
}
// ResolveUsageLimitStatus resolves the current usage-limit status for userID.
//
// Note: There is a potential race condition where two concurrent messages
// from the same user can both pass the limit check if processed in
// parallel, allowing brief overage. This is acceptable because:
// - Cost is only known after the LLM API returns.
// - Overage is bounded by message cost × concurrency.
// - Fail-open is the deliberate design choice for this feature.
//
// Architecture note: today this path enforces one period globally
// (day/week/month) from config.
// To support simultaneous periods, add nullable
// daily/weekly/monthly_limit_micros columns on override tables, where NULL
// means no limit for that period.
// Then scan spend once over the widest active window with conditional SUMs
// for each period and compare each spend/limit pair Go-side, blocking on
// whichever period is tightest.
func ResolveUsageLimitStatus(ctx context.Context, db database.Store, userID uuid.UUID, now time.Time) (*codersdk.ChatUsageLimitStatus, error) {
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
// deployment config reads and cross-user chat spend aggregation.
authCtx := dbauthz.AsChatd(ctx)
config, err := db.GetChatUsageLimitConfig(authCtx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
}
return nil, err
}
if !config.Enabled {
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
}
period, ok := mapDBPeriodToSDK(config.Period)
if !ok {
return nil, xerrors.Errorf("invalid chat usage limit period %q", config.Period)
}
// Resolve effective limit in a single query:
// individual override > group limit > global default.
effectiveLimit, err := db.ResolveUserChatSpendLimit(authCtx, userID)
if err != nil {
return nil, err
}
// -1 means limits are disabled (shouldn't happen since we checked above,
// but handle gracefully).
if effectiveLimit < 0 {
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
}
start, end := ComputeUsagePeriodBounds(now, period)
spendTotal, err := db.GetUserChatSpendInPeriod(authCtx, database.GetUserChatSpendInPeriodParams{
UserID: userID,
StartTime: start,
EndTime: end,
})
if err != nil {
return nil, err
}
return &codersdk.ChatUsageLimitStatus{
IsLimited: true,
Period: period,
SpendLimitMicros: &effectiveLimit,
CurrentSpend: spendTotal,
PeriodStart: start,
PeriodEnd: end,
}, nil
}
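The limit-resolution precedence noted above (individual override > group limit > global default) is computed in SQL by `ResolveUserChatSpendLimit`; a minimal Go sketch of the same precedence, with illustrative names and nil meaning "no value set at this level":

```go
package main

import "fmt"

// effectiveLimit applies the documented precedence:
// individual override, then group limit, then the global default.
// This is an assumption-level sketch, not the query itself.
func effectiveLimit(userOverride, groupLimit *int64, globalDefault int64) int64 {
	if userOverride != nil {
		return *userOverride
	}
	if groupLimit != nil {
		return *groupLimit
	}
	return globalDefault
}

func main() {
	groupLimit := int64(5_000_000)
	// No individual override, so the group limit wins over the global default.
	fmt.Println(effectiveLimit(nil, &groupLimit, 10_000_000)) // 5000000
}
```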
func mapDBPeriodToSDK(dbPeriod string) (codersdk.ChatUsageLimitPeriod, bool) {
switch dbPeriod {
case string(codersdk.ChatUsageLimitPeriodDay):
return codersdk.ChatUsageLimitPeriodDay, true
case string(codersdk.ChatUsageLimitPeriodWeek):
return codersdk.ChatUsageLimitPeriodWeek, true
case string(codersdk.ChatUsageLimitPeriodMonth):
return codersdk.ChatUsageLimitPeriodMonth, true
default:
return "", false
}
}
-132
@@ -1,132 +0,0 @@
package chatd //nolint:testpackage // Keeps chatd unit tests in the package.
import (
"testing"
"time"
"github.com/coder/coder/v2/codersdk"
)
func TestComputeUsagePeriodBounds(t *testing.T) {
t.Parallel()
newYork, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatalf("load America/New_York: %v", err)
}
tests := []struct {
name string
now time.Time
period codersdk.ChatUsageLimitPeriod
wantStart time.Time
wantEnd time.Time
}{
{
name: "day/mid_day",
now: time.Date(2025, time.June, 15, 14, 30, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "day/midnight_exactly",
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "day/end_of_day",
now: time.Date(2025, time.June, 15, 23, 59, 59, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/wednesday",
now: time.Date(2025, time.June, 11, 10, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/monday",
now: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/sunday",
now: time.Date(2025, time.June, 15, 23, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
},
{
name: "week/year_boundary",
now: time.Date(2024, time.December, 31, 12, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodWeek,
wantStart: time.Date(2024, time.December, 30, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.January, 6, 0, 0, 0, 0, time.UTC),
},
{
name: "month/mid_month",
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/first_day",
now: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/last_day",
now: time.Date(2025, time.June, 30, 23, 59, 59, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/february",
now: time.Date(2025, time.February, 15, 12, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2025, time.February, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.March, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "month/leap_year_february",
now: time.Date(2024, time.February, 29, 12, 0, 0, 0, time.UTC),
period: codersdk.ChatUsageLimitPeriodMonth,
wantStart: time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC),
},
{
name: "day/non_utc_timezone",
now: time.Date(2025, time.June, 15, 22, 0, 0, 0, newYork),
period: codersdk.ChatUsageLimitPeriodDay,
wantStart: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
wantEnd: time.Date(2025, time.June, 17, 0, 0, 0, 0, time.UTC),
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
start, end := ComputeUsagePeriodBounds(tc.now, tc.period)
if !start.Equal(tc.wantStart) {
t.Errorf("start: got %v, want %v", start, tc.wantStart)
}
if !end.Equal(tc.wantEnd) {
t.Errorf("end: got %v, want %v", end, tc.wantEnd)
}
})
}
}
+194 -1040
File diff suppressed because it is too large
+139 -980
File diff suppressed because it is too large
+18 -81
@@ -10,7 +10,6 @@ import (
"flag"
"fmt"
"io"
"math"
"net/http"
httppprof "net/http/pprof"
"net/url"
@@ -45,7 +44,6 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/aiseats"
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/audit"
@@ -631,9 +629,7 @@ func New(options *Options) *API {
),
dbRolluper: options.DatabaseRolluper,
ProfileCollector: defaultProfileCollector{},
AISeatTracker: aiseats.Noop{},
}
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
ctx,
options.Logger.Named("workspaceapps"),
@@ -767,27 +763,17 @@ func New(options *Options) *API {
}
api.agentProvider = stn
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
if maxChatsPerAcquire > math.MaxInt32 {
maxChatsPerAcquire = math.MaxInt32
}
if maxChatsPerAcquire < math.MinInt32 {
maxChatsPerAcquire = math.MinInt32
}
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
Logger: options.Logger.Named("chats"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
})
gitSyncLogger := options.Logger.Named("gitsync")
refresher := gitsync.NewRefresher(
@@ -1044,12 +1030,10 @@ func New(options *Options) *API {
// OAuth2 metadata endpoint for RFC 8414 discovery
r.Route("/.well-known/oauth-authorization-server", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2))
r.Get("/*", api.oauth2AuthorizationServerMetadata())
})
// OAuth2 protected resource metadata endpoint for RFC 9728 discovery
r.Route("/.well-known/oauth-protected-resource", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2))
r.Get("/*", api.oauth2ProtectedResourceMetadata())
})
@@ -1162,9 +1146,6 @@ func New(options *Options) *API {
r.Get("/summary", api.chatCostSummary)
})
})
r.Route("/insights", func(r chi.Router) {
r.Get("/pull-requests", api.prInsights)
})
r.Route("/files", func(r chi.Router) {
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
r.Post("/", api.postChatFile)
@@ -1173,8 +1154,6 @@ func New(options *Options) *API {
r.Route("/config", func(r chi.Router) {
r.Get("/system-prompt", api.getChatSystemPrompt)
r.Put("/system-prompt", api.putChatSystemPrompt)
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
r.Get("/user-prompt", api.getUserChatCustomPrompt)
r.Put("/user-prompt", api.putUserChatCustomPrompt)
})
@@ -1196,32 +1175,19 @@ func New(options *Options) *API {
r.Delete("/", api.deleteChatModelConfig)
})
})
r.Route("/usage-limits", func(r chi.Router) {
r.Get("/", api.getChatUsageLimitConfig)
r.Put("/", api.updateChatUsageLimitConfig)
r.Get("/status", api.getMyChatUsageLimitStatus)
r.Route("/overrides/{user}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitOverride)
r.Delete("/", api.deleteChatUsageLimitOverride)
})
r.Route("/group-overrides/{group}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitGroupOverride)
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Patch("/", api.patchChat)
r.Get("/git/watch", api.watchChatGit)
r.Get("/desktop", api.watchChatDesktop)
r.Post("/archive", api.archiveChat)
r.Post("/unarchive", api.unarchiveChat)
r.Get("/messages", api.getChatMessages)
r.Post("/messages", api.postChatMessages)
r.Patch("/messages/{message}", api.patchChatMessage)
r.Route("/stream", func(r chi.Router) {
r.Get("/", api.streamChat)
r.Get("/desktop", api.watchChatDesktop)
r.Get("/git", api.watchChatGit)
})
r.Get("/stream", api.streamChat)
r.Post("/interrupt", api.interruptChat)
r.Get("/diff-status", api.getChatDiffStatus)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
r.Delete("/", api.deleteChatQueuedMessage)
@@ -1233,34 +1199,10 @@ func New(options *Options) *API {
r.Route("/mcp", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
)
// MCP server configuration endpoints.
r.Route("/servers", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents))
r.Get("/", api.listMCPServerConfigs)
r.Post("/", api.createMCPServerConfig)
r.Route("/{mcpServer}", func(r chi.Router) {
r.Get("/", api.getMCPServerConfig)
r.Patch("/", api.updateMCPServerConfig)
r.Delete("/", api.deleteMCPServerConfig)
// OAuth2 user flow
r.Get("/oauth2/connect", api.mcpServerOAuth2Connect)
r.Get("/oauth2/callback", api.mcpServerOAuth2Callback)
r.Delete("/oauth2/disconnect", api.mcpServerOAuth2Disconnect)
})
})
// MCP HTTP transport endpoint with mandatory authentication
r.Route("/http", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP))
r.Mount("/", api.mcpHTTPHandler())
})
})
r.Route("/watch-all-workspacebuilds", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceBuildUpdates),
)
r.Get("/", api.watchAllWorkspaceBuilds)
r.Mount("/http", api.mcpHTTPHandler())
})
})
@@ -1515,8 +1457,6 @@ func New(options *Options) *API {
r.Post("/", api.postUser)
r.Get("/", api.users)
r.Post("/logout", api.postLogout)
r.Post("/me/session/token-to-cookie", api.postSessionTokenCookie)
r.Get("/oidc-claims", api.userOIDCClaims)
// These routes query information about site wide roles.
r.Route("/roles", func(r chi.Router) {
r.Get("/", api.AssignableSiteRoles)
@@ -2087,8 +2027,6 @@ type API struct {
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
// AISeatTracker records AI seat usage.
AISeatTracker aiseats.SeatTracker
// gitSyncWorker refreshes stale chat diff statuses in the
// background.
gitSyncWorker *gitsync.Worker
@@ -2301,7 +2239,6 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
provisionerdserver.Options{
OIDCConfig: api.OIDCConfig,
ExternalAuthConfigs: api.ExternalAuthConfigs,
AISeatTracker: api.AISeatTracker,
Clock: api.Clock,
HeartbeatFn: options.heartbeatFn,
},
-9
@@ -879,15 +879,6 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
m(&req)
}
// Service accounts cannot have a password or email and must
// use login_type=none. Enforce this after mutators so callers
// only need to set ServiceAccount=true.
if req.ServiceAccount {
req.Password = ""
req.Email = ""
req.UserLoginType = codersdk.LoginTypeNone
}
user, err := client.CreateUserWithOrgs(context.Background(), req)
var apiError *codersdk.Error
// If the user already exists by username or email conflict, try again up to "retries" times.
+7 -39
@@ -13,64 +13,32 @@ var _ usage.Inserter = (*UsageInserter)(nil)
type UsageInserter struct {
sync.Mutex
discreteEvents []usagetypes.DiscreteEvent
heartbeatEvents []usagetypes.HeartbeatEvent
seenHeartbeats map[string]struct{}
events []usagetypes.DiscreteEvent
}
func NewUsageInserter() *UsageInserter {
return &UsageInserter{
discreteEvents: []usagetypes.DiscreteEvent{},
seenHeartbeats: map[string]struct{}{},
heartbeatEvents: []usagetypes.HeartbeatEvent{},
events: []usagetypes.DiscreteEvent{},
}
}
func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error {
u.Lock()
defer u.Unlock()
u.discreteEvents = append(u.discreteEvents, event)
u.events = append(u.events, event)
return nil
}
func (u *UsageInserter) InsertHeartbeatUsageEvent(_ context.Context, _ database.Store, id string, event usagetypes.HeartbeatEvent) error {
func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent {
u.Lock()
defer u.Unlock()
if _, seen := u.seenHeartbeats[id]; seen {
return nil
}
u.seenHeartbeats[id] = struct{}{}
u.heartbeatEvents = append(u.heartbeatEvents, event)
return nil
}
func (u *UsageInserter) GetHeartbeatEvents() []usagetypes.HeartbeatEvent {
u.Lock()
defer u.Unlock()
eventsCopy := make([]usagetypes.HeartbeatEvent, len(u.heartbeatEvents))
copy(eventsCopy, u.heartbeatEvents)
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events))
copy(eventsCopy, u.events)
return eventsCopy
}
func (u *UsageInserter) GetDiscreteEvents() []usagetypes.DiscreteEvent {
u.Lock()
defer u.Unlock()
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.discreteEvents))
copy(eventsCopy, u.discreteEvents)
return eventsCopy
}
func (u *UsageInserter) TotalEventCount() int {
u.Lock()
defer u.Unlock()
return len(u.discreteEvents) + len(u.heartbeatEvents)
}
func (u *UsageInserter) Reset() {
u.Lock()
defer u.Unlock()
u.seenHeartbeats = map[string]struct{}{}
u.discreteEvents = []usagetypes.DiscreteEvent{}
u.heartbeatEvents = []usagetypes.HeartbeatEvent{}
u.events = []usagetypes.DiscreteEvent{}
}
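The simplified `UsageInserter` fake keeps the mutex-plus-copy pattern: `GetEvents` hands callers a copy so they cannot mutate the fake's internal slice between assertions. A minimal standalone sketch of that pattern, with illustrative names:

```go
package main

import (
	"fmt"
	"sync"
)

// recorder is a minimal test double showing the same copy-on-read
// idiom as UsageInserter.
type recorder struct {
	sync.Mutex
	events []string
}

func (r *recorder) Add(e string) {
	r.Lock()
	defer r.Unlock()
	r.events = append(r.events, e)
}

// Events returns a copy so callers cannot mutate internal state.
func (r *recorder) Events() []string {
	r.Lock()
	defer r.Unlock()
	cp := make([]string, len(r.events))
	copy(cp, r.events)
	return cp
}

func main() {
	r := &recorder{}
	r.Add("a")
	got := r.Events()
	got[0] = "mutated" // only affects the caller's copy
	fmt.Println(r.Events()[0]) // a
}
```

Returning a defensive copy matters in concurrent tests: a caller iterating the slice while another goroutine appends would otherwise race on the backing array.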
+18 -26
@@ -6,30 +6,22 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs
CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs
CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
)
+7 -92
@@ -21,7 +21,6 @@ import (
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/render"
@@ -195,14 +194,13 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
func ReducedUser(user database.User) codersdk.ReducedUser {
return codersdk.ReducedUser{
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
IsServiceAccount: user.IsServiceAccount,
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
}
}
@@ -1166,86 +1164,3 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
value := v.Int64
return &value
}
// ChatDiffStatus converts a database.ChatDiffStatus to a
// codersdk.ChatDiffStatus. When status is nil an empty value
// containing only the chatID is returned.
func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.ChatDiffStatus {
result := codersdk.ChatDiffStatus{
ChatID: chatID,
}
if status == nil {
return result
}
result.ChatID = status.ChatID
if status.Url.Valid {
u := strings.TrimSpace(status.Url.String)
if u != "" {
result.URL = &u
}
}
if result.URL == nil {
// Try to build a branch URL from the stored origin.
// Since this function does not have access to the API
// instance, we construct a GitHub provider directly as
// a best-effort fallback.
// TODO: This uses the default github.com API base URL,
// so branch URLs for GitHub Enterprise instances will
// be incorrect. To fix this, this function would need
// access to the external auth configs.
gp := gitprovider.New("github", "", nil)
if gp != nil {
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
if branchURL != "" {
result.URL = &branchURL
}
}
}
}
if status.PullRequestState.Valid {
pullRequestState := strings.TrimSpace(status.PullRequestState.String)
if pullRequestState != "" {
result.PullRequestState = &pullRequestState
}
}
result.PullRequestTitle = status.PullRequestTitle
result.PullRequestDraft = status.PullRequestDraft
result.ChangesRequested = status.ChangesRequested
result.Additions = status.Additions
result.Deletions = status.Deletions
result.ChangedFiles = status.ChangedFiles
if status.AuthorLogin.Valid {
result.AuthorLogin = &status.AuthorLogin.String
}
if status.AuthorAvatarUrl.Valid {
result.AuthorAvatarURL = &status.AuthorAvatarUrl.String
}
if status.BaseBranch.Valid {
result.BaseBranch = &status.BaseBranch.String
}
if status.HeadBranch.Valid {
result.HeadBranch = &status.HeadBranch.String
}
if status.PrNumber.Valid {
result.PRNumber = &status.PrNumber.Int32
}
if status.Commits.Valid {
result.Commits = &status.Commits.Int32
}
if status.Approved.Valid {
result.Approved = &status.Approved.Bool
}
if status.ReviewerCount.Valid {
result.ReviewerCount = &status.ReviewerCount.Int32
}
if status.RefreshedAt.Valid {
refreshedAt := status.RefreshedAt.Time
result.RefreshedAt = &refreshedAt
}
staleAt := status.StaleAt
result.StaleAt = &staleAt
return result
}
+22 -336
@@ -705,7 +705,7 @@ var (
DisplayName: "Chat Daemon",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceWorkspace.Type: {policy.ActionRead},
rbac.ResourceDeploymentConfig.Type: {policy.ActionRead},
rbac.ResourceUser.Type: {policy.ActionReadPersonal},
}),
@@ -769,9 +769,6 @@ func AsSubAgentAPI(ctx context.Context, orgID uuid.UUID, userID uuid.UUID) conte
// AsSystemRestricted returns a context with an actor that has permissions
// required for various system operations (login, logout, metrics cache).
// DO NOT USE THIS UNLESS YOU HAVE ABSOLUTELY NO OTHER CHOICE. Prefer using a
// more specific As* helper above (or adding a new, narrowly-scoped one) so
// that permissions remain limited to the operation you need.
func AsSystemRestricted(ctx context.Context) context.Context {
return As(ctx, subjectSystemRestricted)
}
@@ -1267,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
// System roles are stored in the database but have a fixed, code-defined
// meaning. Do not rewrite the name for them so the static "who can assign
// what" mapping applies.
if !rolestore.IsSystemRoleName(roleName.Name) {
if !rbac.SystemRoleName(roleName.Name) {
// To support a dynamic mapping of what roles can assign what, we need
// to store this in the database. For now, just use a static role so
// owners and org admins can assign roles.
@@ -1694,13 +1691,6 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
return q.db.CleanTailnetTunnels(ctx)
}
func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return err
}
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
}
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -1736,13 +1726,6 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
}
func (q *querier) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return 0, err
}
return q.db.CountEnabledModelsWithoutPricing(ctx)
}
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
return nil, err
@@ -1834,6 +1817,18 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.DeleteChatMessagesAfterID(ctx, arg)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -1859,20 +1854,6 @@ func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.Dele
return q.db.DeleteChatQueuedMessage(ctx, arg)
}
func (q *querier) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -1930,20 +1911,6 @@ func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) {
return id, nil
}
func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteMCPServerConfigByID(ctx, id)
}
func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteMCPServerUserToken(ctx, arg)
}
func (q *querier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceOauth2App); err != nil {
return err
@@ -2157,12 +2124,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
}
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
// This is a system-only function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
}
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
@@ -2360,13 +2327,6 @@ func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Tim
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed)
}
func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceLicense); err != nil {
return 0, err
}
return q.db.GetActiveAISeatCount(ctx)
}
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
return nil, err
@@ -2494,17 +2454,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
// The desktop-enabled flag is a deployment-wide setting read by any
// authenticated chat user and by chatd when deciding whether to expose
// computer-use tooling. We only require that an explicit actor is present
// in the context so unauthenticated calls fail closed.
if _, ok := ActorFromContext(ctx); !ok {
return false, ErrNoActor
}
return q.db.GetChatDesktopEnabled(ctx)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2583,14 +2532,6 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
return q.db.GetChatMessagesByChatID(ctx, arg)
}
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatIDDescPaginated(ctx, arg)
}
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2655,33 +2596,8 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
return q.db.GetChatSystemPrompt(ctx)
}
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
}
return q.db.GetChatUsageLimitConfig(ctx)
}
func (q *querier) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitGroupOverrideRow{}, err
}
return q.db.GetChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitUserOverrideRow{}, err
}
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.GetAuthorizedChats(ctx, arg, prep)
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
@@ -2787,13 +2703,6 @@ func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatP
return q.db.GetEnabledChatProviders(ctx)
}
func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledMCPServerConfigs(ctx)
}
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
}
@@ -2852,13 +2761,6 @@ func (q *querier) GetFilteredInboxNotificationsByUserID(ctx context.Context, arg
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetFilteredInboxNotificationsByUserID)(ctx, arg)
}
func (q *querier) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetForcedMCPServerConfigs(ctx)
}
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetGitSSHKey)(ctx, userID)
}
@@ -2999,48 +2901,6 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
return q.db.GetLogoURL(ctx)
}
func (q *querier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.GetMCPServerConfigByID(ctx, id)
}
func (q *querier) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.GetMCPServerConfigBySlug(ctx, slug)
}
func (q *querier) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerConfigs(ctx)
}
func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerConfigsByIDs(ctx, ids)
}
func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
}
return q.db.GetMCPServerUserToken(ctx, arg)
}
func (q *querier) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerUserTokensByUserID(ctx, userID)
}
func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil {
return nil, err
@@ -3227,34 +3087,6 @@ func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg da
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
}
func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsPerModel(ctx, arg)
}
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsRecentPRs(ctx, arg)
}
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetPRInsightsSummaryRow{}, err
}
return q.db.GetPRInsightsSummary(ctx, arg)
}
func (q *querier) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsTimeSeries(ctx, arg)
}
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
if err != nil {
@@ -3918,13 +3750,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
}
return q.db.GetUserChatSpendInPeriod(ctx, arg)
}
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return 0, err
@@ -3932,13 +3757,6 @@ func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64,
return q.db.GetUserCount(ctx, includeSystem)
}
func (q *querier) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
return 0, err
}
return q.db.GetUserGroupSpendLimit(ctx, userID)
}
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
@@ -4608,13 +4426,6 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
}
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
return database.AIBridgeModelThought{}, err
}
return q.db.InsertAIBridgeModelThought(ctx, arg)
}
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
// All aibridge_token_usages records belong to the initiator of their associated interception.
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
@@ -4671,16 +4482,16 @@ func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFil
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
}
func (q *querier) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
// Authorize create on the parent chat (using update permission).
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
return database.ChatMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return nil, err
return database.ChatMessage{}, err
}
return q.db.InsertChatMessages(ctx, arg)
return q.db.InsertChatMessage(ctx, arg)
}
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
@@ -4807,13 +4618,6 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
return q.db.InsertLicense(ctx, arg)
}
func (q *querier) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.InsertMCPServerConfig(ctx, arg)
}
func (q *querier) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil {
return database.WorkspaceAgentMemoryResourceMonitor{}, err
@@ -5311,20 +5115,6 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitGroupOverrides(ctx)
}
func (q *querier) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitOverrides(ctx)
}
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
}
@@ -5444,13 +5234,6 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
return q.db.RemoveUserFromGroups(ctx, arg)
}
func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
return 0, err
}
return q.db.ResolveUserChatSpendLimit(ctx, userID)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -5466,32 +5249,6 @@ func (q *querier) SelectUsageEventsForPublishing(ctx context.Context, arg time.T
return q.db.SelectUsageEventsForPublishing(ctx, arg)
}
func (q *querier) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
msg, err := q.db.GetChatMessageByID(ctx, id)
if err != nil {
return err
}
chat, err := q.db.GetChatByID(ctx, msg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.SoftDeleteChatMessageByID(ctx, id)
}
func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
}
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}
@@ -5573,17 +5330,6 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatMCPServerIDs(ctx, arg)
}
func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
// Authorize update on the parent chat of the edited message.
msg, err := q.db.GetChatMessageByID(ctx, arg.ID)
@@ -5748,13 +5494,6 @@ func (q *querier) UpdateInboxNotificationReadStatus(ctx context.Context, args da
return update(q.log, q.auth, fetchFunc, q.db.UpdateInboxNotificationReadStatus)(ctx, args)
}
func (q *querier) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.UpdateMCPServerConfig(ctx, arg)
}
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
@@ -6678,13 +6417,6 @@ func (q *querier) UpdateWorkspacesTTLByTemplateID(ctx context.Context, arg datab
return q.db.UpdateWorkspacesTTLByTemplateID(ctx, arg)
}
func (q *querier) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return false, err
}
return q.db.UpsertAISeatState(ctx, arg)
}
func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -6706,13 +6438,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
}
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
@@ -6744,27 +6469,6 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
return q.db.UpsertChatSystemPrompt(ctx, value)
}
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
}
return q.db.UpsertChatUsageLimitConfig(ctx, arg)
}
func (q *querier) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.UpsertChatUsageLimitGroupOverrideRow{}, err
}
return q.db.UpsertChatUsageLimitGroupOverride(ctx, arg)
}
func (q *querier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.UpsertChatUsageLimitUserOverrideRow{}, err
}
return q.db.UpsertChatUsageLimitUserOverride(ctx, arg)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
@@ -6800,13 +6504,6 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error {
return q.db.UpsertLogoURL(ctx, value)
}
func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
}
return q.db.UpsertMCPServerUserToken(ctx, arg)
}
func (q *querier) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return err
@@ -6959,13 +6656,6 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
}
func (q *querier) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUsageEvent); err != nil {
return false, err
}
return q.db.UsageEventExistsByID(ctx, id)
}
func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) {
// This check is probably overly restrictive, but the "correct" check isn't
// necessarily obvious. It's only used as a verification check for ACLs right
@@ -7061,7 +6751,3 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeModels(ctx, arg)
}
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
return q.GetChats(ctx, arg)
}
+32 -360
@@ -401,27 +401,16 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("SoftDeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
s.Run("DeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.SoftDeleteChatMessagesAfterIDParams{
arg := database.DeleteChatMessagesAfterIDParams{
ChatID: chat.ID,
AfterID: 123,
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().SoftDeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
dbm.EXPECT().DeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("SoftDeleteChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := database.ChatMessage{
ID: 456,
ChatID: chat.ID,
}
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), msg.ID).Return(nil).AnyTimes()
check.Args(msg.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
@@ -524,10 +513,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
}))
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
}))
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
@@ -573,14 +558,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
arg := database.GetChatMessagesByChatIDDescPaginatedParams{ChatID: chat.ID, BeforeID: 0, LimitVal: 50}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetLastChatMessageByRole", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
@@ -629,17 +606,12 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
// No asserts here because it re-routes through GetChats which uses SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
c1 := testutil.Fake(s.T(), faker, database.Chat{})
c2 := testutil.Fake(s.T(), faker, database.Chat{})
params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
}))
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -652,10 +624,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
@@ -686,13 +654,13 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
}))
s.Run("InsertChatMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatMessagesParams{ChatID: chat.ID})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatMessages(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msgs)
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
}))
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -865,254 +833,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetUserChatSpendInPeriodParams{
UserID: uuid.New(),
StartTime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
spend := int64(123)
dbm.EXPECT().GetUserChatSpendInPeriod(gomock.Any(), arg).Return(spend, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.UserID.String()), policy.ActionRead).Returns(spend)
}))
s.Run("GetUserGroupSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(456)
dbm.EXPECT().GetUserGroupSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("ResolveUserChatSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(789)
dbm.EXPECT().ResolveUserChatSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("GetChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: true,
DefaultLimitMicros: 1_000_000,
Period: "monthly",
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(config, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
override := database.GetChatUsageLimitGroupOverrideRow{
GroupID: groupID,
SpendLimitMicros: sql.NullInt64{Int64: 2_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(override, nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("GetChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
override := database.GetChatUsageLimitUserOverrideRow{
UserID: userID,
SpendLimitMicros: sql.NullInt64{Int64: 3_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitUserOverride(gomock.Any(), userID).Return(override, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("ListChatUsageLimitGroupOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitGroupOverridesRow{{
GroupID: uuid.New(),
GroupName: "group-name",
GroupDisplayName: "Group Name",
GroupAvatarUrl: "https://example.com/group.png",
SpendLimitMicros: sql.NullInt64{Int64: 4_000_000, Valid: true},
MemberCount: 5,
}}
dbm.EXPECT().ListChatUsageLimitGroupOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("ListChatUsageLimitOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitOverridesRow{{
UserID: uuid.New(),
Username: "usage-limit-user",
Name: "Usage Limit User",
AvatarURL: "https://example.com/avatar.png",
SpendLimitMicros: sql.NullInt64{Int64: 5_000_000, Valid: true},
}}
dbm.EXPECT().ListChatUsageLimitOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("UpsertChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
arg := database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 6_000_000,
Period: "monthly",
}
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: arg.Enabled,
DefaultLimitMicros: arg.DefaultLimitMicros,
Period: arg.Period,
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().UpsertChatUsageLimitConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpsertChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitGroupOverrideParams{
SpendLimitMicros: 7_000_000,
GroupID: uuid.New(),
}
override := database.UpsertChatUsageLimitGroupOverrideRow{
GroupID: arg.GroupID,
Name: "group",
DisplayName: "Group",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitGroupOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("UpsertChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitUserOverrideParams{
SpendLimitMicros: 8_000_000,
UserID: uuid.New(),
}
override := database.UpsertChatUsageLimitUserOverrideRow{
UserID: arg.UserID,
Username: "user",
Name: "User",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitUserOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("DeleteChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("CleanupDeletedMCPServerIDsFromChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CleanupDeletedMCPServerIDsFromChats(gomock.Any()).Return(nil).AnyTimes()
check.Args().Asserts(rbac.ResourceChat, policy.ActionUpdate)
}))
s.Run("DeleteMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteMCPServerConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.DeleteMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
dbm.EXPECT().DeleteMCPServerUserToken(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetEnabledMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetEnabledMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetForcedMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetForcedMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetMCPServerConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetMCPServerConfigBySlug", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
slug := "test-mcp-server"
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{Slug: slug})
dbm.EXPECT().GetMCPServerConfigBySlug(gomock.Any(), slug).Return(config, nil).AnyTimes()
check.Args(slug).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerConfigsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
ids := []uuid.UUID{configA.ID, configB.ID}
dbm.EXPECT().GetMCPServerConfigsByIDs(gomock.Any(), ids).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().GetMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(token)
}))
s.Run("GetMCPServerUserTokensByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
userID := uuid.New()
tokens := []database.MCPServerUserToken{testutil.Fake(s.T(), faker, database.MCPServerUserToken{UserID: userID})}
dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens)
}))
s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertMCPServerConfigParams{
DisplayName: "Test MCP Server",
Slug: "test-mcp-server",
}
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{DisplayName: arg.DisplayName, Slug: arg.Slug})
dbm.EXPECT().InsertMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpdateChatMCPServerIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatMCPServerIDsParams{
ID: chat.ID,
MCPServerIDs: []uuid.UUID{uuid.New()},
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
arg := database.UpdateMCPServerConfigParams{
ID: config.ID,
DisplayName: "Updated MCP Server",
Slug: "updated-mcp-server",
}
dbm.EXPECT().UpdateMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpsertMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.UpsertMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
AccessToken: "test-access-token",
TokenType: "bearer",
}
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().UpsertMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(token)
}))
}
func (s *MethodTestSuite) TestFile() {
@@ -1435,14 +1155,6 @@ func (s *MethodTestSuite) TestProvisionerJob() {
}
func (s *MethodTestSuite) TestLicense() {
s.Run("GetActiveAISeatCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetActiveAISeatCount(gomock.Any()).Return(int64(100), nil).AnyTimes()
check.Args().Asserts(rbac.ResourceLicense, policy.ActionRead).Returns(int64(100))
}))
s.Run("UpsertAISeatState", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertAISeatState(gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes()
check.Args(database.UpsertAISeatStateParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
}))
s.Run("GetLicenses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
a := database.License{ID: 1}
b := database.License{ID: 2}
@@ -1489,8 +1201,8 @@ func (s *MethodTestSuite) TestLicense() {
check.Args().Asserts().Returns("value")
}))
s.Run("GetDefaultProxyConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"}, nil).AnyTimes()
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"})
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}, nil).AnyTimes()
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"})
}))
s.Run("GetLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetLogoURL(gomock.Any()).Return("value", nil).AnyTimes()
@@ -1612,7 +1324,7 @@ func (s *MethodTestSuite) TestOrganization() {
org := testutil.Fake(s.T(), faker, database.Organization{})
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
ID: org.ID,
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
WorkspaceSharingDisabled: true,
}
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
@@ -2043,26 +1755,6 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
}))
s.Run("GetPRInsightsSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsSummaryParams{}
dbm.EXPECT().GetPRInsightsSummary(gomock.Any(), arg).Return(database.GetPRInsightsSummaryRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsTimeSeries", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsTimeSeriesParams{}
dbm.EXPECT().GetPRInsightsTimeSeries(gomock.Any(), arg).Return([]database.GetPRInsightsTimeSeriesRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsPerModelParams{}
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsRecentPRsParams{}
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTelemetryTaskEventsParams{}
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
@@ -2551,12 +2243,9 @@ func (s *MethodTestSuite) TestWorkspace() {
check.Args(w.ID).Asserts(w, policy.ActionShare)
}))
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.DeleteWorkspaceACLsByOrganizationParams{
OrganizationID: uuid.New(),
ExcludeServiceAccounts: false,
}
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
orgID := uuid.New()
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
@@ -5262,12 +4951,6 @@ func (s *MethodTestSuite) TestUsageEvents() {
check.Args(params).Asserts(rbac.ResourceUsageEvent, policy.ActionCreate)
}))
s.Run("UsageEventExistsByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
id := uuid.NewString()
db.EXPECT().UsageEventExistsByID(gomock.Any(), id).Return(true, nil)
check.Args(id).Asserts(rbac.ResourceUsageEvent, policy.ActionRead)
}))
s.Run("SelectUsageEventsForPublishing", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
now := dbtime.Now()
db.EXPECT().SelectUsageEventsForPublishing(gomock.Any(), now).Return([]database.UsageEvent{}, nil)
@@ -5328,17 +5011,6 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(params).Asserts(intc, policy.ActionCreate)
}))
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
@@ -5697,16 +5369,12 @@ func TestAsChatd(t *testing.T) {
require.NoError(t, err, "chat %s should be allowed", action)
}
// Workspace read + update (update needed for ActivityBumpWorkspace).
for _, action := range []policy.Action{
policy.ActionRead, policy.ActionUpdate,
} {
err := auth.Authorize(ctx, actor, action, rbac.ResourceWorkspace)
require.NoError(t, err, "workspace %s should be allowed", action)
}
// Workspace read.
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceWorkspace)
require.NoError(t, err, "workspace read should be allowed")
// DeploymentConfig read.
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceDeploymentConfig)
err = auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceDeploymentConfig)
require.NoError(t, err, "deployment config read should be allowed")
// User read_personal (needed for GetUserChatCustomPrompt).
@@ -5717,12 +5385,16 @@ func TestAsChatd(t *testing.T) {
t.Run("DeniedActions", func(t *testing.T) {
t.Parallel()
// Cannot delete workspaces.
err := auth.Authorize(ctx, actor, policy.ActionDelete, rbac.ResourceWorkspace)
require.Error(t, err, "workspace delete should be denied")
// Cannot write workspaces.
for _, action := range []policy.Action{
policy.ActionUpdate, policy.ActionDelete,
} {
err := auth.Authorize(ctx, actor, action, rbac.ResourceWorkspace)
require.Error(t, err, "workspace %s should be denied", action)
}
// Cannot access users.
err = auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceUser)
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceUser)
require.Error(t, err, "user read should be denied")
// Cannot access API keys.
@@ -29,7 +29,6 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/rbac/regosql"
"github.com/coder/coder/v2/coderd/rbac/rolestore"
"github.com/coder/coder/v2/coderd/util/slice"
)
@@ -144,7 +143,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
UUID: pair.OrganizationID,
Valid: pair.OrganizationID != uuid.Nil,
},
IsSystem: rolestore.IsSystemRoleName(pair.Name),
IsSystem: rbac.SystemRoleName(pair.Name),
ID: uuid.New(),
})
}
@@ -650,26 +650,34 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
})
require.NoError(t, err, "insert organization")
// Populate the placeholder system roles (created by DB
// trigger/migration) so org members have expected permissions.
//nolint:gocritic // ReconcileSystemRole needs the system:update
// Populate the placeholder organization-member system role (created by
// DB trigger/migration) so org members have expected permissions.
//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
// permission that `genCtx` does not have.
sysCtx := dbauthz.AsSystemRestricted(genCtx)
for roleName := range rolestore.SystemRoleNames {
role := database.CustomRole{
Name: roleName,
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
}
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
if errors.Is(err, sql.ErrNoRows) {
// The trigger that creates the placeholder role didn't run (e.g.,
// triggers were disabled in the test). Create the role manually.
err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
require.NoError(t, err, "create role "+roleName)
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
}
require.NoError(t, err, "reconcile role "+roleName)
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
Name: rbac.RoleOrgMember(),
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
}, org.WorkspaceSharingDisabled)
if errors.Is(err, sql.ErrNoRows) {
// The trigger that creates the placeholder role didn't run (e.g.,
// triggers were disabled in the test). Create the role manually.
err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
require.NoError(t, err, "create organization-member role")
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
Name: rbac.RoleOrgMember(),
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
}, org.WorkspaceSharingDisabled)
}
require.NoError(t, err, "reconcile organization-member role")
return org
}