Compare commits

..

9 Commits

Author SHA1 Message Date
Atif Ali 337f4474c4 Merge branch 'main' into feat/windows-install-script 2026-03-17 15:28:59 +05:00
Atif Ali 6e09ddc3c1 fix(site): remove duplicated bin handler after rebase 2026-03-16 17:10:54 +00:00
M Atif Ali 9cfd7ad394 fix(install): address remaining PR review feedback 2026-03-16 17:08:02 +00:00
blink-so[bot] 0e3c880455 fix: improve error handling in install.ps1 per Copilot review 2026-03-16 17:08:02 +00:00
blink-so[bot] 97c245c92c test: ignore install.ps1 warning in SpammyLogs test 2026-03-16 17:08:02 +00:00
blink-so[bot] d0083cdb06 fix: address Copilot review feedback
- Use PROCESSOR_ARCHITECTURE env var for better PowerShell 5.1 compatibility
- Add -ErrorAction Stop to Invoke-WebRequest
- Fix documentation to use env:USERPROFILE instead of HOME
- Use Join-Path consistently for path construction
- Case-insensitive PATH comparison for Windows
- Handle empty PATH variable case
- Normalize paths for binary conflict detection
2026-03-16 17:08:02 +00:00
blink-so[bot] 7742854f10 fix: use throw instead of exit to avoid terminating user session
When running via irm ... | iex, using exit 1 would terminate the
entire PowerShell host session. Using throw instead allows the
installer to fail gracefully without closing the user terminal.
2026-03-16 17:08:02 +00:00
blink-so[bot] 926b568a60 fix: address review feedback for install.ps1
- Fix PATH comparison to normalize paths (handle trailing slashes)
- Improve post-install messaging to match install.sh style
  (show PATH extension command when binary not found in PATH)
2026-03-16 17:08:02 +00:00
blink-so[bot] 775d26de97 feat: add Windows PowerShell install script (install.ps1)
This adds a PowerShell install script for Windows users, similar to the
existing install.sh for Linux/macOS. The script:

- Downloads the Coder CLI binary from the Coder server
- Supports both amd64 and arm64 architectures
- Caches downloads in %LOCALAPPDATA%\coder\local_downloads
- Installs to ~/.coder/bin by default (configurable)
- Adds the install directory to the user PATH
- Supports dry-run mode

Usage:
  irm https://your-coder-server/install.ps1 | iex

The script is served from /install.ps1 on the Coder server with the
same template processing as install.sh (injects Origin and Version).
2026-03-16 17:08:02 +00:00
421 changed files with 12644 additions and 39897 deletions
-72
View File
@@ -1,72 +0,0 @@
---
name: pull-requests
description: "Guide for creating, updating, and following up on pull requests in the Coder repository. Use when asked to open a PR, update a PR, rewrite a PR description, or follow up on CI/check failures."
---
# Pull Request Skill
## When to Use This Skill
Use this skill when asked to:
- Create a pull request for the current branch.
- Update an existing PR branch or description.
- Rewrite a PR body.
- Follow up on CI or check failures for an existing PR.
## References
Use the canonical docs for shared conventions and validation guidance:
- PR title and description conventions:
`.claude/docs/PR_STYLE_GUIDE.md`
- Local validation commands and git hooks: `AGENTS.md` (Essential Commands and
Git Hooks sections)
## Lifecycle Rules
1. **Check for an existing PR** before creating a new one:
```bash
gh pr list --head "$(git branch --show-current)" --author @me --json number --jq '.[0].number // empty'
```
If that returns a number, update that PR. If it returns empty output,
create a new one.
2. **Check you are not on main.** If the current branch is `main` or `master`,
create a feature branch before doing PR work.
3. **Default to draft.** Use `gh pr create --draft` unless the user explicitly
asks for ready-for-review.
4. **Keep description aligned with the full diff.** Re-read the diff against
the base branch before writing or updating the title and body. Describe the
entire PR diff, not just the last commit.
5. **Never auto-merge.** Do not merge or mark ready for review unless the user
explicitly asks.
6. **Never push to main or master.**
## CI / Checks Follow-up
**Always watch CI checks after pushing.** Do not push and walk away.
After pushing:
- Monitor CI with `gh pr checks <PR_NUMBER> --watch`.
- Use `gh pr view <PR_NUMBER> --json statusCheckRollup` for programmatic check
status.
If checks fail:
1. Find the failed run ID from the `gh pr checks` output.
2. Read the logs with `gh run view <run-id> --log-failed`.
3. Fix the problem locally.
4. Run `make pre-commit`.
5. Push the fix.
## What Not to Do
- Do not reference or call helper scripts that do not exist in this
repository.
- Do not auto-merge or mark ready for review without explicit user request.
- Do not push to `origin/main` or `origin/master`.
- Do not skip local validation before pushing.
- Do not fabricate or embellish PR descriptions.
-9
View File
@@ -1,9 +0,0 @@
paths:
# The triage workflow uses a quoted heredoc (<<'EOF') with ${VAR}
# placeholders that envsubst expands later. Shellcheck's SC2016
# warns about unexpanded variables in single-quoted strings, but
# the non-expansion is intentional here. Actionlint doesn't honor
# inline shellcheck disable directives inside heredocs.
.github/workflows/triage-via-chat-api.yaml:
ignore:
- 'SC2016'
+33 -33
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -191,7 +191,7 @@ jobs:
# Check for any typos
- name: Check for typos
uses: crate-ci/typos@631208b7aac2daa8b707f55e7331f9112b0e062d # v1.44.0
uses: crate-ci/typos@2d0ce569feab1f8752f1dde43cc2f2aa53236e06 # v1.40.0
with:
config: .github/workflows/typos.toml
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -327,7 +327,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -379,7 +379,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -537,7 +537,7 @@ jobs:
embedded-pg-cache: ${{ steps.embedded-pg-cache.outputs.embedded-pg-cache }}
- name: Upload failed test db dumps
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
@@ -575,7 +575,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -637,7 +637,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -709,7 +709,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -736,7 +736,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -769,7 +769,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -818,7 +818,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -826,7 +826,7 @@ jobs:
- name: Upload debug log
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coderd-debug-logs${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/e2e/test-results/debug.log
@@ -834,7 +834,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -849,7 +849,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -930,7 +930,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1005,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1043,7 +1043,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1097,7 +1097,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1108,7 +1108,7 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -1319,7 +1319,7 @@ jobs:
id: attest_main
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:main"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1356,7 +1356,7 @@ jobs:
id: attest_latest
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:latest"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1393,7 +1393,7 @@ jobs:
id: attest_version
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1457,7 +1457,7 @@ jobs:
- name: Upload build artifact (coder-linux-amd64.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-amd64.tar.gz
path: ./build/*_linux_amd64.tar.gz
@@ -1465,7 +1465,7 @@ jobs:
- name: Upload build artifact (coder-linux-amd64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-amd64.deb
path: ./build/*_linux_amd64.deb
@@ -1473,7 +1473,7 @@ jobs:
- name: Upload build artifact (coder-linux-arm64.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-arm64.tar.gz
path: ./build/*_linux_arm64.tar.gz
@@ -1481,7 +1481,7 @@ jobs:
- name: Upload build artifact (coder-linux-arm64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-arm64.deb
path: ./build/*_linux_arm64.deb
@@ -1489,7 +1489,7 @@ jobs:
- name: Upload build artifact (coder-linux-armv7.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-armv7.tar.gz
path: ./build/*_linux_armv7.tar.gz
@@ -1497,7 +1497,7 @@ jobs:
- name: Upload build artifact (coder-linux-armv7.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-armv7.deb
path: ./build/*_linux_armv7.deb
@@ -1505,7 +1505,7 @@ jobs:
- name: Upload build artifact (coder-windows-amd64.zip)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-windows-amd64.zip
path: ./build/*_windows_amd64.zip
@@ -1543,7 +1543,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-38
View File
@@ -23,44 +23,6 @@ permissions:
concurrency: pr-${{ github.ref }}
jobs:
community-label:
runs-on: ubuntu-latest
permissions:
pull-requests: write
if: >-
${{
github.event_name == 'pull_request_target' &&
github.event.action == 'opened' &&
github.event.pull_request.author_association != 'MEMBER' &&
github.event.pull_request.author_association != 'COLLABORATOR' &&
github.event.pull_request.author_association != 'OWNER'
}}
steps:
- name: Add community label
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const params = {
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
}
const labels = context.payload.pull_request.labels.map((label) => label.name)
if (labels.includes("community")) {
console.log('PR already has "community" label.')
return
}
console.log(
'Adding "community" label for author association "%s".',
context.payload.pull_request.author_association,
)
await github.rest.issues.addLabels({
...params,
labels: ["community"],
})
cla:
runs-on: ubuntu-latest
permissions:
+4 -4
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -65,7 +65,7 @@ jobs:
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -142,7 +142,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -48,7 +48,7 @@ jobs:
persist-credentials: false
- name: Docker login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
+1 -1
View File
@@ -30,7 +30,7 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v45.0.7
- uses: tj-actions/changed-files@e0021407031f5be11a464abee9a0776171c79891 # v45.0.7
id: changed-files
with:
files: |
+4 -4
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -78,11 +78,11 @@ jobs:
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Login to DockerHub
if: github.ref == 'refs/heads/main'
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -30,7 +30,7 @@ jobs:
- name: Sync issues
id: sync
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0.5.0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: sync
@@ -52,7 +52,7 @@ jobs:
- name: Complete release
id: complete
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: complete
+1 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+6 -6
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -248,7 +248,7 @@ jobs:
uses: ./.github/actions/setup-sqlc
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -14,12 +14,12 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: Run Schmoder CI
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
with:
workflow: ci.yaml
repo: coder/schmoder
+12 -10
View File
@@ -80,7 +80,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -155,7 +155,7 @@ jobs:
cat "$CODER_RELEASE_NOTES_FILE"
- name: Docker Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -358,7 +358,7 @@ jobs:
id: attest_base
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.image-base-tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -474,7 +474,7 @@ jobs:
id: attest_main
if: ${{ !inputs.dry_run }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -518,7 +518,7 @@ jobs:
id: attest_latest
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.latest_tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -665,7 +665,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: release-artifacts
path: |
@@ -681,7 +681,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
@@ -700,11 +700,13 @@ jobs:
name: Publish to Homebrew tap
runs-on: ubuntu-latest
needs: release
if: ${{ !inputs.dry_run && inputs.release_channel == 'mainline' }}
if: ${{ !inputs.dry_run }}
steps:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -780,7 +782,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: SARIF file
path: results.sarif
+4 -4
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -69,7 +69,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.34.0
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
@@ -160,7 +160,7 @@ jobs:
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: trivy
path: trivy-results.sarif
+3 -3
View File
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-295
View File
@@ -1,295 +0,0 @@
# This workflow reimplements the AI Triage Automation using the Coder Chat API
# instead of the Tasks API. The Chat API (/api/experimental/chats) is a simpler
# interface that does not require a dedicated GitHub Action or workspace
# provisioning — we just create a chat, poll for completion, and link the
# result on the issue. All API calls use curl + jq directly.
#
# Key differences from the Tasks API workflow (traiage.yaml):
# - No checkout of coder/create-task-action; everything is inline curl/jq.
# - No template_name / template_preset / prefix inputs — the Chat API handles
# resource allocation internally.
# - Uses POST /api/experimental/chats to create a chat session.
# - Polls GET /api/experimental/chats/<id> until the agent finishes.
# - Chat URL format: ${CODER_URL}/agents?chat=${CHAT_ID}
name: AI Triage via Chat API
on:
issues:
types:
- labeled
workflow_dispatch:
inputs:
issue_url:
description: "GitHub Issue URL to process"
required: true
type: string
permissions:
contents: read
jobs:
triage-chat:
name: Triage GitHub Issue via Chat API
runs-on: ubuntu-latest
if: github.event.label.name == 'chat-triage' || github.event_name == 'workflow_dispatch'
timeout-minutes: 30
env:
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
permissions:
contents: read
issues: write
steps:
# ------------------------------------------------------------------
# Step 1: Determine the GitHub user and issue URL.
# Identical to the Tasks API workflow — resolve the actor for
# workflow_dispatch or the issue sender for label events.
# ------------------------------------------------------------------
- name: Determine Inputs
id: determine-inputs
if: always()
env:
GITHUB_ACTOR: ${{ github.actor }}
GITHUB_EVENT_ISSUE_HTML_URL: ${{ github.event.issue.html_url }}
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# For workflow_dispatch, use the actor who triggered it.
# For issues events, use the issue sender.
if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
if ! GITHUB_USER_ID=$(gh api "users/${GITHUB_ACTOR}" --jq '.id'); then
echo "::error::Failed to get GitHub user ID for actor ${GITHUB_ACTOR}"
exit 1
fi
echo "Using workflow_dispatch actor: ${GITHUB_ACTOR} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_ACTOR}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${INPUTS_ISSUE_URL}"
echo "issue_url=${INPUTS_ISSUE_URL}" >> "${GITHUB_OUTPUT}"
exit 0
elif [[ "${GITHUB_EVENT_NAME}" == "issues" ]]; then
GITHUB_USER_ID=${GITHUB_EVENT_USER_ID}
echo "Using issue author: ${GITHUB_EVENT_USER_LOGIN} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_EVENT_USER_LOGIN}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${GITHUB_EVENT_ISSUE_HTML_URL}"
echo "issue_url=${GITHUB_EVENT_ISSUE_HTML_URL}" >> "${GITHUB_OUTPUT}"
exit 0
else
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
exit 1
fi
# ------------------------------------------------------------------
# Step 2: Verify the triggering user has push access.
# Unchanged from the Tasks API workflow.
# ------------------------------------------------------------------
- name: Verify push access
env:
GITHUB_REPOSITORY: ${{ github.repository }}
GH_TOKEN: ${{ github.token }}
GITHUB_USERNAME: ${{ steps.determine-inputs.outputs.github_username }}
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
run: |
set -euo pipefail
can_push="$(gh api "/repos/${GITHUB_REPOSITORY}/collaborators/${GITHUB_USERNAME}/permission" --jq '.user.permissions.push')"
if [[ "${can_push}" != "true" ]]; then
echo "::error title=Access Denied::${GITHUB_USERNAME} does not have push access to ${GITHUB_REPOSITORY}"
exit 1
fi
# ------------------------------------------------------------------
# Step 3: Create a chat via the Coder Chat API.
# Unlike the Tasks API which provisions a full workspace, the Chat
# API creates a lightweight chat session. We POST to
# /api/experimental/chats with the triage prompt as the initial
# message and receive a chat ID back.
# ------------------------------------------------------------------
- name: Create chat via Coder Chat API
id: create-chat
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# Build the same triage prompt used by the Tasks API workflow.
TASK_PROMPT=$(cat <<'EOF'
Fix ${ISSUE_URL}
1. Use the gh CLI to read the issue description and comments.
2. Think carefully and try to understand the root cause. If the issue is unclear or not well defined, ask me to clarify and provide more information.
3. Write a proposed implementation plan to PLAN.md for me to review before starting implementation. Your plan should use TDD and only make the minimal changes necessary to fix the root cause.
4. When I approve your plan, start working on it. If you encounter issues with the plan, ask me for clarification and update the plan as required.
5. When you have finished implementation according to the plan, commit and push your changes, and create a PR using the gh CLI for me to review.
EOF
)
# Perform variable substitution on the prompt — scoped to $ISSUE_URL only.
# Using envsubst without arguments would expand every env var in scope
# (including CODER_SESSION_TOKEN), so we name the variable explicitly.
TASK_PROMPT=$(echo "${TASK_PROMPT}" | envsubst '$ISSUE_URL')
echo "Creating chat with prompt:"
echo "${TASK_PROMPT}"
# POST to the Chat API to create a new chat session.
RESPONSE=$(curl --silent --fail-with-body \
-X POST \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
-H "Content-Type: application/json" \
-d "$(jq -n --arg prompt "${TASK_PROMPT}" \
'{content: [{type: "text", text: $prompt}]}')" \
"${CODER_URL}/api/experimental/chats")
echo "Chat API response:"
echo "${RESPONSE}" | jq .
CHAT_ID=$(echo "${RESPONSE}" | jq -r '.id')
CHAT_STATUS=$(echo "${RESPONSE}" | jq -r '.status')
if [[ -z "${CHAT_ID}" || "${CHAT_ID}" == "null" ]]; then
echo "::error::Failed to create chat — no ID returned"
echo "Response: ${RESPONSE}"
exit 1
fi
# Validate that CHAT_ID is a UUID before using it in URL paths.
# This guards against unexpected API responses being interpolated
# into subsequent curl calls.
if [[ ! "${CHAT_ID}" =~ ^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$ ]]; then
echo "::error::CHAT_ID is not a valid UUID: ${CHAT_ID}"
exit 1
fi
CHAT_URL="${CODER_URL}/agents?chat=${CHAT_ID}"
echo "Chat created: ${CHAT_ID} (status: ${CHAT_STATUS})"
echo "Chat URL: ${CHAT_URL}"
echo "chat_id=${CHAT_ID}" >> "${GITHUB_OUTPUT}"
echo "chat_url=${CHAT_URL}" >> "${GITHUB_OUTPUT}"
# ------------------------------------------------------------------
# Step 4: Poll the chat status until the agent finishes.
# The Chat API is asynchronous — after creation the agent begins
# working in the background. We poll GET /api/experimental/chats/<id>
# every 5 seconds until the status is "waiting" (agent needs input),
# "completed" (agent finished), or "error". Timeout after 10 minutes.
# ------------------------------------------------------------------
- name: Poll chat status
id: poll-status
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
run: |
set -euo pipefail
POLL_INTERVAL=5
# 10 minutes = 600 seconds.
TIMEOUT=600
ELAPSED=0
echo "Polling chat ${CHAT_ID} every ${POLL_INTERVAL}s (timeout: ${TIMEOUT}s)..."
while true; do
RESPONSE=$(curl --silent --fail-with-body \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
"${CODER_URL}/api/experimental/chats/${CHAT_ID}")
STATUS=$(echo "${RESPONSE}" | jq -r '.status')
echo "[${ELAPSED}s] Chat status: ${STATUS}"
case "${STATUS}" in
waiting|completed)
echo "Chat reached terminal status: ${STATUS}"
echo "final_status=${STATUS}" >> "${GITHUB_OUTPUT}"
exit 0
;;
error)
echo "::error::Chat entered error state"
echo "${RESPONSE}" | jq .
echo "final_status=error" >> "${GITHUB_OUTPUT}"
exit 1
;;
pending|running)
# Still working — keep polling.
;;
*)
echo "::warning::Unknown chat status: ${STATUS}"
;;
esac
if [[ ${ELAPSED} -ge ${TIMEOUT} ]]; then
echo "::error::Timed out after ${TIMEOUT}s waiting for chat to finish"
echo "final_status=timeout" >> "${GITHUB_OUTPUT}"
exit 1
fi
sleep "${POLL_INTERVAL}"
ELAPSED=$((ELAPSED + POLL_INTERVAL))
done
# ------------------------------------------------------------------
# Step 5: Comment on the GitHub issue with a link to the chat.
# Only comment if the issue belongs to this repository (same guard
# as the Tasks API workflow).
# ------------------------------------------------------------------
- name: Comment on issue
if: startsWith(steps.determine-inputs.outputs.issue_url, format('{0}/{1}', github.server_url, github.repository))
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
COMMENT_BODY=$(cat <<EOF
🤖 **AI Triage Chat Created**
A Coder chat session has been created to investigate this issue.
**Chat URL:** ${CHAT_URL}
**Chat ID:** \`${CHAT_ID}\`
**Status:** ${FINAL_STATUS}
The agent is working on a triage plan. Visit the chat to follow progress or provide guidance.
EOF
)
gh issue comment "${ISSUE_URL}" --body "${COMMENT_BODY}"
echo "Comment posted on ${ISSUE_URL}"
# ------------------------------------------------------------------
# Step 6: Write a summary to the GitHub Actions step summary.
# ------------------------------------------------------------------
- name: Write summary
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
run: |
set -euo pipefail
{
echo "## AI Triage via Chat API"
echo ""
echo "**Issue:** ${ISSUE_URL}"
echo "**Chat ID:** \`${CHAT_ID}\`"
echo "**Chat URL:** ${CHAT_URL}"
echo "**Status:** ${FINAL_STATUS}"
} >> "${GITHUB_STEP_SUMMARY}"
-2
View File
@@ -29,8 +29,6 @@ EDE = "EDE"
HELO = "HELO"
LKE = "LKE"
byt = "byt"
cpy = "cpy"
Cpy = "Cpy"
typ = "typ"
# file extensions used in seti icon theme
styl = "styl"
+1 -1
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+5 -19
View File
@@ -146,20 +146,11 @@ git config core.hooksPath scripts/githooks
Two hooks run automatically:
- **pre-commit**: Classifies staged files by type and runs either
the full `make pre-commit` or the lightweight `make pre-commit-light`
depending on whether Go, TypeScript, SQL, proto, or Makefile
changes are present. Falls back to the full target when
`CODER_HOOK_RUN_ALL=1` is set. A markdown-only commit takes
seconds; a Go change takes several minutes.
- **pre-push**: Classifies changed files (vs remote branch or
merge-base) and runs `make pre-push` when Go, TypeScript, SQL,
proto, or Makefile changes are detected. Skips tests entirely
for lightweight changes. Allowlisted in
`scripts/githooks/pre-push`. Runs only for developers who opt
in. Falls back to `make pre-push` when the diff range can't
be determined or `CODER_HOOK_RUN_ALL=1` is set. Allow at least
15 minutes for a full run.
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
Fast checks that catch most CI failures. Allow at least 5 minutes.
- **pre-push**: `make pre-push` (heavier checks including tests).
Allowlisted in `scripts/githooks/pre-push`. Runs only for developers
who opt in. Allow at least 15 minutes.
`git commit` and `git push` will appear to hang while hooks run.
This is normal. Do not interrupt, retry, or reduce the timeout.
@@ -217,11 +208,6 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- Commit format: `type(scope): message`
- PR titles follow the same `type(scope): message` format.
- When you use a scope, it must be a real filesystem path containing every
changed file.
- Use a broader path scope, or omit the scope, for cross-cutting changes.
- Example: `fix(coderd/chatd): ...` for changes only in `coderd/chatd/`.
### Frontend Patterns
+11 -39
View File
@@ -136,10 +136,18 @@ endif
# the search path so that these exclusions match.
FIND_EXCLUSIONS= \
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
# Source files used for make targets, evaluated on use.
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
# Same as GO_SRC_FILES but excluding certain files that have problematic
# Makefile dependencies (e.g. pnpm).
MOST_GO_SRC_FILES := $(shell \
find . \
$(FIND_EXCLUSIONS) \
-type f \
-name '*.go' \
-not -name '*_test.go' \
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
)
# All the shell files in the repo, excluding ignored files.
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
@@ -506,10 +514,7 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
# Only wildcard the go files in the develop directory to avoid rebuilds
# when project files are changd. Technically changes to some imports may
# not be detected, but it's unlikely to cause any issues.
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
build/.bin/develop: go.mod go.sum $(GO_SRC_FILES)
CGO_ENABLED=0 go build -o $@ ./scripts/develop
BOLD := $(shell tput bold 2>/dev/null)
@@ -522,10 +527,6 @@ RESET := $(shell tput sgr0 2>/dev/null)
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/biome fmt/markdown
.PHONY: fmt
# Subset of fmt that does not require Go or Node toolchains.
fmt-light: fmt/shfmt fmt/terraform fmt/markdown
.PHONY: fmt-light
fmt/go:
ifdef FILE
# Format single file
@@ -633,10 +634,6 @@ LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
.PHONY: lint
# Subset of lint that does not require Go or Node toolchains.
lint-light: lint/shellcheck lint/markdown lint/helm lint/bootstrap lint/migrations lint/actions/actionlint lint/typos
.PHONY: lint-light
lint/site-icons:
./scripts/check_site_icons.sh
.PHONY: lint/site-icons
@@ -779,25 +776,6 @@ pre-commit:
echo "$(GREEN)✓ pre-commit passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
.PHONY: pre-commit
# Lightweight pre-commit for changes that don't touch Go or
# TypeScript. Skips gen, lint/go, lint/ts, fmt/go, fmt/ts, and
# the binary build. Used by the pre-commit hook when only docs,
# shell, terraform, helm, or other fast-to-check files changed.
pre-commit-light:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit-light.XXXXXX")
echo "$(BOLD)pre-commit-light$(RESET) ($$logdir)"
echo "fmt:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir fmt-light
$(check-unstaged)
echo "lint:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir lint-light
$(check-unstaged)
$(check-untracked)
rm -rf $$logdir
echo "$(GREEN)✓ pre-commit-light passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
.PHONY: pre-commit-light
pre-push:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-push.XXXXXX")
@@ -806,7 +784,6 @@ pre-push:
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
test \
test-js \
test-storybook \
site/out/index.html
rm -rf $$logdir
echo "$(GREEN)✓ pre-push passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
@@ -1341,11 +1318,6 @@ test-js: site/node_modules/.installed
pnpm test:ci
.PHONY: test-js
test-storybook: site/node_modules/.installed
cd site/
pnpm exec vitest run --project=storybook
.PHONY: test-storybook
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
# dependency for any sqlc-cloud related targets.
sqlc-cloud-is-setup:
+2 -7
View File
@@ -385,16 +385,11 @@ func (a *agent) init() {
pathStore := agentgit.NewPathStore()
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore, func() string {
if m := a.manifest.Load(); m != nil {
return m.Directory
}
return ""
})
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.reconnectingPTYServer = reconnectingpty.NewServer(
-14
View File
@@ -57,26 +57,18 @@ type fakeContainerCLI struct {
}
func (f *fakeContainerCLI) List(_ context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.containers, f.listErr
}
func (f *fakeContainerCLI) DetectArchitecture(_ context.Context, _ string) (string, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.arch, f.archErr
}
func (f *fakeContainerCLI) Copy(ctx context.Context, name, src, dst string) error {
f.mu.Lock()
defer f.mu.Unlock()
return f.copyErr
}
func (f *fakeContainerCLI) ExecAs(ctx context.Context, name, user string, args ...string) ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()
return nil, f.execErr
}
@@ -2697,9 +2689,7 @@ func TestAPI(t *testing.T) {
// When: The container is recreated (new container ID) with config changes.
terraformContainer.ID = "new-container-id"
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = terraformContainer.ID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
Apps: []agentcontainers.SubAgentApp{{Slug: "app2"}}, // Changed app triggers recreation logic.
@@ -2831,9 +2821,7 @@ func TestAPI(t *testing.T) {
// Simulate container rebuild: new container ID, changed display apps.
newContainerID := "new-container-id"
terraformContainer.ID = newContainerID
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = newContainerID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
DisplayApps: map[codersdk.DisplayApp]bool{
@@ -4938,11 +4926,9 @@ func TestDevcontainerPrebuildSupport(t *testing.T) {
)
api.Start()
fCCLI.mu.Lock()
fCCLI.containers = codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{testContainer},
}
fCCLI.mu.Unlock()
// Given: We allow the dev container to be created.
fDCCLI.upID = testContainer.ID
+169 -24
View File
@@ -2,9 +2,13 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
@@ -20,6 +24,28 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
portableDesktopVersion = "v0.0.4"
downloadRetries = 3
downloadRetryDelay = time.Second
)
// platformBinaries maps GOARCH to download URL and expected SHA-256
// digest for each supported platform.
var platformBinaries = map[string]struct {
URL string
SHA256 string
}{
"amd64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
},
"arm64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
},
}
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
@@ -52,31 +78,43 @@ type screenshotOutput struct {
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
logger slog.Logger
execer agentexec.Execer
dataDir string // agent's ScriptDataDir, used for binary caching
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
// httpClient is used for downloading the binary. If nil,
// http.DefaultClient is used.
httpClient *http.Client
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary.
// CLI binary, using execer to spawn child processes. dataDir is used
// to cache the downloaded binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
dataDir string,
) Desktop {
return &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: execer,
dataDir: dataDir,
}
}
// httpDo returns the HTTP client to use for downloads.
func (p *portableDesktop) httpDo() *http.Client {
if p.httpClient != nil {
return p.httpClient
}
return http.DefaultClient
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
@@ -361,8 +399,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
return string(out), nil
}
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
// ensureBinary resolves or downloads the portabledesktop binary. It
// must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
@@ -377,23 +415,130 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
return nil
}
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
// 2. Platform checks.
if runtime.GOOS != "linux" {
return xerrors.New("portabledesktop is only supported on Linux")
}
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
}
// 3. Check cache.
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
cachedPath := filepath.Join(cacheDir, "portabledesktop")
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
// Verify it is executable.
if info.Mode()&0o100 != 0 {
p.logger.Info(ctx, "using cached portabledesktop binary",
slog.F("path", cachedPath),
)
p.binPath = scriptBinPath
p.binPath = cachedPath
return nil
}
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
}
// 4. Download with retry.
p.logger.Info(ctx, "downloading portabledesktop binary",
slog.F("url", bin.URL),
slog.F("version", portableDesktopVersion),
slog.F("arch", runtime.GOARCH),
)
var lastErr error
for attempt := range downloadRetries {
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
lastErr = err
p.logger.Warn(ctx, "download attempt failed",
slog.F("attempt", attempt+1),
slog.F("max_attempts", downloadRetries),
slog.Error(err),
)
if attempt < downloadRetries-1 {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(downloadRetryDelay):
}
}
continue
}
p.binPath = cachedPath
p.logger.Info(ctx, "downloaded portabledesktop binary",
slog.F("path", cachedPath),
)
return nil
}
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
}
// downloadBinary fetches a binary from url, verifies its SHA-256
// digest matches expectedSHA256, and atomically writes it to destPath.
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
return xerrors.Errorf("create cache directory: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return xerrors.Errorf("create HTTP request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return xerrors.Errorf("HTTP GET %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
}
// Write to a temp file in the same directory so the final rename
// is atomic on the same filesystem.
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
if err != nil {
return xerrors.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up the temp file on any error path.
success := false
defer func() {
if !success {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
}()
// Stream the response body while computing SHA-256.
hasher := sha256.New()
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
return xerrors.Errorf("download body: %w", err)
}
if err := tmpFile.Close(); err != nil {
return xerrors.Errorf("close temp file: %w", err)
}
// Verify digest.
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
if actualSHA256 != expectedSHA256 {
return xerrors.Errorf(
"SHA-256 mismatch: expected %s, got %s",
expectedSHA256, actualSHA256,
)
}
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
if err := os.Chmod(tmpPath, 0o700); err != nil {
return xerrors.Errorf("chmod: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
return xerrors.Errorf("rename to final path: %w", err)
}
success = true
return nil
}
@@ -2,6 +2,11 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
@@ -72,6 +77,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
// The "up" script prints the JSON line then sleeps until
// the context is canceled (simulating a long-running process).
@@ -82,13 +88,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
}
ctx := t.Context()
ctx := context.Background()
cfg, err := pd.Start(ctx)
require.NoError(t, err)
@@ -105,6 +111,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -113,13 +120,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
cfg1, err := pd.Start(ctx)
require.NoError(t, err)
@@ -147,6 +154,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -155,13 +163,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
require.NoError(t, err)
@@ -172,6 +180,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -180,13 +189,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Screenshot(ctx, ScreenshotOptions{
TargetWidth: 800,
TargetHeight: 600,
@@ -278,13 +287,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -363,13 +372,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -395,13 +404,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
x, y, err := pd.CursorPosition(t.Context())
x, y, err := pd.CursorPosition(context.Background())
require.NoError(t, err)
assert.Equal(t, 100, x)
assert.Equal(t, 200, y)
@@ -419,13 +428,13 @@ func TestPortableDesktop_Close(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Start(ctx)
require.NoError(t, err)
@@ -448,6 +457,81 @@ func TestPortableDesktop_Close(t *testing.T) {
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- downloadBinary tests ---
func TestDownloadBinary_Success(t *testing.T) {
t.Parallel()
binaryContent := []byte("#!/bin/sh\necho portable\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
require.NoError(t, err)
// Verify the file exists and has correct content.
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
// Verify executable permissions.
info, err := os.Stat(destPath)
require.NoError(t, err)
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
}
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("real binary content"))
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "SHA-256 mismatch")
// The destination file should not exist (temp file cleaned up).
_, statErr := os.Stat(destPath)
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
// No leftover temp files in the directory.
entries, err := os.ReadDir(destDir)
require.NoError(t, err)
assert.Empty(t, entries, "no leftover temp files should remain")
}
func TestDownloadBinary_HTTPError(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "status 404")
}
// --- ensureBinary tests ---
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
@@ -457,89 +541,173 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
// immediately without doing any work.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(),
binPath: "/already/set",
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: t.TempDir(),
binPath: "/already/set",
}
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, "/already/set", pd.binPath)
}
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
require.NoError(t, os.Chmod(binPath, 0o755))
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
}
dataDir := t.TempDir()
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
cachedPath := filepath.Join(cacheDir, "portabledesktop")
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, binPath, pd.binPath)
assert.Equal(t, cachedPath, pd.binPath)
}
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Windows does not support Unix permission bits")
}
func TestEnsureBinary_Downloads(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
// environment and we override the package-level platformBinaries.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
// Write without execute permission.
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
_ = binPath
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Save and restore platformBinaries for this test.
origBinaries := platformBinaries
platformBinaries = map[string]struct {
URL string
SHA256 string
}{
runtime.GOARCH: {
URL: srv.URL + "/portabledesktop",
SHA256: expectedSHA,
},
}
t.Cleanup(func() { platformBinaries = origBinaries })
dataDir := t.TempDir()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
httpClient: srv.Client(),
}
// Clear PATH so LookPath won't find a real binary.
// Ensure PATH doesn't contain a real portabledesktop binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
assert.Equal(t, expectedPath, pd.binPath)
// Verify the downloaded file has correct content.
got, err := os.ReadFile(expectedPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
}
func TestEnsureBinary_NotFound(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(), // empty directory
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
binaryContent := []byte("#!/bin/sh\necho retried\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
var mu sync.Mutex
attempt := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
current := attempt
attempt++
mu.Unlock()
// Fail the first 2 attempts, succeed on the third.
if current < 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Test downloadBinary directly to avoid time.Sleep in
// ensureBinary's retry loop. We call it 3 times to simulate
// what ensureBinary would do.
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
var lastErr error
for i := range 3 {
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
if lastErr == nil {
break
}
if i < 2 {
// In the real code, ensureBinary sleeps here.
// We skip the sleep in tests.
continue
}
}
require.NoError(t, lastErr, "download should succeed on the third attempt")
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
mu.Lock()
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
mu.Unlock()
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
// Silence the linter about unused imports — agentexec.DefaultExecer
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
// fmt.Sscanf is used indirectly via the implementation.
var (
_ = agentexec.DefaultExecer
_ = fmt.Sprintf
)
+13 -85
View File
@@ -333,68 +333,22 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
return status, err
}
// Check if the target already exists so we can preserve its
// permissions on the temp file before rename.
var origMode os.FileMode
var haveOrigMode bool
if stat, serr := api.filesystem.Stat(path); serr == nil {
if stat.IsDir() {
return http.StatusBadRequest, xerrors.Errorf("open %s: is a directory", path)
}
origMode = stat.Mode()
haveOrigMode = true
}
// Write to a temp file in the same directory so the rename is
// always on the same device (atomic).
tmpfile, err := afero.TempFile(api.filesystem, dir, filepath.Base(path))
f, err := api.filesystem.Create(path)
if err != nil {
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.EISDIR):
status = http.StatusBadRequest
}
return status, err
}
tmpName := tmpfile.Name()
defer f.Close()
_, err = io.Copy(tmpfile, r.Body)
if err != nil && !errors.Is(err, io.EOF) {
_ = tmpfile.Close()
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Close before rename to flush buffered data and catch write
// errors (e.g. delayed allocation failures).
if err := tmpfile.Close(); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Set permissions on the temp file before rename so there is
// no window where the target has wrong permissions.
if haveOrigMode {
if err := api.filesystem.Chmod(tmpName, origMode); err != nil {
api.logger.Warn(ctx, "unable to set file permissions",
slog.F("path", path),
slog.Error(err),
)
}
}
if err := api.filesystem.Rename(tmpName, path); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
}
return status, err
_, err = io.Copy(f, r.Body)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
api.logger.Error(ctx, "workspace agent write file", slog.Error(err))
}
return 0, nil
@@ -506,44 +460,18 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
if err != nil {
return http.StatusInternalServerError, err
}
tmpName := tmpfile.Name()
defer tmpfile.Close()
if _, err := tmpfile.Write([]byte(content)); err != nil {
_ = tmpfile.Close()
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
// Close before rename to flush buffered data and catch write
// errors (e.g. delayed allocation failures).
if err := tmpfile.Close(); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
// Set permissions on the temp file before rename so there is
// no window where the target has wrong permissions.
if err := api.filesystem.Chmod(tmpName, stat.Mode()); err != nil {
api.logger.Warn(ctx, "unable to set file permissions",
slog.F("path", path),
slog.Error(err),
)
}
err = api.filesystem.Rename(tmpName, path)
err = api.filesystem.Rename(tmpfile.Name(), path)
if err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
}
return status, err
return http.StatusInternalServerError, err
}
return 0, nil
-139
View File
@@ -14,7 +14,6 @@ import (
"strings"
"syscall"
"testing"
"testing/iotest"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -400,83 +399,6 @@ func TestWriteFile(t *testing.T) {
}
}
func TestWriteFile_ReportsIOError(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, nil)
tmpdir := os.TempDir()
path := filepath.Join(tmpdir, "write-io-error")
err := afero.WriteFile(fs, path, []byte("original"), 0o644)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// A reader that always errors simulates a failed body read
// (e.g. network interruption). The atomic write should leave
// the original file intact.
body := iotest.ErrReader(xerrors.New("simulated I/O error"))
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path), body)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusInternalServerError, w.Code)
got := &codersdk.Error{}
err = json.NewDecoder(w.Body).Decode(got)
require.NoError(t, err)
require.ErrorContains(t, got, "simulated I/O error")
// The original file must survive the failed write.
data, err := afero.ReadFile(fs, path)
require.NoError(t, err)
require.Equal(t, "original", string(data))
}
func TestWriteFile_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// Overwrite the file with new content.
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path),
bytes.NewReader([]byte("#!/bin/sh\necho world\n")))
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"write_file should preserve the original file's permissions")
}
func TestEditFiles(t *testing.T) {
t.Parallel()
@@ -985,67 +907,6 @@ func TestEditFiles(t *testing.T) {
}
}
func TestEditFiles_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
// Sanity-check the initial mode.
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
body := workspacesdk.FileEditRequest{
Files: []workspacesdk.FileEdits{
{
Path: path,
Edits: []workspacesdk.FileEdit{
{
Search: "hello",
Replace: "world",
},
},
},
},
}
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(body)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
// Verify content was updated.
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
// Verify permissions are preserved after the
// temp-file-and-rename cycle.
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"edit_files should preserve the original file's permissions")
}
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
t.Parallel()
+2 -58
View File
@@ -1,13 +1,11 @@
package agentproc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -20,13 +18,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// maxWaitDuration is the maximum time a blocking
// process output request can wait, regardless of
// what the client requests.
maxWaitDuration = 5 * time.Minute
)
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
@@ -35,10 +26,10 @@ type API struct {
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore, workingDir func() string) *API {
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv, workingDir),
manager: newManager(logger, execer, updateEnv),
pathStore: pathStore,
}
}
@@ -160,42 +151,6 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
return
}
// Enforce chat ID isolation. If the request carries
// a chat context, only allow access to processes
// belonging to that chat.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
if proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
// Check for blocking mode via query params.
waitStr := r.URL.Query().Get("wait")
wantWait := waitStr == "true"
if wantWait {
// Extend the write deadline so the HTTP server's
// WriteTimeout does not kill the connection while
// we block.
rc := http.NewResponseController(rw)
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration)); err != nil {
api.logger.Error(ctx, "extend write deadline for blocking process output",
slog.Error(err),
)
}
// Cap the wait at maxWaitDuration regardless of
// client-supplied timeout.
waitCtx, waitCancel := context.WithTimeout(ctx, maxWaitDuration)
defer waitCancel()
_ = proc.waitForOutput(waitCtx)
// Fall through to read snapshot below.
}
output, truncated := proc.output()
info := proc.info()
@@ -213,17 +168,6 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
// Enforce chat ID isolation.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
proc, procOK := api.manager.get(id)
if procOK && proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
var req workspacesdk.SignalProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
+3 -277
View File
@@ -7,10 +7,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
@@ -78,22 +76,6 @@ func getOutput(t *testing.T, handler http.Handler, id string) *httptest.Response
return w
}
// getOutputWithHeaders sends a GET /{id}/output request with
// custom headers and returns the recorder.
func getOutputWithHeaders(t *testing.T, handler http.Handler, id string, headers http.Header) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
path := fmt.Sprintf("/%s/output", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
for k, v := range headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
@@ -115,25 +97,18 @@ func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, nil, nil)
return newTestAPIWithUpdateEnv(t, nil)
}
// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, updateEnv, nil)
}
// newTestAPIWithOptions creates a new API with optional
// updateEnv and workingDir hooks.
func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, error), workingDir func() string) http.Handler {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil, workingDir)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
t.Cleanup(func() {
_ = api.Close()
})
@@ -278,100 +253,6 @@ func TestStartProcess(t *testing.T) {
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirIsHome", func(t *testing.T) {
t.Parallel()
// No working directory closure, so the process
// should fall back to $HOME. We verify through
// the process list API which reports the resolved
// working directory using native OS paths,
// avoiding shell path format mismatches on
// Windows (Git Bash returns POSIX paths).
handler := newTestAPI(t)
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("DefaultWorkDirFromClosure", func(t *testing.T) {
t.Parallel()
// The closure provides a valid directory, so the
// process should start there. Use the marker file
// pattern to avoid path format mismatches on
// Windows.
tmpDir := t.TempDir()
handler := newTestAPIWithOptions(t, nil, func() string {
return tmpDir
})
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "touch marker.txt && ls marker.txt",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirClosureNonExistentFallsBackToHome", func(t *testing.T) {
t.Parallel()
// The closure returns a path that doesn't exist,
// so the process should fall back to $HOME.
handler := newTestAPIWithOptions(t, nil, func() string {
return "/tmp/nonexistent-dir-" + fmt.Sprintf("%d", time.Now().UnixNano())
})
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("CustomEnv", func(t *testing.T) {
t.Parallel()
@@ -756,161 +637,6 @@ func TestProcessOutput(t *testing.T) {
require.NoError(t, err)
require.Contains(t, resp.Message, "not found")
})
t.Run("ChatIDEnforcement", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a process with chat-a.
chatA := uuid.New()
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo secret",
Background: true,
}, http.Header{
workspacesdk.CoderChatIDHeader: {chatA.String()},
})
waitForExit(t, handler, id)
// Chat-b should NOT see this process.
chatB := uuid.New()
w1 := getOutputWithHeaders(t, handler, id, http.Header{
workspacesdk.CoderChatIDHeader: {chatB.String()},
})
require.Equal(t, http.StatusNotFound, w1.Code)
// Without any chat ID header, should return 200
// (backwards compatible).
w2 := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w2.Code)
})
t.Run("WaitForExit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello-wait && sleep 0.1",
})
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "hello-wait")
})
t.Run("WaitAlreadyExited", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, id)
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.Contains(t, resp.Output, "done")
})
t.Run("WaitTimeout", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.IntervalMedium)
defer cancel()
w := getOutputWithWaitCtx(ctx, t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Running)
// Kill and wait for the process so cleanup does
// not hang.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
waitForExit(t, handler, id)
})
t.Run("ConcurrentWaiters", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
var (
wg sync.WaitGroup
resps [2]workspacesdk.ProcessOutputResponse
codes [2]int
)
for i := range 2 {
wg.Add(1)
go func() {
defer wg.Done()
w := getOutputWithWait(t, handler, id)
codes[i] = w.Code
_ = json.NewDecoder(w.Body).Decode(&resps[i])
}()
}
// Signal the process to exit so both waiters unblock.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
wg.Wait()
for i := range 2 {
require.Equal(t, http.StatusOK, codes[i], "waiter %d", i)
require.False(t, resps[i].Running, "waiter %d", i)
}
})
}
func getOutputWithWait(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
return getOutputWithWaitCtx(ctx, t, handler, id)
}
func getOutputWithWaitCtx(ctx context.Context, t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
path := fmt.Sprintf("/%s/output?wait=true", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
func TestSignalProcess(t *testing.T) {
@@ -1055,7 +781,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
return current, nil
}, pathStore, nil)
}, pathStore)
defer api.Close()
routes := api.Routes()
+2 -19
View File
@@ -39,13 +39,11 @@ const (
// how much output is written.
type HeadTailBuffer struct {
mu sync.Mutex
cond *sync.Cond
head []byte
tail []byte
tailPos int
tailFull bool
headFull bool
closed bool
totalBytes int
maxHead int
maxTail int
@@ -54,24 +52,20 @@ type HeadTailBuffer struct {
// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: MaxHeadBytes,
maxTail: MaxTailBytes,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: maxHead,
maxTail: maxTail,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// Write implements io.Writer. It is safe for concurrent use.
@@ -302,15 +296,6 @@ func truncateLines(s string) string {
return b.String()
}
// Close marks the buffer as closed and wakes any waiters.
// This is called when the process exits.
func (b *HeadTailBuffer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closed = true
b.cond.Broadcast()
}
// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
b.mu.Lock()
@@ -320,7 +305,5 @@ func (b *HeadTailBuffer) Reset() {
b.tailPos = 0
b.tailFull = false
b.headFull = false
b.closed = false
b.totalBytes = 0
b.cond.Broadcast()
}
+17 -71
View File
@@ -70,25 +70,23 @@ func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
// manager tracks processes spawned by the agent.
type manager struct {
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
workingDir func() string
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
}
// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), workingDir func() string) *manager {
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
return &manager{
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
workingDir: workingDir,
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
}
}
@@ -111,7 +109,9 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
// the process is not tied to any HTTP request.
ctx, cancel := context.WithCancel(context.Background())
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
cmd.Dir = m.resolveWorkDir(req.WorkDir)
if req.WorkDir != "" {
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
cmd.SysProcAttr = procSysProcAttr()
@@ -158,7 +158,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc := &process{
id: id,
command: req.Command,
workDir: cmd.Dir,
workDir: req.WorkDir,
background: req.Background,
chatID: chatID,
cmd: cmd,
@@ -208,9 +208,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc.exitCode = &code
proc.mu.Unlock()
// Wake any waiters blocked on new output or
// process exit before closing the done channel.
proc.buf.Close()
close(proc.done)
}()
@@ -322,54 +319,3 @@ func (m *manager) Close() error {
return nil
}
// waitForOutput blocks until the buffer is closed (process
// exited) or the context is canceled. Returns nil when the
// buffer closed, ctx.Err() when the context expired.
func (p *process) waitForOutput(ctx context.Context) error {
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
nevermind := make(chan struct{})
defer close(nevermind)
go func() {
select {
case <-ctx.Done():
// Acquire the lock before broadcasting to
// guarantee the waiter has entered cond.Wait()
// (which atomically releases the lock).
// Without this, a Broadcast between the loop
// predicate check and cond.Wait() is lost.
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
p.buf.cond.Broadcast()
case <-nevermind:
}
}()
for ctx.Err() == nil && !p.buf.closed {
p.buf.cond.Wait()
}
return ctx.Err()
}
// resolveWorkDir returns the directory a process should start in.
// Priority: explicit request dir > agent configured dir > $HOME.
// Falls through when a candidate is empty or does not exist on
// disk, matching the behavior of SSH sessions.
func (m *manager) resolveWorkDir(requested string) string {
if requested != "" {
return requested
}
if m.workingDir != nil {
if dir := m.workingDir(); dir != "" {
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return dir
}
}
}
if home, err := os.UserHomeDir(); err == nil {
return home
}
return ""
}
+2 -2
View File
@@ -398,11 +398,11 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
},
})
if err != nil {
logger.Warn(ctx, "reporting script completed", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: %s", err.Error()))
}
})
if err != nil {
logger.Warn(ctx, "reporting script completed: track command goroutine", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: track command goroutine: %s", err.Error()))
}
}()
+27 -6
View File
@@ -6,6 +6,7 @@ import (
"context"
"net"
"path/filepath"
"sync"
"testing"
"github.com/google/uuid"
@@ -22,6 +23,26 @@ import (
"github.com/coder/coder/v2/testutil"
)
// logSink captures structured log entries for testing.
type logSink struct {
mu sync.Mutex
entries []slog.SinkEntry
}
func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.mu.Lock()
defer s.mu.Unlock()
s.entries = append(s.entries, e)
}
func (*logSink) Sync() {}
func (s *logSink) getEntries() []slog.SinkEntry {
s.mu.Lock()
defer s.mu.Unlock()
return append([]slog.SinkEntry{}, s.entries...)
}
// getField returns the value of a field by name from a slog.Map.
func getField(fields slog.Map, name string) interface{} {
for _, f := range fields {
@@ -55,8 +76,8 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
sink := testutil.NewFakeSink(t)
logger := sink.Logger(slog.LevelInfo)
sink := &logSink{}
logger := slog.Make(sink)
workspaceID := uuid.New()
templateID := uuid.New()
templateVersionID := uuid.New()
@@ -97,10 +118,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 1
return len(sink.getEntries()) >= 1
}, testutil.WaitShort, testutil.IntervalFast)
entries := sink.Entries()
entries := sink.getEntries()
require.Len(t, entries, 1)
entry := entries[0]
require.Equal(t, slog.LevelInfo, entry.Level)
@@ -131,10 +152,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req2)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 2
return len(sink.getEntries()) >= 2
}, testutil.WaitShort, testutil.IntervalFast)
entries = sink.Entries()
entries = sink.getEntries()
entry = entries[1]
require.Len(t, entries, 2)
require.Equal(t, slog.LevelInfo, entry.Level)
-6
View File
@@ -1000,12 +1000,6 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
Properties: sdkTool.Schema.Properties,
Required: sdkTool.Schema.Required,
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var buf bytes.Buffer
+1 -16
View File
@@ -81,13 +81,7 @@ func TestExpMcpServer(t *testing.T) {
var toolsResponse struct {
Result struct {
Tools []struct {
Name string `json:"name"`
Annotations struct {
ReadOnlyHint *bool `json:"readOnlyHint"`
DestructiveHint *bool `json:"destructiveHint"`
IdempotentHint *bool `json:"idempotentHint"`
OpenWorldHint *bool `json:"openWorldHint"`
} `json:"annotations"`
Name string `json:"name"`
} `json:"tools"`
} `json:"result"`
}
@@ -100,15 +94,6 @@ func TestExpMcpServer(t *testing.T) {
}
slices.Sort(foundTools)
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
annotations := toolsResponse.Result.Tools[0].Annotations
require.NotNil(t, annotations.ReadOnlyHint)
require.NotNil(t, annotations.DestructiveHint)
require.NotNil(t, annotations.IdempotentHint)
require.NotNil(t, annotations.OpenWorldHint)
assert.True(t, *annotations.ReadOnlyHint)
assert.False(t, *annotations.DestructiveHint)
assert.True(t, *annotations.IdempotentHint)
assert.False(t, *annotations.OpenWorldHint)
// Call the tool and ensure it works.
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
+1 -1
View File
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
} else {
updated, err = client.CreateOrganizationRole(ctx, customRole)
if err != nil {
return xerrors.Errorf("create role: %w", err)
return xerrors.Errorf("patch role: %w", err)
}
}
+1
View File
@@ -324,6 +324,7 @@ func TestServer(t *testing.T) {
ignoreLines := []string{
"isn't externally reachable",
"open install.sh: file does not exist",
"open install.ps1: file does not exist",
"telemetry disabled, unable to notify of security issues",
"installed terraform version newer than expected",
"report generator",
-23
View File
@@ -79,29 +79,6 @@ func (r *RootCmd) start() *serpent.Command {
)
build = workspace.LatestBuild
default:
// If the last build was a failed start, run a stop
// first to clean up any partially-provisioned
// resources.
if workspace.LatestBuild.Status == codersdk.WorkspaceStatusFailed &&
workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
_, _ = fmt.Fprintf(inv.Stdout, "The last start build failed. Cleaning up before retrying...\n")
stopBuild, stopErr := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStop,
})
if stopErr != nil {
return xerrors.Errorf("cleanup stop after failed start: %w", stopErr)
}
stopErr = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, stopBuild.ID)
if stopErr != nil {
return xerrors.Errorf("wait for cleanup stop: %w", stopErr)
}
// Re-fetch workspace after stop completes so
// startWorkspace sees the latest state.
workspace, err = namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return err
}
}
build, err = startWorkspace(inv, client, workspace, parameterFlags, bflags, WorkspaceStart)
// It's possible for a workspace build to fail due to the template requiring starting
// workspaces with the active version.
-52
View File
@@ -534,55 +534,3 @@ func TestStart_WithReason(t *testing.T) {
workspace = coderdtest.MustWorkspace(t, member, workspace.ID)
require.Equal(t, codersdk.BuildReasonCLI, workspace.LatestBuild.Reason)
}
func TestStart_FailedStartCleansUp(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
store, ps := dbtestutil.NewDB(t)
client := coderdtest.New(t, &coderdtest.Options{
Database: store,
Pubsub: ps,
IncludeProvisionerDaemon: true,
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, memberClient, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// Insert a failed start build directly into the database so that
// the workspace's latest build is a failed "start" transition.
dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
ID: workspace.ID,
OwnerID: member.ID,
OrganizationID: owner.OrganizationID,
TemplateID: template.ID,
}).
Seed(database.WorkspaceBuild{
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
BuildNumber: workspace.LatestBuild.BuildNumber + 1,
}).
Failed().
Do()
inv, root := clitest.New(t, "start", workspace.Name)
clitest.SetupConfig(t, memberClient, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
// The CLI should detect the failed start and clean up first.
pty.ExpectMatch("Cleaning up before retrying")
pty.ExpectMatch("workspace has been started")
_ = testutil.TryReceive(ctx, t, doneChan)
}
+17 -26
View File
@@ -113,20 +113,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
)
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
// Bypass rate limiting for support bundle collection since it makes many API calls.
// Note: this can only be done by the owner user.
if ok, err := support.CanGenerateFull(inv.Context(), client); err == nil && ok {
cliLog.Debug(inv.Context(), "running as owner")
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
} else if !ok {
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
} else {
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
}
// Check if we're running inside a workspace
if val, found := os.LookupEnv("CODER"); found && val == "true" {
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
@@ -214,6 +200,12 @@ func (r *RootCmd) supportBundle() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
}
// Bypass rate limiting for support bundle collection since it makes many API calls.
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
deps := support.Deps{
Client: client,
// Support adds a sink so we don't need to supply one ourselves.
@@ -362,20 +354,19 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
return
}
var docsURL string
if bun.Deployment.Config != nil {
docsURL = bun.Deployment.Config.Values.DocsURL.String()
} else {
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
if bun.Deployment.Config == nil {
cliui.Error(inv.Stdout, "No deployment configuration available!")
return
}
if bun.Deployment.HealthReport != nil {
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
} else {
cliui.Warn(inv.Stdout, "No deployment health report available.")
docsURL := bun.Deployment.Config.Values.DocsURL.String()
if bun.Deployment.HealthReport == nil {
cliui.Error(inv.Stdout, "No deployment health report available!")
return
}
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
if bun.Network.Netcheck == nil {
+22 -47
View File
@@ -28,9 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/healthcheck"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
@@ -52,21 +50,9 @@ func TestSupportBundle(t *testing.T) {
dc.Values.Prometheus.Enable = true
secretValue := uuid.NewString()
seedSecretDeploymentOptions(t, &dc, secretValue)
// Use a mock healthcheck function to avoid flaky DERP health
// checks in CI. The DERP checker performs real network operations
// (portmapper gateway probing, STUN) that can hang for 60s+ on
// macOS CI runners. Since this test validates support bundle
// generation, not healthcheck correctness, a canned report is
// sufficient.
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
DeploymentValues: dc.Values,
HealthcheckFunc: func(_ context.Context, _ string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
return &healthsdk.HealthcheckReport{
Time: time.Now(),
Healthy: true,
Severity: health.SeverityOK,
}
},
DeploymentValues: dc.Values,
HealthcheckTimeout: testutil.WaitSuperLong,
})
t.Cleanup(func() { closer.Close() })
@@ -74,7 +60,7 @@ func TestSupportBundle(t *testing.T) {
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Set up test fixtures
setupCtx := testutil.Context(t, testutil.WaitLong)
setupCtx := testutil.Context(t, testutil.WaitSuperLong)
workspaceWithAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, func(agents []*proto.Agent) []*proto.Agent {
// This should not show up in the bundle output
agents[0].Env["SECRET_VALUE"] = secretValue
@@ -83,6 +69,22 @@ func TestSupportBundle(t *testing.T) {
workspaceWithoutAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, nil)
memberWorkspace := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, member.ID, nil)
// Wait for healthcheck to complete successfully before continuing with sub-tests.
// The result is cached so subsequent requests will be fast.
healthcheckDone := make(chan *healthsdk.HealthcheckReport)
go func() {
defer close(healthcheckDone)
hc, err := healthsdk.New(client).DebugHealth(setupCtx)
if err != nil {
assert.NoError(t, err, "seed healthcheck cache")
return
}
healthcheckDone <- &hc
}()
if _, ok := testutil.AssertReceive(setupCtx, t, healthcheckDone); !ok {
t.Fatal("healthcheck did not complete in time -- this may be a transient issue")
}
t.Run("WorkspaceWithAgent", func(t *testing.T) {
t.Parallel()
@@ -130,35 +132,12 @@ func TestSupportBundle(t *testing.T) {
assertBundleContents(t, path, true, false, []string{secretValue})
})
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
t.Run("NoPrivilege", func(t *testing.T) {
t.Parallel()
d := t.TempDir()
path := filepath.Join(d, "bundle.zip")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
clitest.SetupConfig(t, memberClient, root)
err := inv.Run()
require.NoError(t, err)
r, err := zip.OpenReader(path)
require.NoError(t, err, "open zip file")
defer r.Close()
fileNames := make(map[string]struct{}, len(r.File))
for _, f := range r.File {
fileNames[f.Name] = struct{}{}
}
// These should always be present in the zip structure, even if
// the content is null/empty for non-admin users.
for _, name := range []string{
"deployment/buildinfo.json",
"deployment/config.json",
"workspace/workspace.json",
"logs.txt",
"cli_logs.txt",
"network/netcheck.json",
"network/interfaces.json",
} {
require.Contains(t, fileNames, name)
}
require.ErrorContains(t, err, "failed authorization check")
})
// This ensures that the CLI does not panic when trying to generate a support bundle
@@ -180,10 +159,6 @@ func TestSupportBundle(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("received request: %s %s", r.Method, r.URL)
switch r.URL.Path {
case "/api/v2/users/me":
resp := codersdk.User{}
w.WriteHeader(http.StatusOK)
assert.NoError(t, json.NewEncoder(w).Encode(resp))
case "/api/v2/authcheck":
// Fake auth check
resp := codersdk.AuthorizationResponse{
+1 -1
View File
@@ -7,7 +7,7 @@
"last_seen_at": "====[timestamp]=====",
"name": "test-daemon",
"version": "v0.0.0-devel",
"api_version": "1.16",
"api_version": "1.15",
"provisioners": [
"echo"
],
-6
View File
@@ -170,12 +170,6 @@ AI BRIDGE OPTIONS:
exporting these records to external SIEM or observability systems.
AI BRIDGE PROXY OPTIONS:
--aibridge-proxy-allowed-private-cidrs string-array, $CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS
Comma-separated list of CIDR ranges that are permitted even though
they fall within blocked private/reserved IP ranges. By default all
private ranges are blocked to prevent SSRF attacks. Use this to allow
access to specific internal networks.
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
provider requests.
+10 -11
View File
@@ -8,17 +8,16 @@ USAGE:
Aliases: user
SUBCOMMANDS:
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
oidc-claims Display the OIDC claims for the authenticated user.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user
cannot log into the platform
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user cannot
log into the platform
———
Run `coder --help` for a list of global options.
-4
View File
@@ -24,10 +24,6 @@ OPTIONS:
-p, --password string
Specifies a password for the new user.
--service-account bool
Create a user account intended to be used by a service or as an
intermediary rather than by a human.
-u, --username string
Specifies a username for the new user.
-24
View File
@@ -1,24 +0,0 @@
coder v0.0.0-devel
USAGE:
coder users oidc-claims [flags]
Display the OIDC claims for the authenticated user.
- Display your OIDC claims:
$ coder users oidc-claims
- Display your OIDC claims as JSON:
$ coder users oidc-claims -o json
OPTIONS:
-c, --column [key|value] (default: key,value)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
-11
View File
@@ -752,11 +752,6 @@ workspace_prebuilds:
# limit; disabled when set to zero.
# (default: 3, type: int)
failure_hard_limit: 3
# Configure the background chat processing daemon.
chat:
# How many pending chats a worker should acquire per polling cycle.
# (default: 10, type: int)
acquireBatchSize: 10
aibridge:
# Whether to start an in-memory aibridged instance.
# (default: false, type: bool)
@@ -873,12 +868,6 @@ aibridgeproxy:
# by the system. If not provided, the system certificate pool is used.
# (default: <unset>, type: string)
upstream_proxy_ca: ""
# Comma-separated list of CIDR ranges that are permitted even though they fall
# within blocked private/reserved IP ranges. By default all private ranges are
# blocked to prevent SSRF attacks. Use this to allow access to specific internal
# networks.
# (default: <unset>, type: string-array)
allowed_private_cidrs: []
# Configure data retention policies for various database tables. Retention
# policies automatically purge old data to reduce database size and improve
# performance. Setting a retention duration to 0 disables automatic purging for
+12 -37
View File
@@ -17,14 +17,13 @@ import (
func (r *RootCmd) userCreate() *serpent.Command {
var (
email string
username string
name string
password string
disableLogin bool
loginType string
serviceAccount bool
orgContext = NewOrganizationContext()
email string
username string
name string
password string
disableLogin bool
loginType string
orgContext = NewOrganizationContext()
)
cmd := &serpent.Command{
Use: "create",
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
if serviceAccount {
switch {
case loginType != "":
return xerrors.New("You cannot use --login-type with --service-account")
case password != "":
return xerrors.New("You cannot use --password with --service-account")
case email != "":
return xerrors.New("You cannot use --email with --service-account")
case disableLogin:
return xerrors.New("You cannot use --disable-login with --service-account")
}
}
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
client, err := r.InitClient(inv)
if err != nil {
return err
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
return err
}
}
if email == "" && !serviceAccount {
if email == "" {
email, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Email:",
Validate: func(s string) error {
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
}
}
userLoginType := codersdk.LoginTypePassword
if disableLogin || serviceAccount {
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
if disableLogin {
userLoginType = codersdk.LoginTypeNone
} else if loginType != "" {
userLoginType = codersdk.LoginType(loginType)
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
Password: password,
OrganizationIDs: []uuid.UUID{organization.ID},
UserLoginType: userLoginType,
ServiceAccount: serviceAccount,
})
if err != nil {
return err
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
case codersdk.LoginTypeOIDC:
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
}
if serviceAccount {
email = "n/a"
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
}
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
Share the instructions below to get them started.
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
)),
Value: serpent.StringOf(&loginType),
},
{
Flag: "service-account",
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
Value: serpent.BoolOf(&serviceAccount),
},
}
orgContext.AttachOptions(cmd)
-53
View File
@@ -8,7 +8,6 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
assert.Equal(t, args[5], created.Username)
assert.Empty(t, created.Name)
})
tests := []struct {
name string
args []string
err string
}{
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
},
{
name: "ServiceAccountLoginType",
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
err: "You cannot use --login-type with --service-account",
},
{
name: "ServiceAccountDisableLogin",
args: []string{"--service-account", "-u", "dean", "--disable-login"},
err: "You cannot use --disable-login with --service-account",
},
{
name: "ServiceAccountEmail",
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
err: "You cannot use --email with --service-account",
},
{
name: "ServiceAccountPassword",
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
err: "You cannot use --password with --service-account",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
clitest.SetupConfig(t, client, root)
err := inv.Run()
if tt.err == "" {
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
created, err := client.User(ctx, "dean")
require.NoError(t, err)
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
} else {
require.Error(t, err)
require.ErrorContains(t, err, tt.err)
}
})
}
}
-79
View File
@@ -1,79 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) userOIDCClaims() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.ChangeFormatterData(
cliui.TableFormat([]claimRow{}, []string{"key", "value"}),
func(data any) (any, error) {
resp, ok := data.(codersdk.OIDCClaimsResponse)
if !ok {
return nil, xerrors.Errorf("expected type %T, got %T", resp, data)
}
rows := make([]claimRow, 0, len(resp.Claims))
for k, v := range resp.Claims {
rows = append(rows, claimRow{
Key: k,
Value: fmt.Sprintf("%v", v),
})
}
return rows, nil
},
),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "oidc-claims",
Short: "Display the OIDC claims for the authenticated user.",
Long: FormatExamples(
Example{
Description: "Display your OIDC claims",
Command: "coder users oidc-claims",
},
Example{
Description: "Display your OIDC claims as JSON",
Command: "coder users oidc-claims -o json",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
resp, err := client.UserOIDCClaims(inv.Context())
if err != nil {
return xerrors.Errorf("get oidc claims: %w", err)
}
out, err := formatter.Format(inv.Context(), resp)
if err != nil {
return err
}
_, err = fmt.Fprintln(inv.Stdout, out)
return err
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
type claimRow struct {
Key string `json:"-" table:"key,default_sort"`
Value string `json:"-" table:"value"`
}
-161
View File
@@ -1,161 +0,0 @@
package cli_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestUserOIDCClaims(t *testing.T) {
t.Parallel()
newOIDCTest := func(t *testing.T) (*oidctest.FakeIDP, *codersdk.Client) {
t.Helper()
fake := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
ownerClient := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
})
return fake, ownerClient
}
t.Run("OwnClaims", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "alice@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
"groups": []string{"admin", "eng"},
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims", "-o", "json")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
var resp codersdk.OIDCClaimsResponse
err = json.Unmarshal(buf.Bytes(), &resp)
require.NoError(t, err, "unmarshal JSON output")
require.NotEmpty(t, resp.Claims, "claims should not be empty")
assert.Equal(t, "alice@coder.com", resp.Claims["email"])
})
t.Run("Table", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "bob@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
output := buf.String()
require.Contains(t, output, "email")
require.Contains(t, output, "bob@coder.com")
})
t.Run("NotOIDCUser", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, client, root)
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "not an OIDC user")
})
// Verify that two different OIDC users each only see their own
// claims. The endpoint has no user parameter, so there is no way
// to request another user's claims by design.
t.Run("OnlyOwnClaims", func(t *testing.T) {
t.Parallel()
aliceFake, aliceOwnerClient := newOIDCTest(t)
aliceClaims := jwt.MapClaims{
"email": "alice-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
aliceClient, aliceLoginResp := aliceFake.Login(t, aliceOwnerClient, aliceClaims)
defer aliceLoginResp.Body.Close()
bobFake, bobOwnerClient := newOIDCTest(t)
bobClaims := jwt.MapClaims{
"email": "bob-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
bobClient, bobLoginResp := bobFake.Login(t, bobOwnerClient, bobClaims)
defer bobLoginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
// Alice sees her own claims.
aliceResp, err := aliceClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "alice-isolation@coder.com", aliceResp.Claims["email"])
// Bob sees his own claims.
bobResp, err := bobClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "bob-isolation@coder.com", bobResp.Claims["email"])
})
t.Run("ClaimsNeverNull", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
// Use minimal claims — just enough for OIDC login.
claims := jwt.MapClaims{
"email": "minimal@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.UserOIDCClaims(ctx)
require.NoError(t, err)
require.NotNil(t, resp.Claims, "claims should never be nil, expected empty map")
})
}
-1
View File
@@ -19,7 +19,6 @@ func (r *RootCmd) users() *serpent.Command {
r.userSingle(),
r.userDelete(),
r.userEditRoles(),
r.userOIDCClaims(),
r.createUserStatusCommand(codersdk.UserStatusActive),
r.createUserStatusCommand(codersdk.UserStatusSuspended),
},
+1688 -2042
View File
File diff suppressed because it is too large Load Diff
+1688 -2020
View File
File diff suppressed because it is too large Load Diff
+491 -916
View File
File diff suppressed because it is too large Load Diff
-534
View File
@@ -2,26 +2,13 @@ package chatd
import (
"context"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
@@ -97,524 +84,3 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
// TestResolveInstructionsReusesTurnLocalWorkspaceAgent verifies that
// resolveInstructions, when driven through a turnWorkspaceContext,
// fetches the workspace agent list from the database exactly once and
// reuses it for both the agent lookup and the connection dial. The
// LS and ReadFile probes are stubbed to 404 so the instruction text
// falls back to agent metadata only.
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
// Chat bound to a valid workspace so instruction resolution runs.
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: uuid.New(),
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
// Times(1) is the core assertion of this test: a second database
// fetch of the agent list would fail the mock expectation.
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
// Directory listing fails; resolveInstructions must tolerate this.
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
workspacesdk.LSResponse{},
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
).Times(1)
// The AGENTS.md read also fails, so no instruction file content
// contributes to the result.
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/project/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(
nil,
"",
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
).Times(1)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionCache: make(map[uuid.UUID]cachedInstruction),
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction := server.resolveInstructions(
ctx,
chat,
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
// Even with both file probes failing, agent metadata must appear
// in the rendered instructions.
require.Contains(t, instruction, "Operating System: linux")
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
// TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent
// exercises the redial path of getWorkspaceConn: when dialing the
// initially-fetched agent fails, the context re-queries the agent
// list and dials the refreshed agent. gomock.InOrder pins the two
// database fetches; the dialed slice records the dial sequence.
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
// First fetch returns the agent whose dial will fail; the second
// returns the agent whose dial succeeds.
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{initialAgent}, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialed []uuid.UUID
server := &Server{db: db}
// Dial fails for the initial agent only, forcing the refresh.
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == initialAgent.ID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
// Both agents must have been dialed, in order: failed initial
// agent first, then the refreshed one.
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
}
// TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage checks
// that a message published locally via publishMessage is delivered to
// a subscriber directly, with no additional database catch-up query.
// The mock only allows the initial snapshot queries, so any extra
// GetChatMessagesByChatID call would fail the test.
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
localMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
// Only the initial snapshot queries are expected.
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishMessage(chatID, localMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
// No further events should follow the single local delivery.
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered checks
// that a pubsub notify is satisfied from the durable event cache when
// the event was pre-cached via cacheDurableMessage: the subscriber
// receives the cached message without a database catch-up query (the
// mock permits only the initial snapshot queries).
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
cachedMessage := codersdk.ChatMessage{
ID: 2,
ChatID: chatID,
Role: codersdk.ChatMessageRoleAssistant,
}
// Only the initial snapshot queries are expected.
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
// Seed the durable cache before the notify arrives.
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &cachedMessage,
})
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
// Notify that a message after ID 1 exists; the cache covers it.
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// TestSubscribeQueriesDatabaseWhenDurableCacheMisses checks the cache
// miss path: a pubsub notify for a message that is not in the durable
// cache must trigger a database catch-up query (AfterID: 1) and
// deliver the message returned by that query.
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
catchupMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
// Snapshot queries, then the catch-up query after the notify.
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 1,
}).Return([]database.ChatMessage{catchupMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
// Nothing was cached, so this notify forces the DB catch-up.
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// TestSubscribeFullRefreshStillUsesDatabaseCatchup checks that an
// edited-message publish triggers a full refresh from the database:
// the catch-up query runs with AfterID: 0 (re-reading from the start)
// and the edited row is delivered to the subscriber.
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
// Same ID as initialMessage: this models an in-place edit.
editedMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
// Snapshot queries, then a full-refresh query (AfterID: 0 again).
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{editedMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishEditedMessage(chatID, editedMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(1), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// TestSubscribeDeliversRetryEventViaPubsubOnce checks that a retry
// event published via publishRetry reaches the subscriber exactly
// once, carrying the retry payload unchanged, with no follow-up
// events or extra database queries.
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
// Only the initial snapshot queries are expected.
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
retryingAt := time.Unix(1_700_000_000, 0).UTC()
expected := &codersdk.ChatStreamRetry{
Attempt: 1,
DelayMs: (1500 * time.Millisecond).Milliseconds(),
Error: "rate limit exceeded",
RetryingAt: retryingAt,
}
server.publishRetry(chatID, expected)
event := requireStreamRetryEvent(t, events)
// The payload must arrive unchanged.
require.Equal(t, expected, event.Retry)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
// newSubscribeTestServer builds a minimal Server for Subscribe tests:
// the supplied store, a test logger that tolerates errors, and an
// in-memory pubsub.
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
	t.Helper()
	srv := &Server{
		db:     db,
		logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		pubsub: dbpubsub.NewInMemory(),
	}
	return srv
}
// requireStreamMessageEvent blocks until the stream yields an event
// and asserts it is a non-nil message event. It fails the test if the
// channel closes first or stays silent for one second.
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
	t.Helper()
	timeout := time.After(time.Second)
	select {
	case <-timeout:
		t.Fatal("timed out waiting for chat stream message event")
		return codersdk.ChatStreamEvent{}
	case ev, open := <-events:
		require.True(t, open, "chat stream closed before delivering an event")
		require.Equal(t, codersdk.ChatStreamEventTypeMessage, ev.Type)
		require.NotNil(t, ev.Message)
		return ev
	}
}
// requireStreamRetryEvent blocks until the stream yields an event and
// asserts it is a non-nil retry event. It fails the test if the
// channel closes first or stays silent for one second.
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
	t.Helper()
	timeout := time.After(time.Second)
	select {
	case <-timeout:
		t.Fatal("timed out waiting for chat stream retry event")
		return codersdk.ChatStreamEvent{}
	case ev, open := <-events:
		require.True(t, open, "chat stream closed before delivering an event")
		require.Equal(t, codersdk.ChatStreamEventTypeRetry, ev.Type)
		require.NotNil(t, ev.Retry)
		return ev
	}
}
// requireNoStreamEvent asserts that the stream stays silent for the
// given duration: a delivered event — or an unexpected close — fails
// the test.
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
	t.Helper()
	select {
	case <-time.After(wait):
		// Quiet for the whole window: success.
	case ev, open := <-events:
		if !open {
			t.Fatal("chat stream closed unexpectedly")
		}
		t.Fatalf("unexpected chat stream event: %+v", ev)
	}
}
// TestPublishToStream_DropWarnRateLimiting walks through a
// realistic lifecycle: buffer fills up, subscriber channel fills
// up, counters get reset between steps. It verifies that WARN
// logs are rate-limited to at most once per streamDropWarnInterval
// and that counter resets re-enable an immediate WARN.
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
mClock := quartz.NewMock(t)
server := &Server{
logger: sink.Logger(),
clock: mClock,
}
chatID := uuid.New()
subCh := make(chan codersdk.ChatStreamEvent, 1)
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
// Set up state that mirrors a running chat: buffer at capacity,
// buffering enabled, one saturated subscriber.
state := &chatStreamState{
buffering: true,
buffer: make([]codersdk.ChatStreamEvent, maxStreamBufferSize),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
uuid.New(): subCh,
},
}
server.chatStreams.Store(chatID, state)
bufferMsg := "chat stream buffer full, dropping oldest event"
subMsg := "dropping chat stream event"
// filter builds a SinkEntry predicate matching one level+message
// pair, used to slice the fake sink's captured log entries.
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
return func(e slog.SinkEntry) bool {
return e.Level == level && e.Message == msg
}
}
// --- Phase 1: buffer-full rate limiting ---
// message_part events hit both the buffer-full and subscriber-full
// paths. The first publish triggers a WARN for each; the rest
// within the window are DEBUG.
partEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}
for i := 0; i < 50; i++ {
server.publishToStream(chatID, partEvent)
}
// Exactly one WARN despite 50 drops; the first WARN reports a
// dropped_count of 1 (only the event that triggered it).
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
// Subscriber also saw 50 drops (one per publish).
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
// --- Phase 2: clock advance triggers second WARN with count ---
mClock.Advance(streamDropWarnInterval + time.Second)
server.publishToStream(chatID, partEvent)
// The second WARN carries the accumulated dropped_count (50).
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 2)
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 2)
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
// --- Phase 3: counter reset (simulates step persist) ---
state.mu.Lock()
state.buffer = make([]codersdk.ChatStreamEvent, maxStreamBufferSize)
state.resetDropCounters()
state.mu.Unlock()
// The very next drop should WARN immediately — the reset zeroed
// lastWarnAt so the interval check passes.
server.publishToStream(chatID, partEvent)
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
}
// requireFieldValue asserts that a SinkEntry contains a field with
// the given name and value; it fails the test if the field is absent
// or carries a different value.
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
	t.Helper()
	for i := range entry.Fields {
		field := entry.Fields[i]
		if field.Name != name {
			continue
		}
		require.Equal(t, expected, field.Value, "field %q value mismatch", name)
		return
	}
	t.Fatalf("field %q not found in log entry", name)
}
+221 -1040
View File
File diff suppressed because it is too large Load Diff
+6 -32
View File
@@ -42,11 +42,6 @@ type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
// Runtime is the wall-clock duration of this step,
// covering LLM streaming, tool execution, and retries.
// Zero indicates the duration was not measured (e.g.
// interrupted steps).
Runtime time.Duration
}
// RunOptions configures a single streaming chat loop run.
@@ -127,7 +122,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
switch c.GetType() {
case fantasy.ContentTypeText:
text, ok := fantasy.AsContentType[fantasy.TextContent](c)
if !ok || strings.TrimSpace(text.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.TextPart{
@@ -136,7 +131,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
})
case fantasy.ContentTypeReasoning:
reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
if !ok || strings.TrimSpace(reasoning.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.ReasoningPart{
@@ -265,7 +260,6 @@ func Run(ctx context.Context, opts RunOptions) error {
for step := 0; totalSteps < opts.MaxSteps; step++ {
totalSteps++
stepStart := time.Now()
// Copy messages so that provider-specific caching
// mutations don't leak back to the caller's slice.
// copy copies Message structs by value, so field
@@ -371,7 +365,6 @@ func Run(ctx context.Context, opts RunOptions) error {
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
@@ -617,12 +610,10 @@ func processStepStream(
result.providerMetadata = part.ProviderMetadata
case fantasy.StreamPartTypeError:
// Detect interruption: the stream may surface the
// cancel as context.Canceled or propagate the
// ErrInterrupted cause directly, depending on
// the provider implementation.
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
// Detect interruption: context canceled with
// ErrInterrupted as the cause.
if errors.Is(part.Error, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
// Flush in-progress content so that
// persistInterruptedStep has access to partial
// text, reasoning, and tool calls that were
@@ -640,23 +631,6 @@ func processStepStream(
}
}
// The stream iterator may stop yielding parts without
// producing a StreamPartTypeError when the context is
// canceled (e.g. some providers close the response body
// silently). Detect this case and flush partial content
// so that persistInterruptedStep can save it.
if ctx.Err() != nil &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
flushActiveState(
&result,
activeTextContent,
activeReasoningContent,
activeToolCalls,
toolNames,
)
return result, ErrInterrupted
}
hasLocalToolCalls := false
for _, tc := range result.toolCalls {
if !tc.ProviderExecuted {
-81
View File
@@ -7,7 +7,6 @@ import (
"strings"
"sync"
"testing"
"time"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
@@ -65,8 +64,6 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
require.Equal(t, 1, persistStepCalls)
require.True(t, persistedStep.ContextLimit.Valid)
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
require.Greater(t, persistedStep.Runtime, time.Duration(0),
"step runtime should be positive")
require.NotEmpty(t, capturedCall.Prompt)
require.False(t, containsPromptSentinel(capturedCall.Prompt))
@@ -578,84 +575,6 @@ func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *test
assert.False(t, localTR.ProviderExecuted)
}
// TestToResponseMessages_FiltersEmptyTextAndReasoningParts verifies
// that toResponseMessages drops empty and whitespace-only text and
// reasoning content, keeps non-empty parts with their original
// (untrimmed) values, and leaves tool calls and tool results
// untouched by the filtering.
func TestToResponseMessages_FiltersEmptyTextAndReasoningParts(t *testing.T) {
t.Parallel()
sr := stepResult{
content: []fantasy.Content{
// Empty text — should be filtered.
fantasy.TextContent{Text: ""},
// Whitespace-only text — should be filtered.
fantasy.TextContent{Text: " \t\n"},
// Empty reasoning — should be filtered.
fantasy.ReasoningContent{Text: ""},
// Whitespace-only reasoning — should be filtered.
fantasy.ReasoningContent{Text: " \n"},
// Non-empty text — should pass through.
fantasy.TextContent{Text: "hello world"},
// Leading/trailing whitespace with content — kept
// with the original value (not trimmed).
fantasy.TextContent{Text: " hello "},
// Non-empty reasoning — should pass through.
fantasy.ReasoningContent{Text: "let me think"},
// Tool call — should be unaffected by filtering.
fantasy.ToolCallContent{
ToolCallID: "tc-1",
ToolName: "read_file",
Input: `{"path":"main.go"}`,
},
// Local tool result — should be unaffected by filtering.
fantasy.ToolResultContent{
ToolCallID: "tc-1",
ToolName: "read_file",
Result: fantasy.ToolResultOutputContentText{Text: "file contents"},
},
},
}
msgs := sr.toResponseMessages()
require.Len(t, msgs, 2, "expected assistant + tool messages")
// First message: assistant role with non-empty text, reasoning,
// and the tool call. The four empty/whitespace-only parts must
// have been dropped.
assistantMsg := msgs[0]
assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role)
require.Len(t, assistantMsg.Content, 4,
"assistant message should have 2x TextPart, ReasoningPart, and ToolCallPart")
// Part 0: non-empty text.
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[0])
require.True(t, ok, "part 0 should be TextPart")
assert.Equal(t, "hello world", textPart.Text)
// Part 1: padded text — original whitespace preserved.
paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[1])
require.True(t, ok, "part 1 should be TextPart")
assert.Equal(t, " hello ", paddedPart.Text)
// Part 2: non-empty reasoning.
reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](assistantMsg.Content[2])
require.True(t, ok, "part 2 should be ReasoningPart")
assert.Equal(t, "let me think", reasoningPart.Text)
// Part 3: tool call (unaffected by text/reasoning filtering).
toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[3])
require.True(t, ok, "part 3 should be ToolCallPart")
assert.Equal(t, "tc-1", toolCallPart.ToolCallID)
assert.Equal(t, "read_file", toolCallPart.ToolName)
// Second message: tool role with the local tool result.
toolMsg := msgs[1]
assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
require.Len(t, toolMsg.Content, 1,
"tool message should have only the local ToolResultPart")
toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
require.True(t, ok, "tool part should be ToolResultPart")
assert.Equal(t, "tc-1", toolResultPart.ToolCallID)
}
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
if len(message.ProviderOptions) == 0 {
return false
@@ -1,399 +0,0 @@
package chatloop
import (
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testProviderData implements fantasy.ProviderOptionsData so we can
// construct arbitrary ProviderMetadata for extractContextLimit tests.
type testProviderData struct {
// data is the raw metadata map the tests want surfaced through
// JSON marshaling.
data map[string]any
}
// Options is the interface's marker method; it has no behavior.
func (*testProviderData) Options() {}
// MarshalJSON exposes the wrapped map as the provider metadata JSON.
func (d *testProviderData) MarshalJSON() ([]byte, error) {
return json.Marshal(d.data)
}
// Required by the ProviderOptionsData interface; unused in tests.
func (d *testProviderData) UnmarshalJSON(b []byte) error {
return json.Unmarshal(b, &d.data)
}
func TestNormalizeMetadataKey(t *testing.T) {
t.Parallel()
tests := []struct {
name string
key string
want string
}{
{name: "lowercase", key: "camelCase", want: "camelcase"},
{name: "hyphens stripped", key: "kebab-case", want: "kebabcase"},
{name: "underscores stripped", key: "snake_case", want: "snakecase"},
{name: "uppercase", key: "UPPER", want: "upper"},
{name: "spaces stripped", key: "with spaces", want: "withspaces"},
{name: "empty", key: "", want: ""},
{name: "digits preserved", key: "123", want: "123"},
{name: "mixed separators", key: "Max_Context-Tokens", want: "maxcontexttokens"},
{name: "dots stripped", key: "context.limit", want: "contextlimit"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := normalizeMetadataKey(tt.key)
require.Equal(t, tt.want, got)
})
}
}
// TestIsContextLimitKey covers isContextLimitKey: exact matches after
// normalization, case/separator variants, the "contains context +
// limit/window/length" and "starts with max + contains context"
// heuristics, and keys that must not match. One known false positive
// is recorded as a skipped case (see BUG note below).
func TestIsContextLimitKey(t *testing.T) {
t.Parallel()
tests := []struct {
name string
key string
want bool
// skip marks a case documenting a known bug; it is skipped
// until the heuristic is fixed.
skip bool
}{ // Exact matches after normalization.
{name: "context_limit", key: "context_limit", want: true},
{name: "context_window", key: "context_window", want: true},
{name: "context_length", key: "context_length", want: true},
{name: "max_context", key: "max_context", want: true},
{name: "max_context_tokens", key: "max_context_tokens", want: true},
{name: "max_input_tokens", key: "max_input_tokens", want: true},
{name: "max_input_token", key: "max_input_token", want: true},
{name: "input_token_limit", key: "input_token_limit", want: true},
// Case and separator variations.
{name: "Context-Window mixed case", key: "Context-Window", want: true},
{name: "MAX_CONTEXT_TOKENS screaming", key: "MAX_CONTEXT_TOKENS", want: true},
{name: "contextLimit camelCase", key: "contextLimit", want: true},
// Fallback heuristic: contains "context" + limit/window/length.
{name: "model_context_limit", key: "model_context_limit", want: true},
{name: "context_window_size", key: "context_window_size", want: true},
{name: "context_length_max", key: "context_length_max", want: true},
// Fallback heuristic: starts with "max" + contains "context".
// BUG(isContextLimitKey): "max_context_version" matches
// because it contains "context" and starts with "max",
// but a version field is not a context limit.
// TODO: Fix the heuristic and remove this skip.
{name: "max_context_version false positive", key: "max_context_version", want: false, skip: true}, // Non-matching keys.
{name: "context_id no limit keyword", key: "context_id", want: false},
{name: "empty string", key: "", want: false},
{name: "unrelated key", key: "model_name", want: false},
{name: "limit without context", key: "rate_limit", want: false},
{name: "max without context", key: "max_tokens", want: false},
{name: "context alone", key: "context", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if tt.skip {
t.Skip("known bug: isContextLimitKey false positive")
}
got := isContextLimitKey(tt.key)
require.Equal(t, tt.want, got)
})
}
}
// TestNumericContextLimitValue covers numericContextLimitValue's
// type handling: float64 (json.Unmarshal's default numeric type),
// the integer families, strings, json.Number, and unhandled types.
// Zero, negative, and fractional values are rejected across the
// board.
func TestNumericContextLimitValue(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name   string
		value  any
		want   int64
		wantOK bool
	}{
		// float64: the default numeric type from json.Unmarshal.
		{name: "float64 integer", value: float64(128000), want: 128000, wantOK: true},
		{name: "float64 fractional rejected", value: float64(128000.5), want: 0, wantOK: false},
		{name: "float64 zero rejected", value: float64(0), want: 0, wantOK: false},
		{name: "float64 negative rejected", value: float64(-1), want: 0, wantOK: false},
		// int64
		{name: "int64 positive", value: int64(200000), want: 200000, wantOK: true},
		{name: "int64 zero rejected", value: int64(0), want: 0, wantOK: false},
		{name: "int64 negative rejected", value: int64(-1), want: 0, wantOK: false},
		// int32
		{name: "int32 positive", value: int32(50000), want: 50000, wantOK: true},
		{name: "int32 zero rejected", value: int32(0), want: 0, wantOK: false},
		// int
		{name: "int positive", value: int(50000), want: 50000, wantOK: true},
		{name: "int zero rejected", value: int(0), want: 0, wantOK: false},
		// string
		{name: "string numeric", value: "128000", want: 128000, wantOK: true},
		{name: "string trimmed", value: " 128000 ", want: 128000, wantOK: true},
		{name: "string non-numeric rejected", value: "not a number", want: 0, wantOK: false},
		{name: "string empty rejected", value: "", want: 0, wantOK: false},
		{name: "string zero rejected", value: "0", want: 0, wantOK: false},
		{name: "string negative rejected", value: "-1", want: 0, wantOK: false},
		// json.Number
		{name: "json.Number valid", value: json.Number("200000"), want: 200000, wantOK: true},
		{name: "json.Number invalid rejected", value: json.Number("invalid"), want: 0, wantOK: false},
		{name: "json.Number zero rejected", value: json.Number("0"), want: 0, wantOK: false},
		// Unhandled types.
		{name: "bool rejected", value: true, want: 0, wantOK: false},
		{name: "nil rejected", value: nil, want: 0, wantOK: false},
		{name: "slice rejected", value: []int{1}, want: 0, wantOK: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			gotValue, gotOK := numericContextLimitValue(tc.value)
			require.Equal(t, tc.wantOK, gotOK)
			require.Equal(t, tc.want, gotValue)
		})
	}
}
// TestPositiveInt64 checks positiveInt64 on a positive value, zero,
// and a negative value: only the positive input is accepted; rejected
// inputs yield zero.
func TestPositiveInt64(t *testing.T) {
	t.Parallel()
	// Positive input passes through unchanged.
	value, accepted := positiveInt64(42)
	require.True(t, accepted)
	require.Equal(t, int64(42), value)
	// Zero is rejected.
	value, accepted = positiveInt64(0)
	require.False(t, accepted)
	require.Equal(t, int64(0), value)
	// Negative input is rejected.
	value, accepted = positiveInt64(-1)
	require.False(t, accepted)
	require.Equal(t, int64(0), value)
}
// TestCollectContextLimitValues covers collectContextLimitValues'
// traversal: flat maps, nested maps, arrays, mixed nesting, keys
// that do not match, and scalar inputs that are ignored entirely.
// The collected values are gathered via the callback.
func TestCollectContextLimitValues(t *testing.T) {
t.Parallel()
t.Run("FlatMap", func(t *testing.T) {
t.Parallel()
// Only the matching key's value should be collected.
input := map[string]any{
"context_limit": float64(200000),
"other_key": float64(999),
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Equal(t, []int64{200000}, collected)
})
t.Run("NestedMaps", func(t *testing.T) {
t.Parallel()
// Traversal must descend through nested maps.
input := map[string]any{
"provider": map[string]any{
"info": map[string]any{
"context_window": float64(100000),
},
},
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Equal(t, []int64{100000}, collected)
})
t.Run("ArrayTraversal", func(t *testing.T) {
t.Parallel()
// Both array elements contribute; order is not asserted.
input := []any{
map[string]any{"context_limit": float64(50000)},
map[string]any{"context_limit": float64(80000)},
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Len(t, collected, 2)
require.Contains(t, collected, int64(50000))
require.Contains(t, collected, int64(80000))
})
t.Run("MixedNesting", func(t *testing.T) {
t.Parallel()
// Map containing an array containing a map.
input := map[string]any{
"models": []any{
map[string]any{
"context_limit": float64(128000),
},
},
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Equal(t, []int64{128000}, collected)
})
t.Run("NonMatchingKey", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"model_name": "gpt-4",
"tokens": float64(1000),
}
var collected []int64
collectContextLimitValues(input, func(v int64) {
collected = append(collected, v)
})
require.Empty(t, collected)
})
t.Run("ScalarIgnored", func(t *testing.T) {
t.Parallel()
// A bare scalar has no keys to match, so nothing is collected.
var collected []int64
collectContextLimitValues("just a string", func(v int64) {
collected = append(collected, v)
})
require.Empty(t, collected)
})
}
// TestFindContextLimitValue covers findContextLimitValue: a single
// matching candidate, the maximum taken over multiple candidates,
// and the not-found result for inputs with no candidates or nil.
func TestFindContextLimitValue(t *testing.T) {
t.Parallel()
t.Run("SingleCandidate", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"context_limit": float64(200000),
}
limit, ok := findContextLimitValue(input)
require.True(t, ok)
require.Equal(t, int64(200000), limit)
})
t.Run("MultipleCandidatesTakesMax", func(t *testing.T) {
t.Parallel()
// When several candidates exist, the largest wins.
input := map[string]any{
"a": map[string]any{"context_limit": float64(50000)},
"b": map[string]any{"context_limit": float64(200000)},
}
limit, ok := findContextLimitValue(input)
require.True(t, ok)
require.Equal(t, int64(200000), limit)
})
t.Run("NoCandidates", func(t *testing.T) {
t.Parallel()
input := map[string]any{
"model": "gpt-4",
}
_, ok := findContextLimitValue(input)
require.False(t, ok)
})
t.Run("NilInput", func(t *testing.T) {
t.Parallel()
_, ok := findContextLimitValue(nil)
require.False(t, ok)
})
}
// TestExtractContextLimit verifies that extractContextLimit finds a context
// limit in provider metadata regardless of the key the provider uses
// (context_limit, max_context_tokens, context_window), searches nested
// structures, takes the maximum across multiple candidates, and returns an
// invalid result when no candidate key is present or the metadata is
// nil/empty.
func TestExtractContextLimit(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name      string
		metadata  fantasy.ProviderMetadata
		wantValid bool
		want      int64
	}{
		{
			name: "AnthropicStyle",
			metadata: fantasy.ProviderMetadata{
				"anthropic": &testProviderData{
					data: map[string]any{
						"cache_read_input_tokens": float64(100),
						"context_limit":           float64(200000),
					},
				},
			},
			wantValid: true,
			want:      200000,
		},
		{
			name: "OpenAIStyle",
			metadata: fantasy.ProviderMetadata{
				"openai": &testProviderData{
					data: map[string]any{
						"max_context_tokens": float64(128000),
					},
				},
			},
			wantValid: true,
			want:      128000,
		},
		{
			name: "NestedDeeply",
			metadata: fantasy.ProviderMetadata{
				"provider": &testProviderData{
					data: map[string]any{
						"info": map[string]any{
							"context_window": float64(100000),
						},
					},
				},
			},
			wantValid: true,
			want:      100000,
		},
		{
			name: "MultipleCandidatesTakesMax",
			metadata: fantasy.ProviderMetadata{
				"a": &testProviderData{
					data: map[string]any{
						"context_limit": float64(50000),
					},
				},
				"b": &testProviderData{
					data: map[string]any{
						"context_limit": float64(200000),
					},
				},
			},
			wantValid: true,
			want:      200000,
		},
		{
			name: "NoMatchingKeys",
			metadata: fantasy.ProviderMetadata{
				"openai": &testProviderData{
					data: map[string]any{
						"model":  "gpt-4",
						"tokens": float64(1000),
					},
				},
			},
			wantValid: false,
		},
		{
			name:      "NilMetadata",
			metadata:  nil,
			wantValid: false,
		},
		{
			name:      "EmptyMetadata",
			metadata:  fantasy.ProviderMetadata{},
			wantValid: false,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			result := extractContextLimit(tc.metadata)
			if !tc.wantValid {
				// Negative cases mirror the original's use of assert
				// (non-fatal) rather than require.
				assert.False(t, result.Valid)
				return
			}
			require.True(t, result.Valid)
			require.Equal(t, tc.want, result.Int64)
		})
	}
}
+3 -215
View File
@@ -139,13 +139,9 @@ func ConvertMessagesWithFiles(
},
})
case codersdk.ChatMessageRoleUser:
userParts := partsToMessageParts(logger, pm.parts, resolved)
if len(userParts) == 0 {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: userParts,
Content: partsToMessageParts(logger, pm.parts, resolved),
})
case codersdk.ChatMessageRoleAssistant:
fantasyParts := normalizeAssistantToolCallInputs(
@@ -157,9 +153,6 @@ func ConvertMessagesWithFiles(
}
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
}
if len(fantasyParts) == 0 {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleAssistant,
Content: fantasyParts,
@@ -173,13 +166,9 @@ func ConvertMessagesWithFiles(
}
}
}
toolParts := partsToMessageParts(logger, pm.parts, resolved)
if len(toolParts) == 0 {
continue
}
prompt = append(prompt, fantasy.Message{
Role: fantasy.MessageRoleTool,
Content: toolParts,
Content: partsToMessageParts(logger, pm.parts, resolved),
})
}
}
@@ -332,7 +321,6 @@ func parseContentV1(role codersdk.ChatMessageRole, raw pqtype.NullRawMessage) ([
if err := json.Unmarshal(raw.RawMessage, &parts); err != nil {
return nil, xerrors.Errorf("parse %s content: %w", role, err)
}
decodeNulInParts(parts)
return parts, nil
}
@@ -1030,16 +1018,11 @@ func sanitizeToolCallID(id string) string {
}
// MarshalParts encodes SDK chat message parts for persistence.
// NUL characters in string fields are encoded as PUA sentinel
// pairs (U+E000 U+E001) before marshaling so the resulting JSON
// never contains \u0000 (rejected by PostgreSQL jsonb). The
// encoding operates on Go string values, not JSON bytes, so it
// survives jsonb text normalization.
func MarshalParts(parts []codersdk.ChatMessagePart) (pqtype.NullRawMessage, error) {
if len(parts) == 0 {
return pqtype.NullRawMessage{}, nil
}
data, err := json.Marshal(encodeNulInParts(parts))
data, err := json.Marshal(parts)
if err != nil {
return pqtype.NullRawMessage{}, xerrors.Errorf("encode chat message parts: %w", err)
}
@@ -1186,23 +1169,11 @@ func partsToMessageParts(
for _, part := range parts {
switch part.Type {
case codersdk.ChatMessagePartTypeText:
// Anthropic rejects empty text content blocks with
// "text content blocks must be non-empty". Empty parts
// can arise when a stream sends TextStart/TextEnd with
// no delta in between. We filter them here rather than
// at persistence time to preserve the raw record.
if strings.TrimSpace(part.Text) == "" {
continue
}
result = append(result, fantasy.TextPart{
Text: part.Text,
ProviderOptions: providerMetadataToOptions(logger, part.ProviderMetadata),
})
case codersdk.ChatMessagePartTypeReasoning:
// Same guard as text parts above.
if strings.TrimSpace(part.Text) == "" {
continue
}
result = append(result, fantasy.ReasoningPart{
Text: part.Text,
ProviderOptions: providerMetadataToOptions(logger, part.ProviderMetadata),
@@ -1245,186 +1216,3 @@ func partsToMessageParts(
}
return result
}
// encodeNulInString replaces NUL (U+0000) characters in s with
// the sentinel pair U+E000 U+E001, and doubles any pre-existing
// U+E000 to U+E000 U+E000 so the encoding is reversible.
// Operates on Unicode code points, not JSON escape sequences,
// making it safe through jsonb round-trips (jsonb stores parsed
// characters, not original escape text).
func encodeNulInString(s string) string {
if !strings.ContainsRune(s, 0) && !strings.ContainsRune(s, '\uE000') {
return s
}
var b strings.Builder
b.Grow(len(s))
for _, r := range s {
switch r {
case '\uE000':
_, _ = b.WriteRune('\uE000')
_, _ = b.WriteRune('\uE000')
case 0:
_, _ = b.WriteRune('\uE000')
_, _ = b.WriteRune('\uE001')
default:
_, _ = b.WriteRune(r)
}
}
return b.String()
}
// decodeNulInString reverses encodeNulInString: U+E000 U+E000
// becomes U+E000, and U+E000 U+E001 becomes NUL.
func decodeNulInString(s string) string {
if !strings.ContainsRune(s, '\uE000') {
return s
}
var b strings.Builder
b.Grow(len(s))
runes := []rune(s)
for i := 0; i < len(runes); i++ {
if runes[i] == '\uE000' && i+1 < len(runes) {
switch runes[i+1] {
case '\uE000':
_, _ = b.WriteRune('\uE000')
i++
case '\uE001':
_, _ = b.WriteRune(0)
i++
default:
// Unpaired sentinel — preserve as-is.
_, _ = b.WriteRune(runes[i])
}
} else {
_, _ = b.WriteRune(runes[i])
}
}
return b.String()
}
// encodeNulInValue recursively walks a JSON value (as produced
// by json.Unmarshal with UseNumber) and applies
// encodeNulInString to every string, including map keys.
func encodeNulInValue(v any) any {
switch val := v.(type) {
case string:
return encodeNulInString(val)
case map[string]any:
out := make(map[string]any, len(val))
for k, elem := range val {
out[encodeNulInString(k)] = encodeNulInValue(elem)
}
return out
case []any:
out := make([]any, len(val))
for i, elem := range val {
out[i] = encodeNulInValue(elem)
}
return out
default:
return v // numbers, bools, nil
}
}
// decodeNulInValue recursively walks a JSON value and applies
// decodeNulInString to every string, including map keys.
func decodeNulInValue(v any) any {
switch val := v.(type) {
case string:
return decodeNulInString(val)
case map[string]any:
out := make(map[string]any, len(val))
for k, elem := range val {
out[decodeNulInString(k)] = decodeNulInValue(elem)
}
return out
case []any:
out := make([]any, len(val))
for i, elem := range val {
out[i] = decodeNulInValue(elem)
}
return out
default:
return v
}
}
// encodeNulInJSON walks all string values (and keys) inside a
// json.RawMessage and applies encodeNulInString. Returns the
// original unchanged when the raw message does not contain NUL
// escapes or U+E000 bytes, or when parsing fails.
func encodeNulInJSON(raw json.RawMessage) json.RawMessage {
if len(raw) == 0 {
return raw
}
// Quick exit: no \u0000 escape and no U+E000 UTF-8 bytes.
if !bytes.Contains(raw, []byte(`\u0000`)) &&
!bytes.Contains(raw, []byte{0xEE, 0x80, 0x80}) {
return raw
}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
var v any
if err := dec.Decode(&v); err != nil {
return raw
}
result, err := json.Marshal(encodeNulInValue(v))
if err != nil {
return raw
}
return result
}
// decodeNulInJSON walks all string values (and keys) inside a
// json.RawMessage and applies decodeNulInString.
func decodeNulInJSON(raw json.RawMessage) json.RawMessage {
if len(raw) == 0 {
return raw
}
// U+E000 encoded as UTF-8 is 0xEE 0x80 0x80.
if !bytes.Contains(raw, []byte{0xEE, 0x80, 0x80}) {
return raw
}
dec := json.NewDecoder(bytes.NewReader(raw))
dec.UseNumber()
var v any
if err := dec.Decode(&v); err != nil {
return raw
}
result, err := json.Marshal(decodeNulInValue(v))
if err != nil {
return raw
}
return result
}
// encodeNulInParts returns a shallow copy of parts with all
// string and json.RawMessage fields NUL-encoded. The caller's
// slice is not modified.
func encodeNulInParts(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
encoded := make([]codersdk.ChatMessagePart, len(parts))
copy(encoded, parts)
for i := range encoded {
p := &encoded[i]
p.Text = encodeNulInString(p.Text)
p.Content = encodeNulInString(p.Content)
p.Args = encodeNulInJSON(p.Args)
p.ArgsDelta = encodeNulInString(p.ArgsDelta)
p.Result = encodeNulInJSON(p.Result)
p.ResultDelta = encodeNulInString(p.ResultDelta)
}
return encoded
}
// decodeNulInParts reverses encodeNulInParts in place.
func decodeNulInParts(parts []codersdk.ChatMessagePart) {
for i := range parts {
p := &parts[i]
p.Text = decodeNulInString(p.Text)
p.Content = decodeNulInString(p.Content)
p.Args = decodeNulInJSON(p.Args)
p.ArgsDelta = decodeNulInString(p.ArgsDelta)
p.Result = decodeNulInJSON(p.Result)
p.ResultDelta = decodeNulInString(p.ResultDelta)
}
}
-327
View File
@@ -17,10 +17,7 @@ import (
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// testMsg builds a database.ChatMessage for ParseContent tests.
@@ -1444,327 +1441,3 @@ func extractToolResultIDs(t *testing.T, msgs ...fantasy.Message) []string {
}
return ids
}
func TestNulEscapeRoundTrip(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
// Seed minimal dependencies for the DB round-trip path:
// user, provider, model config, chat.
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "openai",
APIKey: "test-key",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o-mini",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "nul-roundtrip-test",
})
require.NoError(t, err)
textTests := []struct {
name string
input string
hasNul bool // Whether the input contains actual NUL bytes.
}{
// --- basic ---
{"NoNul", "hello world", false},
{"SingleNul", "a\x00b", true},
{"MultipleNuls", "a\x00b\x00c", true},
{"ConsecutiveNuls", "\x00\x00\x00", true},
// --- boundaries ---
{"EmptyString", "", false},
{"NulOnly", "\x00", true},
{"NulAtStart", "\x00hello", true},
{"NulAtEnd", "hello\x00", true},
// --- sentinel / marker in original data ---
// U+E000 is the sentinel character. The encoder must
// double it so it round-trips without being mistaken
// for an encoded NUL.
{"SentinelInOriginal", "a\uE000b", false},
{"ConsecutiveSentinels", "\uE000\uE000\uE000", false},
// U+E001 is the marker character used in the NUL pair.
{"MarkerCharInOriginal", "a\uE001b", false},
// U+E000 followed by U+E001 looks exactly like an
// encoded NUL in the encoded form, so the encoder must
// double the U+E000 to avoid confusion.
{"SentinelThenMarkerChar", "\uE000\uE001", false},
{"NulAndSentinel", "a\x00b\uE000c", true},
// Both orders: sentinel adjacent to NUL.
{"SentinelThenNul", "\uE000\x00", true},
{"NulThenSentinel", "\x00\uE000", true},
{"AlternatingSentinelNul", "\x00\uE000\x00\uE000", true},
// --- strings containing backslashes ---
// Backslashes are normal characters at the Go string
// level; no special handling needed (unlike the old
// JSON-byte approach).
{"BackslashU0000Text", "\\u0000", false},
{"BackslashThenNul", "\\\x00", true},
// --- literal text that looks like escape patterns ---
{"LiteralTextU0000", "the value is u0000 here", false},
{"LiteralTextUE000", "sentinel uE000 text", false},
// --- other control characters mixed with NUL ---
{"ControlCharsMixedWithNul", "\x01\x00\x02\x00\x1f", true},
// --- long / stress ---
{"LongNulRun", "\x00\x00\x00\x00\x00\x00\x00\x00", true},
// Simulated find -print0 output.
{"FindPrint0", "/usr/bin/ls\x00/usr/bin/cat\x00/usr/bin/grep\x00", true},
}
for _, tc := range textTests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(tc.input),
}
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
// When the input has real NUL bytes, the stored JSON
// must not contain the \u0000 escape sequence.
if tc.hasNul {
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
"encoded JSON must not contain \\u0000")
}
// In-memory round-trip through ParseContent.
msg := testMsgV1(codersdk.ChatMessageRoleAssistant, encoded)
decoded, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.Len(t, decoded, 1)
require.Equal(t, tc.input, decoded[0].Text)
// Full DB round-trip: write to PostgreSQL jsonb, read
// back, and verify the value survives storage.
ctx := testutil.Context(t, testutil.WaitShort)
dbMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{user.ID},
ModelConfigID: []uuid.UUID{model.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{string(encoded.RawMessage)},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
require.Len(t, dbMsgs, 1)
readBack, err := db.GetChatMessageByID(ctx, dbMsgs[0].ID)
require.NoError(t, err)
dbDecoded, err := chatprompt.ParseContent(readBack)
require.NoError(t, err)
require.Len(t, dbDecoded, 1)
require.Equal(t, tc.input, dbDecoded[0].Text)
})
}
// Tool result with NUL in the result JSON value.
t.Run("ToolResultWithNul", func(t *testing.T) {
t.Parallel()
resultJSON := json.RawMessage(`"output:\u0000done"`)
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult("call-1", "my_tool", resultJSON, false),
}
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
"encoded JSON must not contain \\u0000")
msg := testMsgV1(codersdk.ChatMessageRoleTool, encoded)
decoded, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.Len(t, decoded, 1)
// JSON re-serialization may reformat, so compare
// semantically.
assert.JSONEq(t, string(resultJSON), string(decoded[0].Result))
})
// Multiple parts in one message: one with NUL, one without.
t.Run("MultiPartMixed", func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText("clean text"),
codersdk.ChatMessageText("has\x00nul"),
}
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
"encoded JSON must not contain \\u0000")
msg := testMsgV1(codersdk.ChatMessageRoleAssistant, encoded)
decoded, err := chatprompt.ParseContent(msg)
require.NoError(t, err)
require.Len(t, decoded, 2)
require.Equal(t, "clean text", decoded[0].Text)
require.Equal(t, "has\x00nul", decoded[1].Text)
})
}
func TestConvertMessagesWithFiles_FiltersEmptyTextAndReasoningParts(t *testing.T) {
t.Parallel()
// Helper to build a DB message from SDK parts.
makeMsg := func(t *testing.T, role database.ChatMessageRole, parts []codersdk.ChatMessagePart) database.ChatMessage {
t.Helper()
encoded, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
return database.ChatMessage{
Role: role,
Visibility: database.ChatMessageVisibilityBoth,
Content: encoded,
ContentVersion: chatprompt.CurrentContentVersion,
}
}
t.Run("UserRole", func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(""), // empty — filtered
codersdk.ChatMessageText(" \t\n "), // whitespace — filtered
codersdk.ChatMessageReasoning(""), // empty — filtered
codersdk.ChatMessageReasoning(" \n"), // whitespace — filtered
codersdk.ChatMessageText("hello"), // kept
codersdk.ChatMessageText(" hello "), // kept with original whitespace
codersdk.ChatMessageReasoning("thinking deeply"), // kept
codersdk.ChatMessageToolCall("call-1", "my_tool", json.RawMessage(`{"x":1}`)),
codersdk.ChatMessageToolResult("call-1", "my_tool", json.RawMessage(`{"ok":true}`), false),
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleUser, parts)},
nil,
slogtest.Make(t, nil),
)
require.NoError(t, err)
require.Len(t, prompt, 1)
resultParts := prompt[0].Content
require.Len(t, resultParts, 5, "expected 5 parts after filtering empty text/reasoning")
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[0])
require.True(t, ok, "expected TextPart at index 0")
require.Equal(t, "hello", textPart.Text)
// Leading/trailing whitespace is preserved — only
// all-whitespace parts are dropped.
paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[1])
require.True(t, ok, "expected TextPart at index 1")
require.Equal(t, " hello ", paddedPart.Text)
reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](resultParts[2])
require.True(t, ok, "expected ReasoningPart at index 2")
require.Equal(t, "thinking deeply", reasoningPart.Text)
toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](resultParts[3])
require.True(t, ok, "expected ToolCallPart at index 3")
require.Equal(t, "call-1", toolCallPart.ToolCallID)
toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](resultParts[4])
require.True(t, ok, "expected ToolResultPart at index 4")
require.Equal(t, "call-1", toolResultPart.ToolCallID)
})
t.Run("AssistantRole", func(t *testing.T) {
t.Parallel()
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(""), // empty — filtered
codersdk.ChatMessageText(" "), // whitespace — filtered
codersdk.ChatMessageReasoning(""), // empty — filtered
codersdk.ChatMessageText(" reply "), // kept with whitespace
codersdk.ChatMessageToolCall("tc-1", "read_file", json.RawMessage(`{"path":"x"}`)),
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleAssistant, parts)},
nil,
slogtest.Make(t, nil),
)
require.NoError(t, err)
// 2 messages: assistant + synthetic tool result injected
// by injectMissingToolResults for the unmatched tool call.
require.Len(t, prompt, 2)
resultParts := prompt[0].Content
require.Len(t, resultParts, 2, "expected text + tool-call after filtering")
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[0])
require.True(t, ok, "expected TextPart")
require.Equal(t, " reply ", textPart.Text)
tcPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](resultParts[1])
require.True(t, ok, "expected ToolCallPart")
require.Equal(t, "tc-1", tcPart.ToolCallID)
})
t.Run("AllEmptyDropsMessage", func(t *testing.T) {
t.Parallel()
// When every part is filtered, the message itself should
// be dropped rather than appending an empty-content message.
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageText(""),
codersdk.ChatMessageText(" "),
codersdk.ChatMessageReasoning(""),
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
context.Background(),
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleAssistant, parts)},
nil,
slogtest.Make(t, nil),
)
require.NoError(t, err)
require.Empty(t, prompt, "all-empty message should be dropped entirely")
})
}
+1 -9
View File
@@ -1083,7 +1083,6 @@ func openAIProviderOptionsFromChatConfig(
SafetyIdentifier: normalizedStringPointer(options.SafetyIdentifier),
ServiceTier: openAIServiceTierFromChat(options.ServiceTier),
StrictJSONSchema: options.StrictJSONSchema,
Store: boolPtrOrDefault(options.Store, true),
TextVerbosity: OpenAITextVerbosityFromChat(options.TextVerbosity),
User: normalizedStringPointer(options.User),
}
@@ -1100,7 +1099,7 @@ func openAIProviderOptionsFromChatConfig(
MaxCompletionTokens: options.MaxCompletionTokens,
TextVerbosity: normalizedStringPointer(options.TextVerbosity),
Prediction: options.Prediction,
Store: boolPtrOrDefault(options.Store, true),
Store: options.Store,
Metadata: options.Metadata,
PromptCacheKey: normalizedStringPointer(options.PromptCacheKey),
SafetyIdentifier: normalizedStringPointer(options.SafetyIdentifier),
@@ -1281,13 +1280,6 @@ func useOpenAIResponsesOptions(model fantasy.LanguageModel) bool {
}
}
func boolPtrOrDefault(value *bool, def bool) *bool {
if value != nil {
return value
}
return &def
}
func normalizedStringPointer(value *string) *string {
if value == nil {
return nil
+34 -23
View File
@@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
@@ -26,37 +25,37 @@ func TestReasoningEffortFromChat(t *testing.T) {
{
name: "OpenAICaseInsensitive",
provider: "openai",
input: ptr.Ref(" HIGH "),
want: ptr.Ref(string(fantasyopenai.ReasoningEffortHigh)),
input: stringPtr(" HIGH "),
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
},
{
name: "AnthropicEffort",
provider: "anthropic",
input: ptr.Ref("max"),
want: ptr.Ref(string(fantasyanthropic.EffortMax)),
input: stringPtr("max"),
want: stringPtr(string(fantasyanthropic.EffortMax)),
},
{
name: "OpenRouterEffort",
provider: "openrouter",
input: ptr.Ref("medium"),
want: ptr.Ref(string(fantasyopenrouter.ReasoningEffortMedium)),
input: stringPtr("medium"),
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
},
{
name: "VercelEffort",
provider: "vercel",
input: ptr.Ref("xhigh"),
want: ptr.Ref(string(fantasyvercel.ReasoningEffortXHigh)),
input: stringPtr("xhigh"),
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
},
{
name: "InvalidEffortReturnsNil",
provider: "openai",
input: ptr.Ref("unknown"),
input: stringPtr("unknown"),
want: nil,
},
{
name: "UnsupportedProviderReturnsNil",
provider: "bedrock",
input: ptr.Ref("high"),
input: stringPtr("high"),
want: nil,
},
{
@@ -83,8 +82,8 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
options := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Enabled: ptr.Ref(true),
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(true),
},
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"openai"},
@@ -93,22 +92,22 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
}
defaults := &codersdk.ChatModelProviderOptions{
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
Reasoning: &codersdk.ChatModelReasoningOptions{
Enabled: ptr.Ref(false),
Exclude: ptr.Ref(true),
MaxTokens: ptr.Ref[int64](123),
Effort: ptr.Ref("high"),
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
Enabled: boolPtr(false),
Exclude: boolPtr(true),
MaxTokens: int64Ptr(123),
Effort: stringPtr("high"),
},
IncludeUsage: ptr.Ref(true),
IncludeUsage: boolPtr(true),
Provider: &codersdk.ChatModelOpenRouterProvider{
Order: []string{"anthropic"},
AllowFallbacks: ptr.Ref(true),
RequireParameters: ptr.Ref(false),
DataCollection: ptr.Ref("allow"),
AllowFallbacks: boolPtr(true),
RequireParameters: boolPtr(false),
DataCollection: stringPtr("allow"),
Only: []string{"openai"},
Ignore: []string{"foo"},
Quantizations: []string{"int8"},
Sort: ptr.Ref("latency"),
Sort: stringPtr("latency"),
},
},
}
@@ -137,3 +136,15 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
}
func stringPtr(value string) *string {
return &value
}
func boolPtr(value bool) *bool {
return &value
}
func int64Ptr(value int64) *int64 {
return &value
}
+58 -125
View File
@@ -3,12 +3,14 @@ package chattool
import (
"context"
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
@@ -21,10 +23,9 @@ const (
// maxOutputToModel is the maximum output sent to the LLM.
maxOutputToModel = 32 << 10 // 32KB
// snapshotTimeout is how long a non-blocking fallback
// request is allowed to take when retrieving a process
// output snapshot after a blocking wait times out.
snapshotTimeout = 30 * time.Second
// pollInterval is how often we check for process completion
// in foreground mode.
pollInterval = 200 * time.Millisecond
)
// nonInteractiveEnvVars are set on every process to prevent
@@ -88,7 +89,7 @@ type ExecuteArgs struct {
func Execute(options ExecuteOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"execute",
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command times out, the response includes a background_process_id so you can retrieve output later with process_output.",
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding.",
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
@@ -126,7 +127,7 @@ func executeTool(
// run_in_background parameter, which causes the shell to fork
// and exit immediately, leaving an untracked orphan process.
trimmed := strings.TrimSpace(args.Command)
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") && !strings.HasSuffix(trimmed, "|&") {
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") {
background = true
args.Command = strings.TrimSpace(strings.TrimSuffix(trimmed, "&"))
}
@@ -172,7 +173,7 @@ func executeBackground(
return fantasy.NewTextResponse(string(data))
}
// executeForeground starts a process and waits for its
// executeForeground starts a process and polls for its
// completion, enforcing the configured timeout.
func executeForeground(
ctx context.Context,
@@ -211,7 +212,7 @@ func executeForeground(
return errorResult(fmt.Sprintf("start process: %v", err))
}
result := waitForProcess(cmdCtx, conn, resp.ID, timeout)
result := pollProcess(cmdCtx, conn, resp.ID, timeout)
result.WallDurationMs = time.Since(start).Milliseconds()
// Add an advisory note for file-dump commands.
@@ -236,84 +237,62 @@ func truncateOutput(output string) string {
return output
}
// waitForProcess waits for process completion using the
// blocking process output API instead of polling.
func waitForProcess(
// pollProcess polls for process output until the process exits
// or the context times out.
func pollProcess(
ctx context.Context,
conn workspacesdk.AgentConn,
processID string,
timeout time.Duration,
) ExecuteResult {
// Block until the process exits or the context is
// canceled.
resp, err := conn.ProcessOutput(ctx, processID, &workspacesdk.ProcessOutputOptions{
Wait: true,
})
if err != nil {
if ctx.Err() != nil {
// Timeout: fetch final snapshot with a fresh
// context. The blocking request was canceled
// so the response body was lost.
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
// Timeout — get whatever output we have. Use a
// fresh context since cmdCtx is already canceled.
bgCtx, bgCancel := context.WithTimeout(
context.Background(),
snapshotTimeout,
5*time.Second,
)
defer bgCancel()
resp, err = conn.ProcessOutput(bgCtx, processID, nil)
outputResp, outputErr := conn.ProcessOutput(bgCtx, processID)
bgCancel()
output := truncateOutput(outputResp.Output)
timeoutErr := xerrors.Errorf("command timed out after %s", timeout)
if outputErr != nil {
timeoutErr = errors.Join(timeoutErr, xerrors.Errorf("failed to get output: %w", outputErr))
}
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: timeoutErr.Error(),
Truncated: outputResp.Truncated,
}
case <-ticker.C:
outputResp, err := conn.ProcessOutput(ctx, processID)
if err != nil {
return ExecuteResult{
Success: false,
ExitCode: -1,
Error: fmt.Sprintf("command timed out after %s; failed to get output: %v", timeout, err),
BackgroundProcessID: processID,
Success: false,
Error: fmt.Sprintf("get process output: %v", err),
}
}
output := truncateOutput(resp.Output)
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: fmt.Sprintf("command timed out after %s", timeout),
Truncated: resp.Truncated,
BackgroundProcessID: processID,
if !outputResp.Running {
exitCode := 0
if outputResp.ExitCode != nil {
exitCode = *outputResp.ExitCode
}
output := truncateOutput(outputResp.Output)
return ExecuteResult{
Success: exitCode == 0,
Output: output,
ExitCode: exitCode,
Truncated: outputResp.Truncated,
}
}
}
return ExecuteResult{
Success: false,
Error: fmt.Sprintf("get process output: %v", err),
}
}
// The server-side wait may return before the
// process exits if maxWaitDuration is shorter than
// the client's timeout. Retry if our context still
// has time left.
if resp.Running {
if ctx.Err() == nil {
// Still within the caller's timeout, retry.
return waitForProcess(ctx, conn, processID, timeout)
}
output := truncateOutput(resp.Output)
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: fmt.Sprintf("command timed out after %s", timeout),
Truncated: resp.Truncated,
BackgroundProcessID: processID,
}
}
exitCode := 0
if resp.ExitCode != nil {
exitCode = *resp.ExitCode
}
output := truncateOutput(resp.Output)
return ExecuteResult{
Success: exitCode == 0,
Output: output,
ExitCode: exitCode,
Truncated: resp.Truncated,
}
}
@@ -343,19 +322,10 @@ func detectFileDump(command string) string {
return ""
}
const (
// defaultProcessOutputTimeout is the default time the
// process_output tool blocks waiting for new output or
// process exit before returning. This avoids polling
// loops that waste tokens and HTTP round-trips.
defaultProcessOutputTimeout = 10 * time.Second
)
// ProcessOutputArgs are the parameters accepted by the
// process_output tool.
type ProcessOutputArgs struct {
ProcessID string `json:"process_id"`
WaitTimeout *string `json:"wait_timeout,omitempty" description:"Override the default 10s block duration. The call blocks until the process exits or this timeout is reached. Set to '0s' for an immediate snapshot without waiting."`
ProcessID string `json:"process_id"`
}
// ProcessOutput returns an AgentTool that retrieves the output
@@ -365,13 +335,9 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
"process_output",
"Retrieve output from a background process. "+
"Use the process_id returned by execute with "+
"run_in_background=true or from a timed-out "+
"execute's background_process_id. Blocks up to "+
"10s for the process to exit, then returns the "+
"output and exit_code. If still running after "+
"the timeout, returns the output so far. Use "+
"wait_timeout to override the default 10s wait "+
"(e.g. '30s', or '0s' for an immediate snapshot).",
"run_in_background=true. Returns the current output, "+
"whether the process is still running, and the exit "+
"code if it has finished.",
func(ctx context.Context, args ProcessOutputArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
@@ -383,42 +349,9 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
timeout := defaultProcessOutputTimeout
if args.WaitTimeout != nil {
parsed, err := time.ParseDuration(*args.WaitTimeout)
if err != nil {
return fantasy.NewTextErrorResponse(
fmt.Sprintf("invalid wait_timeout %q: %v", *args.WaitTimeout, err),
), nil
}
timeout = parsed
}
var opts *workspacesdk.ProcessOutputOptions
// Save parent context before applying timeout.
parentCtx := ctx
if timeout > 0 {
opts = &workspacesdk.ProcessOutputOptions{
Wait: true,
}
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
resp, err := conn.ProcessOutput(ctx, args.ProcessID, opts)
resp, err := conn.ProcessOutput(ctx, args.ProcessID)
if err != nil {
// If our wait timed out but the parent is still alive,
// fetch a non-blocking snapshot.
if ctx.Err() == nil || parentCtx.Err() != nil {
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
}
bgCtx, bgCancel := context.WithTimeout(parentCtx, snapshotTimeout)
defer bgCancel()
resp, err = conn.ProcessOutput(bgCtx, args.ProcessID, nil)
if err != nil {
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
}
// Fall through to normal response handling below.
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
}
output := truncateOutput(resp.Output)
exitCode := 0
@@ -432,7 +365,7 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
Truncated: resp.Truncated,
}
if resp.Running {
// Process is still running, success is not
// Process is still running success is not
// yet determined.
result.Success = true
result.Note = "process is still running"
@@ -1,100 +0,0 @@
package chattool
import (
"context"
"encoding/json"
"strings"
"testing"
"unicode/utf8"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
)
// TestTruncateOutput exercises the Execute tool's output
// truncation across the interesting size boundaries: empty,
// short, exactly at the limit, over the limit, and a cut that
// would otherwise land mid-way through a multi-byte UTF-8 rune.
func TestTruncateOutput(t *testing.T) {
	t.Parallel()

	t.Run("EmptyOutput", func(t *testing.T) {
		t.Parallel()
		got := runForegroundWithOutput(t, "")
		assert.Empty(t, got.Output)
	})

	t.Run("ShortOutput", func(t *testing.T) {
		t.Parallel()
		got := runForegroundWithOutput(t, "short")
		assert.Equal(t, "short", got.Output)
	})

	t.Run("ExactlyAtLimit", func(t *testing.T) {
		t.Parallel()
		in := strings.Repeat("a", maxOutputToModel)
		got := runForegroundWithOutput(t, in)
		// Output exactly at the limit must pass through untouched.
		assert.Equal(t, in, got.Output)
		assert.Equal(t, maxOutputToModel, len(got.Output))
	})

	t.Run("OverLimit", func(t *testing.T) {
		t.Parallel()
		in := strings.Repeat("b", maxOutputToModel+1024)
		got := runForegroundWithOutput(t, in)
		assert.Equal(t, maxOutputToModel, len(got.Output))
	})

	t.Run("MultiByteCutMidCharacter", func(t *testing.T) {
		t.Parallel()
		// Place a 3-byte UTF-8 rune (U+2603, snowman ☃) so the
		// truncation boundary falls inside it: only one of its
		// three bytes fits under the limit.
		in := strings.Repeat("x", maxOutputToModel-1) + "☃"
		got := runForegroundWithOutput(t, in)
		assert.LessOrEqual(t, len(got.Output), maxOutputToModel)
		assert.True(t, utf8.ValidString(got.Output),
			"truncated output must be valid UTF-8")
	})
}
// runForegroundWithOutput drives the Execute tool end-to-end
// against a mocked agent connection whose single process emits
// the supplied output and exits 0, returning the ExecuteResult
// parsed from the tool response.
func runForegroundWithOutput(t *testing.T, output string) ExecuteResult {
	t.Helper()

	mockCtrl := gomock.NewController(t)
	conn := agentconnmock.NewMockAgentConn(mockCtrl)
	conn.EXPECT().
		StartProcess(gomock.Any(), gomock.Any()).
		Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
	zero := 0
	conn.EXPECT().
		ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
		Return(workspacesdk.ProcessOutputResponse{
			Running:  false,
			ExitCode: &zero,
			Output:   output,
		}, nil)

	execTool := Execute(ExecuteOptions{
		GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
			return conn, nil
		},
	})

	ctx := testutil.Context(t, testutil.WaitMedium)
	resp, err := execTool.Run(ctx, fantasy.ToolCall{
		ID:    "call-1",
		Name:  "execute",
		Input: `{"command":"echo test"}`,
	})
	require.NoError(t, err)

	var parsed ExecuteResult
	require.NoError(t, json.Unmarshal([]byte(resp.Content), &parsed))
	return parsed
}
-493
View File
@@ -1,493 +0,0 @@
package chattool_test
import (
"context"
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
)
// TestExecuteTool covers the Execute tool end-to-end against a
// mocked agent connection: argument validation, trailing-"&"
// background promotion, foreground success and non-zero exit,
// explicit background execution, client-side timeout with a
// snapshot fallback, and the error paths for process start,
// process output, and workspace-connection resolution.
func TestExecuteTool(t *testing.T) {
t.Parallel()
// An empty command must be rejected before any agent call;
// no StartProcess/ProcessOutput expectations are registered.
t.Run("EmptyCommand", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
tool := newExecuteTool(t, mockConn)
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":""}`,
})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.Contains(t, resp.Content, "command is required")
})
// Table test for how a trailing "&" is interpreted: when it
// promotes a command to background mode (and gets stripped),
// and when it must not ("&&" chains, bash's "|&" operator,
// or an explicit run_in_background=true).
t.Run("AmpersandDetection", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
command string
runInBackground *bool
wantCommand string
wantBackground bool
wantBackgroundResp bool // true if the response should contain a background_process_id
comment string
}{
{
name: "SimpleBackground",
command: "cmd &",
wantCommand: "cmd",
wantBackground: true,
wantBackgroundResp: true,
comment: "Trailing & is correctly detected and stripped.",
},
{
name: "TrailingDoubleAmpersand",
command: "cmd &&",
wantCommand: "cmd &&",
wantBackground: false,
wantBackgroundResp: false,
comment: "Ends with &&, excluded by the && suffix check.",
},
{
name: "NoAmpersand",
command: "cmd",
wantCommand: "cmd",
wantBackground: false,
wantBackgroundResp: false,
},
{
name: "ChainThenBackground",
command: "cmd1 && cmd2 &",
wantCommand: "cmd1 && cmd2",
wantBackground: true,
wantBackgroundResp: true,
comment: "Ends with & but not &&, so it gets promoted " +
"to background and the trailing & is stripped. " +
"The remaining command runs in background mode.",
},
{
// "|&" is bash's pipe-stderr operator, not
// backgrounding. It must not be detected as a
// trailing "&".
name: "BashPipeStderr",
command: "cmd |&",
wantCommand: "cmd |&",
wantBackground: false,
wantBackgroundResp: false,
},
{
name: "AlreadyBackgroundWithTrailingAmpersand",
command: "cmd &",
runInBackground: ptr(true),
wantCommand: "cmd &",
wantBackground: true,
wantBackgroundResp: true,
comment: "When run_in_background is already true, " +
"the stripping logic is skipped, preserving " +
"the original command.",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
// Capture the request so we can assert exactly what
// command/background flag reached the agent.
var capturedReq workspacesdk.StartProcessRequest
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
capturedReq = req
return workspacesdk.StartProcessResponse{ID: "proc-1"}, nil
})
// For foreground cases, ProcessOutput is polled.
exitCode := 0
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
}, nil).
AnyTimes()
tool := newExecuteTool(t, mockConn)
input := map[string]any{"command": tc.command}
if tc.runInBackground != nil {
input["run_in_background"] = *tc.runInBackground
}
inputJSON, err := json.Marshal(input)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: string(inputJSON),
})
require.NoError(t, err)
assert.False(t, resp.IsError, "response should not be an error")
assert.Equal(t, tc.wantCommand, capturedReq.Command,
"command passed to StartProcess")
assert.Equal(t, tc.wantBackground, capturedReq.Background,
"background flag passed to StartProcess")
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
if tc.wantBackgroundResp {
assert.NotEmpty(t, result.BackgroundProcessID,
"expected background_process_id in response")
} else {
assert.Empty(t, result.BackgroundProcessID,
"expected no background_process_id")
}
})
}
})
// A foreground command exiting 0 yields success=true with the
// captured output, and the chat-agent env marker is injected.
t.Run("ForegroundSuccess", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
var capturedReq workspacesdk.StartProcessRequest
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
capturedReq = req
return workspacesdk.StartProcessResponse{ID: "proc-1"}, nil
})
exitCode := 0
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
Output: "hello world",
}, nil)
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hello"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.True(t, result.Success)
assert.Equal(t, 0, result.ExitCode)
assert.Equal(t, "hello world", result.Output)
assert.Empty(t, result.BackgroundProcessID)
assert.Equal(t, "true", capturedReq.Env["CODER_CHAT_AGENT"])
})
// A non-zero exit code maps to success=false while still
// surfacing the output and exact exit code to the model.
t.Run("ForegroundNonZeroExit", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
exitCode := 42
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: false,
ExitCode: &exitCode,
Output: "something failed",
}, nil)
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"exit 42"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Equal(t, 42, result.ExitCode)
assert.Equal(t, "something failed", result.Output)
})
// run_in_background=true returns immediately with the process
// ID; note no ProcessOutput expectation is registered.
t.Run("BackgroundExecution", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
assert.True(t, req.Background)
return workspacesdk.StartProcessResponse{ID: "bg-42"}, nil
})
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"sleep 999","run_in_background":true}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.True(t, result.Success)
assert.Equal(t, "bg-42", result.BackgroundProcessID)
})
// When the client-side timeout expires mid-wait, the tool is
// expected to fall back to a non-blocking snapshot and return
// the partial output with exit code -1.
t.Run("Timeout", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
// First call (blocking wait) returns context error
// because the 50ms timeout expires.
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
DoAndReturn(func(ctx context.Context, _ string, _ *workspacesdk.ProcessOutputOptions) (workspacesdk.ProcessOutputResponse, error) {
<-ctx.Done()
return workspacesdk.ProcessOutputResponse{}, ctx.Err()
})
// Second call (snapshot fallback) returns partial output.
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{
Running: true,
Output: "partial output",
}, nil)
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
// 50ms timeout expires during the blocking wait.
Input: `{"command":"sleep 999","timeout":"50ms"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Equal(t, -1, result.ExitCode)
assert.Contains(t, result.Error, "timed out")
assert.Equal(t, "partial output", result.Output)
})
t.Run("StartProcessError", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{}, xerrors.New("connection lost"))
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
// Errors from StartProcess are returned as a JSON body
// with success=false, not as a ToolResponse error.
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Contains(t, result.Error, "connection lost")
})
// Like StartProcessError: a ProcessOutput failure surfaces as
// success=false with the error text, not a tool-level error.
t.Run("ProcessOutputError", func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
mockConn := agentconnmock.NewMockAgentConn(ctrl)
mockConn.EXPECT().
StartProcess(gomock.Any(), gomock.Any()).
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
mockConn.EXPECT().
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("agent disconnected"))
tool := newExecuteTool(t, mockConn)
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
assert.False(t, resp.IsError)
var result chattool.ExecuteResult
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
assert.False(t, result.Success)
assert.Contains(t, result.Error, "agent disconnected")
})
// Misconfiguration (nil resolver) is a tool-level error.
t.Run("GetWorkspaceConnNil", func(t *testing.T) {
t.Parallel()
tool := chattool.Execute(chattool.ExecuteOptions{
GetWorkspaceConn: nil,
})
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.Contains(t, resp.Content, "not configured")
})
// A resolver failure (e.g. workspace offline) is likewise a
// tool-level error with the resolver's message.
t.Run("GetWorkspaceConnError", func(t *testing.T) {
t.Parallel()
tool := chattool.Execute(chattool.ExecuteOptions{
GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
return nil, xerrors.New("workspace offline")
},
})
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := tool.Run(ctx, fantasy.ToolCall{
ID: "call-1",
Name: "execute",
Input: `{"command":"echo hi"}`,
})
require.NoError(t, err)
assert.True(t, resp.IsError)
assert.Contains(t, resp.Content, "workspace offline")
})
}
// TestDetectFileDump verifies that commands which tend to dump
// whole files into the transcript (cat a file, grep/rg with
// file-listing flags) make the Execute tool attach an advisory
// note pointing at the read_file tool, while ordinary commands
// produce no note.
func TestDetectFileDump(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		command string
		wantHit bool
	}{
		{name: "CatFile", command: "cat foo.txt", wantHit: true},
		{name: "NotCatPrefix", command: "concatenate foo", wantHit: false},
		{name: "GrepIncludeAll", command: "grep --include-all pattern", wantHit: true},
		{name: "RgListFiles", command: "rg -l pattern", wantHit: true},
		{name: "GrepRecursive", command: "grep -r pattern", wantHit: false},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Mock a single process that exits 0 with fixed output,
			// so only the advisory-note behavior varies per case.
			mockCtrl := gomock.NewController(t)
			conn := agentconnmock.NewMockAgentConn(mockCtrl)
			conn.EXPECT().
				StartProcess(gomock.Any(), gomock.Any()).
				Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
			zero := 0
			conn.EXPECT().
				ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
				Return(workspacesdk.ProcessOutputResponse{
					Running:  false,
					ExitCode: &zero,
					Output:   "output",
				}, nil)

			execTool := newExecuteTool(t, conn)
			ctx := testutil.Context(t, testutil.WaitMedium)
			payload, err := json.Marshal(map[string]any{
				"command": tt.command,
			})
			require.NoError(t, err)
			resp, err := execTool.Run(ctx, fantasy.ToolCall{
				ID:    "call-1",
				Name:  "execute",
				Input: string(payload),
			})
			require.NoError(t, err)

			var result chattool.ExecuteResult
			require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
			if tt.wantHit {
				assert.Contains(t, result.Note, "read_file",
					"expected advisory note for %q", tt.command)
			} else {
				assert.Empty(t, result.Note,
					"expected no note for %q", tt.command)
			}
		})
	}
}
// newExecuteTool builds an Execute tool whose workspace
// connection resolver always hands back the supplied mock.
func newExecuteTool(t *testing.T, mockConn *agentconnmock.MockAgentConn) fantasy.AgentTool {
	t.Helper()
	opts := chattool.ExecuteOptions{
		GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
			return mockConn, nil
		},
	}
	return chattool.Execute(opts)
}
// ptr returns a pointer to a copy of v; convenient for filling
// optional pointer fields in test fixtures.
func ptr[T any](v T) *T {
	out := v
	return &out
}
-150
View File
@@ -272,156 +272,6 @@ func logMessages(t *testing.T, msgs []codersdk.ChatMessage) {
}
}
// TestOpenAIReasoningRoundTrip is an integration test that verifies
// reasoning items from OpenAI's Responses API survive the full
// persist → reconstruct → re-send cycle when Store: true. It sends
// a query to a reasoning model, waits for completion, then sends a
// follow-up message. If reasoning items are sent back without their
// required following output item, the API rejects the second request:
//
// Item 'rs_xxx' of type 'reasoning' was provided without its
// required following item.
//
// The test requires OPENAI_API_KEY to be set.
func TestOpenAIReasoningRoundTrip(t *testing.T) {
t.Parallel()
// Skip (rather than fail) when no credentials are available,
// so regular CI runs without an API key stay green.
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
t.Skip("OPENAI_API_KEY not set; skipping OpenAI integration test")
}
// Optional endpoint override (e.g. a proxy); empty means the
// provider default.
baseURL := os.Getenv("OPENAI_BASE_URL")
ctx := testutil.Context(t, testutil.WaitSuperLong)
// Stand up a full coderd with the agents experiment.
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
client := coderdtest.New(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
})
_ = coderdtest.CreateFirstUser(t, client)
// Configure an OpenAI provider with the real API key.
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: apiKey,
BaseURL: baseURL,
})
require.NoError(t, err)
// Create a model config for a reasoning model with Store: true
// (the default). Using o4-mini because it always produces
// reasoning items.
contextLimit := int64(200000)
isDefault := true
reasoningSummary := "auto"
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
Model: "o4-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
ModelConfig: &codersdk.ChatModelCallConfig{
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
Store: ptr.Ref(true),
ReasoningSummary: &reasoningSummary,
},
},
},
})
require.NoError(t, err)
// --- Step 1: Send a message that triggers reasoning ---
t.Log("Creating chat with reasoning query...")
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "What is 2+2? Be brief.",
},
},
})
require.NoError(t, err)
t.Logf("Chat created: %s (status=%s)", chat.ID, chat.Status)
// Stream events until the chat reaches a terminal status.
events, closer, err := client.StreamChat(ctx, chat.ID, nil)
require.NoError(t, err)
defer closer.Close()
waitForChatDone(ctx, t, events, "step 1")
// Verify the chat completed and messages were persisted.
chatData, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
t.Logf("Chat status after step 1: %s, messages: %d",
chatData.Status, len(chatMsgs.Messages))
logMessages(t, chatMsgs.Messages)
require.Equal(t, codersdk.ChatStatusWaiting, chatData.Status,
"chat should be in waiting status after step 1")
// Verify the assistant message has reasoning content.
assistantMsg := findAssistantWithText(t, chatMsgs.Messages)
require.NotNil(t, assistantMsg,
"expected an assistant message with text content after step 1")
partTypes := partTypeSet(assistantMsg.Content)
require.Contains(t, partTypes, codersdk.ChatMessagePartTypeReasoning,
"assistant message should contain reasoning parts from o4-mini")
require.Contains(t, partTypes, codersdk.ChatMessagePartTypeText,
"assistant message should contain a text part")
// --- Step 2: Send a follow-up message ---
// This is the critical test: if reasoning items are sent back
// without their required following item, the API will reject
// the request with:
// Item 'rs_xxx' of type 'reasoning' was provided without its
// required following item.
t.Log("Sending follow-up message...")
_, err = client.CreateChatMessage(ctx, chat.ID,
codersdk.CreateChatMessageRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "And what is 3+3? Be brief.",
},
},
})
require.NoError(t, err)
// Stream the follow-up response.
events2, closer2, err := client.StreamChat(ctx, chat.ID, nil)
require.NoError(t, err)
defer closer2.Close()
waitForChatDone(ctx, t, events2, "step 2")
// Verify the follow-up completed and produced content.
chatData2, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
t.Logf("Chat status after step 2: %s, messages: %d",
chatData2.Status, len(chatMsgs2.Messages))
logMessages(t, chatMsgs2.Messages)
require.Equal(t, codersdk.ChatStatusWaiting, chatData2.Status,
"chat should be in waiting status after step 2")
require.Greater(t, len(chatMsgs2.Messages), len(chatMsgs.Messages),
"follow-up should have added more messages")
// The last assistant message should have text.
lastAssistant := findLastAssistantWithText(t, chatMsgs2.Messages)
require.NotNil(t, lastAssistant,
"expected an assistant message with text in the follow-up")
t.Log("OpenAI reasoning round-trip test passed.")
}
// partTypeSet returns the set of part types present in a message.
func partTypeSet(parts []codersdk.ChatMessagePart) map[codersdk.ChatMessagePartType]struct{} {
set := make(map[codersdk.ChatMessagePartType]struct{}, len(parts))
-2
View File
@@ -62,7 +62,6 @@ func (p *Server) maybeGenerateChatTitle(
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
keys chatprovider.ProviderAPIKeys,
generatedTitle *generatedChatTitle,
logger slog.Logger,
) {
input, ok := titleInput(chat, messages)
@@ -112,7 +111,6 @@ func (p *Server) maybeGenerateChatTitle(
return
}
chat.Title = title
generatedTitle.Store(title)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
return
}
+3 -10
View File
@@ -84,14 +84,6 @@ func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
return false
}
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
enabled, err := p.db.GetChatDesktopEnabled(ctx)
if err != nil {
return false
}
return enabled
}
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
tools := []fantasy.AgentTool{
fantasy.NewAgentTool(
@@ -261,8 +253,9 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
}
// Only include the computer use tool when an Anthropic
// provider is configured and desktop is enabled.
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
// provider is configured, since it requires an Anthropic
// model.
if p.isAnthropicConfigured(ctx) {
tools = append(tools, fantasy.NewAgentTool(
"spawn_computer_use_agent",
"Spawn a dedicated computer use agent that can see the desktop "+
+3 -173
View File
@@ -15,7 +15,6 @@ import (
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
@@ -145,20 +144,14 @@ func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
return nil
}
func chatdTestContext(t *testing.T) context.Context {
t.Helper()
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
}
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// No Anthropic key in ProviderAPIKeys.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -183,13 +176,12 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the provider check passes.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// Create a root parent chat.
@@ -240,42 +232,16 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
}
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
parent, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "parent-desktop-disabled",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
})
require.NoError(t, err)
parentChat, err := db.GetChatByID(ctx, parent.ID)
require.NoError(t, err)
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
tool := findToolByName(tools, "spawn_computer_use_agent")
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
}
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
// Provide an Anthropic key so the tool can proceed.
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
Anthropic: "test-anthropic-key",
})
ctx := chatdTestContext(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedInternalChatDeps(ctx, t, db)
// The parent uses an OpenAI model.
@@ -332,139 +298,3 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
assert.Equal(t, "anthropic", chattool.ComputerUseModelProvider)
assert.NotEmpty(t, chattool.ComputerUseModelName)
}
func TestIsSubagentDescendant(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
ctx := chatdTestContext(t)
user, model := seedInternalChatDeps(ctx, t, db)
// Build a chain: root -> child -> grandchild.
root, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "root",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("root")},
})
require.NoError(t, err)
child, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
ParentChatID: uuid.NullUUID{
UUID: root.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: root.ID,
Valid: true,
},
Title: "child",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("child")},
})
require.NoError(t, err)
grandchild, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
ParentChatID: uuid.NullUUID{
UUID: child.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: root.ID,
Valid: true,
},
Title: "grandchild",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("grandchild")},
})
require.NoError(t, err)
// Build a separate, unrelated chain.
unrelated, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
Title: "unrelated-root",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("unrelated")},
})
require.NoError(t, err)
unrelatedChild, err := server.CreateChat(ctx, CreateOptions{
OwnerID: user.ID,
ParentChatID: uuid.NullUUID{
UUID: unrelated.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: unrelated.ID,
Valid: true,
},
Title: "unrelated-child",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("unrelated-child")},
})
require.NoError(t, err)
tests := []struct {
name string
ancestor uuid.UUID
target uuid.UUID
want bool
}{
{
name: "SameID",
ancestor: root.ID,
target: root.ID,
want: false,
},
{
name: "DirectChild",
ancestor: root.ID,
target: child.ID,
want: true,
},
{
name: "GrandChild",
ancestor: root.ID,
target: grandchild.ID,
want: true,
},
{
name: "Unrelated",
ancestor: root.ID,
target: unrelatedChild.ID,
want: false,
},
{
name: "RootChat",
ancestor: child.ID,
target: root.ID,
want: false,
},
{
name: "BrokenChain",
ancestor: root.ID,
target: uuid.New(),
want: false,
},
{
name: "NotDescendant",
ancestor: unrelated.ID,
target: child.ID,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := chatdTestContext(t)
got, err := isSubagentDescendant(ctx, db, tt.ancestor, tt.target)
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
+58 -396
View File
@@ -22,7 +22,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -190,7 +189,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
return
}
params := database.GetChatsParams{
params := database.GetChatsByOwnerIDParams{
OwnerID: apiKey.UserID,
Archived: searchParams.Archived,
AfterID: paginationParams.AfterID,
@@ -200,7 +199,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
LimitOpt: int32(paginationParams.Limit),
}
chats, err := api.Database.GetChats(ctx, params)
chats, err := api.Database.GetChatsByOwnerID(ctx, params)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to list chats.",
@@ -284,41 +283,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
return
}
// Validate MCP server IDs exist.
if len(req.MCPServerIDs) > 0 {
//nolint:gocritic // Need to validate MCP server IDs exist.
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), req.MCPServerIDs)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to validate MCP server IDs.",
Detail: err.Error(),
})
return
}
if len(existingConfigs) != len(req.MCPServerIDs) {
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
for _, c := range existingConfigs {
found[c.ID] = struct{}{}
}
var missing []string
for _, id := range req.MCPServerIDs {
if _, ok := found[id]; !ok {
missing = append(missing, id.String())
}
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "One or more MCP server IDs are invalid.",
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
})
return
}
}
mcpServerIDs := req.MCPServerIDs
if mcpServerIDs == nil {
mcpServerIDs = []uuid.UUID{}
}
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
OwnerID: apiKey.UserID,
WorkspaceID: workspaceSelection.WorkspaceID,
@@ -326,7 +290,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
ModelConfigID: modelConfigID,
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
InitialUserContent: contentBlocks,
MCPServerIDs: mcpServerIDs,
})
if err != nil {
if maybeWriteLimitErr(ctx, rw, err) {
@@ -1406,58 +1369,64 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
logger.Debug(ctx, "desktop Bicopy finished")
}
// patchChat updates a chat resource. Currently supports toggling the
// archived state via the Archived field.
func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
var req codersdk.UpdateChatRequest
if !httpapi.Read(ctx, rw, r, &req) {
if chat.Archived {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat is already archived.",
})
return
}
if req.Archived != nil {
archived := *req.Archived
if archived == chat.Archived {
state := "archived"
if !archived {
state = "not archived"
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Chat is already %s.", state),
})
return
}
var err error
// Use chatDaemon when available so it can notify
// active subscribers. Fall back to direct DB for the
// simple archive flag — no streaming state is involved.
if api.chatDaemon != nil {
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
} else {
err = api.Database.ArchiveChatByID(ctx, chat.ID)
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to archive chat.",
Detail: err.Error(),
})
return
}
var err error
// Use chatDaemon when available so it can notify active
// subscribers. Fall back to direct DB for the simple
// archive flag — no streaming state is involved.
if archived {
if api.chatDaemon != nil {
err = api.chatDaemon.ArchiveChat(ctx, chat)
} else {
err = api.Database.ArchiveChatByID(ctx, chat.ID)
}
} else {
if api.chatDaemon != nil {
err = api.chatDaemon.UnarchiveChat(ctx, chat)
} else {
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
}
}
if err != nil {
action := "archive"
if !archived {
action = "unarchive"
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: fmt.Sprintf("Failed to %s chat.", action),
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
if !chat.Archived {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Chat is not archived.",
})
return
}
var err error
// Use chatDaemon when available so it can notify
// active subscribers. Fall back to direct DB for the
// simple unarchive flag — no streaming state is involved.
if api.chatDaemon != nil {
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
} else {
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to unarchive chat.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
@@ -1492,36 +1461,6 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
return
}
// Validate MCP server IDs exist.
if req.MCPServerIDs != nil && len(*req.MCPServerIDs) > 0 {
//nolint:gocritic // Need to validate MCP server IDs exist.
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), *req.MCPServerIDs)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to validate MCP server IDs.",
Detail: err.Error(),
})
return
}
if len(existingConfigs) != len(*req.MCPServerIDs) {
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
for _, c := range existingConfigs {
found[c.ID] = struct{}{}
}
var missing []string
for _, id := range *req.MCPServerIDs {
if _, ok := found[id]; !ok {
missing = append(missing, id.String())
}
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "One or more MCP server IDs are invalid.",
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
})
return
}
}
sendResult, sendErr := api.chatDaemon.SendMessage(
ctx,
chatd.SendMessageOptions{
@@ -1530,7 +1469,6 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
Content: contentBlocks,
ModelConfigID: req.ModelConfigID,
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
MCPServerIDs: req.MCPServerIDs,
},
)
if sendErr != nil {
@@ -2587,14 +2525,14 @@ func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPrompt{
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{
SystemPrompt: prompt,
})
}
func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req codersdk.ChatSystemPrompt
var req codersdk.UpdateChatSystemPromptRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
@@ -2622,49 +2560,6 @@ func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	// Fetch the deployment-wide chat-desktop toggle from the database.
	enabled, err := api.Database.GetChatDesktopEnabled(ctx)
	if err == nil {
		httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDesktopEnabledResponse{
			EnableDesktop: enabled,
		})
		return
	}
	httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
		Message: "Internal error fetching desktop setting.",
		Detail:  err.Error(),
	})
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
func (api *API) putChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	// Only callers allowed to update deployment config may flip this setting.
	if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	var req codersdk.UpdateChatDesktopEnabledRequest
	if !httpapi.Read(ctx, rw, r, &req) {
		return
	}
	err := api.Database.UpsertChatDesktopEnabled(ctx, req.EnableDesktop)
	switch {
	case httpapi.Is404Error(err):
		httpapi.ResourceNotFound(rw)
		return
	case err != nil:
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error updating desktop setting.",
			Detail:  err.Error(),
		})
		return
	}
	rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
@@ -2687,7 +2582,7 @@ func (api *API) getUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
customPrompt = ""
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
CustomPrompt: customPrompt,
})
}
@@ -2699,7 +2594,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
apiKey = httpmw.APIKey(r)
)
var params codersdk.UserChatCustomPrompt
var params codersdk.UpdateUserChatCustomPromptRequest
if !httpapi.Read(ctx, rw, r, &params) {
return
}
@@ -2726,7 +2621,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
CustomPrompt: updatedConfig.Value,
})
}
@@ -3046,10 +2941,6 @@ func truncateRunes(value string, maxLen int) string {
}
func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
mcpServerIDs := c.MCPServerIDs
if mcpServerIDs == nil {
mcpServerIDs = []uuid.UUID{}
}
chat := codersdk.Chat{
ID: c.ID,
OwnerID: c.OwnerID,
@@ -3059,7 +2950,6 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
Archived: c.Archived,
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
MCPServerIDs: mcpServerIDs,
}
if c.LastError.Valid {
chat.LastError = &c.LastError.String
@@ -3485,16 +3375,6 @@ func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
}
if err := api.Database.DeleteChatProviderByID(ctx, providerID); err != nil {
if database.IsForeignKeyViolation(err,
database.ForeignKeyChatMessagesModelConfigID,
database.ForeignKeyChatsLastModelConfigID,
) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Provider models are still referenced by existing chats.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to delete chat provider.",
Detail: err.Error(),
@@ -4262,221 +4142,3 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
)
return effectiveKeys.APIKey(provider.Provider) != ""
}
// prInsights serves aggregate pull-request metrics for the chat feature.
// It is admin-gated (read access on deployment config) and compares the
// requested window against the immediately preceding window of equal length.
//
// @Summary Get PR insights
// @ID get-pr-insights
// @Security CoderSessionToken
// @Tags Chats
// @Produce json
// @Param start_date query string true "Start date (RFC3339)"
// @Param end_date query string true "End date (RFC3339)"
// @Success 200 {object} codersdk.PRInsightsResponse
// @Router /chats/insights/pull-requests [get]
// @x-apidocgen {"skip": true}
func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	// Admin-only endpoint.
	if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
		httpapi.Forbidden(rw)
		return
	}
	// Parse date range. Defaults: last 30 days ending now.
	now := time.Now()
	defaultStart := now.AddDate(0, 0, -30)
	qp := r.URL.Query()
	p := httpapi.NewQueryParamParser()
	startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
	endDate := p.Time(qp, now, "end_date", time.RFC3339)
	p.ErrorExcessParams(qp)
	if len(p.Errors) > 0 {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message:     "Invalid query parameters.",
			Validations: p.Errors,
		})
		return
	}
	// Calculate previous period of equal length for trend comparison:
	// the previous window is [prevStart, startDate).
	duration := endDate.Sub(startDate)
	prevStart := startDate.Add(-duration)
	// No owner filter — admin sees all data.
	ownerID := uuid.NullUUID{}
	// Run all queries in parallel.
	var (
		currentSummary  database.GetPRInsightsSummaryRow
		previousSummary database.GetPRInsightsSummaryRow
		timeSeries      []database.GetPRInsightsTimeSeriesRow
		byModel         []database.GetPRInsightsPerModelRow
		recentPRs       []database.GetPRInsightsRecentPRsRow
	)
	// One goroutine per query; limit matches the number of queries so all
	// five may run concurrently. egCtx cancels the rest on first failure.
	eg, egCtx := errgroup.WithContext(ctx)
	eg.SetLimit(5)
	eg.Go(func() error {
		var err error
		currentSummary, err = api.Database.GetPRInsightsSummary(egCtx, database.GetPRInsightsSummaryParams{
			StartDate: startDate,
			EndDate:   endDate,
			OwnerID:   ownerID,
		})
		return err
	})
	eg.Go(func() error {
		var err error
		previousSummary, err = api.Database.GetPRInsightsSummary(egCtx, database.GetPRInsightsSummaryParams{
			StartDate: prevStart,
			EndDate:   startDate,
			OwnerID:   ownerID,
		})
		return err
	})
	eg.Go(func() error {
		var err error
		timeSeries, err = api.Database.GetPRInsightsTimeSeries(egCtx, database.GetPRInsightsTimeSeriesParams{
			StartDate: startDate,
			EndDate:   endDate,
			OwnerID:   ownerID,
		})
		return err
	})
	eg.Go(func() error {
		var err error
		byModel, err = api.Database.GetPRInsightsPerModel(egCtx, database.GetPRInsightsPerModelParams{
			StartDate: startDate,
			EndDate:   endDate,
			OwnerID:   ownerID,
		})
		return err
	})
	eg.Go(func() error {
		var err error
		recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
			StartDate: startDate,
			EndDate:   endDate,
			OwnerID:   ownerID,
			// Cap the recent-PR list to the 20 most recent entries.
			LimitVal: 20,
		})
		return err
	})
	if err := eg.Wait(); err != nil {
		httpapi.InternalServerError(rw, err)
		return
	}
	// Build summary with computed fields. Rates and per-PR costs are only
	// derived when the denominator is positive, so zero-activity windows
	// report zero values rather than dividing by zero.
	summary := codersdk.PRInsightsSummary{
		TotalPRsCreated:     currentSummary.TotalPrsCreated,
		TotalPRsMerged:      currentSummary.TotalPrsMerged,
		TotalAdditions:      currentSummary.TotalAdditions,
		TotalDeletions:      currentSummary.TotalDeletions,
		TotalCostMicros:     currentSummary.TotalCostMicros,
		PrevTotalPRsCreated: previousSummary.TotalPrsCreated,
		PrevTotalPRsMerged:  previousSummary.TotalPrsMerged,
	}
	if summary.TotalPRsCreated > 0 {
		summary.MergeRate = float64(summary.TotalPRsMerged) / float64(summary.TotalPRsCreated)
	}
	if summary.TotalPRsMerged > 0 {
		// Integer division: cost is tracked in micros, so truncation loses
		// at most a fraction of a micro-unit per merged PR.
		summary.CostPerMergedPRMicros = currentSummary.MergedCostMicros / summary.TotalPRsMerged
	}
	if summary.PrevTotalPRsCreated > 0 {
		summary.PrevMergeRate = float64(summary.PrevTotalPRsMerged) / float64(summary.PrevTotalPRsCreated)
	}
	if summary.PrevTotalPRsMerged > 0 {
		summary.PrevCostPerMergedPRMicros = previousSummary.MergedCostMicros / summary.PrevTotalPRsMerged
	}
	// Convert time series.
	tsEntries := make([]codersdk.PRInsightsTimeSeriesEntry, 0, len(timeSeries))
	for _, ts := range timeSeries {
		tsEntries = append(tsEntries, codersdk.PRInsightsTimeSeriesEntry{
			Date:       ts.Date,
			PRsCreated: ts.PrsCreated,
			PRsMerged:  ts.PrsMerged,
			PRsClosed:  ts.PrsClosed,
		})
	}
	// Convert model breakdown, deriving per-model merge rate and cost the
	// same way as the overall summary above.
	modelEntries := make([]codersdk.PRInsightsModelBreakdown, 0, len(byModel))
	for _, m := range byModel {
		entry := codersdk.PRInsightsModelBreakdown{
			ModelConfigID:   m.ModelConfigID.UUID,
			DisplayName:     m.DisplayName,
			Provider:        m.Provider,
			TotalPRs:        m.TotalPrs,
			MergedPRs:       m.MergedPrs,
			TotalAdditions:  m.TotalAdditions,
			TotalDeletions:  m.TotalDeletions,
			TotalCostMicros: m.TotalCostMicros,
		}
		if entry.TotalPRs > 0 {
			entry.MergeRate = float64(entry.MergedPRs) / float64(entry.TotalPRs)
		}
		if entry.MergedPRs > 0 {
			entry.CostPerMergedPRMicros = m.MergedCostMicros / entry.MergedPRs
		}
		modelEntries = append(modelEntries, entry)
	}
	// Convert recent PRs. Nullable DB columns are mapped to pointer/zero
	// fields so absent values are distinguishable in the JSON response.
	prEntries := make([]codersdk.PRInsightsPullRequest, 0, len(recentPRs))
	for _, pr := range recentPRs {
		entry := codersdk.PRInsightsPullRequest{
			ChatID:           pr.ChatID,
			PRTitle:          pr.PrTitle,
			Draft:            pr.Draft,
			Additions:        pr.Additions,
			Deletions:        pr.Deletions,
			ChangedFiles:     pr.ChangedFiles,
			ChangesRequested: pr.ChangesRequested,
			BaseBranch:       pr.BaseBranch,
			ModelDisplayName: pr.ModelDisplayName,
			CostMicros:       pr.CostMicros,
			CreatedAt:        pr.CreatedAt,
		}
		if pr.PrUrl.Valid {
			entry.PRURL = &pr.PrUrl.String
		}
		if pr.PrNumber.Valid {
			entry.PRNumber = &pr.PrNumber.Int32
		}
		if pr.State.Valid {
			// State stays a plain string (empty when NULL), unlike the
			// pointer fields above.
			entry.State = pr.State.String
		}
		if pr.Commits.Valid {
			entry.Commits = &pr.Commits.Int32
		}
		if pr.Approved.Valid {
			entry.Approved = &pr.Approved.Bool
		}
		if pr.ReviewerCount.Valid {
			entry.ReviewerCount = &pr.ReviewerCount.Int32
		}
		if pr.AuthorLogin.Valid {
			entry.AuthorLogin = &pr.AuthorLogin.String
		}
		if pr.AuthorAvatarUrl.Valid {
			entry.AuthorAvatarURL = &pr.AuthorAvatarUrl.String
		}
		prEntries = append(prEntries, entry)
	}
	httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
		Summary:    summary,
		TimeSeries: tsEntries,
		ByModel:    modelEntries,
		RecentPRs:  prEntries,
	})
}
+95 -302
View File
@@ -6,7 +6,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
"mime"
"net/http"
"net/http/httptest"
"regexp"
@@ -30,7 +29,6 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
@@ -128,24 +126,22 @@ func insertAssistantCostMessage(
})
require.NoError(t, err)
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chatID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfigID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Content: []string{string(assistantContent.RawMessage)},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{totalCostMicros},
RuntimeMs: []int64{0},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
Role: database.ChatMessageRoleAssistant,
ContentVersion: chatprompt.CurrentContentVersion,
Content: assistantContent,
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{},
OutputTokens: sql.NullInt64{},
TotalTokens: sql.NullInt64{},
ReasoningTokens: sql.NullInt64{},
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{Int64: totalCostMicros, Valid: true},
})
require.NoError(t, err)
}
@@ -1691,7 +1687,7 @@ func TestArchiveChat(t *testing.T) {
require.NoError(t, err)
require.Len(t, chatsBeforeArchive, 2)
err = client.UpdateChat(ctx, chatToArchive.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err = client.ArchiveChat(ctx, chatToArchive.ID)
require.NoError(t, err)
// Default (no filter) returns only non-archived chats.
@@ -1725,7 +1721,7 @@ func TestArchiveChat(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err := client.ArchiveChat(ctx, uuid.New())
requireSDKError(t, err, http.StatusNotFound)
})
@@ -1768,7 +1764,7 @@ func TestArchiveChat(t *testing.T) {
require.NoError(t, err)
// Archive the parent via the API.
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err = client.ArchiveChat(ctx, parentChat.ID)
require.NoError(t, err)
// archived:false should exclude the entire archived family.
@@ -1815,7 +1811,7 @@ func TestUnarchiveChat(t *testing.T) {
require.NoError(t, err)
// Archive the chat first.
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
err = client.ArchiveChat(ctx, chat.ID)
require.NoError(t, err)
// Verify it's archived.
@@ -1826,7 +1822,7 @@ func TestUnarchiveChat(t *testing.T) {
require.Len(t, archivedChats, 1)
require.True(t, archivedChats[0].Archived)
// Unarchive the chat.
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
err = client.UnarchiveChat(ctx, chat.ID)
require.NoError(t, err)
// Verify it's no longer archived.
@@ -1865,9 +1861,10 @@ func TestUnarchiveChat(t *testing.T) {
require.NoError(t, err)
// Trying to unarchive a non-archived chat should fail.
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
err = client.UnarchiveChat(ctx, chat.ID)
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
@@ -1875,7 +1872,7 @@ func TestUnarchiveChat(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client)
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
err := client.UnarchiveChat(ctx, uuid.New())
requireSDKError(t, err, http.StatusNotFound)
})
}
@@ -2663,9 +2660,7 @@ func TestPatchChatMessage(t *testing.T) {
},
})
require.NoError(t, err)
// The edited message is soft-deleted and a new one is inserted,
// so the returned ID will differ from the original.
require.NotEqual(t, userMessageID, edited.ID)
require.Equal(t, userMessageID, edited.ID)
require.Equal(t, codersdk.ChatMessageRoleUser, edited.Role)
foundEditedText := false
@@ -2755,9 +2750,7 @@ func TestPatchChatMessage(t *testing.T) {
},
})
require.NoError(t, err)
// The edited message is soft-deleted and a new one is inserted,
// so the returned ID will differ from the original.
require.NotEqual(t, userMessageID, edited.ID)
require.Equal(t, userMessageID, edited.ID)
// Assert the edit response preserves the file_id.
var foundText, foundFile bool
@@ -3991,30 +3984,6 @@ func TestGetChatFile(t *testing.T) {
require.NotContains(t, cd, strings.Repeat("a", 256))
})
t.Run("UnicodeFilename", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, client)
// Upload with a non-ASCII filename using RFC 5987 encoding,
// which is what the frontend sends for Unicode filenames.
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "スクリーンショット.png", bytes.NewReader(data))
require.NoError(t, err)
res, err := client.Request(ctx, http.MethodGet,
fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
cd := res.Header.Get("Content-Disposition")
require.Contains(t, cd, "inline")
_, params, err := mime.ParseMediaType(cd)
require.NoError(t, err)
require.Equal(t, "スクリーンショット.png", params["filename"])
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -4088,35 +4057,24 @@ func seedChatCostFixture(t *testing.T) chatCostTestFixture {
})
require.NoError(t, err)
results, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{uuid.Nil, uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfig.ID, modelConfig.ID},
Role: []database.ChatMessageRole{"assistant", "assistant"},
Content: []string{"null", "null"},
ContentVersion: []int16{0, 0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
InputTokens: []int64{100, 100},
OutputTokens: []int64{50, 50},
TotalTokens: []int64{0, 0},
ReasoningTokens: []int64{0, 0},
CacheCreationTokens: []int64{0, 0},
CacheReadTokens: []int64{0, 0},
ContextLimit: []int64{0, 0},
Compressed: []bool{false, false},
TotalCostMicros: []int64{500, 500},
RuntimeMs: []int64{0, 0},
})
require.NoError(t, err)
require.Len(t, results, 2)
earliestCreatedAt := results[0].CreatedAt
latestCreatedAt := results[0].CreatedAt
for _, msg := range results {
if msg.CreatedAt.Before(earliestCreatedAt) {
earliestCreatedAt = msg.CreatedAt
var earliestCreatedAt time.Time
var latestCreatedAt time.Time
for i := 0; i < 2; i++ {
message, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
})
require.NoError(t, err)
if i == 0 || message.CreatedAt.Before(earliestCreatedAt) {
earliestCreatedAt = message.CreatedAt
}
if msg.CreatedAt.After(latestCreatedAt) {
latestCreatedAt = msg.CreatedAt
if i == 0 || message.CreatedAt.After(latestCreatedAt) {
latestCreatedAt = message.CreatedAt
}
}
@@ -4204,27 +4162,16 @@ func TestChatCostSummary_AdminDrilldown(t *testing.T) {
})
require.NoError(t, err)
results, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{"assistant"},
Content: []string{"null"},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{200},
OutputTokens: []int64{100},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{750},
RuntimeMs: []int64{0},
message, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 750, Valid: true},
})
require.NoError(t, err)
message := results[0]
options := codersdk.ChatCostSummaryOptions{
// Pad the DB-assigned timestamp so the query window cannot race it.
StartDate: message.CreatedAt.Add(-time.Minute),
@@ -4270,24 +4217,14 @@ func TestChatCostUsers(t *testing.T) {
Title: "admin chat",
})
require.NoError(t, err)
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
ChatID: adminChat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{"assistant"},
Content: []string{"null"},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{100},
OutputTokens: []int64{50},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{300},
RuntimeMs: []int64{0},
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: adminChat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
})
require.NoError(t, err)
@@ -4297,24 +4234,14 @@ func TestChatCostUsers(t *testing.T) {
Title: "member chat",
})
require.NoError(t, err)
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
ChatID: memberChat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{"assistant"},
Content: []string{"null"},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{200},
OutputTokens: []int64{100},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{800},
RuntimeMs: []int64{0},
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: memberChat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 800, Valid: true},
})
require.NoError(t, err)
@@ -4381,24 +4308,14 @@ func TestChatCostSummary_DateRange(t *testing.T) {
})
require.NoError(t, err)
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{"assistant"},
Content: []string{"null"},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{100},
OutputTokens: []int64{50},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{500},
RuntimeMs: []int64{0},
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
})
require.NoError(t, err)
@@ -4446,49 +4363,27 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
})
require.NoError(t, err)
pricedResults, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{"assistant"},
Content: []string{"null"},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{100},
OutputTokens: []int64{50},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{500},
RuntimeMs: []int64{0},
pricedMessage, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
})
require.NoError(t, err)
pricedMessage := pricedResults[0]
unpricedResults, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{"assistant"},
Content: []string{"null"},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{200},
OutputTokens: []int64{75},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
unpricedMessage, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
TotalCostMicros: sql.NullInt64{},
})
require.NoError(t, err)
unpricedMessage := unpricedResults[0]
earliestCreatedAt := pricedMessage.CreatedAt
latestCreatedAt := pricedMessage.CreatedAt
@@ -4567,7 +4462,7 @@ func TestWatchChatDesktop(t *testing.T) {
res, err := client.Request(
ctx,
http.MethodGet,
fmt.Sprintf("/api/experimental/chats/%s/stream/desktop", createdChat.ID),
fmt.Sprintf("/api/experimental/chats/%s/desktop", createdChat.ID),
nil,
)
require.NoError(t, err)
@@ -4617,7 +4512,7 @@ func TestChatSystemPrompt(t *testing.T) {
t.Run("AdminCanSet", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: "You are a helpful coding assistant.",
})
require.NoError(t, err)
@@ -4631,7 +4526,7 @@ func TestChatSystemPrompt(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// Unset by sending an empty string.
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: "",
})
require.NoError(t, err)
@@ -4644,7 +4539,7 @@ func TestChatSystemPrompt(t *testing.T) {
t.Run("NonAdminFails", func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: "This should fail.",
})
requireSDKError(t, err, http.StatusNotFound)
@@ -4665,7 +4560,7 @@ func TestChatSystemPrompt(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
tooLong := strings.Repeat("a", 131073)
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
SystemPrompt: tooLong,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
@@ -4673,108 +4568,6 @@ func TestChatSystemPrompt(t *testing.T) {
})
}
func TestChatDesktopEnabled(t *testing.T) {
t.Parallel()
t.Run("ReturnsFalseWhenUnset", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.False(t, resp.EnableDesktop)
})
t.Run("AdminCanSetTrue", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.True(t, resp.EnableDesktop)
})
t.Run("AdminCanSetFalse", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
// Set true first, then set false.
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
err = adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: false,
})
require.NoError(t, err)
resp, err := adminClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.False(t, resp.EnableDesktop)
})
t.Run("NonAdminCanRead", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
require.NoError(t, err)
resp, err := memberClient.GetChatDesktopEnabled(ctx)
require.NoError(t, err)
require.True(t, resp.EnableDesktop)
})
t.Run("NonAdminWriteFails", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
err := memberClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
EnableDesktop: true,
})
requireSDKError(t, err, http.StatusForbidden)
})
t.Run("UnauthenticatedFails", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
coderdtest.CreateFirstUser(t, adminClient)
anonClient := codersdk.New(adminClient.URL)
_, err := anonClient.GetChatDesktopEnabled(ctx)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
})
}
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
t.Helper()
+17 -55
View File
@@ -10,7 +10,6 @@ import (
"flag"
"fmt"
"io"
"math"
"net/http"
httppprof "net/http/pprof"
"net/url"
@@ -767,27 +766,17 @@ func New(options *Options) *API {
}
api.agentProvider = stn
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
if maxChatsPerAcquire > math.MaxInt32 {
maxChatsPerAcquire = math.MaxInt32
}
if maxChatsPerAcquire < math.MinInt32 {
maxChatsPerAcquire = math.MinInt32
}
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
Logger: options.Logger.Named("chats"),
Database: options.Database,
ReplicaID: api.ID,
SubscribeFn: options.ChatSubscribeFn,
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
AgentConn: api.agentProvider.AgentConn,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
})
gitSyncLogger := options.Logger.Named("gitsync")
refresher := gitsync.NewRefresher(
@@ -1044,12 +1033,10 @@ func New(options *Options) *API {
// OAuth2 metadata endpoint for RFC 8414 discovery
r.Route("/.well-known/oauth-authorization-server", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2))
r.Get("/*", api.oauth2AuthorizationServerMetadata())
})
// OAuth2 protected resource metadata endpoint for RFC 9728 discovery
r.Route("/.well-known/oauth-protected-resource", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2))
r.Get("/*", api.oauth2ProtectedResourceMetadata())
})
@@ -1162,9 +1149,6 @@ func New(options *Options) *API {
r.Get("/summary", api.chatCostSummary)
})
})
r.Route("/insights", func(r chi.Router) {
r.Get("/pull-requests", api.prInsights)
})
r.Route("/files", func(r chi.Router) {
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
r.Post("/", api.postChatFile)
@@ -1173,8 +1157,6 @@ func New(options *Options) *API {
r.Route("/config", func(r chi.Router) {
r.Get("/system-prompt", api.getChatSystemPrompt)
r.Put("/system-prompt", api.putChatSystemPrompt)
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
r.Get("/user-prompt", api.getUserChatCustomPrompt)
r.Put("/user-prompt", api.putUserChatCustomPrompt)
})
@@ -1212,15 +1194,14 @@ func New(options *Options) *API {
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Patch("/", api.patchChat)
r.Get("/git/watch", api.watchChatGit)
r.Get("/desktop", api.watchChatDesktop)
r.Post("/archive", api.archiveChat)
r.Post("/unarchive", api.unarchiveChat)
r.Get("/messages", api.getChatMessages)
r.Post("/messages", api.postChatMessages)
r.Patch("/messages/{message}", api.patchChatMessage)
r.Route("/stream", func(r chi.Router) {
r.Get("/", api.streamChat)
r.Get("/desktop", api.watchChatDesktop)
r.Get("/git", api.watchChatGit)
})
r.Get("/stream", api.streamChat)
r.Post("/interrupt", api.interruptChat)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
@@ -1233,27 +1214,10 @@ func New(options *Options) *API {
r.Route("/mcp", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
)
// MCP server configuration endpoints.
r.Route("/servers", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents))
r.Get("/", api.listMCPServerConfigs)
r.Post("/", api.createMCPServerConfig)
r.Route("/{mcpServer}", func(r chi.Router) {
r.Get("/", api.getMCPServerConfig)
r.Patch("/", api.updateMCPServerConfig)
r.Delete("/", api.deleteMCPServerConfig)
// OAuth2 user flow
r.Get("/oauth2/connect", api.mcpServerOAuth2Connect)
r.Get("/oauth2/callback", api.mcpServerOAuth2Callback)
r.Delete("/oauth2/disconnect", api.mcpServerOAuth2Disconnect)
})
})
// MCP HTTP transport endpoint with mandatory authentication
r.Route("/http", func(r chi.Router) {
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP))
r.Mount("/", api.mcpHTTPHandler())
})
r.Mount("/http", api.mcpHTTPHandler())
})
r.Route("/watch-all-workspacebuilds", func(r chi.Router) {
r.Use(
@@ -1515,8 +1479,6 @@ func New(options *Options) *API {
r.Post("/", api.postUser)
r.Get("/", api.users)
r.Post("/logout", api.postLogout)
r.Post("/me/session/token-to-cookie", api.postSessionTokenCookie)
r.Get("/oidc-claims", api.userOIDCClaims)
// These routes query information about site wide roles.
r.Route("/roles", func(r chi.Router) {
r.Get("/", api.AssignableSiteRoles)
-9
View File
@@ -879,15 +879,6 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
m(&req)
}
// Service accounts cannot have a password or email and must
// use login_type=none. Enforce this after mutators so callers
// only need to set ServiceAccount=true.
if req.ServiceAccount {
req.Password = ""
req.Email = ""
req.UserLoginType = codersdk.LoginTypeNone
}
user, err := client.CreateUserWithOrgs(context.Background(), req)
var apiError *codersdk.Error
// If the user already exists by username or email conflict, try again up to "retries" times.
+7 -39
View File
@@ -13,64 +13,32 @@ var _ usage.Inserter = (*UsageInserter)(nil)
type UsageInserter struct {
sync.Mutex
discreteEvents []usagetypes.DiscreteEvent
heartbeatEvents []usagetypes.HeartbeatEvent
seenHeartbeats map[string]struct{}
events []usagetypes.DiscreteEvent
}
func NewUsageInserter() *UsageInserter {
return &UsageInserter{
discreteEvents: []usagetypes.DiscreteEvent{},
seenHeartbeats: map[string]struct{}{},
heartbeatEvents: []usagetypes.HeartbeatEvent{},
events: []usagetypes.DiscreteEvent{},
}
}
func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error {
u.Lock()
defer u.Unlock()
u.discreteEvents = append(u.discreteEvents, event)
u.events = append(u.events, event)
return nil
}
func (u *UsageInserter) InsertHeartbeatUsageEvent(_ context.Context, _ database.Store, id string, event usagetypes.HeartbeatEvent) error {
func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent {
u.Lock()
defer u.Unlock()
if _, seen := u.seenHeartbeats[id]; seen {
return nil
}
u.seenHeartbeats[id] = struct{}{}
u.heartbeatEvents = append(u.heartbeatEvents, event)
return nil
}
func (u *UsageInserter) GetHeartbeatEvents() []usagetypes.HeartbeatEvent {
u.Lock()
defer u.Unlock()
eventsCopy := make([]usagetypes.HeartbeatEvent, len(u.heartbeatEvents))
copy(eventsCopy, u.heartbeatEvents)
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events))
copy(eventsCopy, u.events)
return eventsCopy
}
func (u *UsageInserter) GetDiscreteEvents() []usagetypes.DiscreteEvent {
u.Lock()
defer u.Unlock()
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.discreteEvents))
copy(eventsCopy, u.discreteEvents)
return eventsCopy
}
func (u *UsageInserter) TotalEventCount() int {
u.Lock()
defer u.Unlock()
return len(u.discreteEvents) + len(u.heartbeatEvents)
}
func (u *UsageInserter) Reset() {
u.Lock()
defer u.Unlock()
u.seenHeartbeats = map[string]struct{}{}
u.discreteEvents = []usagetypes.DiscreteEvent{}
u.heartbeatEvents = []usagetypes.HeartbeatEvent{}
u.events = []usagetypes.DiscreteEvent{}
}
-3
View File
@@ -20,9 +20,6 @@ const (
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs
CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs
CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
+7 -8
View File
@@ -195,14 +195,13 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
func ReducedUser(user database.User) codersdk.ReducedUser {
return codersdk.ReducedUser{
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
IsServiceAccount: user.IsServiceAccount,
MinimalUser: MinimalUser(user),
Email: user.Email,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
LastSeenAt: user.LastSeenAt,
Status: codersdk.UserStatus(user.Status),
LoginType: codersdk.LoginType(user.LoginType),
}
}
+22 -216
View File
@@ -705,7 +705,7 @@ var (
DisplayName: "Chat Daemon",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceWorkspace.Type: {policy.ActionRead},
rbac.ResourceDeploymentConfig.Type: {policy.ActionRead},
rbac.ResourceUser.Type: {policy.ActionReadPersonal},
}),
@@ -769,9 +769,6 @@ func AsSubAgentAPI(ctx context.Context, orgID uuid.UUID, userID uuid.UUID) conte
// AsSystemRestricted returns a context with an actor that has permissions
// required for various system operations (login, logout, metrics cache).
// DO NOT USE THIS UNLESS YOU HAVE ABSOLUTELY NO OTHER CHOICE. Prefer using a
// more specific As* helper above (or adding a new, narrowly-scoped one) so
// that permissions remain limited to the operation you need.
func AsSystemRestricted(ctx context.Context) context.Context {
return As(ctx, subjectSystemRestricted)
}
@@ -1267,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
// System roles are stored in the database but have a fixed, code-defined
// meaning. Do not rewrite the name for them so the static "who can assign
// what" mapping applies.
if !rolestore.IsSystemRoleName(roleName.Name) {
if !rbac.SystemRoleName(roleName.Name) {
// To support a dynamic mapping of what roles can assign what, we need
// to store this in the database. For now, just use a static role so
// owners and org admins can assign roles.
@@ -1694,13 +1691,6 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
return q.db.CleanTailnetTunnels(ctx)
}
func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return err
}
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
}
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
if err != nil {
@@ -1834,6 +1824,18 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
}
func (q *querier) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.DeleteChatMessagesAfterID(ctx, arg)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -1930,20 +1932,6 @@ func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) {
return id, nil
}
func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteMCPServerConfigByID(ctx, id)
}
func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteMCPServerUserToken(ctx, arg)
}
func (q *querier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceOauth2App); err != nil {
return err
@@ -2157,12 +2145,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
}
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
// This is a system-only function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
}
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
@@ -2494,17 +2482,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
// The desktop-enabled flag is a deployment-wide setting read by any
// authenticated chat user and by chatd when deciding whether to expose
// computer-use tooling. We only require that an explicit actor is present
// in the context so unauthenticated calls fail closed.
if _, ok := ActorFromContext(ctx); !ok {
return false, ErrNoActor
}
return q.db.GetChatDesktopEnabled(ctx)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2676,12 +2653,8 @@ func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}
return q.db.GetAuthorizedChats(ctx, arg, prep)
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
@@ -2787,13 +2760,6 @@ func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatP
return q.db.GetEnabledChatProviders(ctx)
}
func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetEnabledMCPServerConfigs(ctx)
}
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
}
@@ -2852,13 +2818,6 @@ func (q *querier) GetFilteredInboxNotificationsByUserID(ctx context.Context, arg
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetFilteredInboxNotificationsByUserID)(ctx, arg)
}
func (q *querier) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetForcedMCPServerConfigs(ctx)
}
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetGitSSHKey)(ctx, userID)
}
@@ -2999,48 +2958,6 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
return q.db.GetLogoURL(ctx)
}
func (q *querier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.GetMCPServerConfigByID(ctx, id)
}
func (q *querier) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.GetMCPServerConfigBySlug(ctx, slug)
}
func (q *querier) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerConfigs(ctx)
}
func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerConfigsByIDs(ctx, ids)
}
func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
}
return q.db.GetMCPServerUserToken(ctx, arg)
}
func (q *querier) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetMCPServerUserTokensByUserID(ctx, userID)
}
func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil {
return nil, err
@@ -3227,34 +3144,6 @@ func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg da
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
}
func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsPerModel(ctx, arg)
}
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsRecentPRs(ctx, arg)
}
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetPRInsightsSummaryRow{}, err
}
return q.db.GetPRInsightsSummary(ctx, arg)
}
func (q *querier) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.GetPRInsightsTimeSeries(ctx, arg)
}
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
if err != nil {
@@ -4608,13 +4497,6 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
}
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
return database.AIBridgeModelThought{}, err
}
return q.db.InsertAIBridgeModelThought(ctx, arg)
}
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
// All aibridge_token_usages records belong to the initiator of their associated interception.
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
@@ -4671,16 +4553,16 @@ func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFil
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
}
func (q *querier) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
// Authorize create on the parent chat (using update permission).
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
return database.ChatMessage{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return nil, err
return database.ChatMessage{}, err
}
return q.db.InsertChatMessages(ctx, arg)
return q.db.InsertChatMessage(ctx, arg)
}
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
@@ -4807,13 +4689,6 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
return q.db.InsertLicense(ctx, arg)
}
func (q *querier) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.InsertMCPServerConfig(ctx, arg)
}
func (q *querier) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil {
return database.WorkspaceAgentMemoryResourceMonitor{}, err
@@ -5466,32 +5341,6 @@ func (q *querier) SelectUsageEventsForPublishing(ctx context.Context, arg time.T
return q.db.SelectUsageEventsForPublishing(ctx, arg)
}
func (q *querier) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
msg, err := q.db.GetChatMessageByID(ctx, id)
if err != nil {
return err
}
chat, err := q.db.GetChatByID(ctx, msg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.SoftDeleteChatMessageByID(ctx, id)
}
func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return err
}
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
}
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
return q.db.TryAcquireLock(ctx, id)
}
@@ -5573,17 +5422,6 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh
return q.db.UpdateChatHeartbeat(ctx, arg)
}
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatMCPServerIDs(ctx, arg)
}
func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
// Authorize update on the parent chat of the edited message.
msg, err := q.db.GetChatMessageByID(ctx, arg.ID)
@@ -5748,13 +5586,6 @@ func (q *querier) UpdateInboxNotificationReadStatus(ctx context.Context, args da
return update(q.log, q.auth, fetchFunc, q.db.UpdateInboxNotificationReadStatus)(ctx, args)
}
func (q *querier) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerConfig{}, err
}
return q.db.UpdateMCPServerConfig(ctx, arg)
}
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
@@ -6706,13 +6537,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
return q.db.UpsertBoundaryUsageStats(ctx, arg)
}
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
}
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
// Authorize update on the parent chat.
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
@@ -6800,13 +6624,6 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error {
return q.db.UpsertLogoURL(ctx, value)
}
func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
}
return q.db.UpsertMCPServerUserToken(ctx, arg)
}
func (q *querier) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return err
@@ -6959,13 +6776,6 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
}
func (q *querier) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUsageEvent); err != nil {
return false, err
}
return q.db.UsageEventExistsByID(ctx, id)
}
func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) {
// This check is probably overly restrictive, but the "correct" check isn't
// necessarily obvious. It's only used as a verification check for ACLs right
@@ -7061,7 +6871,3 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
// database.Store interface, so dbauthz needs to implement it.
return q.ListAIBridgeModels(ctx, arg)
}
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
return q.GetChats(ctx, arg)
}
+32 -204
View File
@@ -401,27 +401,16 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("SoftDeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
s.Run("DeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.SoftDeleteChatMessagesAfterIDParams{
arg := database.DeleteChatMessagesAfterIDParams{
ChatID: chat.ID,
AfterID: 123,
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().SoftDeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
dbm.EXPECT().DeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("SoftDeleteChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := database.ChatMessage{
ID: 456,
ChatID: chat.ID,
}
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), msg.ID).Return(nil).AnyTimes()
check.Args(msg.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
@@ -629,17 +618,12 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
// No asserts here because SQLFilter.
check.Args(params).Asserts()
}))
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
params := database.GetChatsParams{}
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
// No asserts here because it re-routes through GetChats which uses SQLFilter.
check.Args(params, emptyPreparedAuthorized{}).Asserts()
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
c1 := testutil.Fake(s.T(), faker, database.Chat{})
c2 := testutil.Fake(s.T(), faker, database.Chat{})
params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
}))
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -652,10 +636,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
@@ -686,13 +666,13 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
}))
s.Run("InsertChatMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatMessagesParams{ChatID: chat.ID})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatMessages(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msgs)
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
}))
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -865,10 +845,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetUserChatSpendInPeriodParams{
UserID: uuid.New(),
@@ -1005,114 +981,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("CleanupDeletedMCPServerIDsFromChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CleanupDeletedMCPServerIDsFromChats(gomock.Any()).Return(nil).AnyTimes()
check.Args().Asserts(rbac.ResourceChat, policy.ActionUpdate)
}))
s.Run("DeleteMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().DeleteMCPServerConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.DeleteMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
dbm.EXPECT().DeleteMCPServerUserToken(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetEnabledMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetEnabledMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetForcedMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetForcedMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetMCPServerConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetMCPServerConfigBySlug", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
slug := "test-mcp-server"
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{Slug: slug})
dbm.EXPECT().GetMCPServerConfigBySlug(gomock.Any(), slug).Return(config, nil).AnyTimes()
check.Args(slug).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
dbm.EXPECT().GetMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerConfigsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
ids := []uuid.UUID{configA.ID, configB.ID}
dbm.EXPECT().GetMCPServerConfigsByIDs(gomock.Any(), ids).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
}))
s.Run("GetMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().GetMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(token)
}))
s.Run("GetMCPServerUserTokensByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
userID := uuid.New()
tokens := []database.MCPServerUserToken{testutil.Fake(s.T(), faker, database.MCPServerUserToken{UserID: userID})}
dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens)
}))
s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertMCPServerConfigParams{
DisplayName: "Test MCP Server",
Slug: "test-mcp-server",
}
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{DisplayName: arg.DisplayName, Slug: arg.Slug})
dbm.EXPECT().InsertMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpdateChatMCPServerIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatMCPServerIDsParams{
ID: chat.ID,
MCPServerIDs: []uuid.UUID{uuid.New()},
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
arg := database.UpdateMCPServerConfigParams{
ID: config.ID,
DisplayName: "Updated MCP Server",
Slug: "updated-mcp-server",
}
dbm.EXPECT().UpdateMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpsertMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.UpsertMCPServerUserTokenParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
AccessToken: "test-access-token",
TokenType: "bearer",
}
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().UpsertMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(token)
}))
}
func (s *MethodTestSuite) TestFile() {
@@ -1489,8 +1357,8 @@ func (s *MethodTestSuite) TestLicense() {
check.Args().Asserts().Returns("value")
}))
s.Run("GetDefaultProxyConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"}, nil).AnyTimes()
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"})
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}, nil).AnyTimes()
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"})
}))
s.Run("GetLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetLogoURL(gomock.Any()).Return("value", nil).AnyTimes()
@@ -1612,7 +1480,7 @@ func (s *MethodTestSuite) TestOrganization() {
org := testutil.Fake(s.T(), faker, database.Organization{})
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
ID: org.ID,
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
WorkspaceSharingDisabled: true,
}
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
@@ -2043,26 +1911,6 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
}))
s.Run("GetPRInsightsSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsSummaryParams{}
dbm.EXPECT().GetPRInsightsSummary(gomock.Any(), arg).Return(database.GetPRInsightsSummaryRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsTimeSeries", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsTimeSeriesParams{}
dbm.EXPECT().GetPRInsightsTimeSeries(gomock.Any(), arg).Return([]database.GetPRInsightsTimeSeriesRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsPerModelParams{}
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetPRInsightsRecentPRsParams{}
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
}))
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTelemetryTaskEventsParams{}
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
@@ -2551,12 +2399,9 @@ func (s *MethodTestSuite) TestWorkspace() {
check.Args(w.ID).Asserts(w, policy.ActionShare)
}))
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.DeleteWorkspaceACLsByOrganizationParams{
OrganizationID: uuid.New(),
ExcludeServiceAccounts: false,
}
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
orgID := uuid.New()
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
w := testutil.Fake(s.T(), faker, database.Workspace{})
@@ -5262,12 +5107,6 @@ func (s *MethodTestSuite) TestUsageEvents() {
check.Args(params).Asserts(rbac.ResourceUsageEvent, policy.ActionCreate)
}))
s.Run("UsageEventExistsByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
id := uuid.NewString()
db.EXPECT().UsageEventExistsByID(gomock.Any(), id).Return(true, nil)
check.Args(id).Asserts(rbac.ResourceUsageEvent, policy.ActionRead)
}))
s.Run("SelectUsageEventsForPublishing", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
now := dbtime.Now()
db.EXPECT().SelectUsageEventsForPublishing(gomock.Any(), now).Return([]database.UsageEvent{}, nil)
@@ -5328,17 +5167,6 @@ func (s *MethodTestSuite) TestAIBridge() {
check.Args(params).Asserts(intc, policy.ActionCreate)
}))
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate)
}))
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intID := uuid.UUID{2}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
@@ -5697,16 +5525,12 @@ func TestAsChatd(t *testing.T) {
require.NoError(t, err, "chat %s should be allowed", action)
}
// Workspace read + update (update needed for ActivityBumpWorkspace).
for _, action := range []policy.Action{
policy.ActionRead, policy.ActionUpdate,
} {
err := auth.Authorize(ctx, actor, action, rbac.ResourceWorkspace)
require.NoError(t, err, "workspace %s should be allowed", action)
}
// Workspace read.
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceWorkspace)
require.NoError(t, err, "workspace read should be allowed")
// DeploymentConfig read.
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceDeploymentConfig)
err = auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceDeploymentConfig)
require.NoError(t, err, "deployment config read should be allowed")
// User read_personal (needed for GetUserChatCustomPrompt).
@@ -5717,12 +5541,16 @@ func TestAsChatd(t *testing.T) {
t.Run("DeniedActions", func(t *testing.T) {
t.Parallel()
// Cannot delete workspaces.
err := auth.Authorize(ctx, actor, policy.ActionDelete, rbac.ResourceWorkspace)
require.Error(t, err, "workspace delete should be denied")
// Cannot write workspaces.
for _, action := range []policy.Action{
policy.ActionUpdate, policy.ActionDelete,
} {
err := auth.Authorize(ctx, actor, action, rbac.ResourceWorkspace)
require.Error(t, err, "workspace %s should be denied", action)
}
// Cannot access users.
err = auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceUser)
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceUser)
require.Error(t, err, "user read should be denied")
// Cannot access API keys.
+1 -2
View File
@@ -29,7 +29,6 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/rbac/regosql"
"github.com/coder/coder/v2/coderd/rbac/rolestore"
"github.com/coder/coder/v2/coderd/util/slice"
)
@@ -144,7 +143,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
UUID: pair.OrganizationID,
Valid: pair.OrganizationID != uuid.Nil,
},
IsSystem: rolestore.IsSystemRoleName(pair.Name),
IsSystem: rbac.SystemRoleName(pair.Name),
ID: uuid.New(),
})
}
+25 -17
View File
@@ -650,26 +650,34 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
})
require.NoError(t, err, "insert organization")
// Populate the placeholder system roles (created by DB
// trigger/migration) so org members have expected permissions.
//nolint:gocritic // ReconcileSystemRole needs the system:update
// Populate the placeholder organization-member system role (created by
// DB trigger/migration) so org members have expected permissions.
//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
// permission that `genCtx` does not have.
sysCtx := dbauthz.AsSystemRestricted(genCtx)
for roleName := range rolestore.SystemRoleNames {
role := database.CustomRole{
Name: roleName,
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
}
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
if errors.Is(err, sql.ErrNoRows) {
// The trigger that creates the placeholder role didn't run (e.g.,
// triggers were disabled in the test). Create the role manually.
err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
require.NoError(t, err, "create role "+roleName)
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
}
require.NoError(t, err, "reconcile role "+roleName)
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
Name: rbac.RoleOrgMember(),
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
}, org.WorkspaceSharingDisabled)
if errors.Is(err, sql.ErrNoRows) {
// The trigger that creates the placeholder role didn't run (e.g.,
// triggers were disabled in the test). Create the role manually.
err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
require.NoError(t, err, "create organization-member role")
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
Name: rbac.RoleOrgMember(),
OrganizationID: uuid.NullUUID{
UUID: org.ID,
Valid: true,
},
}, org.WorkspaceSharingDisabled)
}
require.NoError(t, err, "reconcile organization-member role")
return org
}
+18 -220
View File
@@ -264,14 +264,6 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error {
return r0
}
func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
start := time.Now()
r0 := m.s.CleanupDeletedMCPServerIDsFromChats(ctx)
m.queryLatencies.WithLabelValues("CleanupDeletedMCPServerIDsFromChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CleanupDeletedMCPServerIDsFromChats").Inc()
return r0
}
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
@@ -392,6 +384,14 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
return r0
}
func (m queryMetricsStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
start := time.Now()
r0 := m.s.DeleteChatMessagesAfterID(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesAfterID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
@@ -488,22 +488,6 @@ func (m queryMetricsStore) DeleteLicense(ctx context.Context, id int32) (int32,
return r0, r1
}
func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteMCPServerConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("DeleteMCPServerConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerConfigByID").Inc()
return r0
}
func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
start := time.Now()
r0 := m.s.DeleteMCPServerUserToken(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteMCPServerUserToken").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserToken").Inc()
return r0
}
func (m queryMetricsStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteOAuth2ProviderAppByClientID(ctx, id)
@@ -712,11 +696,10 @@ func (m queryMetricsStore) DeleteWorkspaceACLByID(ctx context.Context, id uuid.U
return r0
}
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, arg)
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
m.queryLatencies.WithLabelValues("DeleteWorkspaceACLsByOrganization").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteWorkspaceACLsByOrganization").Inc()
return r0
}
@@ -1064,14 +1047,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
m.queryLatencies.WithLabelValues("GetChatDesktopEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDesktopEnabled").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
@@ -1216,11 +1191,11 @@ func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, us
return r0, r1
}
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChats(ctx, arg)
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChats").Inc()
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
return r0, r1
}
@@ -1352,14 +1327,6 @@ func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]datab
return r0, r1
}
func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx)
m.queryLatencies.WithLabelValues("GetEnabledMCPServerConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledMCPServerConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
start := time.Now()
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
@@ -1416,14 +1383,6 @@ func (m queryMetricsStore) GetFilteredInboxNotificationsByUserID(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetForcedMCPServerConfigs(ctx)
m.queryLatencies.WithLabelValues("GetForcedMCPServerConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetForcedMCPServerConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
start := time.Now()
r0, r1 := m.s.GetGitSSHKey(ctx, userID)
@@ -1584,54 +1543,6 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) {
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigByID(ctx, id)
m.queryLatencies.WithLabelValues("GetMCPServerConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigByID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigBySlug(ctx, slug)
m.queryLatencies.WithLabelValues("GetMCPServerConfigBySlug").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigBySlug").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigs(ctx)
m.queryLatencies.WithLabelValues("GetMCPServerConfigs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerConfigsByIDs(ctx, ids)
m.queryLatencies.WithLabelValues("GetMCPServerConfigsByIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigsByIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerUserToken(ctx, arg)
m.queryLatencies.WithLabelValues("GetMCPServerUserToken").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserToken").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerUserTokensByUserID(ctx, userID)
m.queryLatencies.WithLabelValues("GetMCPServerUserTokensByUserID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserTokensByUserID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
start := time.Now()
r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg)
@@ -1824,38 +1735,6 @@ func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Contex
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsPerModel(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsPerModel").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPerModel").Inc()
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsSummary(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsSummary").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsSummary").Inc()
return r0, r1
}
func (m queryMetricsStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
start := time.Now()
r0, r1 := m.s.GetPRInsightsTimeSeries(ctx, arg)
m.queryLatencies.WithLabelValues("GetPRInsightsTimeSeries").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsTimeSeries").Inc()
return r0, r1
}
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
start := time.Now()
r0, r1 := m.s.GetParameterSchemasByJobID(ctx, jobID)
@@ -3072,14 +2951,6 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
start := time.Now()
r0, r1 := m.s.InsertAIBridgeModelThought(ctx, arg)
m.queryLatencies.WithLabelValues("InsertAIBridgeModelThought").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertAIBridgeModelThought").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
@@ -3144,11 +3015,11 @@ func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.Inse
return r0, r1
}
func (m queryMetricsStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatMessages(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessages").Inc()
r0, r1 := m.s.InsertChatMessage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
return r0, r1
}
@@ -3272,14 +3143,6 @@ func (m queryMetricsStore) InsertLicense(ctx context.Context, arg database.Inser
return r0, r1
}
func (m queryMetricsStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.InsertMCPServerConfig(ctx, arg)
m.queryLatencies.WithLabelValues("InsertMCPServerConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertMCPServerConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
start := time.Now()
r0, r1 := m.s.InsertMemoryResourceMonitor(ctx, arg)
@@ -3864,22 +3727,6 @@ func (m queryMetricsStore) SelectUsageEventsForPublishing(ctx context.Context, n
return r0, r1
}
func (m queryMetricsStore) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
start := time.Now()
r0 := m.s.SoftDeleteChatMessageByID(ctx, id)
m.queryLatencies.WithLabelValues("SoftDeleteChatMessageByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteChatMessageByID").Inc()
return r0
}
func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
start := time.Now()
r0 := m.s.SoftDeleteChatMessagesAfterID(ctx, arg)
m.queryLatencies.WithLabelValues("SoftDeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteChatMessagesAfterID").Inc()
return r0
}
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
start := time.Now()
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
@@ -3952,14 +3799,6 @@ func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatMCPServerIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatMCPServerIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatMessageByID(ctx, arg)
@@ -4064,14 +3903,6 @@ func (m queryMetricsStore) UpdateInboxNotificationReadStatus(ctx context.Context
return r0
}
func (m queryMetricsStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
start := time.Now()
r0, r1 := m.s.UpdateMCPServerConfig(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateMCPServerConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateMCPServerConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
start := time.Now()
r0, r1 := m.s.UpdateMemberRoles(ctx, arg)
@@ -4132,7 +3963,6 @@ func (m queryMetricsStore) UpdateOrganizationWorkspaceSharingSettings(ctx contex
start := time.Now()
r0, r1 := m.s.UpdateOrganizationWorkspaceSharingSettings(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateOrganizationWorkspaceSharingSettings").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOrganizationWorkspaceSharingSettings").Inc()
return r0, r1
}
@@ -4712,14 +4542,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
return r0, r1
}
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
start := time.Now()
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
m.queryLatencies.WithLabelValues("UpsertChatDesktopEnabled").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDesktopEnabled").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
@@ -4808,14 +4630,6 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro
return r0
}
func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertMCPServerUserToken").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserToken").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
start := time.Now()
r0 := m.s.UpsertNotificationReportGeneratorLog(ctx, arg)
@@ -4952,14 +4766,6 @@ func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, a
return r0, r1
}
func (m queryMetricsStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
start := time.Now()
r0, r1 := m.s.UsageEventExistsByID(ctx, id)
m.queryLatencies.WithLabelValues("UsageEventExistsByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UsageEventExistsByID").Inc()
return r0, r1
}
func (m queryMetricsStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
start := time.Now()
r0, r1 := m.s.ValidateGroupIDs(ctx, groupIds)
@@ -5079,11 +4885,3 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedChats").Inc()
return r0, r1
}
+31 -401
View File
@@ -334,20 +334,6 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx)
}
// CleanupDeletedMCPServerIDsFromChats mocks base method.
func (m *MockStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupDeletedMCPServerIDsFromChats", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupDeletedMCPServerIDsFromChats indicates an expected call of CleanupDeletedMCPServerIDsFromChats.
func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
}
// CountAIBridgeInterceptions mocks base method.
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -612,6 +598,20 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
}
// DeleteChatMessagesAfterID mocks base method.
func (m *MockStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatMessagesAfterID", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatMessagesAfterID indicates an expected call of DeleteChatMessagesAfterID.
func (mr *MockStoreMockRecorder) DeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesAfterID), ctx, arg)
}
// DeleteChatModelConfigByID mocks base method.
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -783,34 +783,6 @@ func (mr *MockStoreMockRecorder) DeleteLicense(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLicense", reflect.TypeOf((*MockStore)(nil).DeleteLicense), ctx, id)
}
// DeleteMCPServerConfigByID mocks base method.
func (m *MockStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteMCPServerConfigByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteMCPServerConfigByID indicates an expected call of DeleteMCPServerConfigByID.
func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id)
}
// DeleteMCPServerUserToken mocks base method.
func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteMCPServerUserToken", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteMCPServerUserToken indicates an expected call of DeleteMCPServerUserToken.
func (mr *MockStoreMockRecorder) DeleteMCPServerUserToken(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserToken), ctx, arg)
}
// DeleteOAuth2ProviderAppByClientID mocks base method.
func (m *MockStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -1183,17 +1155,17 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
}
// DeleteWorkspaceACLsByOrganization mocks base method.
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, arg)
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, arg any) *gomock.Call {
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
}
// DeleteWorkspaceAgentPortShare mocks base method.
@@ -1759,21 +1731,6 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
}
// GetAuthorizedChats mocks base method.
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthorizedChats indicates an expected call of GetAuthorizedChats.
func (mr *MockStoreMockRecorder) GetAuthorizedChats(ctx, arg, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedChats", reflect.TypeOf((*MockStore)(nil).GetAuthorizedChats), ctx, arg, prepared)
}
// GetAuthorizedConnectionLogsOffset mocks base method.
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
@@ -1939,21 +1896,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
}
// GetChatDesktopEnabled mocks base method.
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatDesktopEnabled", ctx)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatDesktopEnabled indicates an expected call of GetChatDesktopEnabled.
func (mr *MockStoreMockRecorder) GetChatDesktopEnabled(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).GetChatDesktopEnabled), ctx)
}
// GetChatDiffStatusByChatID mocks base method.
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
@@ -2224,19 +2166,19 @@ func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
}
// GetChats mocks base method.
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
// GetChatsByOwnerID mocks base method.
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChats indicates an expected call of GetChats.
func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
}
// GetConnectionLogsOffset mocks base method.
@@ -2479,21 +2421,6 @@ func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
}
// GetEnabledMCPServerConfigs mocks base method.
func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEnabledMCPServerConfigs", ctx)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEnabledMCPServerConfigs indicates an expected call of GetEnabledMCPServerConfigs.
func (mr *MockStoreMockRecorder) GetEnabledMCPServerConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledMCPServerConfigs), ctx)
}
// GetExternalAuthLink mocks base method.
func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
m.ctrl.T.Helper()
@@ -2599,21 +2526,6 @@ func (mr *MockStoreMockRecorder) GetFilteredInboxNotificationsByUserID(ctx, arg
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredInboxNotificationsByUserID", reflect.TypeOf((*MockStore)(nil).GetFilteredInboxNotificationsByUserID), ctx, arg)
}
// GetForcedMCPServerConfigs mocks base method.
func (m *MockStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetForcedMCPServerConfigs", ctx)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetForcedMCPServerConfigs indicates an expected call of GetForcedMCPServerConfigs.
func (mr *MockStoreMockRecorder) GetForcedMCPServerConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForcedMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetForcedMCPServerConfigs), ctx)
}
// GetGitSSHKey mocks base method.
func (m *MockStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
m.ctrl.T.Helper()
@@ -2914,96 +2826,6 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx)
}
// GetMCPServerConfigByID mocks base method.
func (m *MockStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigByID", ctx, id)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigByID indicates an expected call of GetMCPServerConfigByID.
func (mr *MockStoreMockRecorder) GetMCPServerConfigByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigByID), ctx, id)
}
// GetMCPServerConfigBySlug mocks base method.
func (m *MockStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigBySlug", ctx, slug)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigBySlug indicates an expected call of GetMCPServerConfigBySlug.
func (mr *MockStoreMockRecorder) GetMCPServerConfigBySlug(ctx, slug any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigBySlug", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigBySlug), ctx, slug)
}
// GetMCPServerConfigs mocks base method.
func (m *MockStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigs", ctx)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigs indicates an expected call of GetMCPServerConfigs.
func (mr *MockStoreMockRecorder) GetMCPServerConfigs(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigs), ctx)
}
// GetMCPServerConfigsByIDs mocks base method.
func (m *MockStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerConfigsByIDs", ctx, ids)
ret0, _ := ret[0].([]database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerConfigsByIDs indicates an expected call of GetMCPServerConfigsByIDs.
func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids)
}
// GetMCPServerUserToken mocks base method.
func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerUserToken", ctx, arg)
ret0, _ := ret[0].(database.MCPServerUserToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerUserToken indicates an expected call of GetMCPServerUserToken.
func (mr *MockStoreMockRecorder) GetMCPServerUserToken(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserToken), ctx, arg)
}
// GetMCPServerUserTokensByUserID mocks base method.
func (m *MockStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerUserTokensByUserID", ctx, userID)
ret0, _ := ret[0].([]database.MCPServerUserToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerUserTokensByUserID indicates an expected call of GetMCPServerUserTokensByUserID.
func (mr *MockStoreMockRecorder) GetMCPServerUserTokensByUserID(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserTokensByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserTokensByUserID), ctx, userID)
}
// GetNotificationMessagesByStatus mocks base method.
func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
m.ctrl.T.Helper()
@@ -3364,66 +3186,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
}
// GetPRInsightsPerModel mocks base method.
func (m *MockStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsPerModel", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsPerModelRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsPerModel indicates an expected call of GetPRInsightsPerModel.
func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
}
// GetPRInsightsRecentPRs mocks base method.
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
}
// GetPRInsightsSummary mocks base method.
func (m *MockStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsSummary", ctx, arg)
ret0, _ := ret[0].(database.GetPRInsightsSummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsSummary indicates an expected call of GetPRInsightsSummary.
func (mr *MockStoreMockRecorder) GetPRInsightsSummary(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsSummary", reflect.TypeOf((*MockStore)(nil).GetPRInsightsSummary), ctx, arg)
}
// GetPRInsightsTimeSeries mocks base method.
func (m *MockStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPRInsightsTimeSeries", ctx, arg)
ret0, _ := ret[0].([]database.GetPRInsightsTimeSeriesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPRInsightsTimeSeries indicates an expected call of GetPRInsightsTimeSeries.
func (mr *MockStoreMockRecorder) GetPRInsightsTimeSeries(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsTimeSeries", reflect.TypeOf((*MockStore)(nil).GetPRInsightsTimeSeries), ctx, arg)
}
// GetParameterSchemasByJobID mocks base method.
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
m.ctrl.T.Helper()
@@ -5748,21 +5510,6 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeInterception", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeInterception), ctx, arg)
}
// InsertAIBridgeModelThought mocks base method.
func (m *MockStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertAIBridgeModelThought", ctx, arg)
ret0, _ := ret[0].(database.AIBridgeModelThought)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertAIBridgeModelThought indicates an expected call of InsertAIBridgeModelThought.
func (mr *MockStoreMockRecorder) InsertAIBridgeModelThought(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeModelThought", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeModelThought), ctx, arg)
}
// InsertAIBridgeTokenUsage mocks base method.
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -5883,19 +5630,19 @@ func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg)
}
// InsertChatMessages mocks base method.
func (m *MockStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
// InsertChatMessage mocks base method.
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatMessages", ctx, arg)
ret0, _ := ret[0].([]database.ChatMessage)
ret := m.ctrl.Call(m, "InsertChatMessage", ctx, arg)
ret0, _ := ret[0].(database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatMessages indicates an expected call of InsertChatMessages.
func (mr *MockStoreMockRecorder) InsertChatMessages(ctx, arg any) *gomock.Call {
// InsertChatMessage indicates an expected call of InsertChatMessage.
func (mr *MockStoreMockRecorder) InsertChatMessage(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessages", reflect.TypeOf((*MockStore)(nil).InsertChatMessages), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessage", reflect.TypeOf((*MockStore)(nil).InsertChatMessage), ctx, arg)
}
// InsertChatModelConfig mocks base method.
@@ -6119,21 +5866,6 @@ func (mr *MockStoreMockRecorder) InsertLicense(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertLicense", reflect.TypeOf((*MockStore)(nil).InsertLicense), ctx, arg)
}
// InsertMCPServerConfig mocks base method.
func (m *MockStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertMCPServerConfig", ctx, arg)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertMCPServerConfig indicates an expected call of InsertMCPServerConfig.
func (mr *MockStoreMockRecorder) InsertMCPServerConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMCPServerConfig", reflect.TypeOf((*MockStore)(nil).InsertMCPServerConfig), ctx, arg)
}
// InsertMemoryResourceMonitor mocks base method.
func (m *MockStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
m.ctrl.T.Helper()
@@ -7275,34 +7007,6 @@ func (mr *MockStoreMockRecorder) SelectUsageEventsForPublishing(ctx, now any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SelectUsageEventsForPublishing", reflect.TypeOf((*MockStore)(nil).SelectUsageEventsForPublishing), ctx, now)
}
// SoftDeleteChatMessageByID mocks base method.
func (m *MockStore) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SoftDeleteChatMessageByID", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// SoftDeleteChatMessageByID indicates an expected call of SoftDeleteChatMessageByID.
func (mr *MockStoreMockRecorder) SoftDeleteChatMessageByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessageByID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessageByID), ctx, id)
}
// SoftDeleteChatMessagesAfterID mocks base method.
func (m *MockStore) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SoftDeleteChatMessagesAfterID", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// SoftDeleteChatMessagesAfterID indicates an expected call of SoftDeleteChatMessagesAfterID.
func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
}
// TryAcquireLock mocks base method.
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
m.ctrl.T.Helper()
@@ -7433,21 +7137,6 @@ func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
}
// UpdateChatMCPServerIDs mocks base method.
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatMCPServerIDs", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatMCPServerIDs indicates an expected call of UpdateChatMCPServerIDs.
func (mr *MockStoreMockRecorder) UpdateChatMCPServerIDs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatMCPServerIDs", reflect.TypeOf((*MockStore)(nil).UpdateChatMCPServerIDs), ctx, arg)
}
// UpdateChatMessageByID mocks base method.
func (m *MockStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -7641,21 +7330,6 @@ func (mr *MockStoreMockRecorder) UpdateInboxNotificationReadStatus(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateInboxNotificationReadStatus", reflect.TypeOf((*MockStore)(nil).UpdateInboxNotificationReadStatus), ctx, arg)
}
// UpdateMCPServerConfig mocks base method.
func (m *MockStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateMCPServerConfig", ctx, arg)
ret0, _ := ret[0].(database.MCPServerConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateMCPServerConfig indicates an expected call of UpdateMCPServerConfig.
func (mr *MockStoreMockRecorder) UpdateMCPServerConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMCPServerConfig", reflect.TypeOf((*MockStore)(nil).UpdateMCPServerConfig), ctx, arg)
}
// UpdateMemberRoles mocks base method.
func (m *MockStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
m.ctrl.T.Helper()
@@ -8806,20 +8480,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
}
// UpsertChatDesktopEnabled mocks base method.
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDesktopEnabled", ctx, enableDesktop)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatDesktopEnabled indicates an expected call of UpsertChatDesktopEnabled.
func (mr *MockStoreMockRecorder) UpsertChatDesktopEnabled(ctx, enableDesktop any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).UpsertChatDesktopEnabled), ctx, enableDesktop)
}
// UpsertChatDiffStatus mocks base method.
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
@@ -8980,21 +8640,6 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value)
}
// UpsertMCPServerUserToken mocks base method.
func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertMCPServerUserToken", ctx, arg)
ret0, _ := ret[0].(database.MCPServerUserToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertMCPServerUserToken indicates an expected call of UpsertMCPServerUserToken.
func (mr *MockStoreMockRecorder) UpsertMCPServerUserToken(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserToken), ctx, arg)
}
// UpsertNotificationReportGeneratorLog mocks base method.
func (m *MockStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
m.ctrl.T.Helper()
@@ -9241,21 +8886,6 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceAppAuditSession(ctx, arg any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertWorkspaceAppAuditSession", reflect.TypeOf((*MockStore)(nil).UpsertWorkspaceAppAuditSession), ctx, arg)
}
// UsageEventExistsByID mocks base method.
func (m *MockStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UsageEventExistsByID", ctx, id)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UsageEventExistsByID indicates an expected call of UsageEventExistsByID.
func (mr *MockStoreMockRecorder) UsageEventExistsByID(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageEventExistsByID", reflect.TypeOf((*MockStore)(nil).UsageEventExistsByID), ctx, id)
}
// ValidateGroupIDs mocks base method.
func (m *MockStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
m.ctrl.T.Helper()
+12 -148
View File
@@ -512,12 +512,6 @@ CREATE TYPE resource_type AS ENUM (
'ai_seat'
);
CREATE TYPE shareable_workspace_owners AS ENUM (
'none',
'everyone',
'service_accounts'
);
CREATE TYPE startup_script_behavior AS ENUM (
'blocking',
'non-blocking'
@@ -620,35 +614,28 @@ CREATE FUNCTION aggregate_usage_event() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
-- Check for supported event types and throw error for unknown types.
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
-- Check for supported event types and throw error for unknown types
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
END IF;
INSERT INTO usage_events_daily (day, event_type, usage_data)
VALUES (
-- Extract the date from the created_at timestamp, always using UTC for
-- consistency
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
NEW.event_type,
NEW.event_data
)
ON CONFLICT (day, event_type) DO UPDATE SET
usage_data = CASE
-- Handle simple counter events by summing the count.
-- Handle simple counter events by summing the count
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
jsonb_build_object(
'count',
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
-- Heartbeat events: keep the max value seen that day
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
jsonb_build_object(
'count',
GREATEST(
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
)
END;
RETURN NEW;
@@ -805,7 +792,7 @@ BEGIN
END;
$$;
CREATE FUNCTION insert_organization_system_roles() RETURNS trigger
CREATE FUNCTION insert_org_member_system_role() RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
@@ -820,8 +807,7 @@ BEGIN
is_system,
created_at,
updated_at
) VALUES
(
) VALUES (
'organization-member',
'',
NEW.id,
@@ -832,18 +818,6 @@ BEGIN
true,
NOW(),
NOW()
),
(
'organization-service-account',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
);
RETURN NEW;
END;
@@ -1112,15 +1086,6 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
CREATE TABLE aibridge_model_thoughts (
interception_id uuid NOT NULL,
content text NOT NULL,
metadata jsonb,
created_at timestamp with time zone NOT NULL
);
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
CREATE TABLE aibridge_token_usages (
id uuid NOT NULL,
interception_id uuid NOT NULL,
@@ -1289,9 +1254,7 @@ CREATE TABLE chat_messages (
compressed boolean DEFAULT false NOT NULL,
created_by uuid,
content_version smallint NOT NULL,
total_cost_micros bigint,
runtime_ms bigint,
deleted boolean DEFAULT false NOT NULL
total_cost_micros bigint
);
CREATE SEQUENCE chat_messages_id_seq
@@ -1393,8 +1356,7 @@ CREATE TABLE chats (
last_model_config_id uuid NOT NULL,
archived boolean DEFAULT false NOT NULL,
last_error text,
mode chat_mode,
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL
mode chat_mode
);
CREATE TABLE connection_logs (
@@ -1671,53 +1633,6 @@ CREATE SEQUENCE licenses_id_seq
ALTER SEQUENCE licenses_id_seq OWNED BY licenses.id;
CREATE TABLE mcp_server_configs (
id uuid DEFAULT gen_random_uuid() NOT NULL,
display_name text NOT NULL,
slug text NOT NULL,
description text DEFAULT ''::text NOT NULL,
icon_url text DEFAULT ''::text NOT NULL,
transport text DEFAULT 'streamable_http'::text NOT NULL,
url text NOT NULL,
auth_type text DEFAULT 'none'::text NOT NULL,
oauth2_client_id text DEFAULT ''::text NOT NULL,
oauth2_client_secret text DEFAULT ''::text NOT NULL,
oauth2_client_secret_key_id text,
oauth2_auth_url text DEFAULT ''::text NOT NULL,
oauth2_token_url text DEFAULT ''::text NOT NULL,
oauth2_scopes text DEFAULT ''::text NOT NULL,
api_key_header text DEFAULT 'Authorization'::text NOT NULL,
api_key_value text DEFAULT ''::text NOT NULL,
api_key_value_key_id text,
custom_headers text DEFAULT '{}'::text NOT NULL,
custom_headers_key_id text,
tool_allow_list text[] DEFAULT '{}'::text[] NOT NULL,
tool_deny_list text[] DEFAULT '{}'::text[] NOT NULL,
availability text DEFAULT 'default_off'::text NOT NULL,
enabled boolean DEFAULT false NOT NULL,
created_by uuid,
updated_by uuid,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))),
CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))),
CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text])))
);
CREATE TABLE mcp_server_user_tokens (
id uuid DEFAULT gen_random_uuid() NOT NULL,
mcp_server_config_id uuid NOT NULL,
user_id uuid NOT NULL,
access_token text NOT NULL,
access_token_key_id text,
refresh_token text DEFAULT ''::text NOT NULL,
refresh_token_key_id text,
token_type text DEFAULT 'Bearer'::text NOT NULL,
expiry timestamp with time zone,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL
);
CREATE TABLE notification_messages (
id uuid NOT NULL,
notification_template_id uuid NOT NULL,
@@ -1908,11 +1823,9 @@ CREATE TABLE organizations (
display_name text NOT NULL,
icon text DEFAULT ''::text NOT NULL,
deleted boolean DEFAULT false NOT NULL,
shareable_workspace_owners shareable_workspace_owners DEFAULT 'everyone'::shareable_workspace_owners NOT NULL
workspace_sharing_disabled boolean DEFAULT false NOT NULL
);
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
CREATE TABLE parameter_schemas (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -2712,7 +2625,7 @@ CREATE TABLE usage_events (
publish_started_at timestamp with time zone,
published_at timestamp with time zone,
failure_message text,
CONSTRAINT usage_event_type_check CHECK ((event_type = ANY (ARRAY['dc_managed_agents_v1'::text, 'hb_ai_seats_v1'::text])))
CONSTRAINT usage_event_type_check CHECK ((event_type = 'dc_managed_agents_v1'::text))
);
COMMENT ON TABLE usage_events IS 'usage_events contains usage data that is collected from the product and potentially shipped to the usage collector service.';
@@ -3391,18 +3304,6 @@ ALTER TABLE ONLY licenses
ALTER TABLE ONLY licenses
ADD CONSTRAINT licenses_pkey PRIMARY KEY (id);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id);
ALTER TABLE ONLY notification_messages
ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id);
@@ -3661,8 +3562,6 @@ CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptio
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
@@ -3751,12 +3650,6 @@ CREATE INDEX idx_inbox_notifications_user_id_read_at ON inbox_notifications USIN
CREATE INDEX idx_inbox_notifications_user_id_template_id_targets ON inbox_notifications USING btree (user_id, template_id, targets);
CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs USING btree (enabled) WHERE (enabled = true);
CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs USING btree (enabled, availability) WHERE ((enabled = true) AND (availability = 'force_on'::text));
CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens USING btree (user_id);
CREATE INDEX idx_notification_messages_status ON notification_messages USING btree (status);
CREATE INDEX idx_organization_member_organization_id_uuid ON organization_members USING btree (organization_id);
@@ -3785,8 +3678,6 @@ CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree
CREATE UNIQUE INDEX idx_unique_preset_name ON template_version_presets USING btree (name, template_version_id);
CREATE INDEX idx_usage_events_ai_seats ON usage_events USING btree (event_type, created_at) WHERE (event_type = 'hb_ai_seats_v1'::text);
CREATE INDEX idx_usage_events_select_for_publishing ON usage_events USING btree (published_at, publish_started_at, created_at);
CREATE INDEX idx_user_deleted_deleted_at ON user_deleted USING btree (deleted_at);
@@ -3961,7 +3852,7 @@ CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_p
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles();
CREATE TRIGGER trigger_insert_org_member_system_role AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_org_member_system_role();
CREATE TRIGGER trigger_nullify_next_start_at_on_workspace_autostart_modificati AFTER UPDATE ON workspaces FOR EACH ROW EXECUTE FUNCTION nullify_next_start_at_on_workspace_autostart_modification();
@@ -4081,33 +3972,6 @@ ALTER TABLE ONLY jfrog_xray_scans
ALTER TABLE ONLY jfrog_xray_scans
ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_configs
ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ALTER TABLE ONLY mcp_server_user_tokens
ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY notification_messages
ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
@@ -40,15 +40,6 @@ const (
ForeignKeyInboxNotificationsUserID ForeignKeyConstraint = "inbox_notifications_user_id_fkey" // ALTER TABLE ONLY inbox_notifications ADD CONSTRAINT inbox_notifications_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyJfrogXrayScansAgentID ForeignKeyConstraint = "jfrog_xray_scans_agent_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
ForeignKeyJfrogXrayScansWorkspaceID ForeignKeyConstraint = "jfrog_xray_scans_workspace_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
ForeignKeyMcpServerConfigsAPIKeyValueKeyID ForeignKeyConstraint = "mcp_server_configs_api_key_value_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerConfigsCreatedBy ForeignKeyConstraint = "mcp_server_configs_created_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
ForeignKeyMcpServerConfigsCustomHeadersKeyID ForeignKeyConstraint = "mcp_server_configs_custom_headers_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerConfigsOauth2ClientSecretKeyID ForeignKeyConstraint = "mcp_server_configs_oauth2_client_secret_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerConfigsUpdatedBy ForeignKeyConstraint = "mcp_server_configs_updated_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
ForeignKeyMcpServerUserTokensAccessTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_access_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerUserTokensMcpServerConfigID ForeignKeyConstraint = "mcp_server_user_tokens_mcp_server_config_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
ForeignKeyMcpServerUserTokensRefreshTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_refresh_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
ForeignKeyMcpServerUserTokensUserID ForeignKeyConstraint = "mcp_server_user_tokens_user_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyNotificationMessagesNotificationTemplateID ForeignKeyConstraint = "notification_messages_notification_template_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
ForeignKeyNotificationMessagesUserID ForeignKeyConstraint = "notification_messages_user_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyNotificationPreferencesNotificationTemplateID ForeignKeyConstraint = "notification_preferences_notification_template_id_fkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
@@ -26,7 +26,6 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
"GetWorkspaces": "GetAuthorizedWorkspaces",
"GetUsers": "GetAuthorizedUsers",
"GetChats": "GetAuthorizedChats",
}
// Scan custom
@@ -1,3 +0,0 @@
DROP INDEX idx_aibridge_model_thoughts_interception_id;
DROP TABLE aibridge_model_thoughts;
@@ -1,10 +0,0 @@
CREATE TABLE aibridge_model_thoughts (
interception_id UUID NOT NULL,
content TEXT NOT NULL,
metadata jsonb,
created_at TIMESTAMPTZ NOT NULL
);
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts(interception_id);
@@ -1,52 +0,0 @@
DELETE FROM custom_roles
WHERE name = 'organization-service-account' AND is_system = true;
ALTER TABLE organizations
ADD COLUMN workspace_sharing_disabled boolean NOT NULL DEFAULT false;
-- Migrate back: 'none' -> disabled, everything else -> enabled.
UPDATE organizations
SET workspace_sharing_disabled = true
WHERE shareable_workspace_owners = 'none';
ALTER TABLE organizations DROP COLUMN shareable_workspace_owners;
DROP TYPE shareable_workspace_owners;
-- Restore the original single-role trigger from migration 408.
DROP TRIGGER IF EXISTS trigger_insert_organization_system_roles ON organizations;
DROP FUNCTION IF EXISTS insert_organization_system_roles;
CREATE OR REPLACE FUNCTION insert_org_member_system_role() RETURNS trigger AS $$
BEGIN
INSERT INTO custom_roles (
name,
display_name,
organization_id,
site_permissions,
org_permissions,
user_permissions,
member_permissions,
is_system,
created_at,
updated_at
) VALUES (
'organization-member',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_insert_org_member_system_role
AFTER INSERT ON organizations
FOR EACH ROW
EXECUTE FUNCTION insert_org_member_system_role();
@@ -1,101 +0,0 @@
CREATE TYPE shareable_workspace_owners AS ENUM ('none', 'everyone', 'service_accounts');
ALTER TABLE organizations
ADD COLUMN shareable_workspace_owners shareable_workspace_owners NOT NULL DEFAULT 'everyone';
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
-- Migrate existing data from the boolean column.
UPDATE organizations
SET shareable_workspace_owners = 'none'
WHERE workspace_sharing_disabled = true;
ALTER TABLE organizations DROP COLUMN workspace_sharing_disabled;
-- Defensively rename any existing 'organization-service-account' roles
-- so they don't collide with the new system role.
UPDATE custom_roles
SET name = name || '-' || id::text
-- lower(name) is part of the existing unique index
WHERE lower(name) = 'organization-service-account';
-- Create skeleton organization-service-account system roles for all
-- existing organizations, mirroring what migration 408 did for
-- organization-member.
INSERT INTO custom_roles (
name,
display_name,
organization_id,
site_permissions,
org_permissions,
user_permissions,
member_permissions,
is_system,
created_at,
updated_at
)
SELECT
'organization-service-account',
'',
id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
FROM
organizations;
-- Replace the single-role trigger with one that creates both system
-- roles when a new organization is inserted.
DROP TRIGGER IF EXISTS trigger_insert_org_member_system_role ON organizations;
DROP FUNCTION IF EXISTS insert_org_member_system_role;
CREATE OR REPLACE FUNCTION insert_organization_system_roles() RETURNS trigger AS $$
BEGIN
INSERT INTO custom_roles (
name,
display_name,
organization_id,
site_permissions,
org_permissions,
user_permissions,
member_permissions,
is_system,
created_at,
updated_at
) VALUES
(
'organization-member',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
),
(
'organization-service-account',
'',
NEW.id,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
'[]'::jsonb,
true,
NOW(),
NOW()
);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_insert_organization_system_roles
AFTER INSERT ON organizations
FOR EACH ROW
EXECUTE FUNCTION insert_organization_system_roles();
@@ -1,38 +0,0 @@
-- Down migration: remove hb_ai_seats_v1 support from usage events.
DROP INDEX IF EXISTS idx_usage_events_ai_seats;
-- Remove hb_ai_seats_v1 rows so the original constraint can be restored.
DELETE FROM usage_events WHERE event_type = 'hb_ai_seats_v1';
DELETE FROM usage_events_daily WHERE event_type = 'hb_ai_seats_v1';
-- Restore original constraint.
ALTER TABLE usage_events
DROP CONSTRAINT usage_event_type_check,
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1'));
-- Restore the original aggregate function without hb_ai_seats_v1 support.
CREATE OR REPLACE FUNCTION aggregate_usage_event()
RETURNS TRIGGER AS $$
BEGIN
-- Reject any event type this function does not know how to aggregate.
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
END IF;
-- Upsert into the daily rollup; days are bucketed in UTC.
INSERT INTO usage_events_daily (day, event_type, usage_data)
VALUES (
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
NEW.event_type,
NEW.event_data
)
ON CONFLICT (day, event_type) DO UPDATE SET
usage_data = CASE
-- Counter events accumulate: sum the existing and incoming counts,
-- treating a missing/non-numeric 'count' key as 0.
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
jsonb_build_object(
'count',
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
-- No ELSE branch: unreachable given the guard above, but note that a
-- future unhandled type would silently set usage_data to NULL here.
END;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -1,50 +0,0 @@
-- Expand the CHECK constraint to allow hb_ai_seats_v1.
ALTER TABLE usage_events
DROP CONSTRAINT usage_event_type_check,
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1', 'hb_ai_seats_v1'));
-- Partial index for efficient lookups of AI seat heartbeat events by time.
-- This will be used for the admin dashboard to see seat count over time.
CREATE INDEX idx_usage_events_ai_seats
ON usage_events (event_type, created_at)
WHERE event_type = 'hb_ai_seats_v1';
-- Update the aggregate function to handle hb_ai_seats_v1 events.
-- Note: heartbeat events do not simply replace the previous value — the
-- daily rollup keeps the MAXIMUM count seen that day (see GREATEST below).
CREATE OR REPLACE FUNCTION aggregate_usage_event()
RETURNS TRIGGER AS $$
BEGIN
-- Check for supported event types and throw error for unknown types.
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
END IF;
-- Upsert into the daily rollup; days are bucketed in UTC.
INSERT INTO usage_events_daily (day, event_type, usage_data)
VALUES (
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
NEW.event_type,
NEW.event_data
)
ON CONFLICT (day, event_type) DO UPDATE SET
usage_data = CASE
-- Handle simple counter events by summing the count.
-- A missing/non-numeric 'count' key is treated as 0.
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
jsonb_build_object(
'count',
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
-- Heartbeat events: keep the max value seen that day, so
-- late or out-of-order heartbeats cannot lower a recorded peak.
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
jsonb_build_object(
'count',
GREATEST(
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
COALESCE((NEW.event_data->>'count')::bigint, 0)
)
)
-- No ELSE branch: unreachable given the guard above, but note that a
-- future unhandled type would silently set usage_data to NULL here.
END;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -1 +0,0 @@
-- Revert: drop the chat message runtime column (and any stored values).
ALTER TABLE chat_messages DROP COLUMN runtime_ms;
@@ -1 +0,0 @@
-- Track per-message runtime, in milliseconds per the _ms suffix.
-- Nullable with no default so pre-existing rows simply have no value.
ALTER TABLE chat_messages ADD COLUMN runtime_ms bigint;
@@ -1,2 +0,0 @@
-- Permanently remove soft-deleted messages before dropping the flag,
-- since their deleted state cannot be represented once the column is gone.
DELETE FROM chat_messages WHERE deleted = true;
ALTER TABLE chat_messages DROP COLUMN deleted;
@@ -1 +0,0 @@
-- Soft-delete flag for chat messages; existing rows default to not deleted.
ALTER TABLE chat_messages ADD COLUMN deleted boolean NOT NULL DEFAULT false;
@@ -1,6 +0,0 @@
-- Tear down MCP server configuration storage.
ALTER TABLE chats DROP COLUMN IF EXISTS mcp_server_ids;
-- Indexes are dropped explicitly for clarity; the DROP TABLEs below
-- would remove them implicitly anyway.
DROP INDEX IF EXISTS idx_mcp_server_configs_enabled;
DROP INDEX IF EXISTS idx_mcp_server_configs_forced;
DROP INDEX IF EXISTS idx_mcp_server_user_tokens_user_id;
-- Drop the user-token table before the configs table it presumably
-- references via foreign key — verify FK direction against the up migration.
DROP TABLE IF EXISTS mcp_server_user_tokens;
DROP TABLE IF EXISTS mcp_server_configs;

Some files were not shown because too many files have changed in this diff. Show More