Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 337f4474c4 | |||
| 6e09ddc3c1 | |||
| 9cfd7ad394 | |||
| 0e3c880455 | |||
| 97c245c92c | |||
| d0083cdb06 | |||
| 7742854f10 | |||
| 926b568a60 | |||
| 775d26de97 |
@@ -1,72 +0,0 @@
|
||||
---
|
||||
name: pull-requests
|
||||
description: "Guide for creating, updating, and following up on pull requests in the Coder repository. Use when asked to open a PR, update a PR, rewrite a PR description, or follow up on CI/check failures."
|
||||
---
|
||||
|
||||
# Pull Request Skill
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use this skill when asked to:
|
||||
|
||||
- Create a pull request for the current branch.
|
||||
- Update an existing PR branch or description.
|
||||
- Rewrite a PR body.
|
||||
- Follow up on CI or check failures for an existing PR.
|
||||
|
||||
## References
|
||||
|
||||
Use the canonical docs for shared conventions and validation guidance:
|
||||
|
||||
- PR title and description conventions:
|
||||
`.claude/docs/PR_STYLE_GUIDE.md`
|
||||
- Local validation commands and git hooks: `AGENTS.md` (Essential Commands and
|
||||
Git Hooks sections)
|
||||
|
||||
## Lifecycle Rules
|
||||
|
||||
1. **Check for an existing PR** before creating a new one:
|
||||
|
||||
```bash
|
||||
gh pr list --head "$(git branch --show-current)" --author @me --json number --jq '.[0].number // empty'
|
||||
```
|
||||
|
||||
If that returns a number, update that PR. If it returns empty output,
|
||||
create a new one.
|
||||
2. **Check you are not on main.** If the current branch is `main` or `master`,
|
||||
create a feature branch before doing PR work.
|
||||
3. **Default to draft.** Use `gh pr create --draft` unless the user explicitly
|
||||
asks for ready-for-review.
|
||||
4. **Keep description aligned with the full diff.** Re-read the diff against
|
||||
the base branch before writing or updating the title and body. Describe the
|
||||
entire PR diff, not just the last commit.
|
||||
5. **Never auto-merge.** Do not merge or mark ready for review unless the user
|
||||
explicitly asks.
|
||||
6. **Never push to main or master.**
|
||||
|
||||
## CI / Checks Follow-up
|
||||
|
||||
**Always watch CI checks after pushing.** Do not push and walk away.
|
||||
|
||||
After pushing:
|
||||
|
||||
- Monitor CI with `gh pr checks <PR_NUMBER> --watch`.
|
||||
- Use `gh pr view <PR_NUMBER> --json statusCheckRollup` for programmatic check
|
||||
status.
|
||||
|
||||
If checks fail:
|
||||
|
||||
1. Find the failed run ID from the `gh pr checks` output.
|
||||
2. Read the logs with `gh run view <run-id> --log-failed`.
|
||||
3. Fix the problem locally.
|
||||
4. Run `make pre-commit`.
|
||||
5. Push the fix.
|
||||
|
||||
## What Not to Do
|
||||
|
||||
- Do not reference or call helper scripts that do not exist in this
|
||||
repository.
|
||||
- Do not auto-merge or mark ready for review without explicit user request.
|
||||
- Do not push to `origin/main` or `origin/master`.
|
||||
- Do not skip local validation before pushing.
|
||||
- Do not fabricate or embellish PR descriptions.
|
||||
@@ -1,9 +0,0 @@
|
||||
paths:
|
||||
# The triage workflow uses a quoted heredoc (<<'EOF') with ${VAR}
|
||||
# placeholders that envsubst expands later. Shellcheck's SC2016
|
||||
# warns about unexpanded variables in single-quoted strings, but
|
||||
# the non-expansion is intentional here. Actionlint doesn't honor
|
||||
# inline shellcheck disable directives inside heredocs.
|
||||
.github/workflows/triage-via-chat-api.yaml:
|
||||
ignore:
|
||||
- 'SC2016'
|
||||
+33
-33
@@ -35,7 +35,7 @@ jobs:
|
||||
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -157,7 +157,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -191,7 +191,7 @@ jobs:
|
||||
|
||||
# Check for any typos
|
||||
- name: Check for typos
|
||||
uses: crate-ci/typos@631208b7aac2daa8b707f55e7331f9112b0e062d # v1.44.0
|
||||
uses: crate-ci/typos@2d0ce569feab1f8752f1dde43cc2f2aa53236e06 # v1.40.0
|
||||
with:
|
||||
config: .github/workflows/typos.toml
|
||||
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -272,7 +272,7 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -327,7 +327,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -379,7 +379,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -537,7 +537,7 @@ jobs:
|
||||
embedded-pg-cache: ${{ steps.embedded-pg-cache.outputs.embedded-pg-cache }}
|
||||
|
||||
- name: Upload failed test db dumps
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: failed-test-db-dump-${{matrix.os}}
|
||||
path: "**/*.test.sql"
|
||||
@@ -575,7 +575,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -637,7 +637,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -709,7 +709,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -736,7 +736,7 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -769,7 +769,7 @@ jobs:
|
||||
name: ${{ matrix.variant.name }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -818,7 +818,7 @@ jobs:
|
||||
|
||||
- name: Upload Playwright Failed Tests
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
|
||||
path: ./site/test-results/**/*.webm
|
||||
@@ -826,7 +826,7 @@ jobs:
|
||||
|
||||
- name: Upload debug log
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coderd-debug-logs${{ matrix.variant.premium && '-premium' || '' }}
|
||||
path: ./site/e2e/test-results/debug.log
|
||||
@@ -834,7 +834,7 @@ jobs:
|
||||
|
||||
- name: Upload pprof dumps
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
|
||||
path: ./site/test-results/**/debug-pprof-*.txt
|
||||
@@ -849,7 +849,7 @@ jobs:
|
||||
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -930,7 +930,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1005,7 +1005,7 @@ jobs:
|
||||
if: always()
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1043,7 +1043,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1097,7 +1097,7 @@ jobs:
|
||||
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -1108,7 +1108,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: GHCR Login
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -1319,7 +1319,7 @@ jobs:
|
||||
id: attest_main
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:main"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1356,7 +1356,7 @@ jobs:
|
||||
id: attest_latest
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:latest"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1393,7 +1393,7 @@ jobs:
|
||||
id: attest_version
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1457,7 +1457,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifact (coder-linux-amd64.tar.gz)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-amd64.tar.gz
|
||||
path: ./build/*_linux_amd64.tar.gz
|
||||
@@ -1465,7 +1465,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifact (coder-linux-amd64.deb)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-amd64.deb
|
||||
path: ./build/*_linux_amd64.deb
|
||||
@@ -1473,7 +1473,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifact (coder-linux-arm64.tar.gz)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-arm64.tar.gz
|
||||
path: ./build/*_linux_arm64.tar.gz
|
||||
@@ -1481,7 +1481,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifact (coder-linux-arm64.deb)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-arm64.deb
|
||||
path: ./build/*_linux_arm64.deb
|
||||
@@ -1489,7 +1489,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifact (coder-linux-armv7.tar.gz)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-armv7.tar.gz
|
||||
path: ./build/*_linux_armv7.tar.gz
|
||||
@@ -1497,7 +1497,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifact (coder-linux-armv7.deb)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-armv7.deb
|
||||
path: ./build/*_linux_armv7.deb
|
||||
@@ -1505,7 +1505,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifact (coder-windows-amd64.zip)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-windows-amd64.zip
|
||||
path: ./build/*_windows_amd64.zip
|
||||
@@ -1543,7 +1543,7 @@ jobs:
|
||||
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -23,44 +23,6 @@ permissions:
|
||||
concurrency: pr-${{ github.ref }}
|
||||
|
||||
jobs:
|
||||
community-label:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
if: >-
|
||||
${{
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.action == 'opened' &&
|
||||
github.event.pull_request.author_association != 'MEMBER' &&
|
||||
github.event.pull_request.author_association != 'COLLABORATOR' &&
|
||||
github.event.pull_request.author_association != 'OWNER'
|
||||
}}
|
||||
steps:
|
||||
- name: Add community label
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
script: |
|
||||
const params = {
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
}
|
||||
|
||||
const labels = context.payload.pull_request.labels.map((label) => label.name)
|
||||
if (labels.includes("community")) {
|
||||
console.log('PR already has "community" label.')
|
||||
return
|
||||
}
|
||||
|
||||
console.log(
|
||||
'Adding "community" label for author association "%s".',
|
||||
context.payload.pull_request.author_association,
|
||||
)
|
||||
await github.rest.issues.addLabels({
|
||||
...params,
|
||||
labels: ["community"],
|
||||
})
|
||||
|
||||
cla:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
packages: write # to retag image as dogfood
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: GHCR Login
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
needs: deploy
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
if: github.repository_owner == 'coder'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker login
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: ./.github/actions/setup-node
|
||||
|
||||
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v45.0.7
|
||||
- uses: tj-actions/changed-files@e0021407031f5be11a464abee9a0776171c79891 # v45.0.7
|
||||
id: changed-files
|
||||
with:
|
||||
files: |
|
||||
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -78,11 +78,11 @@ jobs:
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
|
||||
|
||||
- name: Login to DockerHub
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
|
||||
- name: Sync issues
|
||||
id: sync
|
||||
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0.5.0
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
|
||||
- name: Complete release
|
||||
id: complete
|
||||
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
runs-on: "ubuntu-latest"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
pull-requests: write # needed for commenting on PRs
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -248,7 +248,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-sqlc
|
||||
|
||||
- name: GHCR Login
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -14,12 +14,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Run Schmoder CI
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
|
||||
with:
|
||||
workflow: ci.yaml
|
||||
repo: coder/schmoder
|
||||
|
||||
@@ -80,7 +80,7 @@ jobs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -155,7 +155,7 @@ jobs:
|
||||
cat "$CODER_RELEASE_NOTES_FILE"
|
||||
|
||||
- name: Docker Login
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -358,7 +358,7 @@ jobs:
|
||||
id: attest_base
|
||||
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
|
||||
with:
|
||||
subject-name: ${{ steps.image-base-tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -474,7 +474,7 @@ jobs:
|
||||
id: attest_main
|
||||
if: ${{ !inputs.dry_run }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
|
||||
with:
|
||||
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -518,7 +518,7 @@ jobs:
|
||||
id: attest_latest
|
||||
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
|
||||
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
|
||||
with:
|
||||
subject-name: ${{ steps.latest_tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -665,7 +665,7 @@ jobs:
|
||||
|
||||
- name: Upload artifacts to actions (if dry-run)
|
||||
if: ${{ inputs.dry_run }}
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: release-artifacts
|
||||
path: |
|
||||
@@ -681,7 +681,7 @@ jobs:
|
||||
|
||||
- name: Upload latest sbom artifact to actions (if dry-run)
|
||||
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: latest-sbom-artifact
|
||||
path: ./coder_latest_sbom.spdx.json
|
||||
@@ -700,11 +700,13 @@ jobs:
|
||||
name: Publish to Homebrew tap
|
||||
runs-on: ubuntu-latest
|
||||
needs: release
|
||||
if: ${{ !inputs.dry_run && inputs.release_channel == 'mainline' }}
|
||||
if: ${{ !inputs.dry_run }}
|
||||
|
||||
steps:
|
||||
# TODO: skip this if it's not a new release (i.e. a backport). This is
|
||||
# fine right now because it just makes a PR that we can close.
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -780,7 +782,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
|
||||
# Upload the results as artifacts.
|
||||
- name: "Upload artifact"
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: SARIF file
|
||||
path: results.sarif
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.34.0
|
||||
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
|
||||
with:
|
||||
image-ref: ${{ steps.build.outputs.image }}
|
||||
format: sarif
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
category: "Trivy"
|
||||
|
||||
- name: Upload Trivy scan results as an artifact
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: trivy
|
||||
path: trivy-results.sarif
|
||||
|
||||
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -96,7 +96,7 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
actions: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -1,295 +0,0 @@
|
||||
# This workflow reimplements the AI Triage Automation using the Coder Chat API
|
||||
# instead of the Tasks API. The Chat API (/api/experimental/chats) is a simpler
|
||||
# interface that does not require a dedicated GitHub Action or workspace
|
||||
# provisioning — we just create a chat, poll for completion, and link the
|
||||
# result on the issue. All API calls use curl + jq directly.
|
||||
#
|
||||
# Key differences from the Tasks API workflow (traiage.yaml):
|
||||
# - No checkout of coder/create-task-action; everything is inline curl/jq.
|
||||
# - No template_name / template_preset / prefix inputs — the Chat API handles
|
||||
# resource allocation internally.
|
||||
# - Uses POST /api/experimental/chats to create a chat session.
|
||||
# - Polls GET /api/experimental/chats/<id> until the agent finishes.
|
||||
# - Chat URL format: ${CODER_URL}/agents?chat=${CHAT_ID}
|
||||
|
||||
name: AI Triage via Chat API
|
||||
|
||||
on:
|
||||
issues:
|
||||
types:
|
||||
- labeled
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
issue_url:
|
||||
description: "GitHub Issue URL to process"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
triage-chat:
|
||||
name: Triage GitHub Issue via Chat API
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.label.name == 'chat-triage' || github.event_name == 'workflow_dispatch'
|
||||
timeout-minutes: 30
|
||||
env:
|
||||
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
|
||||
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
steps:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Determine the GitHub user and issue URL.
|
||||
# Identical to the Tasks API workflow — resolve the actor for
|
||||
# workflow_dispatch or the issue sender for label events.
|
||||
# ------------------------------------------------------------------
|
||||
- name: Determine Inputs
|
||||
id: determine-inputs
|
||||
if: always()
|
||||
env:
|
||||
GITHUB_ACTOR: ${{ github.actor }}
|
||||
GITHUB_EVENT_ISSUE_HTML_URL: ${{ github.event.issue.html_url }}
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
|
||||
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
|
||||
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# For workflow_dispatch, use the actor who triggered it.
|
||||
# For issues events, use the issue sender.
|
||||
if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
|
||||
if ! GITHUB_USER_ID=$(gh api "users/${GITHUB_ACTOR}" --jq '.id'); then
|
||||
echo "::error::Failed to get GitHub user ID for actor ${GITHUB_ACTOR}"
|
||||
exit 1
|
||||
fi
|
||||
echo "Using workflow_dispatch actor: ${GITHUB_ACTOR} (ID: ${GITHUB_USER_ID})"
|
||||
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
|
||||
echo "github_username=${GITHUB_ACTOR}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
echo "Using issue URL: ${INPUTS_ISSUE_URL}"
|
||||
echo "issue_url=${INPUTS_ISSUE_URL}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
exit 0
|
||||
elif [[ "${GITHUB_EVENT_NAME}" == "issues" ]]; then
|
||||
GITHUB_USER_ID=${GITHUB_EVENT_USER_ID}
|
||||
echo "Using issue author: ${GITHUB_EVENT_USER_LOGIN} (ID: ${GITHUB_USER_ID})"
|
||||
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
|
||||
echo "github_username=${GITHUB_EVENT_USER_LOGIN}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
echo "Using issue URL: ${GITHUB_EVENT_ISSUE_HTML_URL}"
|
||||
echo "issue_url=${GITHUB_EVENT_ISSUE_HTML_URL}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
exit 0
|
||||
else
|
||||
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Verify the triggering user has push access.
|
||||
# Unchanged from the Tasks API workflow.
|
||||
# ------------------------------------------------------------------
|
||||
- name: Verify push access
|
||||
env:
|
||||
GITHUB_REPOSITORY: ${{ github.repository }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_USERNAME: ${{ steps.determine-inputs.outputs.github_username }}
|
||||
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
can_push="$(gh api "/repos/${GITHUB_REPOSITORY}/collaborators/${GITHUB_USERNAME}/permission" --jq '.user.permissions.push')"
|
||||
if [[ "${can_push}" != "true" ]]; then
|
||||
echo "::error title=Access Denied::${GITHUB_USERNAME} does not have push access to ${GITHUB_REPOSITORY}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 3: Create a chat via the Coder Chat API.
|
||||
# Unlike the Tasks API which provisions a full workspace, the Chat
|
||||
# API creates a lightweight chat session. We POST to
|
||||
# /api/experimental/chats with the triage prompt as the initial
|
||||
# message and receive a chat ID back.
|
||||
# ------------------------------------------------------------------
|
||||
- name: Create chat via Coder Chat API
|
||||
id: create-chat
|
||||
env:
|
||||
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Build the same triage prompt used by the Tasks API workflow.
|
||||
TASK_PROMPT=$(cat <<'EOF'
|
||||
Fix ${ISSUE_URL}
|
||||
|
||||
1. Use the gh CLI to read the issue description and comments.
|
||||
2. Think carefully and try to understand the root cause. If the issue is unclear or not well defined, ask me to clarify and provide more information.
|
||||
3. Write a proposed implementation plan to PLAN.md for me to review before starting implementation. Your plan should use TDD and only make the minimal changes necessary to fix the root cause.
|
||||
4. When I approve your plan, start working on it. If you encounter issues with the plan, ask me for clarification and update the plan as required.
|
||||
5. When you have finished implementation according to the plan, commit and push your changes, and create a PR using the gh CLI for me to review.
|
||||
EOF
|
||||
)
|
||||
# Perform variable substitution on the prompt — scoped to $ISSUE_URL only.
|
||||
# Using envsubst without arguments would expand every env var in scope
|
||||
# (including CODER_SESSION_TOKEN), so we name the variable explicitly.
|
||||
TASK_PROMPT=$(echo "${TASK_PROMPT}" | envsubst '$ISSUE_URL')
|
||||
|
||||
echo "Creating chat with prompt:"
|
||||
echo "${TASK_PROMPT}"
|
||||
|
||||
# POST to the Chat API to create a new chat session.
|
||||
RESPONSE=$(curl --silent --fail-with-body \
|
||||
-X POST \
|
||||
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$(jq -n --arg prompt "${TASK_PROMPT}" \
|
||||
'{content: [{type: "text", text: $prompt}]}')" \
|
||||
"${CODER_URL}/api/experimental/chats")
|
||||
|
||||
echo "Chat API response:"
|
||||
echo "${RESPONSE}" | jq .
|
||||
|
||||
CHAT_ID=$(echo "${RESPONSE}" | jq -r '.id')
|
||||
CHAT_STATUS=$(echo "${RESPONSE}" | jq -r '.status')
|
||||
|
||||
if [[ -z "${CHAT_ID}" || "${CHAT_ID}" == "null" ]]; then
|
||||
echo "::error::Failed to create chat — no ID returned"
|
||||
echo "Response: ${RESPONSE}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate that CHAT_ID is a UUID before using it in URL paths.
|
||||
# This guards against unexpected API responses being interpolated
|
||||
# into subsequent curl calls.
|
||||
if [[ ! "${CHAT_ID}" =~ ^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$ ]]; then
|
||||
echo "::error::CHAT_ID is not a valid UUID: ${CHAT_ID}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CHAT_URL="${CODER_URL}/agents?chat=${CHAT_ID}"
|
||||
|
||||
echo "Chat created: ${CHAT_ID} (status: ${CHAT_STATUS})"
|
||||
echo "Chat URL: ${CHAT_URL}"
|
||||
|
||||
echo "chat_id=${CHAT_ID}" >> "${GITHUB_OUTPUT}"
|
||||
echo "chat_url=${CHAT_URL}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 4: Poll the chat status until the agent finishes.
|
||||
# The Chat API is asynchronous — after creation the agent begins
|
||||
# working in the background. We poll GET /api/experimental/chats/<id>
|
||||
# every 5 seconds until the status is "waiting" (agent needs input),
|
||||
# "completed" (agent finished), or "error". Timeout after 10 minutes.
|
||||
# ------------------------------------------------------------------
|
||||
- name: Poll chat status
|
||||
id: poll-status
|
||||
env:
|
||||
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
POLL_INTERVAL=5
|
||||
# 10 minutes = 600 seconds.
|
||||
TIMEOUT=600
|
||||
ELAPSED=0
|
||||
|
||||
echo "Polling chat ${CHAT_ID} every ${POLL_INTERVAL}s (timeout: ${TIMEOUT}s)..."
|
||||
|
||||
while true; do
|
||||
RESPONSE=$(curl --silent --fail-with-body \
|
||||
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
|
||||
"${CODER_URL}/api/experimental/chats/${CHAT_ID}")
|
||||
|
||||
STATUS=$(echo "${RESPONSE}" | jq -r '.status')
|
||||
|
||||
echo "[${ELAPSED}s] Chat status: ${STATUS}"
|
||||
|
||||
case "${STATUS}" in
|
||||
waiting|completed)
|
||||
echo "Chat reached terminal status: ${STATUS}"
|
||||
echo "final_status=${STATUS}" >> "${GITHUB_OUTPUT}"
|
||||
exit 0
|
||||
;;
|
||||
error)
|
||||
echo "::error::Chat entered error state"
|
||||
echo "${RESPONSE}" | jq .
|
||||
echo "final_status=error" >> "${GITHUB_OUTPUT}"
|
||||
exit 1
|
||||
;;
|
||||
pending|running)
|
||||
# Still working — keep polling.
|
||||
;;
|
||||
*)
|
||||
echo "::warning::Unknown chat status: ${STATUS}"
|
||||
;;
|
||||
esac
|
||||
|
||||
if [[ ${ELAPSED} -ge ${TIMEOUT} ]]; then
|
||||
echo "::error::Timed out after ${TIMEOUT}s waiting for chat to finish"
|
||||
echo "final_status=timeout" >> "${GITHUB_OUTPUT}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
sleep "${POLL_INTERVAL}"
|
||||
ELAPSED=$((ELAPSED + POLL_INTERVAL))
|
||||
done
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 5: Comment on the GitHub issue with a link to the chat.
|
||||
# Only comment if the issue belongs to this repository (same guard
|
||||
# as the Tasks API workflow).
|
||||
# ------------------------------------------------------------------
|
||||
- name: Comment on issue
|
||||
if: startsWith(steps.determine-inputs.outputs.issue_url, format('{0}/{1}', github.server_url, github.repository))
|
||||
env:
|
||||
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
|
||||
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
|
||||
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
|
||||
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
COMMENT_BODY=$(cat <<EOF
|
||||
🤖 **AI Triage Chat Created**
|
||||
|
||||
A Coder chat session has been created to investigate this issue.
|
||||
|
||||
**Chat URL:** ${CHAT_URL}
|
||||
**Chat ID:** \`${CHAT_ID}\`
|
||||
**Status:** ${FINAL_STATUS}
|
||||
|
||||
The agent is working on a triage plan. Visit the chat to follow progress or provide guidance.
|
||||
EOF
|
||||
)
|
||||
|
||||
gh issue comment "${ISSUE_URL}" --body "${COMMENT_BODY}"
|
||||
echo "Comment posted on ${ISSUE_URL}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 6: Write a summary to the GitHub Actions step summary.
|
||||
# ------------------------------------------------------------------
|
||||
- name: Write summary
|
||||
env:
|
||||
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
|
||||
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
|
||||
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
|
||||
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
{
|
||||
echo "## AI Triage via Chat API"
|
||||
echo ""
|
||||
echo "**Issue:** ${ISSUE_URL}"
|
||||
echo "**Chat ID:** \`${CHAT_ID}\`"
|
||||
echo "**Chat URL:** ${CHAT_URL}"
|
||||
echo "**Status:** ${FINAL_STATUS}"
|
||||
} >> "${GITHUB_STEP_SUMMARY}"
|
||||
@@ -29,8 +29,6 @@ EDE = "EDE"
|
||||
HELO = "HELO"
|
||||
LKE = "LKE"
|
||||
byt = "byt"
|
||||
cpy = "cpy"
|
||||
Cpy = "Cpy"
|
||||
typ = "typ"
|
||||
# file extensions used in seti icon theme
|
||||
styl = "styl"
|
||||
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
pull-requests: write # required to post PR review comments by the action
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -146,20 +146,11 @@ git config core.hooksPath scripts/githooks
|
||||
|
||||
Two hooks run automatically:
|
||||
|
||||
- **pre-commit**: Classifies staged files by type and runs either
|
||||
the full `make pre-commit` or the lightweight `make pre-commit-light`
|
||||
depending on whether Go, TypeScript, SQL, proto, or Makefile
|
||||
changes are present. Falls back to the full target when
|
||||
`CODER_HOOK_RUN_ALL=1` is set. A markdown-only commit takes
|
||||
seconds; a Go change takes several minutes.
|
||||
- **pre-push**: Classifies changed files (vs remote branch or
|
||||
merge-base) and runs `make pre-push` when Go, TypeScript, SQL,
|
||||
proto, or Makefile changes are detected. Skips tests entirely
|
||||
for lightweight changes. Allowlisted in
|
||||
`scripts/githooks/pre-push`. Runs only for developers who opt
|
||||
in. Falls back to `make pre-push` when the diff range can't
|
||||
be determined or `CODER_HOOK_RUN_ALL=1` is set. Allow at least
|
||||
15 minutes for a full run.
|
||||
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
|
||||
Fast checks that catch most CI failures. Allow at least 5 minutes.
|
||||
- **pre-push**: `make pre-push` (heavier checks including tests).
|
||||
Allowlisted in `scripts/githooks/pre-push`. Runs only for developers
|
||||
who opt in. Allow at least 15 minutes.
|
||||
|
||||
`git commit` and `git push` will appear to hang while hooks run.
|
||||
This is normal. Do not interrupt, retry, or reduce the timeout.
|
||||
@@ -217,11 +208,6 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
|
||||
|
||||
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
|
||||
- Commit format: `type(scope): message`
|
||||
- PR titles follow the same `type(scope): message` format.
|
||||
- When you use a scope, it must be a real filesystem path containing every
|
||||
changed file.
|
||||
- Use a broader path scope, or omit the scope, for cross-cutting changes.
|
||||
- Example: `fix(coderd/chatd): ...` for changes only in `coderd/chatd/`.
|
||||
|
||||
### Frontend Patterns
|
||||
|
||||
|
||||
@@ -136,10 +136,18 @@ endif
|
||||
# the search path so that these exclusions match.
|
||||
FIND_EXCLUSIONS= \
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
|
||||
|
||||
# Source files used for make targets, evaluated on use.
|
||||
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
|
||||
|
||||
# Same as GO_SRC_FILES but excluding certain files that have problematic
|
||||
# Makefile dependencies (e.g. pnpm).
|
||||
MOST_GO_SRC_FILES := $(shell \
|
||||
find . \
|
||||
$(FIND_EXCLUSIONS) \
|
||||
-type f \
|
||||
-name '*.go' \
|
||||
-not -name '*_test.go' \
|
||||
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
|
||||
)
|
||||
# All the shell files in the repo, excluding ignored files.
|
||||
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
|
||||
|
||||
@@ -506,10 +514,7 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
cp "$<" "$$output_file"
|
||||
.PHONY: install
|
||||
|
||||
# Only wildcard the go files in the develop directory to avoid rebuilds
|
||||
# when project files are changd. Technically changes to some imports may
|
||||
# not be detected, but it's unlikely to cause any issues.
|
||||
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
|
||||
build/.bin/develop: go.mod go.sum $(GO_SRC_FILES)
|
||||
CGO_ENABLED=0 go build -o $@ ./scripts/develop
|
||||
|
||||
BOLD := $(shell tput bold 2>/dev/null)
|
||||
@@ -522,10 +527,6 @@ RESET := $(shell tput sgr0 2>/dev/null)
|
||||
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/biome fmt/markdown
|
||||
.PHONY: fmt
|
||||
|
||||
# Subset of fmt that does not require Go or Node toolchains.
|
||||
fmt-light: fmt/shfmt fmt/terraform fmt/markdown
|
||||
.PHONY: fmt-light
|
||||
|
||||
fmt/go:
|
||||
ifdef FILE
|
||||
# Format single file
|
||||
@@ -633,10 +634,6 @@ LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
|
||||
.PHONY: lint
|
||||
|
||||
# Subset of lint that does not require Go or Node toolchains.
|
||||
lint-light: lint/shellcheck lint/markdown lint/helm lint/bootstrap lint/migrations lint/actions/actionlint lint/typos
|
||||
.PHONY: lint-light
|
||||
|
||||
lint/site-icons:
|
||||
./scripts/check_site_icons.sh
|
||||
.PHONY: lint/site-icons
|
||||
@@ -779,25 +776,6 @@ pre-commit:
|
||||
echo "$(GREEN)✓ pre-commit passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
|
||||
.PHONY: pre-commit
|
||||
|
||||
# Lightweight pre-commit for changes that don't touch Go or
|
||||
# TypeScript. Skips gen, lint/go, lint/ts, fmt/go, fmt/ts, and
|
||||
# the binary build. Used by the pre-commit hook when only docs,
|
||||
# shell, terraform, helm, or other fast-to-check files changed.
|
||||
pre-commit-light:
|
||||
start=$$(date +%s)
|
||||
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit-light.XXXXXX")
|
||||
echo "$(BOLD)pre-commit-light$(RESET) ($$logdir)"
|
||||
echo "fmt:"
|
||||
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir fmt-light
|
||||
$(check-unstaged)
|
||||
echo "lint:"
|
||||
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir lint-light
|
||||
$(check-unstaged)
|
||||
$(check-untracked)
|
||||
rm -rf $$logdir
|
||||
echo "$(GREEN)✓ pre-commit-light passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
|
||||
.PHONY: pre-commit-light
|
||||
|
||||
pre-push:
|
||||
start=$$(date +%s)
|
||||
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-push.XXXXXX")
|
||||
@@ -806,7 +784,6 @@ pre-push:
|
||||
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
|
||||
test \
|
||||
test-js \
|
||||
test-storybook \
|
||||
site/out/index.html
|
||||
rm -rf $$logdir
|
||||
echo "$(GREEN)✓ pre-push passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
|
||||
@@ -1341,11 +1318,6 @@ test-js: site/node_modules/.installed
|
||||
pnpm test:ci
|
||||
.PHONY: test-js
|
||||
|
||||
test-storybook: site/node_modules/.installed
|
||||
cd site/
|
||||
pnpm exec vitest run --project=storybook
|
||||
.PHONY: test-storybook
|
||||
|
||||
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
|
||||
# dependency for any sqlc-cloud related targets.
|
||||
sqlc-cloud-is-setup:
|
||||
|
||||
+2
-7
@@ -385,16 +385,11 @@ func (a *agent) init() {
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore, func() string {
|
||||
if m := a.manifest.Load(); m != nil {
|
||||
return m.Directory
|
||||
}
|
||||
return ""
|
||||
})
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
|
||||
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
|
||||
@@ -57,26 +57,18 @@ type fakeContainerCLI struct {
|
||||
}
|
||||
|
||||
func (f *fakeContainerCLI) List(_ context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.containers, f.listErr
|
||||
}
|
||||
|
||||
func (f *fakeContainerCLI) DetectArchitecture(_ context.Context, _ string) (string, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.arch, f.archErr
|
||||
}
|
||||
|
||||
func (f *fakeContainerCLI) Copy(ctx context.Context, name, src, dst string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.copyErr
|
||||
}
|
||||
|
||||
func (f *fakeContainerCLI) ExecAs(ctx context.Context, name, user string, args ...string) ([]byte, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return nil, f.execErr
|
||||
}
|
||||
|
||||
@@ -2697,9 +2689,7 @@ func TestAPI(t *testing.T) {
|
||||
|
||||
// When: The container is recreated (new container ID) with config changes.
|
||||
terraformContainer.ID = "new-container-id"
|
||||
fCCLI.mu.Lock()
|
||||
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
|
||||
fCCLI.mu.Unlock()
|
||||
fDCCLI.upID = terraformContainer.ID
|
||||
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
|
||||
Apps: []agentcontainers.SubAgentApp{{Slug: "app2"}}, // Changed app triggers recreation logic.
|
||||
@@ -2831,9 +2821,7 @@ func TestAPI(t *testing.T) {
|
||||
// Simulate container rebuild: new container ID, changed display apps.
|
||||
newContainerID := "new-container-id"
|
||||
terraformContainer.ID = newContainerID
|
||||
fCCLI.mu.Lock()
|
||||
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
|
||||
fCCLI.mu.Unlock()
|
||||
fDCCLI.upID = newContainerID
|
||||
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
|
||||
DisplayApps: map[codersdk.DisplayApp]bool{
|
||||
@@ -4938,11 +4926,9 @@ func TestDevcontainerPrebuildSupport(t *testing.T) {
|
||||
)
|
||||
api.Start()
|
||||
|
||||
fCCLI.mu.Lock()
|
||||
fCCLI.containers = codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{testContainer},
|
||||
}
|
||||
fCCLI.mu.Unlock()
|
||||
|
||||
// Given: We allow the dev container to be created.
|
||||
fDCCLI.upID = testContainer.ID
|
||||
|
||||
@@ -2,9 +2,13 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -20,6 +24,28 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
portableDesktopVersion = "v0.0.4"
|
||||
downloadRetries = 3
|
||||
downloadRetryDelay = time.Second
|
||||
)
|
||||
|
||||
// platformBinaries maps GOARCH to download URL and expected SHA-256
|
||||
// digest for each supported platform.
|
||||
var platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
"amd64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
|
||||
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
|
||||
},
|
||||
"arm64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
|
||||
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
|
||||
},
|
||||
}
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
// `portabledesktop up --json`.
|
||||
type portableDesktopOutput struct {
|
||||
@@ -52,31 +78,43 @@ type screenshotOutput struct {
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
scriptBinDir string // coder script bin directory
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
dataDir string // agent's ScriptDataDir, used for binary caching
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
|
||||
// httpClient is used for downloading the binary. If nil,
|
||||
// http.DefaultClient is used.
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. scriptBinDir is
|
||||
// the coder script bin directory checked for the binary.
|
||||
// CLI binary, using execer to spawn child processes. dataDir is used
|
||||
// to cache the downloaded binary.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
scriptBinDir string,
|
||||
dataDir string,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
}
|
||||
|
||||
// httpDo returns the HTTP client to use for downloads.
|
||||
func (p *portableDesktop) httpDo() *http.Client {
|
||||
if p.httpClient != nil {
|
||||
return p.httpClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
p.mu.Lock()
|
||||
@@ -361,8 +399,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// ensureBinary resolves the portabledesktop binary from PATH or the
|
||||
// coder script bin directory. It must be called while p.mu is held.
|
||||
// ensureBinary resolves or downloads the portabledesktop binary. It
|
||||
// must be called while p.mu is held.
|
||||
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
if p.binPath != "" {
|
||||
return nil
|
||||
@@ -377,23 +415,130 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Check the coder script bin directory.
|
||||
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
|
||||
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
|
||||
// On Windows, permission bits don't indicate executability,
|
||||
// so accept any regular file.
|
||||
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
|
||||
p.logger.Info(ctx, "found portabledesktop in script bin directory",
|
||||
slog.F("path", scriptBinPath),
|
||||
// 2. Platform checks.
|
||||
if runtime.GOOS != "linux" {
|
||||
return xerrors.New("portabledesktop is only supported on Linux")
|
||||
}
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
// 3. Check cache.
|
||||
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
|
||||
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
|
||||
// Verify it is executable.
|
||||
if info.Mode()&0o100 != 0 {
|
||||
p.logger.Info(ctx, "using cached portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
p.binPath = scriptBinPath
|
||||
p.binPath = cachedPath
|
||||
return nil
|
||||
}
|
||||
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
|
||||
slog.F("path", scriptBinPath),
|
||||
slog.F("mode", info.Mode().String()),
|
||||
}
|
||||
|
||||
// 4. Download with retry.
|
||||
p.logger.Info(ctx, "downloading portabledesktop binary",
|
||||
slog.F("url", bin.URL),
|
||||
slog.F("version", portableDesktopVersion),
|
||||
slog.F("arch", runtime.GOARCH),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := range downloadRetries {
|
||||
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
|
||||
lastErr = err
|
||||
p.logger.Warn(ctx, "download attempt failed",
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("max_attempts", downloadRetries),
|
||||
slog.Error(err),
|
||||
)
|
||||
if attempt < downloadRetries-1 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(downloadRetryDelay):
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
p.binPath = cachedPath
|
||||
p.logger.Info(ctx, "downloaded portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
|
||||
}
|
||||
|
||||
// downloadBinary fetches a binary from url, verifies its SHA-256
|
||||
// digest matches expectedSHA256, and atomically writes it to destPath.
|
||||
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
|
||||
return xerrors.Errorf("create cache directory: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create HTTP request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("HTTP GET %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Write to a temp file in the same directory so the final rename
|
||||
// is atomic on the same filesystem.
|
||||
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Clean up the temp file on any error path.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stream the response body while computing SHA-256.
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
|
||||
return xerrors.Errorf("download body: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return xerrors.Errorf("close temp file: %w", err)
|
||||
}
|
||||
|
||||
// Verify digest.
|
||||
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actualSHA256 != expectedSHA256 {
|
||||
return xerrors.Errorf(
|
||||
"SHA-256 mismatch: expected %s, got %s",
|
||||
expectedSHA256, actualSHA256,
|
||||
)
|
||||
}
|
||||
|
||||
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
|
||||
if err := os.Chmod(tmpPath, 0o700); err != nil {
|
||||
return xerrors.Errorf("chmod: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
return xerrors.Errorf("rename to final path: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,11 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -72,6 +77,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// The "up" script prints the JSON line then sleeps until
|
||||
// the context is canceled (simulating a long-running process).
|
||||
@@ -82,13 +88,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
cfg, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -105,6 +111,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -113,13 +120,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
cfg1, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -147,6 +154,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -155,13 +163,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -172,6 +180,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
@@ -180,13 +189,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
_, err := pd.Screenshot(ctx, ScreenshotOptions{
|
||||
TargetWidth: 800,
|
||||
TargetHeight: 600,
|
||||
@@ -278,13 +287,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
@@ -363,13 +372,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
@@ -395,13 +404,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
x, y, err := pd.CursorPosition(t.Context())
|
||||
x, y, err := pd.CursorPosition(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, x)
|
||||
assert.Equal(t, 200, y)
|
||||
@@ -419,13 +428,13 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
ctx := context.Background()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -448,6 +457,81 @@ func TestPortableDesktop_Close(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
}
|
||||
|
||||
// --- downloadBinary tests ---
|
||||
|
||||
func TestDownloadBinary_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho portable\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and has correct content.
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
// Verify executable permissions.
|
||||
info, err := os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("real binary content"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "SHA-256 mismatch")
|
||||
|
||||
// The destination file should not exist (temp file cleaned up).
|
||||
_, statErr := os.Stat(destPath)
|
||||
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
|
||||
|
||||
// No leftover temp files in the directory.
|
||||
entries, err := os.ReadDir(destDir)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, entries, "no leftover temp files should remain")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_HTTPError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "status 404")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
@@ -457,89 +541,173 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
// immediately without doing any work.
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
}
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/already/set", pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
|
||||
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
require.NoError(t, os.Chmod(binPath, 0o755))
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
dataDir := t.TempDir()
|
||||
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
|
||||
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binPath, pd.binPath)
|
||||
assert.Equal(t, cachedPath, pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Windows does not support Unix permission bits")
|
||||
}
|
||||
func TestEnsureBinary_Downloads(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
// environment and we override the package-level platformBinaries.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
// Write without execute permission.
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
_ = binPath
|
||||
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Save and restore platformBinaries for this test.
|
||||
origBinaries := platformBinaries
|
||||
platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
runtime.GOARCH: {
|
||||
URL: srv.URL + "/portabledesktop",
|
||||
SHA256: expectedSHA,
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { platformBinaries = origBinaries })
|
||||
|
||||
dataDir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
httpClient: srv.Client(),
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
// Ensure PATH doesn't contain a real portabledesktop binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
|
||||
assert.Equal(t, expectedPath, pd.binPath)
|
||||
|
||||
// Verify the downloaded file has correct content.
|
||||
got, err := os.ReadFile(expectedPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_NotFound(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(), // empty directory
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
binaryContent := []byte("#!/bin/sh\necho retried\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
var mu sync.Mutex
|
||||
attempt := 0
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
mu.Lock()
|
||||
current := attempt
|
||||
attempt++
|
||||
mu.Unlock()
|
||||
|
||||
// Fail the first 2 attempts, succeed on the third.
|
||||
if current < 2 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Test downloadBinary directly to avoid time.Sleep in
|
||||
// ensureBinary's retry loop. We call it 3 times to simulate
|
||||
// what ensureBinary would do.
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
var lastErr error
|
||||
for i := range 3 {
|
||||
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
if i < 2 {
|
||||
// In the real code, ensureBinary sleeps here.
|
||||
// We skip the sleep in tests.
|
||||
continue
|
||||
}
|
||||
}
|
||||
require.NoError(t, lastErr, "download should succeed on the third attempt")
|
||||
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
mu.Lock()
|
||||
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
|
||||
// Silence the linter about unused imports — agentexec.DefaultExecer
|
||||
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
|
||||
// fmt.Sscanf is used indirectly via the implementation.
|
||||
var (
|
||||
_ = agentexec.DefaultExecer
|
||||
_ = fmt.Sprintf
|
||||
)
|
||||
|
||||
+13
-85
@@ -333,68 +333,22 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
|
||||
return status, err
|
||||
}
|
||||
|
||||
// Check if the target already exists so we can preserve its
|
||||
// permissions on the temp file before rename.
|
||||
var origMode os.FileMode
|
||||
var haveOrigMode bool
|
||||
if stat, serr := api.filesystem.Stat(path); serr == nil {
|
||||
if stat.IsDir() {
|
||||
return http.StatusBadRequest, xerrors.Errorf("open %s: is a directory", path)
|
||||
}
|
||||
origMode = stat.Mode()
|
||||
haveOrigMode = true
|
||||
}
|
||||
|
||||
// Write to a temp file in the same directory so the rename is
|
||||
// always on the same device (atomic).
|
||||
tmpfile, err := afero.TempFile(api.filesystem, dir, filepath.Base(path))
|
||||
f, err := api.filesystem.Create(path)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
switch {
|
||||
case errors.Is(err, os.ErrPermission):
|
||||
status = http.StatusForbidden
|
||||
case errors.Is(err, syscall.EISDIR):
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
tmpName := tmpfile.Name()
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(tmpfile, r.Body)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
_ = tmpfile.Close()
|
||||
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
|
||||
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
|
||||
}
|
||||
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Close before rename to flush buffered data and catch write
|
||||
// errors (e.g. delayed allocation failures).
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
|
||||
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
|
||||
}
|
||||
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Set permissions on the temp file before rename so there is
|
||||
// no window where the target has wrong permissions.
|
||||
if haveOrigMode {
|
||||
if err := api.filesystem.Chmod(tmpName, origMode); err != nil {
|
||||
api.logger.Warn(ctx, "unable to set file permissions",
|
||||
slog.F("path", path),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if err := api.filesystem.Rename(tmpName, path); err != nil {
|
||||
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
|
||||
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
return status, err
|
||||
_, err = io.Copy(f, r.Body)
|
||||
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
|
||||
api.logger.Error(ctx, "workspace agent write file", slog.Error(err))
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
@@ -506,44 +460,18 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
tmpName := tmpfile.Name()
|
||||
defer tmpfile.Close()
|
||||
|
||||
if _, err := tmpfile.Write([]byte(content)); err != nil {
|
||||
_ = tmpfile.Close()
|
||||
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
|
||||
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
|
||||
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
|
||||
}
|
||||
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Close before rename to flush buffered data and catch write
|
||||
// errors (e.g. delayed allocation failures).
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
|
||||
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
|
||||
}
|
||||
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Set permissions on the temp file before rename so there is
|
||||
// no window where the target has wrong permissions.
|
||||
if err := api.filesystem.Chmod(tmpName, stat.Mode()); err != nil {
|
||||
api.logger.Warn(ctx, "unable to set file permissions",
|
||||
slog.F("path", path),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
err = api.filesystem.Rename(tmpName, path)
|
||||
err = api.filesystem.Rename(tmpfile.Name(), path)
|
||||
if err != nil {
|
||||
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
|
||||
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
return status, err
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
@@ -400,83 +399,6 @@ func TestWriteFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFile_ReportsIOError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
|
||||
tmpdir := os.TempDir()
|
||||
path := filepath.Join(tmpdir, "write-io-error")
|
||||
err := afero.WriteFile(fs, path, []byte("original"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
// A reader that always errors simulates a failed body read
|
||||
// (e.g. network interruption). The atomic write should leave
|
||||
// the original file intact.
|
||||
body := iotest.ErrReader(xerrors.New("simulated I/O error"))
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
|
||||
fmt.Sprintf("/write-file?path=%s", path), body)
|
||||
api.Routes().ServeHTTP(w, r)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
got := &codersdk.Error{}
|
||||
err = json.NewDecoder(w.Body).Decode(got)
|
||||
require.NoError(t, err)
|
||||
require.ErrorContains(t, got, "simulated I/O error")
|
||||
|
||||
// The original file must survive the failed write.
|
||||
data, err := afero.ReadFile(fs, path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original", string(data))
|
||||
}
|
||||
|
||||
func TestWriteFile_PreservesPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permissions are not reliably supported on Windows")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
osFs := afero.NewOsFs()
|
||||
api := agentfiles.NewAPI(logger, osFs, nil)
|
||||
|
||||
path := filepath.Join(dir, "script.sh")
|
||||
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
info, err := osFs.Stat(path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
// Overwrite the file with new content.
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
|
||||
fmt.Sprintf("/write-file?path=%s", path),
|
||||
bytes.NewReader([]byte("#!/bin/sh\necho world\n")))
|
||||
api.Routes().ServeHTTP(w, r)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
data, err := afero.ReadFile(osFs, path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
|
||||
|
||||
info, err = osFs.Stat(path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
|
||||
"write_file should preserve the original file's permissions")
|
||||
}
|
||||
|
||||
func TestEditFiles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -985,67 +907,6 @@ func TestEditFiles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEditFiles_PreservesPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permissions are not reliably supported on Windows")
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
osFs := afero.NewOsFs()
|
||||
api := agentfiles.NewAPI(logger, osFs, nil)
|
||||
|
||||
path := filepath.Join(dir, "script.sh")
|
||||
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Sanity-check the initial mode.
|
||||
info, err := osFs.Stat(path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
body := workspacesdk.FileEditRequest{
|
||||
Files: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: path,
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{
|
||||
Search: "hello",
|
||||
Replace: "world",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
buf := bytes.NewBuffer(nil)
|
||||
enc := json.NewEncoder(buf)
|
||||
enc.SetEscapeHTML(false)
|
||||
err = enc.Encode(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf)
|
||||
api.Routes().ServeHTTP(w, r)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Verify content was updated.
|
||||
data, err := afero.ReadFile(osFs, path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
|
||||
|
||||
// Verify permissions are preserved after the
|
||||
// temp-file-and-rename cycle.
|
||||
info, err = osFs.Stat(path)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
|
||||
"edit_files should preserve the original file's permissions")
|
||||
}
|
||||
|
||||
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+2
-58
@@ -1,13 +1,11 @@
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
@@ -20,13 +18,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxWaitDuration is the maximum time a blocking
|
||||
// process output request can wait, regardless of
|
||||
// what the client requests.
|
||||
maxWaitDuration = 5 * time.Minute
|
||||
)
|
||||
|
||||
// API exposes process-related operations through the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
@@ -35,10 +26,10 @@ type API struct {
|
||||
}
|
||||
|
||||
// NewAPI creates a new process API handler.
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore, workingDir func() string) *API {
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
manager: newManager(logger, execer, updateEnv, workingDir),
|
||||
manager: newManager(logger, execer, updateEnv),
|
||||
pathStore: pathStore,
|
||||
}
|
||||
}
|
||||
@@ -160,42 +151,6 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Enforce chat ID isolation. If the request carries
|
||||
// a chat context, only allow access to processes
|
||||
// belonging to that chat.
|
||||
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
if proc.chatID != "" && proc.chatID != chatID.String() {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Process %q not found.", id),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check for blocking mode via query params.
|
||||
waitStr := r.URL.Query().Get("wait")
|
||||
wantWait := waitStr == "true"
|
||||
|
||||
if wantWait {
|
||||
// Extend the write deadline so the HTTP server's
|
||||
// WriteTimeout does not kill the connection while
|
||||
// we block.
|
||||
rc := http.NewResponseController(rw)
|
||||
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration)); err != nil {
|
||||
api.logger.Error(ctx, "extend write deadline for blocking process output",
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
// Cap the wait at maxWaitDuration regardless of
|
||||
// client-supplied timeout.
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, maxWaitDuration)
|
||||
defer waitCancel()
|
||||
|
||||
_ = proc.waitForOutput(waitCtx)
|
||||
// Fall through to read snapshot below.
|
||||
}
|
||||
|
||||
output, truncated := proc.output()
|
||||
info := proc.info()
|
||||
|
||||
@@ -213,17 +168,6 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
// Enforce chat ID isolation.
|
||||
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
proc, procOK := api.manager.get(id)
|
||||
if procOK && proc.chatID != "" && proc.chatID != chatID.String() {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: fmt.Sprintf("Process %q not found.", id),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var req workspacesdk.SignalProcessRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
|
||||
+3
-277
@@ -7,10 +7,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -78,22 +76,6 @@ func getOutput(t *testing.T, handler http.Handler, id string) *httptest.Response
|
||||
return w
|
||||
}
|
||||
|
||||
// getOutputWithHeaders sends a GET /{id}/output request with
|
||||
// custom headers and returns the recorder.
|
||||
func getOutputWithHeaders(t *testing.T, handler http.Handler, id string, headers http.Header) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
path := fmt.Sprintf("/%s/output", id)
|
||||
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
|
||||
for k, v := range headers {
|
||||
req.Header[k] = v
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
// postSignal sends a POST /{id}/signal request and returns
|
||||
// the recorder.
|
||||
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
|
||||
@@ -115,25 +97,18 @@ func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.
|
||||
// execer, returning the handler and API.
|
||||
func newTestAPI(t *testing.T) http.Handler {
|
||||
t.Helper()
|
||||
return newTestAPIWithOptions(t, nil, nil)
|
||||
return newTestAPIWithUpdateEnv(t, nil)
|
||||
}
|
||||
|
||||
// newTestAPIWithUpdateEnv creates a new API with an optional
|
||||
// updateEnv hook for testing environment injection.
|
||||
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
|
||||
t.Helper()
|
||||
return newTestAPIWithOptions(t, updateEnv, nil)
|
||||
}
|
||||
|
||||
// newTestAPIWithOptions creates a new API with optional
|
||||
// updateEnv and workingDir hooks.
|
||||
func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, error), workingDir func() string) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil, workingDir)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
|
||||
t.Cleanup(func() {
|
||||
_ = api.Close()
|
||||
})
|
||||
@@ -278,100 +253,6 @@ func TestStartProcess(t *testing.T) {
|
||||
require.Contains(t, resp.Output, "marker.txt")
|
||||
})
|
||||
|
||||
t.Run("DefaultWorkDirIsHome", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// No working directory closure, so the process
|
||||
// should fall back to $HOME. We verify through
|
||||
// the process list API which reports the resolved
|
||||
// working directory using native OS paths,
|
||||
// avoiding shell path format mismatches on
|
||||
// Windows (Git Bash returns POSIX paths).
|
||||
handler := newTestAPI(t)
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo ok",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var listResp workspacesdk.ListProcessesResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
|
||||
var proc *workspacesdk.ProcessInfo
|
||||
for i := range listResp.Processes {
|
||||
if listResp.Processes[i].ID == id {
|
||||
proc = &listResp.Processes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, proc, "process not found in list")
|
||||
require.Equal(t, homeDir, proc.WorkDir)
|
||||
})
|
||||
|
||||
t.Run("DefaultWorkDirFromClosure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The closure provides a valid directory, so the
|
||||
// process should start there. Use the marker file
|
||||
// pattern to avoid path format mismatches on
|
||||
// Windows.
|
||||
tmpDir := t.TempDir()
|
||||
handler := newTestAPIWithOptions(t, nil, func() string {
|
||||
return tmpDir
|
||||
})
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "touch marker.txt && ls marker.txt",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
require.Contains(t, resp.Output, "marker.txt")
|
||||
})
|
||||
|
||||
t.Run("DefaultWorkDirClosureNonExistentFallsBackToHome", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The closure returns a path that doesn't exist,
|
||||
// so the process should fall back to $HOME.
|
||||
handler := newTestAPIWithOptions(t, nil, func() string {
|
||||
return "/tmp/nonexistent-dir-" + fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
})
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo ok",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var listResp workspacesdk.ListProcessesResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
|
||||
var proc *workspacesdk.ProcessInfo
|
||||
for i := range listResp.Processes {
|
||||
if listResp.Processes[i].ID == id {
|
||||
proc = &listResp.Processes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, proc, "process not found in list")
|
||||
require.Equal(t, homeDir, proc.WorkDir)
|
||||
})
|
||||
|
||||
t.Run("CustomEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -756,161 +637,6 @@ func TestProcessOutput(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Message, "not found")
|
||||
})
|
||||
|
||||
t.Run("ChatIDEnforcement", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Start a process with chat-a.
|
||||
chatA := uuid.New()
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo secret",
|
||||
Background: true,
|
||||
}, http.Header{
|
||||
workspacesdk.CoderChatIDHeader: {chatA.String()},
|
||||
})
|
||||
waitForExit(t, handler, id)
|
||||
|
||||
// Chat-b should NOT see this process.
|
||||
chatB := uuid.New()
|
||||
w1 := getOutputWithHeaders(t, handler, id, http.Header{
|
||||
workspacesdk.CoderChatIDHeader: {chatB.String()},
|
||||
})
|
||||
require.Equal(t, http.StatusNotFound, w1.Code)
|
||||
|
||||
// Without any chat ID header, should return 200
|
||||
// (backwards compatible).
|
||||
w2 := getOutput(t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w2.Code)
|
||||
})
|
||||
|
||||
t.Run("WaitForExit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo hello-wait && sleep 0.1",
|
||||
})
|
||||
|
||||
w := getOutputWithWait(t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ProcessOutputResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
require.Contains(t, resp.Output, "hello-wait")
|
||||
})
|
||||
|
||||
t.Run("WaitAlreadyExited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo done",
|
||||
})
|
||||
|
||||
waitForExit(t, handler, id)
|
||||
|
||||
w := getOutputWithWait(t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ProcessOutputResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.Running)
|
||||
require.Contains(t, resp.Output, "done")
|
||||
})
|
||||
|
||||
t.Run("WaitTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.IntervalMedium)
|
||||
defer cancel()
|
||||
|
||||
w := getOutputWithWaitCtx(ctx, t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ProcessOutputResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.Running)
|
||||
|
||||
// Kill and wait for the process so cleanup does
|
||||
// not hang.
|
||||
postSignal(
|
||||
t, handler, id,
|
||||
workspacesdk.SignalProcessRequest{Signal: "kill"},
|
||||
)
|
||||
waitForExit(t, handler, id)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentWaiters", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
resps [2]workspacesdk.ProcessOutputResponse
|
||||
codes [2]int
|
||||
)
|
||||
for i := range 2 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
w := getOutputWithWait(t, handler, id)
|
||||
codes[i] = w.Code
|
||||
_ = json.NewDecoder(w.Body).Decode(&resps[i])
|
||||
}()
|
||||
}
|
||||
|
||||
// Signal the process to exit so both waiters unblock.
|
||||
postSignal(
|
||||
t, handler, id,
|
||||
workspacesdk.SignalProcessRequest{Signal: "kill"},
|
||||
)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for i := range 2 {
|
||||
require.Equal(t, http.StatusOK, codes[i], "waiter %d", i)
|
||||
require.False(t, resps[i].Running, "waiter %d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func getOutputWithWait(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
return getOutputWithWaitCtx(ctx, t, handler, id)
|
||||
}
|
||||
|
||||
func getOutputWithWaitCtx(ctx context.Context, t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
path := fmt.Sprintf("/%s/output?wait=true", id)
|
||||
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func TestSignalProcess(t *testing.T) {
|
||||
@@ -1055,7 +781,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
|
||||
return current, nil
|
||||
}, pathStore, nil)
|
||||
}, pathStore)
|
||||
defer api.Close()
|
||||
|
||||
routes := api.Routes()
|
||||
|
||||
@@ -39,13 +39,11 @@ const (
|
||||
// how much output is written.
|
||||
type HeadTailBuffer struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
head []byte
|
||||
tail []byte
|
||||
tailPos int
|
||||
tailFull bool
|
||||
headFull bool
|
||||
closed bool
|
||||
totalBytes int
|
||||
maxHead int
|
||||
maxTail int
|
||||
@@ -54,24 +52,20 @@ type HeadTailBuffer struct {
|
||||
// NewHeadTailBuffer creates a new HeadTailBuffer with the
|
||||
// default head and tail sizes.
|
||||
func NewHeadTailBuffer() *HeadTailBuffer {
|
||||
b := &HeadTailBuffer{
|
||||
return &HeadTailBuffer{
|
||||
maxHead: MaxHeadBytes,
|
||||
maxTail: MaxTailBytes,
|
||||
}
|
||||
b.cond = sync.NewCond(&b.mu)
|
||||
return b
|
||||
}
|
||||
|
||||
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
|
||||
// head and tail sizes. This is useful for testing truncation
|
||||
// logic with smaller buffers.
|
||||
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
|
||||
b := &HeadTailBuffer{
|
||||
return &HeadTailBuffer{
|
||||
maxHead: maxHead,
|
||||
maxTail: maxTail,
|
||||
}
|
||||
b.cond = sync.NewCond(&b.mu)
|
||||
return b
|
||||
}
|
||||
|
||||
// Write implements io.Writer. It is safe for concurrent use.
|
||||
@@ -302,15 +296,6 @@ func truncateLines(s string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Close marks the buffer as closed and wakes any waiters.
|
||||
// This is called when the process exits.
|
||||
func (b *HeadTailBuffer) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.closed = true
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
|
||||
// Reset clears the buffer, discarding all data.
|
||||
func (b *HeadTailBuffer) Reset() {
|
||||
b.mu.Lock()
|
||||
@@ -320,7 +305,5 @@ func (b *HeadTailBuffer) Reset() {
|
||||
b.tailPos = 0
|
||||
b.tailFull = false
|
||||
b.headFull = false
|
||||
b.closed = false
|
||||
b.totalBytes = 0
|
||||
b.cond.Broadcast()
|
||||
}
|
||||
|
||||
+17
-71
@@ -70,25 +70,23 @@ func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
|
||||
|
||||
// manager tracks processes spawned by the agent.
|
||||
type manager struct {
|
||||
mu sync.Mutex
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
clock quartz.Clock
|
||||
procs map[string]*process
|
||||
closed bool
|
||||
updateEnv func(current []string) (updated []string, err error)
|
||||
workingDir func() string
|
||||
mu sync.Mutex
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
clock quartz.Clock
|
||||
procs map[string]*process
|
||||
closed bool
|
||||
updateEnv func(current []string) (updated []string, err error)
|
||||
}
|
||||
|
||||
// newManager creates a new process manager.
|
||||
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), workingDir func() string) *manager {
|
||||
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
|
||||
return &manager{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
clock: quartz.NewReal(),
|
||||
procs: make(map[string]*process),
|
||||
updateEnv: updateEnv,
|
||||
workingDir: workingDir,
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
clock: quartz.NewReal(),
|
||||
procs: make(map[string]*process),
|
||||
updateEnv: updateEnv,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,7 +109,9 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
// the process is not tied to any HTTP request.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
|
||||
cmd.Dir = m.resolveWorkDir(req.WorkDir)
|
||||
if req.WorkDir != "" {
|
||||
cmd.Dir = req.WorkDir
|
||||
}
|
||||
cmd.Stdin = nil
|
||||
cmd.SysProcAttr = procSysProcAttr()
|
||||
|
||||
@@ -158,7 +158,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
proc := &process{
|
||||
id: id,
|
||||
command: req.Command,
|
||||
workDir: cmd.Dir,
|
||||
workDir: req.WorkDir,
|
||||
background: req.Background,
|
||||
chatID: chatID,
|
||||
cmd: cmd,
|
||||
@@ -208,9 +208,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
proc.exitCode = &code
|
||||
proc.mu.Unlock()
|
||||
|
||||
// Wake any waiters blocked on new output or
|
||||
// process exit before closing the done channel.
|
||||
proc.buf.Close()
|
||||
close(proc.done)
|
||||
}()
|
||||
|
||||
@@ -322,54 +319,3 @@ func (m *manager) Close() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForOutput blocks until the buffer is closed (process
|
||||
// exited) or the context is canceled. Returns nil when the
|
||||
// buffer closed, ctx.Err() when the context expired.
|
||||
func (p *process) waitForOutput(ctx context.Context) error {
|
||||
p.buf.cond.L.Lock()
|
||||
defer p.buf.cond.L.Unlock()
|
||||
|
||||
nevermind := make(chan struct{})
|
||||
defer close(nevermind)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Acquire the lock before broadcasting to
|
||||
// guarantee the waiter has entered cond.Wait()
|
||||
// (which atomically releases the lock).
|
||||
// Without this, a Broadcast between the loop
|
||||
// predicate check and cond.Wait() is lost.
|
||||
p.buf.cond.L.Lock()
|
||||
defer p.buf.cond.L.Unlock()
|
||||
p.buf.cond.Broadcast()
|
||||
case <-nevermind:
|
||||
}
|
||||
}()
|
||||
|
||||
for ctx.Err() == nil && !p.buf.closed {
|
||||
p.buf.cond.Wait()
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// resolveWorkDir returns the directory a process should start in.
|
||||
// Priority: explicit request dir > agent configured dir > $HOME.
|
||||
// Falls through when a candidate is empty or does not exist on
|
||||
// disk, matching the behavior of SSH sessions.
|
||||
func (m *manager) resolveWorkDir(requested string) string {
|
||||
if requested != "" {
|
||||
return requested
|
||||
}
|
||||
if m.workingDir != nil {
|
||||
if dir := m.workingDir(); dir != "" {
|
||||
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
||||
return dir
|
||||
}
|
||||
}
|
||||
}
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return home
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -398,11 +398,11 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "reporting script completed", slog.Error(err))
|
||||
logger.Error(ctx, fmt.Sprintf("reporting script completed: %s", err.Error()))
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "reporting script completed: track command goroutine", slog.Error(err))
|
||||
logger.Error(ctx, fmt.Sprintf("reporting script completed: track command goroutine: %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -22,6 +23,26 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// logSink captures structured log entries for testing.
|
||||
type logSink struct {
|
||||
mu sync.Mutex
|
||||
entries []slog.SinkEntry
|
||||
}
|
||||
|
||||
func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.entries = append(s.entries, e)
|
||||
}
|
||||
|
||||
func (*logSink) Sync() {}
|
||||
|
||||
func (s *logSink) getEntries() []slog.SinkEntry {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return append([]slog.SinkEntry{}, s.entries...)
|
||||
}
|
||||
|
||||
// getField returns the value of a field by name from a slog.Map.
|
||||
func getField(fields slog.Map, name string) interface{} {
|
||||
for _, f := range fields {
|
||||
@@ -55,8 +76,8 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, srv.Close()) })
|
||||
|
||||
sink := testutil.NewFakeSink(t)
|
||||
logger := sink.Logger(slog.LevelInfo)
|
||||
sink := &logSink{}
|
||||
logger := slog.Make(sink)
|
||||
workspaceID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
templateVersionID := uuid.New()
|
||||
@@ -97,10 +118,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
|
||||
sendBoundaryLogsRequest(t, conn, req)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return len(sink.Entries()) >= 1
|
||||
return len(sink.getEntries()) >= 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
entries := sink.Entries()
|
||||
entries := sink.getEntries()
|
||||
require.Len(t, entries, 1)
|
||||
entry := entries[0]
|
||||
require.Equal(t, slog.LevelInfo, entry.Level)
|
||||
@@ -131,10 +152,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
|
||||
sendBoundaryLogsRequest(t, conn, req2)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return len(sink.Entries()) >= 2
|
||||
return len(sink.getEntries()) >= 2
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
entries = sink.Entries()
|
||||
entries = sink.getEntries()
|
||||
entry = entries[1]
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, slog.LevelInfo, entry.Level)
|
||||
|
||||
@@ -1000,12 +1000,6 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
|
||||
Properties: sdkTool.Schema.Properties,
|
||||
Required: sdkTool.Schema.Required,
|
||||
},
|
||||
Annotations: mcp.ToolAnnotation{
|
||||
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
|
||||
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
|
||||
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
|
||||
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
+1
-16
@@ -81,13 +81,7 @@ func TestExpMcpServer(t *testing.T) {
|
||||
var toolsResponse struct {
|
||||
Result struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Annotations struct {
|
||||
ReadOnlyHint *bool `json:"readOnlyHint"`
|
||||
DestructiveHint *bool `json:"destructiveHint"`
|
||||
IdempotentHint *bool `json:"idempotentHint"`
|
||||
OpenWorldHint *bool `json:"openWorldHint"`
|
||||
} `json:"annotations"`
|
||||
Name string `json:"name"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
@@ -100,15 +94,6 @@ func TestExpMcpServer(t *testing.T) {
|
||||
}
|
||||
slices.Sort(foundTools)
|
||||
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
|
||||
annotations := toolsResponse.Result.Tools[0].Annotations
|
||||
require.NotNil(t, annotations.ReadOnlyHint)
|
||||
require.NotNil(t, annotations.DestructiveHint)
|
||||
require.NotNil(t, annotations.IdempotentHint)
|
||||
require.NotNil(t, annotations.OpenWorldHint)
|
||||
assert.True(t, *annotations.ReadOnlyHint)
|
||||
assert.False(t, *annotations.DestructiveHint)
|
||||
assert.True(t, *annotations.IdempotentHint)
|
||||
assert.False(t, *annotations.OpenWorldHint)
|
||||
|
||||
// Call the tool and ensure it works.
|
||||
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
|
||||
|
||||
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
} else {
|
||||
updated, err = client.CreateOrganizationRole(ctx, customRole)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create role: %w", err)
|
||||
return xerrors.Errorf("patch role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -324,6 +324,7 @@ func TestServer(t *testing.T) {
|
||||
ignoreLines := []string{
|
||||
"isn't externally reachable",
|
||||
"open install.sh: file does not exist",
|
||||
"open install.ps1: file does not exist",
|
||||
"telemetry disabled, unable to notify of security issues",
|
||||
"installed terraform version newer than expected",
|
||||
"report generator",
|
||||
|
||||
@@ -79,29 +79,6 @@ func (r *RootCmd) start() *serpent.Command {
|
||||
)
|
||||
build = workspace.LatestBuild
|
||||
default:
|
||||
// If the last build was a failed start, run a stop
|
||||
// first to clean up any partially-provisioned
|
||||
// resources.
|
||||
if workspace.LatestBuild.Status == codersdk.WorkspaceStatusFailed &&
|
||||
workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "The last start build failed. Cleaning up before retrying...\n")
|
||||
stopBuild, stopErr := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
|
||||
Transition: codersdk.WorkspaceTransitionStop,
|
||||
})
|
||||
if stopErr != nil {
|
||||
return xerrors.Errorf("cleanup stop after failed start: %w", stopErr)
|
||||
}
|
||||
stopErr = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, stopBuild.ID)
|
||||
if stopErr != nil {
|
||||
return xerrors.Errorf("wait for cleanup stop: %w", stopErr)
|
||||
}
|
||||
// Re-fetch workspace after stop completes so
|
||||
// startWorkspace sees the latest state.
|
||||
workspace, err = namedWorkspace(inv.Context(), client, inv.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
build, err = startWorkspace(inv, client, workspace, parameterFlags, bflags, WorkspaceStart)
|
||||
// It's possible for a workspace build to fail due to the template requiring starting
|
||||
// workspaces with the active version.
|
||||
|
||||
@@ -534,55 +534,3 @@ func TestStart_WithReason(t *testing.T) {
|
||||
workspace = coderdtest.MustWorkspace(t, member, workspace.ID)
|
||||
require.Equal(t, codersdk.BuildReasonCLI, workspace.LatestBuild.Reason)
|
||||
}
|
||||
|
||||
func TestStart_FailedStartCleansUp(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: store,
|
||||
Pubsub: ps,
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, memberClient, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// Insert a failed start build directly into the database so that
|
||||
// the workspace's latest build is a failed "start" transition.
|
||||
dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
ID: workspace.ID,
|
||||
OwnerID: member.ID,
|
||||
OrganizationID: owner.OrganizationID,
|
||||
TemplateID: template.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{
|
||||
TemplateVersionID: version.ID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
BuildNumber: workspace.LatestBuild.BuildNumber + 1,
|
||||
}).
|
||||
Failed().
|
||||
Do()
|
||||
|
||||
inv, root := clitest.New(t, "start", workspace.Name)
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// The CLI should detect the failed start and clean up first.
|
||||
pty.ExpectMatch("Cleaning up before retrying")
|
||||
pty.ExpectMatch("workspace has been started")
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
}
|
||||
|
||||
+17
-26
@@ -113,20 +113,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
|
||||
)
|
||||
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
|
||||
|
||||
// Bypass rate limiting for support bundle collection since it makes many API calls.
|
||||
// Note: this can only be done by the owner user.
|
||||
if ok, err := support.CanGenerateFull(inv.Context(), client); err == nil && ok {
|
||||
cliLog.Debug(inv.Context(), "running as owner")
|
||||
client.HTTPClient.Transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
|
||||
}
|
||||
} else if !ok {
|
||||
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
|
||||
} else {
|
||||
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
|
||||
}
|
||||
|
||||
// Check if we're running inside a workspace
|
||||
if val, found := os.LookupEnv("CODER"); found && val == "true" {
|
||||
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
|
||||
@@ -214,6 +200,12 @@ func (r *RootCmd) supportBundle() *serpent.Command {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
|
||||
}
|
||||
|
||||
// Bypass rate limiting for support bundle collection since it makes many API calls.
|
||||
client.HTTPClient.Transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
|
||||
}
|
||||
|
||||
deps := support.Deps{
|
||||
Client: client,
|
||||
// Support adds a sink so we don't need to supply one ourselves.
|
||||
@@ -362,20 +354,19 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
|
||||
return
|
||||
}
|
||||
|
||||
var docsURL string
|
||||
if bun.Deployment.Config != nil {
|
||||
docsURL = bun.Deployment.Config.Values.DocsURL.String()
|
||||
} else {
|
||||
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
|
||||
if bun.Deployment.Config == nil {
|
||||
cliui.Error(inv.Stdout, "No deployment configuration available!")
|
||||
return
|
||||
}
|
||||
|
||||
if bun.Deployment.HealthReport != nil {
|
||||
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
|
||||
if len(deployHealthSummary) > 0 {
|
||||
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
|
||||
}
|
||||
} else {
|
||||
cliui.Warn(inv.Stdout, "No deployment health report available.")
|
||||
docsURL := bun.Deployment.Config.Values.DocsURL.String()
|
||||
if bun.Deployment.HealthReport == nil {
|
||||
cliui.Error(inv.Stdout, "No deployment health report available!")
|
||||
return
|
||||
}
|
||||
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
|
||||
if len(deployHealthSummary) > 0 {
|
||||
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
|
||||
}
|
||||
|
||||
if bun.Network.Netcheck == nil {
|
||||
|
||||
+22
-47
@@ -28,9 +28,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
@@ -52,21 +50,9 @@ func TestSupportBundle(t *testing.T) {
|
||||
dc.Values.Prometheus.Enable = true
|
||||
secretValue := uuid.NewString()
|
||||
seedSecretDeploymentOptions(t, &dc, secretValue)
|
||||
// Use a mock healthcheck function to avoid flaky DERP health
|
||||
// checks in CI. The DERP checker performs real network operations
|
||||
// (portmapper gateway probing, STUN) that can hang for 60s+ on
|
||||
// macOS CI runners. Since this test validates support bundle
|
||||
// generation, not healthcheck correctness, a canned report is
|
||||
// sufficient.
|
||||
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
DeploymentValues: dc.Values,
|
||||
HealthcheckFunc: func(_ context.Context, _ string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
return &healthsdk.HealthcheckReport{
|
||||
Time: time.Now(),
|
||||
Healthy: true,
|
||||
Severity: health.SeverityOK,
|
||||
}
|
||||
},
|
||||
DeploymentValues: dc.Values,
|
||||
HealthcheckTimeout: testutil.WaitSuperLong,
|
||||
})
|
||||
|
||||
t.Cleanup(func() { closer.Close() })
|
||||
@@ -74,7 +60,7 @@ func TestSupportBundle(t *testing.T) {
|
||||
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
// Set up test fixtures
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setupCtx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
workspaceWithAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, func(agents []*proto.Agent) []*proto.Agent {
|
||||
// This should not show up in the bundle output
|
||||
agents[0].Env["SECRET_VALUE"] = secretValue
|
||||
@@ -83,6 +69,22 @@ func TestSupportBundle(t *testing.T) {
|
||||
workspaceWithoutAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, nil)
|
||||
memberWorkspace := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, member.ID, nil)
|
||||
|
||||
// Wait for healthcheck to complete successfully before continuing with sub-tests.
|
||||
// The result is cached so subsequent requests will be fast.
|
||||
healthcheckDone := make(chan *healthsdk.HealthcheckReport)
|
||||
go func() {
|
||||
defer close(healthcheckDone)
|
||||
hc, err := healthsdk.New(client).DebugHealth(setupCtx)
|
||||
if err != nil {
|
||||
assert.NoError(t, err, "seed healthcheck cache")
|
||||
return
|
||||
}
|
||||
healthcheckDone <- &hc
|
||||
}()
|
||||
if _, ok := testutil.AssertReceive(setupCtx, t, healthcheckDone); !ok {
|
||||
t.Fatal("healthcheck did not complete in time -- this may be a transient issue")
|
||||
}
|
||||
|
||||
t.Run("WorkspaceWithAgent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -130,35 +132,12 @@ func TestSupportBundle(t *testing.T) {
|
||||
assertBundleContents(t, path, true, false, []string{secretValue})
|
||||
})
|
||||
|
||||
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
|
||||
t.Run("NoPrivilege", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := t.TempDir()
|
||||
path := filepath.Join(d, "bundle.zip")
|
||||
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
|
||||
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
err := inv.Run()
|
||||
require.NoError(t, err)
|
||||
r, err := zip.OpenReader(path)
|
||||
require.NoError(t, err, "open zip file")
|
||||
defer r.Close()
|
||||
fileNames := make(map[string]struct{}, len(r.File))
|
||||
for _, f := range r.File {
|
||||
fileNames[f.Name] = struct{}{}
|
||||
}
|
||||
// These should always be present in the zip structure, even if
|
||||
// the content is null/empty for non-admin users.
|
||||
for _, name := range []string{
|
||||
"deployment/buildinfo.json",
|
||||
"deployment/config.json",
|
||||
"workspace/workspace.json",
|
||||
"logs.txt",
|
||||
"cli_logs.txt",
|
||||
"network/netcheck.json",
|
||||
"network/interfaces.json",
|
||||
} {
|
||||
require.Contains(t, fileNames, name)
|
||||
}
|
||||
require.ErrorContains(t, err, "failed authorization check")
|
||||
})
|
||||
|
||||
// This ensures that the CLI does not panic when trying to generate a support bundle
|
||||
@@ -180,10 +159,6 @@ func TestSupportBundle(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("received request: %s %s", r.Method, r.URL)
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me":
|
||||
resp := codersdk.User{}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
assert.NoError(t, json.NewEncoder(w).Encode(resp))
|
||||
case "/api/v2/authcheck":
|
||||
// Fake auth check
|
||||
resp := codersdk.AuthorizationResponse{
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"last_seen_at": "====[timestamp]=====",
|
||||
"name": "test-daemon",
|
||||
"version": "v0.0.0-devel",
|
||||
"api_version": "1.16",
|
||||
"api_version": "1.15",
|
||||
"provisioners": [
|
||||
"echo"
|
||||
],
|
||||
|
||||
-6
@@ -170,12 +170,6 @@ AI BRIDGE OPTIONS:
|
||||
exporting these records to external SIEM or observability systems.
|
||||
|
||||
AI BRIDGE PROXY OPTIONS:
|
||||
--aibridge-proxy-allowed-private-cidrs string-array, $CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS
|
||||
Comma-separated list of CIDR ranges that are permitted even though
|
||||
they fall within blocked private/reserved IP ranges. By default all
|
||||
private ranges are blocked to prevent SSRF attacks. Use this to allow
|
||||
access to specific internal networks.
|
||||
|
||||
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
|
||||
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
|
||||
provider requests.
|
||||
|
||||
+10
-11
@@ -8,17 +8,16 @@ USAGE:
|
||||
Aliases: user
|
||||
|
||||
SUBCOMMANDS:
|
||||
activate Update a user's status to 'active'. Active users can fully
|
||||
interact with the platform
|
||||
create Create a new user.
|
||||
delete Delete a user by username or user_id.
|
||||
edit-roles Edit a user's roles by username or id
|
||||
list Prints the list of users.
|
||||
oidc-claims Display the OIDC claims for the authenticated user.
|
||||
show Show a single user. Use 'me' to indicate the currently
|
||||
authenticated user.
|
||||
suspend Update a user's status to 'suspended'. A suspended user
|
||||
cannot log into the platform
|
||||
activate Update a user's status to 'active'. Active users can fully
|
||||
interact with the platform
|
||||
create Create a new user.
|
||||
delete Delete a user by username or user_id.
|
||||
edit-roles Edit a user's roles by username or id
|
||||
list Prints the list of users.
|
||||
show Show a single user. Use 'me' to indicate the currently
|
||||
authenticated user.
|
||||
suspend Update a user's status to 'suspended'. A suspended user cannot
|
||||
log into the platform
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
@@ -24,10 +24,6 @@ OPTIONS:
|
||||
-p, --password string
|
||||
Specifies a password for the new user.
|
||||
|
||||
--service-account bool
|
||||
Create a user account intended to be used by a service or as an
|
||||
intermediary rather than by a human.
|
||||
|
||||
-u, --username string
|
||||
Specifies a username for the new user.
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
coder v0.0.0-devel
|
||||
|
||||
USAGE:
|
||||
coder users oidc-claims [flags]
|
||||
|
||||
Display the OIDC claims for the authenticated user.
|
||||
|
||||
- Display your OIDC claims:
|
||||
|
||||
$ coder users oidc-claims
|
||||
|
||||
- Display your OIDC claims as JSON:
|
||||
|
||||
$ coder users oidc-claims -o json
|
||||
|
||||
OPTIONS:
|
||||
-c, --column [key|value] (default: key,value)
|
||||
Columns to display in table output.
|
||||
|
||||
-o, --output table|json (default: table)
|
||||
Output format.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
-11
@@ -752,11 +752,6 @@ workspace_prebuilds:
|
||||
# limit; disabled when set to zero.
|
||||
# (default: 3, type: int)
|
||||
failure_hard_limit: 3
|
||||
# Configure the background chat processing daemon.
|
||||
chat:
|
||||
# How many pending chats a worker should acquire per polling cycle.
|
||||
# (default: 10, type: int)
|
||||
acquireBatchSize: 10
|
||||
aibridge:
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# (default: false, type: bool)
|
||||
@@ -873,12 +868,6 @@ aibridgeproxy:
|
||||
# by the system. If not provided, the system certificate pool is used.
|
||||
# (default: <unset>, type: string)
|
||||
upstream_proxy_ca: ""
|
||||
# Comma-separated list of CIDR ranges that are permitted even though they fall
|
||||
# within blocked private/reserved IP ranges. By default all private ranges are
|
||||
# blocked to prevent SSRF attacks. Use this to allow access to specific internal
|
||||
# networks.
|
||||
# (default: <unset>, type: string-array)
|
||||
allowed_private_cidrs: []
|
||||
# Configure data retention policies for various database tables. Retention
|
||||
# policies automatically purge old data to reduce database size and improve
|
||||
# performance. Setting a retention duration to 0 disables automatic purging for
|
||||
|
||||
+12
-37
@@ -17,14 +17,13 @@ import (
|
||||
|
||||
func (r *RootCmd) userCreate() *serpent.Command {
|
||||
var (
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
serviceAccount bool
|
||||
orgContext = NewOrganizationContext()
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
orgContext = NewOrganizationContext()
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "create",
|
||||
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
serpent.RequireNArgs(0),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
if serviceAccount {
|
||||
switch {
|
||||
case loginType != "":
|
||||
return xerrors.New("You cannot use --login-type with --service-account")
|
||||
case password != "":
|
||||
return xerrors.New("You cannot use --password with --service-account")
|
||||
case email != "":
|
||||
return xerrors.New("You cannot use --email with --service-account")
|
||||
case disableLogin:
|
||||
return xerrors.New("You cannot use --disable-login with --service-account")
|
||||
}
|
||||
}
|
||||
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if email == "" && !serviceAccount {
|
||||
if email == "" {
|
||||
email, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Email:",
|
||||
Validate: func(s string) error {
|
||||
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
}
|
||||
}
|
||||
userLoginType := codersdk.LoginTypePassword
|
||||
if disableLogin || serviceAccount {
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
if disableLogin {
|
||||
userLoginType = codersdk.LoginTypeNone
|
||||
} else if loginType != "" {
|
||||
userLoginType = codersdk.LoginType(loginType)
|
||||
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
Password: password,
|
||||
OrganizationIDs: []uuid.UUID{organization.ID},
|
||||
UserLoginType: userLoginType,
|
||||
ServiceAccount: serviceAccount,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
case codersdk.LoginTypeOIDC:
|
||||
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
|
||||
}
|
||||
if serviceAccount {
|
||||
email = "n/a"
|
||||
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
|
||||
Share the instructions below to get them started.
|
||||
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
|
||||
)),
|
||||
Value: serpent.StringOf(&loginType),
|
||||
},
|
||||
{
|
||||
Flag: "service-account",
|
||||
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
|
||||
Value: serpent.BoolOf(&serviceAccount),
|
||||
},
|
||||
}
|
||||
|
||||
orgContext.AttachOptions(cmd)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
|
||||
assert.Equal(t, args[5], created.Username)
|
||||
assert.Empty(t, created.Name)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "ServiceAccount",
|
||||
args: []string{"--service-account", "-u", "dean"},
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountLoginType",
|
||||
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
|
||||
err: "You cannot use --login-type with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountDisableLogin",
|
||||
args: []string{"--service-account", "-u", "dean", "--disable-login"},
|
||||
err: "You cannot use --disable-login with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountEmail",
|
||||
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
|
||||
err: "You cannot use --email with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountPassword",
|
||||
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
|
||||
err: "You cannot use --password with --service-account",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
err := inv.Run()
|
||||
if tt.err == "" {
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created, err := client.User(ctx, "dean")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) userOIDCClaims() *serpent.Command {
|
||||
formatter := cliui.NewOutputFormatter(
|
||||
cliui.ChangeFormatterData(
|
||||
cliui.TableFormat([]claimRow{}, []string{"key", "value"}),
|
||||
func(data any) (any, error) {
|
||||
resp, ok := data.(codersdk.OIDCClaimsResponse)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("expected type %T, got %T", resp, data)
|
||||
}
|
||||
rows := make([]claimRow, 0, len(resp.Claims))
|
||||
for k, v := range resp.Claims {
|
||||
rows = append(rows, claimRow{
|
||||
Key: k,
|
||||
Value: fmt.Sprintf("%v", v),
|
||||
})
|
||||
}
|
||||
return rows, nil
|
||||
},
|
||||
),
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "oidc-claims",
|
||||
Short: "Display the OIDC claims for the authenticated user.",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "Display your OIDC claims",
|
||||
Command: "coder users oidc-claims",
|
||||
},
|
||||
Example{
|
||||
Description: "Display your OIDC claims as JSON",
|
||||
Command: "coder users oidc-claims -o json",
|
||||
},
|
||||
),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.UserOIDCClaims(inv.Context())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get oidc claims: %w", err)
|
||||
}
|
||||
|
||||
out, err := formatter.Format(inv.Context(), resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintln(inv.Stdout, out)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
formatter.AttachOptions(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
type claimRow struct {
|
||||
Key string `json:"-" table:"key,default_sort"`
|
||||
Value string `json:"-" table:"value"`
|
||||
}
|
||||
@@ -1,161 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestUserOIDCClaims(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newOIDCTest := func(t *testing.T) (*oidctest.FakeIDP, *codersdk.Client) {
|
||||
t.Helper()
|
||||
|
||||
fake := oidctest.NewFakeIDP(t,
|
||||
oidctest.WithServing(),
|
||||
)
|
||||
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
|
||||
cfg.AllowSignups = true
|
||||
})
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{
|
||||
OIDCConfig: cfg,
|
||||
})
|
||||
return fake, ownerClient
|
||||
}
|
||||
|
||||
t.Run("OwnClaims", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fake, ownerClient := newOIDCTest(t)
|
||||
claims := jwt.MapClaims{
|
||||
"email": "alice@coder.com",
|
||||
"email_verified": true,
|
||||
"sub": uuid.NewString(),
|
||||
"groups": []string{"admin", "eng"},
|
||||
}
|
||||
userClient, loginResp := fake.Login(t, ownerClient, claims)
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
inv, root := clitest.New(t, "users", "oidc-claims", "-o", "json")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
inv.Stdout = buf
|
||||
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
var resp codersdk.OIDCClaimsResponse
|
||||
err = json.Unmarshal(buf.Bytes(), &resp)
|
||||
require.NoError(t, err, "unmarshal JSON output")
|
||||
require.NotEmpty(t, resp.Claims, "claims should not be empty")
|
||||
assert.Equal(t, "alice@coder.com", resp.Claims["email"])
|
||||
})
|
||||
|
||||
t.Run("Table", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fake, ownerClient := newOIDCTest(t)
|
||||
claims := jwt.MapClaims{
|
||||
"email": "bob@coder.com",
|
||||
"email_verified": true,
|
||||
"sub": uuid.NewString(),
|
||||
}
|
||||
userClient, loginResp := fake.Login(t, ownerClient, claims)
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
inv, root := clitest.New(t, "users", "oidc-claims")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
inv.Stdout = buf
|
||||
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
require.Contains(t, output, "email")
|
||||
require.Contains(t, output, "bob@coder.com")
|
||||
})
|
||||
|
||||
t.Run("NotOIDCUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
inv, root := clitest.New(t, "users", "oidc-claims")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an OIDC user")
|
||||
})
|
||||
|
||||
// Verify that two different OIDC users each only see their own
|
||||
// claims. The endpoint has no user parameter, so there is no way
|
||||
// to request another user's claims by design.
|
||||
t.Run("OnlyOwnClaims", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
aliceFake, aliceOwnerClient := newOIDCTest(t)
|
||||
aliceClaims := jwt.MapClaims{
|
||||
"email": "alice-isolation@coder.com",
|
||||
"email_verified": true,
|
||||
"sub": uuid.NewString(),
|
||||
}
|
||||
aliceClient, aliceLoginResp := aliceFake.Login(t, aliceOwnerClient, aliceClaims)
|
||||
defer aliceLoginResp.Body.Close()
|
||||
|
||||
bobFake, bobOwnerClient := newOIDCTest(t)
|
||||
bobClaims := jwt.MapClaims{
|
||||
"email": "bob-isolation@coder.com",
|
||||
"email_verified": true,
|
||||
"sub": uuid.NewString(),
|
||||
}
|
||||
bobClient, bobLoginResp := bobFake.Login(t, bobOwnerClient, bobClaims)
|
||||
defer bobLoginResp.Body.Close()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
// Alice sees her own claims.
|
||||
aliceResp, err := aliceClient.UserOIDCClaims(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "alice-isolation@coder.com", aliceResp.Claims["email"])
|
||||
|
||||
// Bob sees his own claims.
|
||||
bobResp, err := bobClient.UserOIDCClaims(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bob-isolation@coder.com", bobResp.Claims["email"])
|
||||
})
|
||||
|
||||
t.Run("ClaimsNeverNull", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fake, ownerClient := newOIDCTest(t)
|
||||
// Use minimal claims — just enough for OIDC login.
|
||||
claims := jwt.MapClaims{
|
||||
"email": "minimal@coder.com",
|
||||
"email_verified": true,
|
||||
"sub": uuid.NewString(),
|
||||
}
|
||||
userClient, loginResp := fake.Login(t, ownerClient, claims)
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := userClient.UserOIDCClaims(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Claims, "claims should never be nil, expected empty map")
|
||||
})
|
||||
}
|
||||
@@ -19,7 +19,6 @@ func (r *RootCmd) users() *serpent.Command {
|
||||
r.userSingle(),
|
||||
r.userDelete(),
|
||||
r.userEditRoles(),
|
||||
r.userOIDCClaims(),
|
||||
r.createUserStatusCommand(codersdk.UserStatusActive),
|
||||
r.createUserStatusCommand(codersdk.UserStatusSuspended),
|
||||
},
|
||||
|
||||
Generated
+1688
-2042
File diff suppressed because it is too large
Load Diff
Generated
+1688
-2020
File diff suppressed because it is too large
Load Diff
+491
-916
File diff suppressed because it is too large
Load Diff
@@ -2,26 +2,13 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
@@ -97,524 +84,3 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
|
||||
require.ErrorContains(t, err, loadErr.Error())
|
||||
require.Equal(t, chat, refreshed)
|
||||
}
|
||||
|
||||
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).Times(1)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
nil,
|
||||
"",
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
|
||||
).Times(1)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
instructionCache: make(map[uuid.UUID]cachedInstruction),
|
||||
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return conn, func() {}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction := server.resolveInstructions(
|
||||
ctx,
|
||||
chat,
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
)
|
||||
require.Contains(t, instruction, "Operating System: linux")
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{initialAgent}, nil),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
|
||||
)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
|
||||
var dialed []uuid.UUID
|
||||
server := &Server{db: db}
|
||||
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialed = append(dialed, agentID)
|
||||
if agentID == initialAgent.ID {
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
}
|
||||
return conn, func() {}, nil
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, gotConn)
|
||||
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
|
||||
}
|
||||
|
||||
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(context.Background())
|
||||
defer cancelCtx()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
initialMessage := database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleUser,
|
||||
}
|
||||
localMessage := database.ChatMessage{
|
||||
ID: 2,
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
server.publishMessage(chatID, localMessage)
|
||||
|
||||
event := requireStreamMessageEvent(t, events)
|
||||
require.Equal(t, int64(2), event.Message.ID)
|
||||
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(context.Background())
|
||||
defer cancelCtx()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
initialMessage := database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleUser,
|
||||
}
|
||||
cachedMessage := codersdk.ChatMessage{
|
||||
ID: 2,
|
||||
ChatID: chatID,
|
||||
Role: codersdk.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessage,
|
||||
ChatID: chatID,
|
||||
Message: &cachedMessage,
|
||||
})
|
||||
|
||||
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
||||
AfterMessageID: 1,
|
||||
})
|
||||
|
||||
event := requireStreamMessageEvent(t, events)
|
||||
require.Equal(t, int64(2), event.Message.ID)
|
||||
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(context.Background())
|
||||
defer cancelCtx()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
initialMessage := database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleUser,
|
||||
}
|
||||
catchupMessage := database.ChatMessage{
|
||||
ID: 2,
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 1,
|
||||
}).Return([]database.ChatMessage{catchupMessage}, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
||||
AfterMessageID: 1,
|
||||
})
|
||||
|
||||
event := requireStreamMessageEvent(t, events)
|
||||
require.Equal(t, int64(2), event.Message.ID)
|
||||
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(context.Background())
|
||||
defer cancelCtx()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
initialMessage := database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleUser,
|
||||
}
|
||||
editedMessage := database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: chatID,
|
||||
Role: database.ChatMessageRoleUser,
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{initialMessage}, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{editedMessage}, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
server.publishEditedMessage(chatID, editedMessage)
|
||||
|
||||
event := requireStreamMessageEvent(t, events)
|
||||
require.Equal(t, int64(1), event.Message.ID)
|
||||
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(context.Background())
|
||||
defer cancelCtx()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
chatID := uuid.New()
|
||||
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return(nil, nil),
|
||||
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
|
||||
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
|
||||
)
|
||||
|
||||
server := newSubscribeTestServer(t, db)
|
||||
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
|
||||
require.True(t, ok)
|
||||
defer cancel()
|
||||
|
||||
retryingAt := time.Unix(1_700_000_000, 0).UTC()
|
||||
expected := &codersdk.ChatStreamRetry{
|
||||
Attempt: 1,
|
||||
DelayMs: (1500 * time.Millisecond).Milliseconds(),
|
||||
Error: "rate limit exceeded",
|
||||
RetryingAt: retryingAt,
|
||||
}
|
||||
|
||||
server.publishRetry(chatID, expected)
|
||||
|
||||
event := requireStreamRetryEvent(t, events)
|
||||
require.Equal(t, expected, event.Retry)
|
||||
requireNoStreamEvent(t, events, 200*time.Millisecond)
|
||||
}
|
||||
|
||||
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
|
||||
t.Helper()
|
||||
|
||||
return &Server{
|
||||
db: db,
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
pubsub: dbpubsub.NewInMemory(),
|
||||
}
|
||||
}
|
||||
|
||||
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
require.True(t, ok, "chat stream closed before delivering an event")
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type)
|
||||
require.NotNil(t, event.Message)
|
||||
return event
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for chat stream message event")
|
||||
return codersdk.ChatStreamEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
require.True(t, ok, "chat stream closed before delivering an event")
|
||||
require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type)
|
||||
require.NotNil(t, event.Retry)
|
||||
return event
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for chat stream retry event")
|
||||
return codersdk.ChatStreamEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
if !ok {
|
||||
t.Fatal("chat stream closed unexpectedly")
|
||||
}
|
||||
t.Fatalf("unexpected chat stream event: %+v", event)
|
||||
case <-time.After(wait):
|
||||
}
|
||||
}
|
||||
|
||||
// TestPublishToStream_DropWarnRateLimiting walks through a
|
||||
// realistic lifecycle: buffer fills up, subscriber channel fills
|
||||
// up, counters get reset between steps. It verifies that WARN
|
||||
// logs are rate-limited to at most once per streamDropWarnInterval
|
||||
// and that counter resets re-enable an immediate WARN.
|
||||
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sink := testutil.NewFakeSink(t)
|
||||
mClock := quartz.NewMock(t)
|
||||
|
||||
server := &Server{
|
||||
logger: sink.Logger(),
|
||||
clock: mClock,
|
||||
}
|
||||
|
||||
chatID := uuid.New()
|
||||
subCh := make(chan codersdk.ChatStreamEvent, 1)
|
||||
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
|
||||
|
||||
// Set up state that mirrors a running chat: buffer at capacity,
|
||||
// buffering enabled, one saturated subscriber.
|
||||
state := &chatStreamState{
|
||||
buffering: true,
|
||||
buffer: make([]codersdk.ChatStreamEvent, maxStreamBufferSize),
|
||||
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
|
||||
uuid.New(): subCh,
|
||||
},
|
||||
}
|
||||
server.chatStreams.Store(chatID, state)
|
||||
|
||||
bufferMsg := "chat stream buffer full, dropping oldest event"
|
||||
subMsg := "dropping chat stream event"
|
||||
|
||||
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
|
||||
return func(e slog.SinkEntry) bool {
|
||||
return e.Level == level && e.Message == msg
|
||||
}
|
||||
}
|
||||
|
||||
// --- Phase 1: buffer-full rate limiting ---
|
||||
// message_part events hit both the buffer-full and subscriber-full
|
||||
// paths. The first publish triggers a WARN for each; the rest
|
||||
// within the window are DEBUG.
|
||||
partEvent := codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeMessagePart,
|
||||
MessagePart: &codersdk.ChatStreamMessagePart{},
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
server.publishToStream(chatID, partEvent)
|
||||
}
|
||||
|
||||
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
|
||||
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
|
||||
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
|
||||
|
||||
// Subscriber also saw 50 drops (one per publish).
|
||||
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
|
||||
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
|
||||
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
|
||||
|
||||
// --- Phase 2: clock advance triggers second WARN with count ---
|
||||
mClock.Advance(streamDropWarnInterval + time.Second)
|
||||
server.publishToStream(chatID, partEvent)
|
||||
|
||||
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
|
||||
require.Len(t, bufWarn, 2)
|
||||
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
|
||||
|
||||
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
|
||||
require.Len(t, subWarn, 2)
|
||||
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
|
||||
|
||||
// --- Phase 3: counter reset (simulates step persist) ---
|
||||
state.mu.Lock()
|
||||
state.buffer = make([]codersdk.ChatStreamEvent, maxStreamBufferSize)
|
||||
state.resetDropCounters()
|
||||
state.mu.Unlock()
|
||||
|
||||
// The very next drop should WARN immediately — the reset zeroed
|
||||
// lastWarnAt so the interval check passes.
|
||||
server.publishToStream(chatID, partEvent)
|
||||
|
||||
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
|
||||
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
|
||||
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
|
||||
|
||||
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
|
||||
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
|
||||
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
|
||||
}
|
||||
|
||||
// requireFieldValue asserts that a SinkEntry contains a field with
|
||||
// the given name and value.
|
||||
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
|
||||
t.Helper()
|
||||
for _, f := range entry.Fields {
|
||||
if f.Name == name {
|
||||
require.Equal(t, expected, f.Value, "field %q value mismatch", name)
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("field %q not found in log entry", name)
|
||||
}
|
||||
|
||||
+221
-1040
File diff suppressed because it is too large
Load Diff
@@ -42,11 +42,6 @@ type PersistedStep struct {
|
||||
Content []fantasy.Content
|
||||
Usage fantasy.Usage
|
||||
ContextLimit sql.NullInt64
|
||||
// Runtime is the wall-clock duration of this step,
|
||||
// covering LLM streaming, tool execution, and retries.
|
||||
// Zero indicates the duration was not measured (e.g.
|
||||
// interrupted steps).
|
||||
Runtime time.Duration
|
||||
}
|
||||
|
||||
// RunOptions configures a single streaming chat loop run.
|
||||
@@ -127,7 +122,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
|
||||
switch c.GetType() {
|
||||
case fantasy.ContentTypeText:
|
||||
text, ok := fantasy.AsContentType[fantasy.TextContent](c)
|
||||
if !ok || strings.TrimSpace(text.Text) == "" {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
assistantParts = append(assistantParts, fantasy.TextPart{
|
||||
@@ -136,7 +131,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
|
||||
})
|
||||
case fantasy.ContentTypeReasoning:
|
||||
reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
|
||||
if !ok || strings.TrimSpace(reasoning.Text) == "" {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
assistantParts = append(assistantParts, fantasy.ReasoningPart{
|
||||
@@ -265,7 +260,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
|
||||
for step := 0; totalSteps < opts.MaxSteps; step++ {
|
||||
totalSteps++
|
||||
stepStart := time.Now()
|
||||
// Copy messages so that provider-specific caching
|
||||
// mutations don't leak back to the caller's slice.
|
||||
// copy copies Message structs by value, so field
|
||||
@@ -371,7 +365,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
Content: result.content,
|
||||
Usage: result.usage,
|
||||
ContextLimit: contextLimit,
|
||||
Runtime: time.Since(stepStart),
|
||||
}); err != nil {
|
||||
if errors.Is(err, ErrInterrupted) {
|
||||
persistInterruptedStep(ctx, opts, &result)
|
||||
@@ -617,12 +610,10 @@ func processStepStream(
|
||||
result.providerMetadata = part.ProviderMetadata
|
||||
|
||||
case fantasy.StreamPartTypeError:
|
||||
// Detect interruption: the stream may surface the
|
||||
// cancel as context.Canceled or propagate the
|
||||
// ErrInterrupted cause directly, depending on
|
||||
// the provider implementation.
|
||||
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
|
||||
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
|
||||
// Detect interruption: context canceled with
|
||||
// ErrInterrupted as the cause.
|
||||
if errors.Is(part.Error, context.Canceled) &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
// Flush in-progress content so that
|
||||
// persistInterruptedStep has access to partial
|
||||
// text, reasoning, and tool calls that were
|
||||
@@ -640,23 +631,6 @@ func processStepStream(
|
||||
}
|
||||
}
|
||||
|
||||
// The stream iterator may stop yielding parts without
|
||||
// producing a StreamPartTypeError when the context is
|
||||
// canceled (e.g. some providers close the response body
|
||||
// silently). Detect this case and flush partial content
|
||||
// so that persistInterruptedStep can save it.
|
||||
if ctx.Err() != nil &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
flushActiveState(
|
||||
&result,
|
||||
activeTextContent,
|
||||
activeReasoningContent,
|
||||
activeToolCalls,
|
||||
toolNames,
|
||||
)
|
||||
return result, ErrInterrupted
|
||||
}
|
||||
|
||||
hasLocalToolCalls := false
|
||||
for _, tc := range result.toolCalls {
|
||||
if !tc.ProviderExecuted {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
@@ -65,8 +64,6 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
require.Equal(t, 1, persistStepCalls)
|
||||
require.True(t, persistedStep.ContextLimit.Valid)
|
||||
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
|
||||
require.Greater(t, persistedStep.Runtime, time.Duration(0),
|
||||
"step runtime should be positive")
|
||||
|
||||
require.NotEmpty(t, capturedCall.Prompt)
|
||||
require.False(t, containsPromptSentinel(capturedCall.Prompt))
|
||||
@@ -578,84 +575,6 @@ func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *test
|
||||
assert.False(t, localTR.ProviderExecuted)
|
||||
}
|
||||
|
||||
func TestToResponseMessages_FiltersEmptyTextAndReasoningParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sr := stepResult{
|
||||
content: []fantasy.Content{
|
||||
// Empty text — should be filtered.
|
||||
fantasy.TextContent{Text: ""},
|
||||
// Whitespace-only text — should be filtered.
|
||||
fantasy.TextContent{Text: " \t\n"},
|
||||
// Empty reasoning — should be filtered.
|
||||
fantasy.ReasoningContent{Text: ""},
|
||||
// Whitespace-only reasoning — should be filtered.
|
||||
fantasy.ReasoningContent{Text: " \n"},
|
||||
// Non-empty text — should pass through.
|
||||
fantasy.TextContent{Text: "hello world"},
|
||||
// Leading/trailing whitespace with content — kept
|
||||
// with the original value (not trimmed).
|
||||
fantasy.TextContent{Text: " hello "},
|
||||
// Non-empty reasoning — should pass through.
|
||||
fantasy.ReasoningContent{Text: "let me think"},
|
||||
// Tool call — should be unaffected by filtering.
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "tc-1",
|
||||
ToolName: "read_file",
|
||||
Input: `{"path":"main.go"}`,
|
||||
},
|
||||
// Local tool result — should be unaffected by filtering.
|
||||
fantasy.ToolResultContent{
|
||||
ToolCallID: "tc-1",
|
||||
ToolName: "read_file",
|
||||
Result: fantasy.ToolResultOutputContentText{Text: "file contents"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgs := sr.toResponseMessages()
|
||||
require.Len(t, msgs, 2, "expected assistant + tool messages")
|
||||
|
||||
// First message: assistant role with non-empty text, reasoning,
|
||||
// and the tool call. The four empty/whitespace-only parts must
|
||||
// have been dropped.
|
||||
assistantMsg := msgs[0]
|
||||
assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role)
|
||||
require.Len(t, assistantMsg.Content, 4,
|
||||
"assistant message should have 2x TextPart, ReasoningPart, and ToolCallPart")
|
||||
|
||||
// Part 0: non-empty text.
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[0])
|
||||
require.True(t, ok, "part 0 should be TextPart")
|
||||
assert.Equal(t, "hello world", textPart.Text)
|
||||
|
||||
// Part 1: padded text — original whitespace preserved.
|
||||
paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[1])
|
||||
require.True(t, ok, "part 1 should be TextPart")
|
||||
assert.Equal(t, " hello ", paddedPart.Text)
|
||||
|
||||
// Part 2: non-empty reasoning.
|
||||
reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](assistantMsg.Content[2])
|
||||
require.True(t, ok, "part 2 should be ReasoningPart")
|
||||
assert.Equal(t, "let me think", reasoningPart.Text)
|
||||
|
||||
// Part 3: tool call (unaffected by text/reasoning filtering).
|
||||
toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[3])
|
||||
require.True(t, ok, "part 3 should be ToolCallPart")
|
||||
assert.Equal(t, "tc-1", toolCallPart.ToolCallID)
|
||||
assert.Equal(t, "read_file", toolCallPart.ToolName)
|
||||
|
||||
// Second message: tool role with the local tool result.
|
||||
toolMsg := msgs[1]
|
||||
assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
|
||||
require.Len(t, toolMsg.Content, 1,
|
||||
"tool message should have only the local ToolResultPart")
|
||||
|
||||
toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
|
||||
require.True(t, ok, "tool part should be ToolResultPart")
|
||||
assert.Equal(t, "tc-1", toolResultPart.ToolCallID)
|
||||
}
|
||||
|
||||
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
|
||||
if len(message.ProviderOptions) == 0 {
|
||||
return false
|
||||
|
||||
@@ -1,399 +0,0 @@
|
||||
package chatloop
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testProviderData implements fantasy.ProviderOptionsData so we can
|
||||
// construct arbitrary ProviderMetadata for extractContextLimit tests.
|
||||
type testProviderData struct {
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
func (*testProviderData) Options() {}
|
||||
|
||||
func (d *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.data)
|
||||
}
|
||||
|
||||
// Required by the ProviderOptionsData interface; unused in tests.
|
||||
func (d *testProviderData) UnmarshalJSON(b []byte) error {
|
||||
return json.Unmarshal(b, &d.data)
|
||||
}
|
||||
|
||||
func TestNormalizeMetadataKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
want string
|
||||
}{
|
||||
{name: "lowercase", key: "camelCase", want: "camelcase"},
|
||||
{name: "hyphens stripped", key: "kebab-case", want: "kebabcase"},
|
||||
{name: "underscores stripped", key: "snake_case", want: "snakecase"},
|
||||
{name: "uppercase", key: "UPPER", want: "upper"},
|
||||
{name: "spaces stripped", key: "with spaces", want: "withspaces"},
|
||||
{name: "empty", key: "", want: ""},
|
||||
{name: "digits preserved", key: "123", want: "123"},
|
||||
{name: "mixed separators", key: "Max_Context-Tokens", want: "maxcontexttokens"},
|
||||
{name: "dots stripped", key: "context.limit", want: "contextlimit"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := normalizeMetadataKey(tt.key)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsContextLimitKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
want bool
|
||||
skip bool
|
||||
}{ // Exact matches after normalization.
|
||||
{name: "context_limit", key: "context_limit", want: true},
|
||||
{name: "context_window", key: "context_window", want: true},
|
||||
{name: "context_length", key: "context_length", want: true},
|
||||
{name: "max_context", key: "max_context", want: true},
|
||||
{name: "max_context_tokens", key: "max_context_tokens", want: true},
|
||||
{name: "max_input_tokens", key: "max_input_tokens", want: true},
|
||||
{name: "max_input_token", key: "max_input_token", want: true},
|
||||
{name: "input_token_limit", key: "input_token_limit", want: true},
|
||||
|
||||
// Case and separator variations.
|
||||
{name: "Context-Window mixed case", key: "Context-Window", want: true},
|
||||
{name: "MAX_CONTEXT_TOKENS screaming", key: "MAX_CONTEXT_TOKENS", want: true},
|
||||
{name: "contextLimit camelCase", key: "contextLimit", want: true},
|
||||
|
||||
// Fallback heuristic: contains "context" + limit/window/length.
|
||||
{name: "model_context_limit", key: "model_context_limit", want: true},
|
||||
{name: "context_window_size", key: "context_window_size", want: true},
|
||||
{name: "context_length_max", key: "context_length_max", want: true},
|
||||
|
||||
// Fallback heuristic: starts with "max" + contains "context".
|
||||
// BUG(isContextLimitKey): "max_context_version" matches
|
||||
// because it contains "context" and starts with "max",
|
||||
// but a version field is not a context limit.
|
||||
// TODO: Fix the heuristic and remove this skip.
|
||||
{name: "max_context_version false positive", key: "max_context_version", want: false, skip: true}, // Non-matching keys.
|
||||
{name: "context_id no limit keyword", key: "context_id", want: false},
|
||||
{name: "empty string", key: "", want: false},
|
||||
{name: "unrelated key", key: "model_name", want: false},
|
||||
{name: "limit without context", key: "rate_limit", want: false},
|
||||
{name: "max without context", key: "max_tokens", want: false},
|
||||
{name: "context alone", key: "context", want: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if tt.skip {
|
||||
t.Skip("known bug: isContextLimitKey false positive")
|
||||
}
|
||||
got := isContextLimitKey(tt.key)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumericContextLimitValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value any
|
||||
want int64
|
||||
wantOK bool
|
||||
}{
|
||||
// float64: the default numeric type from json.Unmarshal.
|
||||
{name: "float64 integer", value: float64(128000), want: 128000, wantOK: true},
|
||||
{name: "float64 fractional rejected", value: float64(128000.5), want: 0, wantOK: false},
|
||||
{name: "float64 zero rejected", value: float64(0), want: 0, wantOK: false},
|
||||
{name: "float64 negative rejected", value: float64(-1), want: 0, wantOK: false},
|
||||
|
||||
// int64
|
||||
{name: "int64 positive", value: int64(200000), want: 200000, wantOK: true},
|
||||
{name: "int64 zero rejected", value: int64(0), want: 0, wantOK: false},
|
||||
{name: "int64 negative rejected", value: int64(-1), want: 0, wantOK: false},
|
||||
|
||||
// int32
|
||||
{name: "int32 positive", value: int32(50000), want: 50000, wantOK: true},
|
||||
{name: "int32 zero rejected", value: int32(0), want: 0, wantOK: false},
|
||||
|
||||
// int
|
||||
{name: "int positive", value: int(50000), want: 50000, wantOK: true},
|
||||
{name: "int zero rejected", value: int(0), want: 0, wantOK: false},
|
||||
|
||||
// string
|
||||
{name: "string numeric", value: "128000", want: 128000, wantOK: true},
|
||||
{name: "string trimmed", value: " 128000 ", want: 128000, wantOK: true},
|
||||
{name: "string non-numeric rejected", value: "not a number", want: 0, wantOK: false},
|
||||
{name: "string empty rejected", value: "", want: 0, wantOK: false},
|
||||
{name: "string zero rejected", value: "0", want: 0, wantOK: false},
|
||||
{name: "string negative rejected", value: "-1", want: 0, wantOK: false},
|
||||
|
||||
// json.Number
|
||||
{name: "json.Number valid", value: json.Number("200000"), want: 200000, wantOK: true},
|
||||
{name: "json.Number invalid rejected", value: json.Number("invalid"), want: 0, wantOK: false},
|
||||
{name: "json.Number zero rejected", value: json.Number("0"), want: 0, wantOK: false},
|
||||
|
||||
// Unhandled types.
|
||||
{name: "bool rejected", value: true, want: 0, wantOK: false},
|
||||
{name: "nil rejected", value: nil, want: 0, wantOK: false},
|
||||
{name: "slice rejected", value: []int{1}, want: 0, wantOK: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, ok := numericContextLimitValue(tt.value)
|
||||
require.Equal(t, tt.wantOK, ok)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPositiveInt64(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := positiveInt64(42)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(42), got)
|
||||
|
||||
got, ok = positiveInt64(0)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int64(0), got)
|
||||
|
||||
got, ok = positiveInt64(-1)
|
||||
require.False(t, ok)
|
||||
require.Equal(t, int64(0), got)
|
||||
}
|
||||
|
||||
func TestCollectContextLimitValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("FlatMap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := map[string]any{
|
||||
"context_limit": float64(200000),
|
||||
"other_key": float64(999),
|
||||
}
|
||||
var collected []int64
|
||||
collectContextLimitValues(input, func(v int64) {
|
||||
collected = append(collected, v)
|
||||
})
|
||||
require.Equal(t, []int64{200000}, collected)
|
||||
})
|
||||
|
||||
t.Run("NestedMaps", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := map[string]any{
|
||||
"provider": map[string]any{
|
||||
"info": map[string]any{
|
||||
"context_window": float64(100000),
|
||||
},
|
||||
},
|
||||
}
|
||||
var collected []int64
|
||||
collectContextLimitValues(input, func(v int64) {
|
||||
collected = append(collected, v)
|
||||
})
|
||||
require.Equal(t, []int64{100000}, collected)
|
||||
})
|
||||
|
||||
t.Run("ArrayTraversal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := []any{
|
||||
map[string]any{"context_limit": float64(50000)},
|
||||
map[string]any{"context_limit": float64(80000)},
|
||||
}
|
||||
var collected []int64
|
||||
collectContextLimitValues(input, func(v int64) {
|
||||
collected = append(collected, v)
|
||||
})
|
||||
require.Len(t, collected, 2)
|
||||
require.Contains(t, collected, int64(50000))
|
||||
require.Contains(t, collected, int64(80000))
|
||||
})
|
||||
|
||||
t.Run("MixedNesting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := map[string]any{
|
||||
"models": []any{
|
||||
map[string]any{
|
||||
"context_limit": float64(128000),
|
||||
},
|
||||
},
|
||||
}
|
||||
var collected []int64
|
||||
collectContextLimitValues(input, func(v int64) {
|
||||
collected = append(collected, v)
|
||||
})
|
||||
require.Equal(t, []int64{128000}, collected)
|
||||
})
|
||||
|
||||
t.Run("NonMatchingKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := map[string]any{
|
||||
"model_name": "gpt-4",
|
||||
"tokens": float64(1000),
|
||||
}
|
||||
var collected []int64
|
||||
collectContextLimitValues(input, func(v int64) {
|
||||
collected = append(collected, v)
|
||||
})
|
||||
require.Empty(t, collected)
|
||||
})
|
||||
|
||||
t.Run("ScalarIgnored", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var collected []int64
|
||||
collectContextLimitValues("just a string", func(v int64) {
|
||||
collected = append(collected, v)
|
||||
})
|
||||
require.Empty(t, collected)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFindContextLimitValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("SingleCandidate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := map[string]any{
|
||||
"context_limit": float64(200000),
|
||||
}
|
||||
limit, ok := findContextLimitValue(input)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(200000), limit)
|
||||
})
|
||||
|
||||
t.Run("MultipleCandidatesTakesMax", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := map[string]any{
|
||||
"a": map[string]any{"context_limit": float64(50000)},
|
||||
"b": map[string]any{"context_limit": float64(200000)},
|
||||
}
|
||||
limit, ok := findContextLimitValue(input)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(200000), limit)
|
||||
})
|
||||
|
||||
t.Run("NoCandidates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
input := map[string]any{
|
||||
"model": "gpt-4",
|
||||
}
|
||||
_, ok := findContextLimitValue(input)
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("NilInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, ok := findContextLimitValue(nil)
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractContextLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AnthropicStyle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
metadata := fantasy.ProviderMetadata{
|
||||
"anthropic": &testProviderData{
|
||||
data: map[string]any{
|
||||
"cache_read_input_tokens": float64(100),
|
||||
"context_limit": float64(200000),
|
||||
},
|
||||
},
|
||||
}
|
||||
result := extractContextLimit(metadata)
|
||||
require.True(t, result.Valid)
|
||||
require.Equal(t, int64(200000), result.Int64)
|
||||
})
|
||||
|
||||
t.Run("OpenAIStyle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
metadata := fantasy.ProviderMetadata{
|
||||
"openai": &testProviderData{
|
||||
data: map[string]any{
|
||||
"max_context_tokens": float64(128000),
|
||||
},
|
||||
},
|
||||
}
|
||||
result := extractContextLimit(metadata)
|
||||
require.True(t, result.Valid)
|
||||
require.Equal(t, int64(128000), result.Int64)
|
||||
})
|
||||
|
||||
t.Run("NestedDeeply", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
metadata := fantasy.ProviderMetadata{
|
||||
"provider": &testProviderData{
|
||||
data: map[string]any{
|
||||
"info": map[string]any{
|
||||
"context_window": float64(100000),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := extractContextLimit(metadata)
|
||||
require.True(t, result.Valid)
|
||||
require.Equal(t, int64(100000), result.Int64)
|
||||
})
|
||||
|
||||
t.Run("MultipleCandidatesTakesMax", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
metadata := fantasy.ProviderMetadata{
|
||||
"a": &testProviderData{
|
||||
data: map[string]any{
|
||||
"context_limit": float64(50000),
|
||||
},
|
||||
},
|
||||
"b": &testProviderData{
|
||||
data: map[string]any{
|
||||
"context_limit": float64(200000),
|
||||
},
|
||||
},
|
||||
}
|
||||
result := extractContextLimit(metadata)
|
||||
require.True(t, result.Valid)
|
||||
require.Equal(t, int64(200000), result.Int64)
|
||||
})
|
||||
|
||||
t.Run("NoMatchingKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
metadata := fantasy.ProviderMetadata{
|
||||
"openai": &testProviderData{
|
||||
data: map[string]any{
|
||||
"model": "gpt-4",
|
||||
"tokens": float64(1000),
|
||||
},
|
||||
},
|
||||
}
|
||||
result := extractContextLimit(metadata)
|
||||
assert.False(t, result.Valid)
|
||||
})
|
||||
|
||||
t.Run("NilMetadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := extractContextLimit(nil)
|
||||
assert.False(t, result.Valid)
|
||||
})
|
||||
|
||||
t.Run("EmptyMetadata", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := extractContextLimit(fantasy.ProviderMetadata{})
|
||||
assert.False(t, result.Valid)
|
||||
})
|
||||
}
|
||||
@@ -139,13 +139,9 @@ func ConvertMessagesWithFiles(
|
||||
},
|
||||
})
|
||||
case codersdk.ChatMessageRoleUser:
|
||||
userParts := partsToMessageParts(logger, pm.parts, resolved)
|
||||
if len(userParts) == 0 {
|
||||
continue
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: userParts,
|
||||
Content: partsToMessageParts(logger, pm.parts, resolved),
|
||||
})
|
||||
case codersdk.ChatMessageRoleAssistant:
|
||||
fantasyParts := normalizeAssistantToolCallInputs(
|
||||
@@ -157,9 +153,6 @@ func ConvertMessagesWithFiles(
|
||||
}
|
||||
toolNameByCallID[sanitizeToolCallID(toolCall.ToolCallID)] = toolCall.ToolName
|
||||
}
|
||||
if len(fantasyParts) == 0 {
|
||||
continue
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: fantasyParts,
|
||||
@@ -173,13 +166,9 @@ func ConvertMessagesWithFiles(
|
||||
}
|
||||
}
|
||||
}
|
||||
toolParts := partsToMessageParts(logger, pm.parts, resolved)
|
||||
if len(toolParts) == 0 {
|
||||
continue
|
||||
}
|
||||
prompt = append(prompt, fantasy.Message{
|
||||
Role: fantasy.MessageRoleTool,
|
||||
Content: toolParts,
|
||||
Content: partsToMessageParts(logger, pm.parts, resolved),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -332,7 +321,6 @@ func parseContentV1(role codersdk.ChatMessageRole, raw pqtype.NullRawMessage) ([
|
||||
if err := json.Unmarshal(raw.RawMessage, &parts); err != nil {
|
||||
return nil, xerrors.Errorf("parse %s content: %w", role, err)
|
||||
}
|
||||
decodeNulInParts(parts)
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
@@ -1030,16 +1018,11 @@ func sanitizeToolCallID(id string) string {
|
||||
}
|
||||
|
||||
// MarshalParts encodes SDK chat message parts for persistence.
|
||||
// NUL characters in string fields are encoded as PUA sentinel
|
||||
// pairs (U+E000 U+E001) before marshaling so the resulting JSON
|
||||
// never contains \u0000 (rejected by PostgreSQL jsonb). The
|
||||
// encoding operates on Go string values, not JSON bytes, so it
|
||||
// survives jsonb text normalization.
|
||||
func MarshalParts(parts []codersdk.ChatMessagePart) (pqtype.NullRawMessage, error) {
|
||||
if len(parts) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
data, err := json.Marshal(encodeNulInParts(parts))
|
||||
data, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return pqtype.NullRawMessage{}, xerrors.Errorf("encode chat message parts: %w", err)
|
||||
}
|
||||
@@ -1186,23 +1169,11 @@ func partsToMessageParts(
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeText:
|
||||
// Anthropic rejects empty text content blocks with
|
||||
// "text content blocks must be non-empty". Empty parts
|
||||
// can arise when a stream sends TextStart/TextEnd with
|
||||
// no delta in between. We filter them here rather than
|
||||
// at persistence time to preserve the raw record.
|
||||
if strings.TrimSpace(part.Text) == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, fantasy.TextPart{
|
||||
Text: part.Text,
|
||||
ProviderOptions: providerMetadataToOptions(logger, part.ProviderMetadata),
|
||||
})
|
||||
case codersdk.ChatMessagePartTypeReasoning:
|
||||
// Same guard as text parts above.
|
||||
if strings.TrimSpace(part.Text) == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, fantasy.ReasoningPart{
|
||||
Text: part.Text,
|
||||
ProviderOptions: providerMetadataToOptions(logger, part.ProviderMetadata),
|
||||
@@ -1245,186 +1216,3 @@ func partsToMessageParts(
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeNulInString replaces NUL (U+0000) characters in s with
|
||||
// the sentinel pair U+E000 U+E001, and doubles any pre-existing
|
||||
// U+E000 to U+E000 U+E000 so the encoding is reversible.
|
||||
// Operates on Unicode code points, not JSON escape sequences,
|
||||
// making it safe through jsonb round-trips (jsonb stores parsed
|
||||
// characters, not original escape text).
|
||||
func encodeNulInString(s string) string {
|
||||
if !strings.ContainsRune(s, 0) && !strings.ContainsRune(s, '\uE000') {
|
||||
return s
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
for _, r := range s {
|
||||
switch r {
|
||||
case '\uE000':
|
||||
_, _ = b.WriteRune('\uE000')
|
||||
_, _ = b.WriteRune('\uE000')
|
||||
case 0:
|
||||
_, _ = b.WriteRune('\uE000')
|
||||
_, _ = b.WriteRune('\uE001')
|
||||
default:
|
||||
_, _ = b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// decodeNulInString reverses encodeNulInString: U+E000 U+E000
|
||||
// becomes U+E000, and U+E000 U+E001 becomes NUL.
|
||||
func decodeNulInString(s string) string {
|
||||
if !strings.ContainsRune(s, '\uE000') {
|
||||
return s
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
runes := []rune(s)
|
||||
for i := 0; i < len(runes); i++ {
|
||||
if runes[i] == '\uE000' && i+1 < len(runes) {
|
||||
switch runes[i+1] {
|
||||
case '\uE000':
|
||||
_, _ = b.WriteRune('\uE000')
|
||||
i++
|
||||
case '\uE001':
|
||||
_, _ = b.WriteRune(0)
|
||||
i++
|
||||
default:
|
||||
// Unpaired sentinel — preserve as-is.
|
||||
_, _ = b.WriteRune(runes[i])
|
||||
}
|
||||
} else {
|
||||
_, _ = b.WriteRune(runes[i])
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// encodeNulInValue recursively walks a JSON value (as produced
|
||||
// by json.Unmarshal with UseNumber) and applies
|
||||
// encodeNulInString to every string, including map keys.
|
||||
func encodeNulInValue(v any) any {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return encodeNulInString(val)
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(val))
|
||||
for k, elem := range val {
|
||||
out[encodeNulInString(k)] = encodeNulInValue(elem)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(val))
|
||||
for i, elem := range val {
|
||||
out[i] = encodeNulInValue(elem)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return v // numbers, bools, nil
|
||||
}
|
||||
}
|
||||
|
||||
// decodeNulInValue recursively walks a JSON value and applies
|
||||
// decodeNulInString to every string, including map keys.
|
||||
func decodeNulInValue(v any) any {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return decodeNulInString(val)
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(val))
|
||||
for k, elem := range val {
|
||||
out[decodeNulInString(k)] = decodeNulInValue(elem)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(val))
|
||||
for i, elem := range val {
|
||||
out[i] = decodeNulInValue(elem)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// encodeNulInJSON walks all string values (and keys) inside a
|
||||
// json.RawMessage and applies encodeNulInString. Returns the
|
||||
// original unchanged when the raw message does not contain NUL
|
||||
// escapes or U+E000 bytes, or when parsing fails.
|
||||
func encodeNulInJSON(raw json.RawMessage) json.RawMessage {
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
// Quick exit: no \u0000 escape and no U+E000 UTF-8 bytes.
|
||||
if !bytes.Contains(raw, []byte(`\u0000`)) &&
|
||||
!bytes.Contains(raw, []byte{0xEE, 0x80, 0x80}) {
|
||||
return raw
|
||||
}
|
||||
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||
dec.UseNumber()
|
||||
var v any
|
||||
if err := dec.Decode(&v); err != nil {
|
||||
return raw
|
||||
}
|
||||
result, err := json.Marshal(encodeNulInValue(v))
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// decodeNulInJSON walks all string values (and keys) inside a
|
||||
// json.RawMessage and applies decodeNulInString.
|
||||
func decodeNulInJSON(raw json.RawMessage) json.RawMessage {
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
// U+E000 encoded as UTF-8 is 0xEE 0x80 0x80.
|
||||
if !bytes.Contains(raw, []byte{0xEE, 0x80, 0x80}) {
|
||||
return raw
|
||||
}
|
||||
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||
dec.UseNumber()
|
||||
var v any
|
||||
if err := dec.Decode(&v); err != nil {
|
||||
return raw
|
||||
}
|
||||
result, err := json.Marshal(decodeNulInValue(v))
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeNulInParts returns a shallow copy of parts with all
|
||||
// string and json.RawMessage fields NUL-encoded. The caller's
|
||||
// slice is not modified.
|
||||
func encodeNulInParts(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
|
||||
encoded := make([]codersdk.ChatMessagePart, len(parts))
|
||||
copy(encoded, parts)
|
||||
for i := range encoded {
|
||||
p := &encoded[i]
|
||||
p.Text = encodeNulInString(p.Text)
|
||||
p.Content = encodeNulInString(p.Content)
|
||||
p.Args = encodeNulInJSON(p.Args)
|
||||
p.ArgsDelta = encodeNulInString(p.ArgsDelta)
|
||||
p.Result = encodeNulInJSON(p.Result)
|
||||
p.ResultDelta = encodeNulInString(p.ResultDelta)
|
||||
}
|
||||
return encoded
|
||||
}
|
||||
|
||||
// decodeNulInParts reverses encodeNulInParts in place.
|
||||
func decodeNulInParts(parts []codersdk.ChatMessagePart) {
|
||||
for i := range parts {
|
||||
p := &parts[i]
|
||||
p.Text = decodeNulInString(p.Text)
|
||||
p.Content = decodeNulInString(p.Content)
|
||||
p.Args = decodeNulInJSON(p.Args)
|
||||
p.ArgsDelta = decodeNulInString(p.ArgsDelta)
|
||||
p.Result = decodeNulInJSON(p.Result)
|
||||
p.ResultDelta = decodeNulInString(p.ResultDelta)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,10 +17,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// testMsg builds a database.ChatMessage for ParseContent tests.
|
||||
@@ -1444,327 +1441,3 @@ func extractToolResultIDs(t *testing.T, msgs ...fantasy.Message) []string {
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func TestNulEscapeRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Seed minimal dependencies for the DB round-trip path:
|
||||
// user, provider, model config, chat.
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "openai",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "nul-roundtrip-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
textTests := []struct {
|
||||
name string
|
||||
input string
|
||||
hasNul bool // Whether the input contains actual NUL bytes.
|
||||
}{
|
||||
// --- basic ---
|
||||
{"NoNul", "hello world", false},
|
||||
{"SingleNul", "a\x00b", true},
|
||||
{"MultipleNuls", "a\x00b\x00c", true},
|
||||
{"ConsecutiveNuls", "\x00\x00\x00", true},
|
||||
|
||||
// --- boundaries ---
|
||||
{"EmptyString", "", false},
|
||||
{"NulOnly", "\x00", true},
|
||||
{"NulAtStart", "\x00hello", true},
|
||||
{"NulAtEnd", "hello\x00", true},
|
||||
|
||||
// --- sentinel / marker in original data ---
|
||||
// U+E000 is the sentinel character. The encoder must
|
||||
// double it so it round-trips without being mistaken
|
||||
// for an encoded NUL.
|
||||
{"SentinelInOriginal", "a\uE000b", false},
|
||||
{"ConsecutiveSentinels", "\uE000\uE000\uE000", false},
|
||||
// U+E001 is the marker character used in the NUL pair.
|
||||
{"MarkerCharInOriginal", "a\uE001b", false},
|
||||
// U+E000 followed by U+E001 looks exactly like an
|
||||
// encoded NUL in the encoded form, so the encoder must
|
||||
// double the U+E000 to avoid confusion.
|
||||
{"SentinelThenMarkerChar", "\uE000\uE001", false},
|
||||
{"NulAndSentinel", "a\x00b\uE000c", true},
|
||||
// Both orders: sentinel adjacent to NUL.
|
||||
{"SentinelThenNul", "\uE000\x00", true},
|
||||
{"NulThenSentinel", "\x00\uE000", true},
|
||||
{"AlternatingSentinelNul", "\x00\uE000\x00\uE000", true},
|
||||
|
||||
// --- strings containing backslashes ---
|
||||
// Backslashes are normal characters at the Go string
|
||||
// level; no special handling needed (unlike the old
|
||||
// JSON-byte approach).
|
||||
{"BackslashU0000Text", "\\u0000", false},
|
||||
{"BackslashThenNul", "\\\x00", true},
|
||||
|
||||
// --- literal text that looks like escape patterns ---
|
||||
{"LiteralTextU0000", "the value is u0000 here", false},
|
||||
{"LiteralTextUE000", "sentinel uE000 text", false},
|
||||
|
||||
// --- other control characters mixed with NUL ---
|
||||
{"ControlCharsMixedWithNul", "\x01\x00\x02\x00\x1f", true},
|
||||
|
||||
// --- long / stress ---
|
||||
{"LongNulRun", "\x00\x00\x00\x00\x00\x00\x00\x00", true},
|
||||
// Simulated find -print0 output.
|
||||
{"FindPrint0", "/usr/bin/ls\x00/usr/bin/cat\x00/usr/bin/grep\x00", true},
|
||||
}
|
||||
|
||||
for _, tc := range textTests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(tc.input),
|
||||
}
|
||||
|
||||
encoded, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When the input has real NUL bytes, the stored JSON
|
||||
// must not contain the \u0000 escape sequence.
|
||||
if tc.hasNul {
|
||||
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
|
||||
"encoded JSON must not contain \\u0000")
|
||||
}
|
||||
|
||||
// In-memory round-trip through ParseContent.
|
||||
msg := testMsgV1(codersdk.ChatMessageRoleAssistant, encoded)
|
||||
decoded, err := chatprompt.ParseContent(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, decoded, 1)
|
||||
require.Equal(t, tc.input, decoded[0].Text)
|
||||
|
||||
// Full DB round-trip: write to PostgreSQL jsonb, read
|
||||
// back, and verify the value survives storage.
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
dbMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID},
|
||||
ModelConfigID: []uuid.UUID{model.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
Content: []string{string(encoded.RawMessage)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dbMsgs, 1)
|
||||
|
||||
readBack, err := db.GetChatMessageByID(ctx, dbMsgs[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
dbDecoded, err := chatprompt.ParseContent(readBack)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dbDecoded, 1)
|
||||
require.Equal(t, tc.input, dbDecoded[0].Text)
|
||||
})
|
||||
}
|
||||
|
||||
// Tool result with NUL in the result JSON value.
|
||||
t.Run("ToolResultWithNul", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resultJSON := json.RawMessage(`"output:\u0000done"`)
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult("call-1", "my_tool", resultJSON, false),
|
||||
}
|
||||
|
||||
encoded, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
|
||||
"encoded JSON must not contain \\u0000")
|
||||
|
||||
msg := testMsgV1(codersdk.ChatMessageRoleTool, encoded)
|
||||
decoded, err := chatprompt.ParseContent(msg)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, decoded, 1)
|
||||
// JSON re-serialization may reformat, so compare
|
||||
// semantically.
|
||||
assert.JSONEq(t, string(resultJSON), string(decoded[0].Result))
|
||||
})
|
||||
|
||||
// Multiple parts in one message: one with NUL, one without.
|
||||
t.Run("MultiPartMixed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("clean text"),
|
||||
codersdk.ChatMessageText("has\x00nul"),
|
||||
}
|
||||
|
||||
encoded, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(encoded.RawMessage), `\u0000`,
|
||||
"encoded JSON must not contain \\u0000")
|
||||
|
||||
msg := testMsgV1(codersdk.ChatMessageRoleAssistant, encoded)
|
||||
decoded, err := chatprompt.ParseContent(msg)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, decoded, 2)
|
||||
require.Equal(t, "clean text", decoded[0].Text)
|
||||
require.Equal(t, "has\x00nul", decoded[1].Text)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConvertMessagesWithFiles_FiltersEmptyTextAndReasoningParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Helper to build a DB message from SDK parts.
|
||||
makeMsg := func(t *testing.T, role database.ChatMessageRole, parts []codersdk.ChatMessagePart) database.ChatMessage {
|
||||
t.Helper()
|
||||
encoded, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
return database.ChatMessage{
|
||||
Role: role,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: encoded,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("UserRole", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(""), // empty — filtered
|
||||
codersdk.ChatMessageText(" \t\n "), // whitespace — filtered
|
||||
codersdk.ChatMessageReasoning(""), // empty — filtered
|
||||
codersdk.ChatMessageReasoning(" \n"), // whitespace — filtered
|
||||
codersdk.ChatMessageText("hello"), // kept
|
||||
codersdk.ChatMessageText(" hello "), // kept with original whitespace
|
||||
codersdk.ChatMessageReasoning("thinking deeply"), // kept
|
||||
codersdk.ChatMessageToolCall("call-1", "my_tool", json.RawMessage(`{"x":1}`)),
|
||||
codersdk.ChatMessageToolResult("call-1", "my_tool", json.RawMessage(`{"ok":true}`), false),
|
||||
}
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleUser, parts)},
|
||||
nil,
|
||||
slogtest.Make(t, nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, prompt, 1)
|
||||
|
||||
resultParts := prompt[0].Content
|
||||
require.Len(t, resultParts, 5, "expected 5 parts after filtering empty text/reasoning")
|
||||
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[0])
|
||||
require.True(t, ok, "expected TextPart at index 0")
|
||||
require.Equal(t, "hello", textPart.Text)
|
||||
|
||||
// Leading/trailing whitespace is preserved — only
|
||||
// all-whitespace parts are dropped.
|
||||
paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[1])
|
||||
require.True(t, ok, "expected TextPart at index 1")
|
||||
require.Equal(t, " hello ", paddedPart.Text)
|
||||
|
||||
reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](resultParts[2])
|
||||
require.True(t, ok, "expected ReasoningPart at index 2")
|
||||
require.Equal(t, "thinking deeply", reasoningPart.Text)
|
||||
|
||||
toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](resultParts[3])
|
||||
require.True(t, ok, "expected ToolCallPart at index 3")
|
||||
require.Equal(t, "call-1", toolCallPart.ToolCallID)
|
||||
|
||||
toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](resultParts[4])
|
||||
require.True(t, ok, "expected ToolResultPart at index 4")
|
||||
require.Equal(t, "call-1", toolResultPart.ToolCallID)
|
||||
})
|
||||
|
||||
t.Run("AssistantRole", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(""), // empty — filtered
|
||||
codersdk.ChatMessageText(" "), // whitespace — filtered
|
||||
codersdk.ChatMessageReasoning(""), // empty — filtered
|
||||
codersdk.ChatMessageText(" reply "), // kept with whitespace
|
||||
codersdk.ChatMessageToolCall("tc-1", "read_file", json.RawMessage(`{"path":"x"}`)),
|
||||
}
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleAssistant, parts)},
|
||||
nil,
|
||||
slogtest.Make(t, nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
// 2 messages: assistant + synthetic tool result injected
|
||||
// by injectMissingToolResults for the unmatched tool call.
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultParts := prompt[0].Content
|
||||
require.Len(t, resultParts, 2, "expected text + tool-call after filtering")
|
||||
|
||||
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](resultParts[0])
|
||||
require.True(t, ok, "expected TextPart")
|
||||
require.Equal(t, " reply ", textPart.Text)
|
||||
|
||||
tcPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](resultParts[1])
|
||||
require.True(t, ok, "expected ToolCallPart")
|
||||
require.Equal(t, "tc-1", tcPart.ToolCallID)
|
||||
})
|
||||
|
||||
t.Run("AllEmptyDropsMessage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// When every part is filtered, the message itself should
|
||||
// be dropped rather than appending an empty-content message.
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText(""),
|
||||
codersdk.ChatMessageText(" "),
|
||||
codersdk.ChatMessageReasoning(""),
|
||||
}
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
context.Background(),
|
||||
[]database.ChatMessage{makeMsg(t, database.ChatMessageRoleAssistant, parts)},
|
||||
nil,
|
||||
slogtest.Make(t, nil),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, prompt, "all-empty message should be dropped entirely")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1083,7 +1083,6 @@ func openAIProviderOptionsFromChatConfig(
|
||||
SafetyIdentifier: normalizedStringPointer(options.SafetyIdentifier),
|
||||
ServiceTier: openAIServiceTierFromChat(options.ServiceTier),
|
||||
StrictJSONSchema: options.StrictJSONSchema,
|
||||
Store: boolPtrOrDefault(options.Store, true),
|
||||
TextVerbosity: OpenAITextVerbosityFromChat(options.TextVerbosity),
|
||||
User: normalizedStringPointer(options.User),
|
||||
}
|
||||
@@ -1100,7 +1099,7 @@ func openAIProviderOptionsFromChatConfig(
|
||||
MaxCompletionTokens: options.MaxCompletionTokens,
|
||||
TextVerbosity: normalizedStringPointer(options.TextVerbosity),
|
||||
Prediction: options.Prediction,
|
||||
Store: boolPtrOrDefault(options.Store, true),
|
||||
Store: options.Store,
|
||||
Metadata: options.Metadata,
|
||||
PromptCacheKey: normalizedStringPointer(options.PromptCacheKey),
|
||||
SafetyIdentifier: normalizedStringPointer(options.SafetyIdentifier),
|
||||
@@ -1281,13 +1280,6 @@ func useOpenAIResponsesOptions(model fantasy.LanguageModel) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func boolPtrOrDefault(value *bool, def bool) *bool {
|
||||
if value != nil {
|
||||
return value
|
||||
}
|
||||
return &def
|
||||
}
|
||||
|
||||
func normalizedStringPointer(value *string) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -26,37 +25,37 @@ func TestReasoningEffortFromChat(t *testing.T) {
|
||||
{
|
||||
name: "OpenAICaseInsensitive",
|
||||
provider: "openai",
|
||||
input: ptr.Ref(" HIGH "),
|
||||
want: ptr.Ref(string(fantasyopenai.ReasoningEffortHigh)),
|
||||
input: stringPtr(" HIGH "),
|
||||
want: stringPtr(string(fantasyopenai.ReasoningEffortHigh)),
|
||||
},
|
||||
{
|
||||
name: "AnthropicEffort",
|
||||
provider: "anthropic",
|
||||
input: ptr.Ref("max"),
|
||||
want: ptr.Ref(string(fantasyanthropic.EffortMax)),
|
||||
input: stringPtr("max"),
|
||||
want: stringPtr(string(fantasyanthropic.EffortMax)),
|
||||
},
|
||||
{
|
||||
name: "OpenRouterEffort",
|
||||
provider: "openrouter",
|
||||
input: ptr.Ref("medium"),
|
||||
want: ptr.Ref(string(fantasyopenrouter.ReasoningEffortMedium)),
|
||||
input: stringPtr("medium"),
|
||||
want: stringPtr(string(fantasyopenrouter.ReasoningEffortMedium)),
|
||||
},
|
||||
{
|
||||
name: "VercelEffort",
|
||||
provider: "vercel",
|
||||
input: ptr.Ref("xhigh"),
|
||||
want: ptr.Ref(string(fantasyvercel.ReasoningEffortXHigh)),
|
||||
input: stringPtr("xhigh"),
|
||||
want: stringPtr(string(fantasyvercel.ReasoningEffortXHigh)),
|
||||
},
|
||||
{
|
||||
name: "InvalidEffortReturnsNil",
|
||||
provider: "openai",
|
||||
input: ptr.Ref("unknown"),
|
||||
input: stringPtr("unknown"),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "UnsupportedProviderReturnsNil",
|
||||
provider: "bedrock",
|
||||
input: ptr.Ref("high"),
|
||||
input: stringPtr("high"),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
@@ -83,8 +82,8 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
|
||||
options := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Enabled: ptr.Ref(true),
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
Order: []string{"openai"},
|
||||
@@ -93,22 +92,22 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
}
|
||||
defaults := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Enabled: ptr.Ref(false),
|
||||
Exclude: ptr.Ref(true),
|
||||
MaxTokens: ptr.Ref[int64](123),
|
||||
Effort: ptr.Ref("high"),
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(false),
|
||||
Exclude: boolPtr(true),
|
||||
MaxTokens: int64Ptr(123),
|
||||
Effort: stringPtr("high"),
|
||||
},
|
||||
IncludeUsage: ptr.Ref(true),
|
||||
IncludeUsage: boolPtr(true),
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
Order: []string{"anthropic"},
|
||||
AllowFallbacks: ptr.Ref(true),
|
||||
RequireParameters: ptr.Ref(false),
|
||||
DataCollection: ptr.Ref("allow"),
|
||||
AllowFallbacks: boolPtr(true),
|
||||
RequireParameters: boolPtr(false),
|
||||
DataCollection: stringPtr("allow"),
|
||||
Only: []string{"openai"},
|
||||
Ignore: []string{"foo"},
|
||||
Quantizations: []string{"int8"},
|
||||
Sort: ptr.Ref("latency"),
|
||||
Sort: stringPtr("latency"),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -137,3 +136,15 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations)
|
||||
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
|
||||
}
|
||||
|
||||
func stringPtr(value string) *string {
|
||||
return &value
|
||||
}
|
||||
|
||||
func boolPtr(value bool) *bool {
|
||||
return &value
|
||||
}
|
||||
|
||||
func int64Ptr(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ package chattool
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
@@ -21,10 +23,9 @@ const (
|
||||
// maxOutputToModel is the maximum output sent to the LLM.
|
||||
maxOutputToModel = 32 << 10 // 32KB
|
||||
|
||||
// snapshotTimeout is how long a non-blocking fallback
|
||||
// request is allowed to take when retrieving a process
|
||||
// output snapshot after a blocking wait times out.
|
||||
snapshotTimeout = 30 * time.Second
|
||||
// pollInterval is how often we check for process completion
|
||||
// in foreground mode.
|
||||
pollInterval = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// nonInteractiveEnvVars are set on every process to prevent
|
||||
@@ -88,7 +89,7 @@ type ExecuteArgs struct {
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding. If the command times out, the response includes a background_process_id so you can retrieve output later with process_output.",
|
||||
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding.",
|
||||
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -126,7 +127,7 @@ func executeTool(
|
||||
// run_in_background parameter, which causes the shell to fork
|
||||
// and exit immediately, leaving an untracked orphan process.
|
||||
trimmed := strings.TrimSpace(args.Command)
|
||||
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") && !strings.HasSuffix(trimmed, "|&") {
|
||||
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") {
|
||||
background = true
|
||||
args.Command = strings.TrimSpace(strings.TrimSuffix(trimmed, "&"))
|
||||
}
|
||||
@@ -172,7 +173,7 @@ func executeBackground(
|
||||
return fantasy.NewTextResponse(string(data))
|
||||
}
|
||||
|
||||
// executeForeground starts a process and waits for its
|
||||
// executeForeground starts a process and polls for its
|
||||
// completion, enforcing the configured timeout.
|
||||
func executeForeground(
|
||||
ctx context.Context,
|
||||
@@ -211,7 +212,7 @@ func executeForeground(
|
||||
return errorResult(fmt.Sprintf("start process: %v", err))
|
||||
}
|
||||
|
||||
result := waitForProcess(cmdCtx, conn, resp.ID, timeout)
|
||||
result := pollProcess(cmdCtx, conn, resp.ID, timeout)
|
||||
result.WallDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Add an advisory note for file-dump commands.
|
||||
@@ -236,84 +237,62 @@ func truncateOutput(output string) string {
|
||||
return output
|
||||
}
|
||||
|
||||
// waitForProcess waits for process completion using the
|
||||
// blocking process output API instead of polling.
|
||||
func waitForProcess(
|
||||
// pollProcess polls for process output until the process exits
|
||||
// or the context times out.
|
||||
func pollProcess(
|
||||
ctx context.Context,
|
||||
conn workspacesdk.AgentConn,
|
||||
processID string,
|
||||
timeout time.Duration,
|
||||
) ExecuteResult {
|
||||
// Block until the process exits or the context is
|
||||
// canceled.
|
||||
resp, err := conn.ProcessOutput(ctx, processID, &workspacesdk.ProcessOutputOptions{
|
||||
Wait: true,
|
||||
})
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
// Timeout: fetch final snapshot with a fresh
|
||||
// context. The blocking request was canceled
|
||||
// so the response body was lost.
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Timeout — get whatever output we have. Use a
|
||||
// fresh context since cmdCtx is already canceled.
|
||||
bgCtx, bgCancel := context.WithTimeout(
|
||||
context.Background(),
|
||||
snapshotTimeout,
|
||||
5*time.Second,
|
||||
)
|
||||
defer bgCancel()
|
||||
resp, err = conn.ProcessOutput(bgCtx, processID, nil)
|
||||
outputResp, outputErr := conn.ProcessOutput(bgCtx, processID)
|
||||
bgCancel()
|
||||
output := truncateOutput(outputResp.Output)
|
||||
timeoutErr := xerrors.Errorf("command timed out after %s", timeout)
|
||||
if outputErr != nil {
|
||||
timeoutErr = errors.Join(timeoutErr, xerrors.Errorf("failed to get output: %w", outputErr))
|
||||
}
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Output: output,
|
||||
ExitCode: -1,
|
||||
Error: timeoutErr.Error(),
|
||||
Truncated: outputResp.Truncated,
|
||||
}
|
||||
case <-ticker.C:
|
||||
outputResp, err := conn.ProcessOutput(ctx, processID)
|
||||
if err != nil {
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
ExitCode: -1,
|
||||
Error: fmt.Sprintf("command timed out after %s; failed to get output: %v", timeout, err),
|
||||
BackgroundProcessID: processID,
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("get process output: %v", err),
|
||||
}
|
||||
}
|
||||
output := truncateOutput(resp.Output)
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Output: output,
|
||||
ExitCode: -1,
|
||||
Error: fmt.Sprintf("command timed out after %s", timeout),
|
||||
Truncated: resp.Truncated,
|
||||
BackgroundProcessID: processID,
|
||||
if !outputResp.Running {
|
||||
exitCode := 0
|
||||
if outputResp.ExitCode != nil {
|
||||
exitCode = *outputResp.ExitCode
|
||||
}
|
||||
output := truncateOutput(outputResp.Output)
|
||||
return ExecuteResult{
|
||||
Success: exitCode == 0,
|
||||
Output: output,
|
||||
ExitCode: exitCode,
|
||||
Truncated: outputResp.Truncated,
|
||||
}
|
||||
}
|
||||
}
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("get process output: %v", err),
|
||||
}
|
||||
}
|
||||
|
||||
// The server-side wait may return before the
|
||||
// process exits if maxWaitDuration is shorter than
|
||||
// the client's timeout. Retry if our context still
|
||||
// has time left.
|
||||
if resp.Running {
|
||||
if ctx.Err() == nil {
|
||||
// Still within the caller's timeout, retry.
|
||||
return waitForProcess(ctx, conn, processID, timeout)
|
||||
}
|
||||
output := truncateOutput(resp.Output)
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Output: output,
|
||||
ExitCode: -1,
|
||||
Error: fmt.Sprintf("command timed out after %s", timeout),
|
||||
Truncated: resp.Truncated,
|
||||
BackgroundProcessID: processID,
|
||||
}
|
||||
}
|
||||
|
||||
exitCode := 0
|
||||
if resp.ExitCode != nil {
|
||||
exitCode = *resp.ExitCode
|
||||
}
|
||||
output := truncateOutput(resp.Output)
|
||||
return ExecuteResult{
|
||||
Success: exitCode == 0,
|
||||
Output: output,
|
||||
ExitCode: exitCode,
|
||||
Truncated: resp.Truncated,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,19 +322,10 @@ func detectFileDump(command string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultProcessOutputTimeout is the default time the
|
||||
// process_output tool blocks waiting for new output or
|
||||
// process exit before returning. This avoids polling
|
||||
// loops that waste tokens and HTTP round-trips.
|
||||
defaultProcessOutputTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// ProcessOutputArgs are the parameters accepted by the
|
||||
// process_output tool.
|
||||
type ProcessOutputArgs struct {
|
||||
ProcessID string `json:"process_id"`
|
||||
WaitTimeout *string `json:"wait_timeout,omitempty" description:"Override the default 10s block duration. The call blocks until the process exits or this timeout is reached. Set to '0s' for an immediate snapshot without waiting."`
|
||||
ProcessID string `json:"process_id"`
|
||||
}
|
||||
|
||||
// ProcessOutput returns an AgentTool that retrieves the output
|
||||
@@ -365,13 +335,9 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
|
||||
"process_output",
|
||||
"Retrieve output from a background process. "+
|
||||
"Use the process_id returned by execute with "+
|
||||
"run_in_background=true or from a timed-out "+
|
||||
"execute's background_process_id. Blocks up to "+
|
||||
"10s for the process to exit, then returns the "+
|
||||
"output and exit_code. If still running after "+
|
||||
"the timeout, returns the output so far. Use "+
|
||||
"wait_timeout to override the default 10s wait "+
|
||||
"(e.g. '30s', or '0s' for an immediate snapshot).",
|
||||
"run_in_background=true. Returns the current output, "+
|
||||
"whether the process is still running, and the exit "+
|
||||
"code if it has finished.",
|
||||
func(ctx context.Context, args ProcessOutputArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -383,42 +349,9 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
timeout := defaultProcessOutputTimeout
|
||||
if args.WaitTimeout != nil {
|
||||
parsed, err := time.ParseDuration(*args.WaitTimeout)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("invalid wait_timeout %q: %v", *args.WaitTimeout, err),
|
||||
), nil
|
||||
}
|
||||
timeout = parsed
|
||||
}
|
||||
var opts *workspacesdk.ProcessOutputOptions
|
||||
// Save parent context before applying timeout.
|
||||
parentCtx := ctx
|
||||
if timeout > 0 {
|
||||
opts = &workspacesdk.ProcessOutputOptions{
|
||||
Wait: true,
|
||||
}
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
resp, err := conn.ProcessOutput(ctx, args.ProcessID, opts)
|
||||
resp, err := conn.ProcessOutput(ctx, args.ProcessID)
|
||||
if err != nil {
|
||||
// If our wait timed out but the parent is still alive,
|
||||
// fetch a non-blocking snapshot.
|
||||
if ctx.Err() == nil || parentCtx.Err() != nil {
|
||||
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
|
||||
}
|
||||
bgCtx, bgCancel := context.WithTimeout(parentCtx, snapshotTimeout)
|
||||
defer bgCancel()
|
||||
resp, err = conn.ProcessOutput(bgCtx, args.ProcessID, nil)
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
|
||||
}
|
||||
// Fall through to normal response handling below.
|
||||
return errorResult(fmt.Sprintf("get process output: %v", err)), nil
|
||||
}
|
||||
output := truncateOutput(resp.Output)
|
||||
exitCode := 0
|
||||
@@ -432,7 +365,7 @@ func ProcessOutput(options ProcessToolOptions) fantasy.AgentTool {
|
||||
Truncated: resp.Truncated,
|
||||
}
|
||||
if resp.Running {
|
||||
// Process is still running, success is not
|
||||
// Process is still running — success is not
|
||||
// yet determined.
|
||||
result.Success = true
|
||||
result.Note = "process is still running"
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestTruncateOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := runForegroundWithOutput(t, "")
|
||||
assert.Empty(t, result.Output)
|
||||
})
|
||||
|
||||
t.Run("ShortOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := runForegroundWithOutput(t, "short")
|
||||
assert.Equal(t, "short", result.Output)
|
||||
})
|
||||
|
||||
t.Run("ExactlyAtLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
output := strings.Repeat("a", maxOutputToModel)
|
||||
result := runForegroundWithOutput(t, output)
|
||||
assert.Equal(t, maxOutputToModel, len(result.Output))
|
||||
assert.Equal(t, output, result.Output)
|
||||
})
|
||||
|
||||
t.Run("OverLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
output := strings.Repeat("b", maxOutputToModel+1024)
|
||||
result := runForegroundWithOutput(t, output)
|
||||
assert.Equal(t, maxOutputToModel, len(result.Output))
|
||||
})
|
||||
|
||||
t.Run("MultiByteCutMidCharacter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Build output that places a 3-byte UTF-8 character
|
||||
// (U+2603, snowman ☃) right at the truncation boundary
|
||||
// so the cut falls mid-character.
|
||||
padding := strings.Repeat("x", maxOutputToModel-1)
|
||||
output := padding + "☃" // ☃ is 3 bytes, only 1 byte fits
|
||||
result := runForegroundWithOutput(t, output)
|
||||
assert.LessOrEqual(t, len(result.Output), maxOutputToModel)
|
||||
assert.True(t, utf8.ValidString(result.Output),
|
||||
"truncated output must be valid UTF-8")
|
||||
})
|
||||
}
|
||||
|
||||
// runForegroundWithOutput runs a foreground command through the
|
||||
// Execute tool with a mock that returns the given output, and
|
||||
// returns the parsed result.
|
||||
func runForegroundWithOutput(t *testing.T, output string) ExecuteResult {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
exitCode := 0
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Running: false,
|
||||
ExitCode: &exitCode,
|
||||
Output: output,
|
||||
}, nil)
|
||||
|
||||
tool := Execute(ExecuteOptions{
|
||||
GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
},
|
||||
})
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"echo test"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var result ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
return result
|
||||
}
|
||||
@@ -1,493 +0,0 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestExecuteTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyCommand", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":""}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "command is required")
|
||||
})
|
||||
|
||||
t.Run("AmpersandDetection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
runInBackground *bool
|
||||
wantCommand string
|
||||
wantBackground bool
|
||||
wantBackgroundResp bool // true if the response should contain a background_process_id
|
||||
comment string
|
||||
}{
|
||||
{
|
||||
name: "SimpleBackground",
|
||||
command: "cmd &",
|
||||
wantCommand: "cmd",
|
||||
wantBackground: true,
|
||||
wantBackgroundResp: true,
|
||||
comment: "Trailing & is correctly detected and stripped.",
|
||||
},
|
||||
{
|
||||
name: "TrailingDoubleAmpersand",
|
||||
command: "cmd &&",
|
||||
wantCommand: "cmd &&",
|
||||
wantBackground: false,
|
||||
wantBackgroundResp: false,
|
||||
comment: "Ends with &&, excluded by the && suffix check.",
|
||||
},
|
||||
{
|
||||
name: "NoAmpersand",
|
||||
command: "cmd",
|
||||
wantCommand: "cmd",
|
||||
wantBackground: false,
|
||||
wantBackgroundResp: false,
|
||||
},
|
||||
{
|
||||
name: "ChainThenBackground",
|
||||
command: "cmd1 && cmd2 &",
|
||||
wantCommand: "cmd1 && cmd2",
|
||||
wantBackground: true,
|
||||
wantBackgroundResp: true,
|
||||
comment: "Ends with & but not &&, so it gets promoted " +
|
||||
"to background and the trailing & is stripped. " +
|
||||
"The remaining command runs in background mode.",
|
||||
},
|
||||
{
|
||||
// "|&" is bash's pipe-stderr operator, not
|
||||
// backgrounding. It must not be detected as a
|
||||
// trailing "&".
|
||||
name: "BashPipeStderr",
|
||||
command: "cmd |&",
|
||||
wantCommand: "cmd |&",
|
||||
wantBackground: false,
|
||||
wantBackgroundResp: false,
|
||||
},
|
||||
{
|
||||
name: "AlreadyBackgroundWithTrailingAmpersand",
|
||||
command: "cmd &",
|
||||
runInBackground: ptr(true),
|
||||
wantCommand: "cmd &",
|
||||
wantBackground: true,
|
||||
wantBackgroundResp: true,
|
||||
comment: "When run_in_background is already true, " +
|
||||
"the stripping logic is skipped, preserving " +
|
||||
"the original command.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var capturedReq workspacesdk.StartProcessRequest
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
|
||||
capturedReq = req
|
||||
return workspacesdk.StartProcessResponse{ID: "proc-1"}, nil
|
||||
})
|
||||
|
||||
// For foreground cases, ProcessOutput is polled.
|
||||
exitCode := 0
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Running: false,
|
||||
ExitCode: &exitCode,
|
||||
}, nil).
|
||||
AnyTimes()
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
|
||||
input := map[string]any{"command": tc.command}
|
||||
if tc.runInBackground != nil {
|
||||
input["run_in_background"] = *tc.runInBackground
|
||||
}
|
||||
inputJSON, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: string(inputJSON),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError, "response should not be an error")
|
||||
assert.Equal(t, tc.wantCommand, capturedReq.Command,
|
||||
"command passed to StartProcess")
|
||||
assert.Equal(t, tc.wantBackground, capturedReq.Background,
|
||||
"background flag passed to StartProcess")
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
if tc.wantBackgroundResp {
|
||||
assert.NotEmpty(t, result.BackgroundProcessID,
|
||||
"expected background_process_id in response")
|
||||
} else {
|
||||
assert.Empty(t, result.BackgroundProcessID,
|
||||
"expected no background_process_id")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ForegroundSuccess", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
var capturedReq workspacesdk.StartProcessRequest
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
|
||||
capturedReq = req
|
||||
return workspacesdk.StartProcessResponse{ID: "proc-1"}, nil
|
||||
})
|
||||
exitCode := 0
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Running: false,
|
||||
ExitCode: &exitCode,
|
||||
Output: "hello world",
|
||||
}, nil)
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"echo hello"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.True(t, result.Success)
|
||||
assert.Equal(t, 0, result.ExitCode)
|
||||
assert.Equal(t, "hello world", result.Output)
|
||||
assert.Empty(t, result.BackgroundProcessID)
|
||||
assert.Equal(t, "true", capturedReq.Env["CODER_CHAT_AGENT"])
|
||||
})
|
||||
|
||||
t.Run("ForegroundNonZeroExit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
exitCode := 42
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Running: false,
|
||||
ExitCode: &exitCode,
|
||||
Output: "something failed",
|
||||
}, nil)
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"exit 42"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.False(t, result.Success)
|
||||
assert.Equal(t, 42, result.ExitCode)
|
||||
assert.Equal(t, "something failed", result.Output)
|
||||
})
|
||||
|
||||
t.Run("BackgroundExecution", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, req workspacesdk.StartProcessRequest) (workspacesdk.StartProcessResponse, error) {
|
||||
assert.True(t, req.Background)
|
||||
return workspacesdk.StartProcessResponse{ID: "bg-42"}, nil
|
||||
})
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"sleep 999","run_in_background":true}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.True(t, result.Success)
|
||||
assert.Equal(t, "bg-42", result.BackgroundProcessID)
|
||||
})
|
||||
|
||||
t.Run("Timeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
|
||||
// First call (blocking wait) returns context error
|
||||
// because the 50ms timeout expires.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, _ string, _ *workspacesdk.ProcessOutputOptions) (workspacesdk.ProcessOutputResponse, error) {
|
||||
<-ctx.Done()
|
||||
return workspacesdk.ProcessOutputResponse{}, ctx.Err()
|
||||
})
|
||||
// Second call (snapshot fallback) returns partial output.
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Running: true,
|
||||
Output: "partial output",
|
||||
}, nil)
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
// 50ms timeout expires during the blocking wait.
|
||||
Input: `{"command":"sleep 999","timeout":"50ms"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.False(t, result.Success)
|
||||
assert.Equal(t, -1, result.ExitCode)
|
||||
assert.Contains(t, result.Error, "timed out")
|
||||
assert.Equal(t, "partial output", result.Output)
|
||||
})
|
||||
|
||||
t.Run("StartProcessError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{}, xerrors.New("connection lost"))
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"echo hi"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Errors from StartProcess are returned as a JSON body
|
||||
// with success=false, not as a ToolResponse error.
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.False(t, result.Success)
|
||||
assert.Contains(t, result.Error, "connection lost")
|
||||
})
|
||||
|
||||
t.Run("ProcessOutputError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{}, xerrors.New("agent disconnected"))
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"echo hi"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsError)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
assert.False(t, result.Success)
|
||||
assert.Contains(t, result.Error, "agent disconnected")
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceConnNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool := chattool.Execute(chattool.ExecuteOptions{
|
||||
GetWorkspaceConn: nil,
|
||||
})
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"echo hi"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "not configured")
|
||||
})
|
||||
|
||||
t.Run("GetWorkspaceConnError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool := chattool.Execute(chattool.ExecuteOptions{
|
||||
GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("workspace offline")
|
||||
},
|
||||
})
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: `{"command":"echo hi"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "workspace offline")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetectFileDump(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
wantHit bool
|
||||
}{
|
||||
{
|
||||
name: "CatFile",
|
||||
command: "cat foo.txt",
|
||||
wantHit: true,
|
||||
},
|
||||
{
|
||||
name: "NotCatPrefix",
|
||||
command: "concatenate foo",
|
||||
wantHit: false,
|
||||
},
|
||||
{
|
||||
name: "GrepIncludeAll",
|
||||
command: "grep --include-all pattern",
|
||||
wantHit: true,
|
||||
},
|
||||
{
|
||||
name: "RgListFiles",
|
||||
command: "rg -l pattern",
|
||||
wantHit: true,
|
||||
},
|
||||
{
|
||||
name: "GrepRecursive",
|
||||
command: "grep -r pattern",
|
||||
wantHit: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().
|
||||
StartProcess(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.StartProcessResponse{ID: "proc-1"}, nil)
|
||||
exitCode := 0
|
||||
mockConn.EXPECT().
|
||||
ProcessOutput(gomock.Any(), "proc-1", gomock.Any()).
|
||||
Return(workspacesdk.ProcessOutputResponse{
|
||||
Running: false,
|
||||
ExitCode: &exitCode,
|
||||
Output: "output",
|
||||
}, nil)
|
||||
|
||||
tool := newExecuteTool(t, mockConn)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
input, err := json.Marshal(map[string]any{
|
||||
"command": tc.command,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "execute",
|
||||
Input: string(input),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var result chattool.ExecuteResult
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
if tc.wantHit {
|
||||
assert.Contains(t, result.Note, "read_file",
|
||||
"expected advisory note for %q", tc.command)
|
||||
} else {
|
||||
assert.Empty(t, result.Note,
|
||||
"expected no note for %q", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newExecuteTool creates an Execute tool wired to the given mock.
|
||||
func newExecuteTool(t *testing.T, mockConn *agentconnmock.MockAgentConn) fantasy.AgentTool {
|
||||
t.Helper()
|
||||
return chattool.Execute(chattool.ExecuteOptions{
|
||||
GetWorkspaceConn: func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
@@ -272,156 +272,6 @@ func logMessages(t *testing.T, msgs []codersdk.ChatMessage) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIReasoningRoundTrip is an integration test that verifies
|
||||
// reasoning items from OpenAI's Responses API survive the full
|
||||
// persist → reconstruct → re-send cycle when Store: true. It sends
|
||||
// a query to a reasoning model, waits for completion, then sends a
|
||||
// follow-up message. If reasoning items are sent back without their
|
||||
// required following output item, the API rejects the second request:
|
||||
//
|
||||
// Item 'rs_xxx' of type 'reasoning' was provided without its
|
||||
// required following item.
|
||||
//
|
||||
// The test requires OPENAI_API_KEY to be set.
|
||||
func TestOpenAIReasoningRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
if apiKey == "" {
|
||||
t.Skip("OPENAI_API_KEY not set; skipping OpenAI integration test")
|
||||
}
|
||||
baseURL := os.Getenv("OPENAI_BASE_URL")
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
|
||||
// Stand up a full coderd with the agents experiment.
|
||||
deploymentValues := coderdtest.DeploymentValues(t)
|
||||
deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Configure an OpenAI provider with the real API key.
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a model config for a reasoning model with Store: true
|
||||
// (the default). Using o4-mini because it always produces
|
||||
// reasoning items.
|
||||
contextLimit := int64(200000)
|
||||
isDefault := true
|
||||
reasoningSummary := "auto"
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
Model: "o4-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
ModelConfig: &codersdk.ChatModelCallConfig{
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
Store: ptr.Ref(true),
|
||||
ReasoningSummary: &reasoningSummary,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// --- Step 1: Send a message that triggers reasoning ---
|
||||
t.Log("Creating chat with reasoning query...")
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "What is 2+2? Be brief.",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Logf("Chat created: %s (status=%s)", chat.ID, chat.Status)
|
||||
|
||||
// Stream events until the chat reaches a terminal status.
|
||||
events, closer, err := client.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer closer.Close()
|
||||
|
||||
waitForChatDone(ctx, t, events, "step 1")
|
||||
|
||||
// Verify the chat completed and messages were persisted.
|
||||
chatData, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Chat status after step 1: %s, messages: %d",
|
||||
chatData.Status, len(chatMsgs.Messages))
|
||||
logMessages(t, chatMsgs.Messages)
|
||||
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, chatData.Status,
|
||||
"chat should be in waiting status after step 1")
|
||||
|
||||
// Verify the assistant message has reasoning content.
|
||||
assistantMsg := findAssistantWithText(t, chatMsgs.Messages)
|
||||
require.NotNil(t, assistantMsg,
|
||||
"expected an assistant message with text content after step 1")
|
||||
|
||||
partTypes := partTypeSet(assistantMsg.Content)
|
||||
require.Contains(t, partTypes, codersdk.ChatMessagePartTypeReasoning,
|
||||
"assistant message should contain reasoning parts from o4-mini")
|
||||
require.Contains(t, partTypes, codersdk.ChatMessagePartTypeText,
|
||||
"assistant message should contain a text part")
|
||||
|
||||
// --- Step 2: Send a follow-up message ---
|
||||
// This is the critical test: if reasoning items are sent back
|
||||
// without their required following item, the API will reject
|
||||
// the request with:
|
||||
// Item 'rs_xxx' of type 'reasoning' was provided without its
|
||||
// required following item.
|
||||
t.Log("Sending follow-up message...")
|
||||
_, err = client.CreateChatMessage(ctx, chat.ID,
|
||||
codersdk.CreateChatMessageRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "And what is 3+3? Be brief.",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Stream the follow-up response.
|
||||
events2, closer2, err := client.StreamChat(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer closer2.Close()
|
||||
|
||||
waitForChatDone(ctx, t, events2, "step 2")
|
||||
|
||||
// Verify the follow-up completed and produced content.
|
||||
chatData2, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Chat status after step 2: %s, messages: %d",
|
||||
chatData2.Status, len(chatMsgs2.Messages))
|
||||
logMessages(t, chatMsgs2.Messages)
|
||||
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, chatData2.Status,
|
||||
"chat should be in waiting status after step 2")
|
||||
require.Greater(t, len(chatMsgs2.Messages), len(chatMsgs.Messages),
|
||||
"follow-up should have added more messages")
|
||||
|
||||
// The last assistant message should have text.
|
||||
lastAssistant := findLastAssistantWithText(t, chatMsgs2.Messages)
|
||||
require.NotNil(t, lastAssistant,
|
||||
"expected an assistant message with text in the follow-up")
|
||||
|
||||
t.Log("OpenAI reasoning round-trip test passed.")
|
||||
}
|
||||
|
||||
// partTypeSet returns the set of part types present in a message.
|
||||
func partTypeSet(parts []codersdk.ChatMessagePart) map[codersdk.ChatMessagePartType]struct{} {
|
||||
set := make(map[codersdk.ChatMessagePartType]struct{}, len(parts))
|
||||
|
||||
@@ -62,7 +62,6 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
generatedTitle *generatedChatTitle,
|
||||
logger slog.Logger,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
@@ -112,7 +111,6 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
return
|
||||
}
|
||||
chat.Title = title
|
||||
generatedTitle.Store(title)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,14 +84,6 @@ func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
|
||||
enabled, err := p.db.GetChatDesktopEnabled(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
tools := []fantasy.AgentTool{
|
||||
fantasy.NewAgentTool(
|
||||
@@ -261,8 +253,9 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
}
|
||||
|
||||
// Only include the computer use tool when an Anthropic
|
||||
// provider is configured and desktop is enabled.
|
||||
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
|
||||
// provider is configured, since it requires an Anthropic
|
||||
// model.
|
||||
if p.isAnthropicConfigured(ctx) {
|
||||
tools = append(tools, fantasy.NewAgentTool(
|
||||
"spawn_computer_use_agent",
|
||||
"Spawn a dedicated computer use agent that can see the desktop "+
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
@@ -145,20 +144,14 @@ func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func chatdTestContext(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// No Anthropic key in ProviderAPIKeys.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
@@ -183,13 +176,12 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the provider check passes.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
@@ -240,42 +232,16 @@ func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-desktop-disabled",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the tool can proceed.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// The parent uses an OpenAI model.
|
||||
@@ -332,139 +298,3 @@ func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
assert.Equal(t, "anthropic", chattool.ComputerUseModelProvider)
|
||||
assert.NotEmpty(t, chattool.ComputerUseModelName)
|
||||
}
|
||||
|
||||
func TestIsSubagentDescendant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Build a chain: root -> child -> grandchild.
|
||||
root, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "root",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("root")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
child, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: root.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: root.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: "child",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("child")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
grandchild, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: child.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: root.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: "grandchild",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("grandchild")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build a separate, unrelated chain.
|
||||
unrelated, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "unrelated-root",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("unrelated")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
unrelatedChild, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: unrelated.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: unrelated.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: "unrelated-child",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("unrelated-child")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ancestor uuid.UUID
|
||||
target uuid.UUID
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "SameID",
|
||||
ancestor: root.ID,
|
||||
target: root.ID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "DirectChild",
|
||||
ancestor: root.ID,
|
||||
target: child.ID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "GrandChild",
|
||||
ancestor: root.ID,
|
||||
target: grandchild.ID,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Unrelated",
|
||||
ancestor: root.ID,
|
||||
target: unrelatedChild.ID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "RootChat",
|
||||
ancestor: child.ID,
|
||||
target: root.ID,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "BrokenChain",
|
||||
ancestor: root.ID,
|
||||
target: uuid.New(),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "NotDescendant",
|
||||
ancestor: unrelated.ID,
|
||||
target: child.ID,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := chatdTestContext(t)
|
||||
got, err := isSubagentDescendant(ctx, db, tt.ancestor, tt.target)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+58
-396
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
@@ -190,7 +189,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
params := database.GetChatsParams{
|
||||
params := database.GetChatsByOwnerIDParams{
|
||||
OwnerID: apiKey.UserID,
|
||||
Archived: searchParams.Archived,
|
||||
AfterID: paginationParams.AfterID,
|
||||
@@ -200,7 +199,7 @@ func (api *API) listChats(rw http.ResponseWriter, r *http.Request) {
|
||||
LimitOpt: int32(paginationParams.Limit),
|
||||
}
|
||||
|
||||
chats, err := api.Database.GetChats(ctx, params)
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, params)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to list chats.",
|
||||
@@ -284,41 +283,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate MCP server IDs exist.
|
||||
if len(req.MCPServerIDs) > 0 {
|
||||
//nolint:gocritic // Need to validate MCP server IDs exist.
|
||||
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), req.MCPServerIDs)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to validate MCP server IDs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(existingConfigs) != len(req.MCPServerIDs) {
|
||||
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
|
||||
for _, c := range existingConfigs {
|
||||
found[c.ID] = struct{}{}
|
||||
}
|
||||
var missing []string
|
||||
for _, id := range req.MCPServerIDs {
|
||||
if _, ok := found[id]; !ok {
|
||||
missing = append(missing, id.String())
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "One or more MCP server IDs are invalid.",
|
||||
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mcpServerIDs := req.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
|
||||
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: apiKey.UserID,
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
@@ -326,7 +290,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ModelConfigID: modelConfigID,
|
||||
SystemPrompt: api.resolvedChatSystemPrompt(ctx),
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
@@ -1406,58 +1369,64 @@ func (api *API) watchChatDesktop(rw http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug(ctx, "desktop Bicopy finished")
|
||||
}
|
||||
|
||||
// patchChat updates a chat resource. Currently supports toggling the
|
||||
// archived state via the Archived field.
|
||||
func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) archiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
var req codersdk.UpdateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
if chat.Archived {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat is already archived.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Archived != nil {
|
||||
archived := *req.Archived
|
||||
if archived == chat.Archived {
|
||||
state := "archived"
|
||||
if !archived {
|
||||
state = "not archived"
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Chat is already %s.", state),
|
||||
})
|
||||
return
|
||||
}
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple archive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to archive chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify active
|
||||
// subscribers. Fall back to direct DB for the simple
|
||||
// archive flag — no streaming state is involved.
|
||||
if archived {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.ArchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.ArchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
} else {
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
action := "archive"
|
||||
if !archived {
|
||||
action = "unarchive"
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: fmt.Sprintf("Failed to %s chat.", action),
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (api *API) unarchiveChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
if !chat.Archived {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Chat is not archived.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
// Use chatDaemon when available so it can notify
|
||||
// active subscribers. Fall back to direct DB for the
|
||||
// simple unarchive flag — no streaming state is involved.
|
||||
if api.chatDaemon != nil {
|
||||
err = api.chatDaemon.UnarchiveChat(ctx, chat.ID)
|
||||
} else {
|
||||
err = api.Database.UnarchiveChatByID(ctx, chat.ID)
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to unarchive chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
@@ -1492,36 +1461,6 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate MCP server IDs exist.
|
||||
if req.MCPServerIDs != nil && len(*req.MCPServerIDs) > 0 {
|
||||
//nolint:gocritic // Need to validate MCP server IDs exist.
|
||||
existingConfigs, err := api.Database.GetMCPServerConfigsByIDs(dbauthz.AsSystemRestricted(ctx), *req.MCPServerIDs)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to validate MCP server IDs.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(existingConfigs) != len(*req.MCPServerIDs) {
|
||||
found := make(map[uuid.UUID]struct{}, len(existingConfigs))
|
||||
for _, c := range existingConfigs {
|
||||
found[c.ID] = struct{}{}
|
||||
}
|
||||
var missing []string
|
||||
for _, id := range *req.MCPServerIDs {
|
||||
if _, ok := found[id]; !ok {
|
||||
missing = append(missing, id.String())
|
||||
}
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "One or more MCP server IDs are invalid.",
|
||||
Detail: fmt.Sprintf("Invalid IDs: %s", strings.Join(missing, ", ")),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sendResult, sendErr := api.chatDaemon.SendMessage(
|
||||
ctx,
|
||||
chatd.SendMessageOptions{
|
||||
@@ -1530,7 +1469,6 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
Content: contentBlocks,
|
||||
ModelConfigID: req.ModelConfigID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
MCPServerIDs: req.MCPServerIDs,
|
||||
},
|
||||
)
|
||||
if sendErr != nil {
|
||||
@@ -2587,14 +2525,14 @@ func (api *API) getChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPrompt{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatSystemPromptResponse{
|
||||
SystemPrompt: prompt,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
var req codersdk.ChatSystemPrompt
|
||||
var req codersdk.UpdateChatSystemPromptRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
@@ -2622,49 +2560,6 @@ func (api *API) putChatSystemPrompt(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
enabled, err := api.Database.GetChatDesktopEnabled(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching desktop setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatDesktopEnabledResponse{
|
||||
EnableDesktop: enabled,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
func (api *API) putChatDesktopEnabled(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateChatDesktopEnabledRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertChatDesktopEnabled(ctx, req.EnableDesktop); httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
} else if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating desktop setting.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
@@ -2687,7 +2582,7 @@ func (api *API) getUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
customPrompt = ""
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
|
||||
CustomPrompt: customPrompt,
|
||||
})
|
||||
}
|
||||
@@ -2699,7 +2594,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
apiKey = httpmw.APIKey(r)
|
||||
)
|
||||
|
||||
var params codersdk.UserChatCustomPrompt
|
||||
var params codersdk.UpdateUserChatCustomPromptRequest
|
||||
if !httpapi.Read(ctx, rw, r, ¶ms) {
|
||||
return
|
||||
}
|
||||
@@ -2726,7 +2621,7 @@ func (api *API) putUserChatCustomPrompt(rw http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPrompt{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserChatCustomPromptResponse{
|
||||
CustomPrompt: updatedConfig.Value,
|
||||
})
|
||||
}
|
||||
@@ -3046,10 +2941,6 @@ func truncateRunes(value string, maxLen int) string {
|
||||
}
|
||||
|
||||
func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.Chat {
|
||||
mcpServerIDs := c.MCPServerIDs
|
||||
if mcpServerIDs == nil {
|
||||
mcpServerIDs = []uuid.UUID{}
|
||||
}
|
||||
chat := codersdk.Chat{
|
||||
ID: c.ID,
|
||||
OwnerID: c.OwnerID,
|
||||
@@ -3059,7 +2950,6 @@ func convertChat(c database.Chat, diffStatus *database.ChatDiffStatus) codersdk.
|
||||
Archived: c.Archived,
|
||||
CreatedAt: c.CreatedAt,
|
||||
UpdatedAt: c.UpdatedAt,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
}
|
||||
if c.LastError.Valid {
|
||||
chat.LastError = &c.LastError.String
|
||||
@@ -3485,16 +3375,6 @@ func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := api.Database.DeleteChatProviderByID(ctx, providerID); err != nil {
|
||||
if database.IsForeignKeyViolation(err,
|
||||
database.ForeignKeyChatMessagesModelConfigID,
|
||||
database.ForeignKeyChatsLastModelConfigID,
|
||||
) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Provider models are still referenced by existing chats.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to delete chat provider.",
|
||||
Detail: err.Error(),
|
||||
@@ -4262,221 +4142,3 @@ func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider databas
|
||||
)
|
||||
return effectiveKeys.APIKey(provider.Provider) != ""
|
||||
}
|
||||
|
||||
// @Summary Get PR insights
|
||||
// @ID get-pr-insights
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Produce json
|
||||
// @Param start_date query string true "Start date (RFC3339)"
|
||||
// @Param end_date query string true "End date (RFC3339)"
|
||||
// @Success 200 {object} codersdk.PRInsightsResponse
|
||||
// @Router /chats/insights/pull-requests [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Admin-only endpoint.
|
||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse date range.
|
||||
now := time.Now()
|
||||
defaultStart := now.AddDate(0, 0, -30)
|
||||
|
||||
qp := r.URL.Query()
|
||||
p := httpapi.NewQueryParamParser()
|
||||
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
|
||||
endDate := p.Time(qp, now, "end_date", time.RFC3339)
|
||||
p.ErrorExcessParams(qp)
|
||||
if len(p.Errors) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid query parameters.",
|
||||
Validations: p.Errors,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate previous period of equal length for trend comparison.
|
||||
duration := endDate.Sub(startDate)
|
||||
prevStart := startDate.Add(-duration)
|
||||
|
||||
// No owner filter — admin sees all data.
|
||||
ownerID := uuid.NullUUID{}
|
||||
|
||||
// Run all queries in parallel.
|
||||
var (
|
||||
currentSummary database.GetPRInsightsSummaryRow
|
||||
previousSummary database.GetPRInsightsSummaryRow
|
||||
timeSeries []database.GetPRInsightsTimeSeriesRow
|
||||
byModel []database.GetPRInsightsPerModelRow
|
||||
recentPRs []database.GetPRInsightsRecentPRsRow
|
||||
)
|
||||
|
||||
eg, egCtx := errgroup.WithContext(ctx)
|
||||
eg.SetLimit(5)
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
currentSummary, err = api.Database.GetPRInsightsSummary(egCtx, database.GetPRInsightsSummaryParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
})
|
||||
return err
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
previousSummary, err = api.Database.GetPRInsightsSummary(egCtx, database.GetPRInsightsSummaryParams{
|
||||
StartDate: prevStart,
|
||||
EndDate: startDate,
|
||||
OwnerID: ownerID,
|
||||
})
|
||||
return err
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
timeSeries, err = api.Database.GetPRInsightsTimeSeries(egCtx, database.GetPRInsightsTimeSeriesParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
})
|
||||
return err
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
byModel, err = api.Database.GetPRInsightsPerModel(egCtx, database.GetPRInsightsPerModelParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
})
|
||||
return err
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
LimitVal: 20,
|
||||
})
|
||||
return err
|
||||
})
|
||||
|
||||
if err := eg.Wait(); err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build summary with computed fields.
|
||||
summary := codersdk.PRInsightsSummary{
|
||||
TotalPRsCreated: currentSummary.TotalPrsCreated,
|
||||
TotalPRsMerged: currentSummary.TotalPrsMerged,
|
||||
TotalAdditions: currentSummary.TotalAdditions,
|
||||
TotalDeletions: currentSummary.TotalDeletions,
|
||||
TotalCostMicros: currentSummary.TotalCostMicros,
|
||||
PrevTotalPRsCreated: previousSummary.TotalPrsCreated,
|
||||
PrevTotalPRsMerged: previousSummary.TotalPrsMerged,
|
||||
}
|
||||
if summary.TotalPRsCreated > 0 {
|
||||
summary.MergeRate = float64(summary.TotalPRsMerged) / float64(summary.TotalPRsCreated)
|
||||
}
|
||||
if summary.TotalPRsMerged > 0 {
|
||||
summary.CostPerMergedPRMicros = currentSummary.MergedCostMicros / summary.TotalPRsMerged
|
||||
}
|
||||
if summary.PrevTotalPRsCreated > 0 {
|
||||
summary.PrevMergeRate = float64(summary.PrevTotalPRsMerged) / float64(summary.PrevTotalPRsCreated)
|
||||
}
|
||||
if summary.PrevTotalPRsMerged > 0 {
|
||||
summary.PrevCostPerMergedPRMicros = previousSummary.MergedCostMicros / summary.PrevTotalPRsMerged
|
||||
}
|
||||
|
||||
// Convert time series.
|
||||
tsEntries := make([]codersdk.PRInsightsTimeSeriesEntry, 0, len(timeSeries))
|
||||
for _, ts := range timeSeries {
|
||||
tsEntries = append(tsEntries, codersdk.PRInsightsTimeSeriesEntry{
|
||||
Date: ts.Date,
|
||||
PRsCreated: ts.PrsCreated,
|
||||
PRsMerged: ts.PrsMerged,
|
||||
PRsClosed: ts.PrsClosed,
|
||||
})
|
||||
}
|
||||
|
||||
// Convert model breakdown.
|
||||
modelEntries := make([]codersdk.PRInsightsModelBreakdown, 0, len(byModel))
|
||||
for _, m := range byModel {
|
||||
entry := codersdk.PRInsightsModelBreakdown{
|
||||
ModelConfigID: m.ModelConfigID.UUID,
|
||||
DisplayName: m.DisplayName,
|
||||
Provider: m.Provider,
|
||||
TotalPRs: m.TotalPrs,
|
||||
MergedPRs: m.MergedPrs,
|
||||
TotalAdditions: m.TotalAdditions,
|
||||
TotalDeletions: m.TotalDeletions,
|
||||
TotalCostMicros: m.TotalCostMicros,
|
||||
}
|
||||
if entry.TotalPRs > 0 {
|
||||
entry.MergeRate = float64(entry.MergedPRs) / float64(entry.TotalPRs)
|
||||
}
|
||||
if entry.MergedPRs > 0 {
|
||||
entry.CostPerMergedPRMicros = m.MergedCostMicros / entry.MergedPRs
|
||||
}
|
||||
modelEntries = append(modelEntries, entry)
|
||||
}
|
||||
|
||||
// Convert recent PRs.
|
||||
prEntries := make([]codersdk.PRInsightsPullRequest, 0, len(recentPRs))
|
||||
for _, pr := range recentPRs {
|
||||
entry := codersdk.PRInsightsPullRequest{
|
||||
ChatID: pr.ChatID,
|
||||
PRTitle: pr.PrTitle,
|
||||
Draft: pr.Draft,
|
||||
Additions: pr.Additions,
|
||||
Deletions: pr.Deletions,
|
||||
ChangedFiles: pr.ChangedFiles,
|
||||
ChangesRequested: pr.ChangesRequested,
|
||||
BaseBranch: pr.BaseBranch,
|
||||
ModelDisplayName: pr.ModelDisplayName,
|
||||
CostMicros: pr.CostMicros,
|
||||
CreatedAt: pr.CreatedAt,
|
||||
}
|
||||
if pr.PrUrl.Valid {
|
||||
entry.PRURL = &pr.PrUrl.String
|
||||
}
|
||||
if pr.PrNumber.Valid {
|
||||
entry.PRNumber = &pr.PrNumber.Int32
|
||||
}
|
||||
if pr.State.Valid {
|
||||
entry.State = pr.State.String
|
||||
}
|
||||
if pr.Commits.Valid {
|
||||
entry.Commits = &pr.Commits.Int32
|
||||
}
|
||||
if pr.Approved.Valid {
|
||||
entry.Approved = &pr.Approved.Bool
|
||||
}
|
||||
if pr.ReviewerCount.Valid {
|
||||
entry.ReviewerCount = &pr.ReviewerCount.Int32
|
||||
}
|
||||
if pr.AuthorLogin.Valid {
|
||||
entry.AuthorLogin = &pr.AuthorLogin.String
|
||||
}
|
||||
if pr.AuthorAvatarUrl.Valid {
|
||||
entry.AuthorAvatarURL = &pr.AuthorAvatarUrl.String
|
||||
}
|
||||
prEntries = append(prEntries, entry)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
RecentPRs: prEntries,
|
||||
})
|
||||
}
|
||||
|
||||
+95
-302
@@ -6,7 +6,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
@@ -30,7 +29,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
@@ -128,24 +126,22 @@ func insertAssistantCostMessage(
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chatID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfigID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Content: []string{string(assistantContent.RawMessage)},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{totalCostMicros},
|
||||
RuntimeMs: []int64{0},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
TotalTokens: sql.NullInt64{},
|
||||
ReasoningTokens: sql.NullInt64{},
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{Int64: totalCostMicros, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -1691,7 +1687,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, chatsBeforeArchive, 2)
|
||||
|
||||
err = client.UpdateChat(ctx, chatToArchive.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err = client.ArchiveChat(ctx, chatToArchive.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Default (no filter) returns only non-archived chats.
|
||||
@@ -1725,7 +1721,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err := client.ArchiveChat(ctx, uuid.New())
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
@@ -1768,7 +1764,7 @@ func TestArchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the parent via the API.
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err = client.ArchiveChat(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// archived:false should exclude the entire archived family.
|
||||
@@ -1815,7 +1811,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the chat first.
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
err = client.ArchiveChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's archived.
|
||||
@@ -1826,7 +1822,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.Len(t, archivedChats, 1)
|
||||
require.True(t, archivedChats[0].Archived)
|
||||
// Unarchive the chat.
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
err = client.UnarchiveChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's no longer archived.
|
||||
@@ -1865,9 +1861,10 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trying to unarchive a non-archived chat should fail.
|
||||
err = client.UpdateChat(ctx, chat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
err = client.UnarchiveChat(ctx, chat.ID)
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1875,7 +1872,7 @@ func TestUnarchiveChat(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.UpdateChat(ctx, uuid.New(), codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
err := client.UnarchiveChat(ctx, uuid.New())
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
@@ -2663,9 +2660,7 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// The edited message is soft-deleted and a new one is inserted,
|
||||
// so the returned ID will differ from the original.
|
||||
require.NotEqual(t, userMessageID, edited.ID)
|
||||
require.Equal(t, userMessageID, edited.ID)
|
||||
require.Equal(t, codersdk.ChatMessageRoleUser, edited.Role)
|
||||
|
||||
foundEditedText := false
|
||||
@@ -2755,9 +2750,7 @@ func TestPatchChatMessage(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// The edited message is soft-deleted and a new one is inserted,
|
||||
// so the returned ID will differ from the original.
|
||||
require.NotEqual(t, userMessageID, edited.ID)
|
||||
require.Equal(t, userMessageID, edited.ID)
|
||||
|
||||
// Assert the edit response preserves the file_id.
|
||||
var foundText, foundFile bool
|
||||
@@ -3991,30 +3984,6 @@ func TestGetChatFile(t *testing.T) {
|
||||
require.NotContains(t, cd, strings.Repeat("a", 256))
|
||||
})
|
||||
|
||||
t.Run("UnicodeFilename", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Upload with a non-ASCII filename using RFC 5987 encoding,
|
||||
// which is what the frontend sends for Unicode filenames.
|
||||
data := append([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, make([]byte, 64)...)
|
||||
uploaded, err := client.UploadChatFile(ctx, firstUser.OrganizationID, "image/png", "スクリーンショット.png", bytes.NewReader(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(ctx, http.MethodGet,
|
||||
fmt.Sprintf("/api/experimental/chats/files/%s", uploaded.ID), nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
cd := res.Header.Get("Content-Disposition")
|
||||
require.Contains(t, cd, "inline")
|
||||
_, params, err := mime.ParseMediaType(cd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "スクリーンショット.png", params["filename"])
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
@@ -4088,35 +4057,24 @@ func seedChatCostFixture(t *testing.T) chatCostTestFixture {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil, uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID, modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{"assistant", "assistant"},
|
||||
Content: []string{"null", "null"},
|
||||
ContentVersion: []int16{0, 0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{100, 100},
|
||||
OutputTokens: []int64{50, 50},
|
||||
TotalTokens: []int64{0, 0},
|
||||
ReasoningTokens: []int64{0, 0},
|
||||
CacheCreationTokens: []int64{0, 0},
|
||||
CacheReadTokens: []int64{0, 0},
|
||||
ContextLimit: []int64{0, 0},
|
||||
Compressed: []bool{false, false},
|
||||
TotalCostMicros: []int64{500, 500},
|
||||
RuntimeMs: []int64{0, 0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 2)
|
||||
earliestCreatedAt := results[0].CreatedAt
|
||||
latestCreatedAt := results[0].CreatedAt
|
||||
for _, msg := range results {
|
||||
if msg.CreatedAt.Before(earliestCreatedAt) {
|
||||
earliestCreatedAt = msg.CreatedAt
|
||||
var earliestCreatedAt time.Time
|
||||
var latestCreatedAt time.Time
|
||||
for i := 0; i < 2; i++ {
|
||||
message, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if i == 0 || message.CreatedAt.Before(earliestCreatedAt) {
|
||||
earliestCreatedAt = message.CreatedAt
|
||||
}
|
||||
if msg.CreatedAt.After(latestCreatedAt) {
|
||||
latestCreatedAt = msg.CreatedAt
|
||||
if i == 0 || message.CreatedAt.After(latestCreatedAt) {
|
||||
latestCreatedAt = message.CreatedAt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4204,27 +4162,16 @@ func TestChatCostSummary_AdminDrilldown(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{"assistant"},
|
||||
Content: []string{"null"},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{200},
|
||||
OutputTokens: []int64{100},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{750},
|
||||
RuntimeMs: []int64{0},
|
||||
message, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 750, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
message := results[0]
|
||||
options := codersdk.ChatCostSummaryOptions{
|
||||
// Pad the DB-assigned timestamp so the query window cannot race it.
|
||||
StartDate: message.CreatedAt.Add(-time.Minute),
|
||||
@@ -4270,24 +4217,14 @@ func TestChatCostUsers(t *testing.T) {
|
||||
Title: "admin chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
|
||||
ChatID: adminChat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{"assistant"},
|
||||
Content: []string{"null"},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{100},
|
||||
OutputTokens: []int64{50},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{300},
|
||||
RuntimeMs: []int64{0},
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: adminChat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -4297,24 +4234,14 @@ func TestChatCostUsers(t *testing.T) {
|
||||
Title: "member chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
|
||||
ChatID: memberChat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{"assistant"},
|
||||
Content: []string{"null"},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{200},
|
||||
OutputTokens: []int64{100},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{800},
|
||||
RuntimeMs: []int64{0},
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: memberChat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 800, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -4381,24 +4308,14 @@ func TestChatCostSummary_DateRange(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{"assistant"},
|
||||
Content: []string{"null"},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{100},
|
||||
OutputTokens: []int64{50},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{500},
|
||||
RuntimeMs: []int64{0},
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -4446,49 +4363,27 @@ func TestChatCostSummary_UnpricedMessages(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pricedResults, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{"assistant"},
|
||||
Content: []string{"null"},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{100},
|
||||
OutputTokens: []int64{50},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{500},
|
||||
RuntimeMs: []int64{0},
|
||||
pricedMessage, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
pricedMessage := pricedResults[0]
|
||||
|
||||
unpricedResults, err := db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{"assistant"},
|
||||
Content: []string{"null"},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{200},
|
||||
OutputTokens: []int64{75},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
unpricedMessage, err := db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
unpricedMessage := unpricedResults[0]
|
||||
|
||||
earliestCreatedAt := pricedMessage.CreatedAt
|
||||
latestCreatedAt := pricedMessage.CreatedAt
|
||||
@@ -4567,7 +4462,7 @@ func TestWatchChatDesktop(t *testing.T) {
|
||||
res, err := client.Request(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/stream/desktop", createdChat.ID),
|
||||
fmt.Sprintf("/api/experimental/chats/%s/desktop", createdChat.ID),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -4617,7 +4512,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
t.Run("AdminCanSet", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: "You are a helpful coding assistant.",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -4631,7 +4526,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Unset by sending an empty string.
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -4644,7 +4539,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
t.Run("NonAdminFails", func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := memberClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: "This should fail.",
|
||||
})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
@@ -4665,7 +4560,7 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
tooLong := strings.Repeat("a", 131073)
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.ChatSystemPrompt{
|
||||
err := adminClient.UpdateChatSystemPrompt(ctx, codersdk.UpdateChatSystemPromptRequest{
|
||||
SystemPrompt: tooLong,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
@@ -4673,108 +4568,6 @@ func TestChatSystemPrompt(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatDesktopEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsFalseWhenUnset", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("AdminCanSetTrue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("AdminCanSetFalse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
// Set true first, then set false.
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := adminClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("NonAdminCanRead", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
err := adminClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := memberClient.GetChatDesktopEnabled(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.EnableDesktop)
|
||||
})
|
||||
|
||||
t.Run("NonAdminWriteFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
err := memberClient.UpdateChatDesktopEnabled(ctx, codersdk.UpdateChatDesktopEnabledRequest{
|
||||
EnableDesktop: true,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("UnauthenticatedFails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
coderdtest.CreateFirstUser(t, adminClient)
|
||||
|
||||
anonClient := codersdk.New(adminClient.URL)
|
||||
_, err := anonClient.GetChatDesktopEnabled(ctx)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusUnauthorized, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+17
-55
@@ -10,7 +10,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
httppprof "net/http/pprof"
|
||||
"net/url"
|
||||
@@ -767,27 +766,17 @@ func New(options *Options) *API {
|
||||
}
|
||||
api.agentProvider = stn
|
||||
|
||||
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
|
||||
if maxChatsPerAcquire > math.MaxInt32 {
|
||||
maxChatsPerAcquire = math.MaxInt32
|
||||
}
|
||||
if maxChatsPerAcquire < math.MinInt32 {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
})
|
||||
gitSyncLogger := options.Logger.Named("gitsync")
|
||||
refresher := gitsync.NewRefresher(
|
||||
@@ -1044,12 +1033,10 @@ func New(options *Options) *API {
|
||||
|
||||
// OAuth2 metadata endpoint for RFC 8414 discovery
|
||||
r.Route("/.well-known/oauth-authorization-server", func(r chi.Router) {
|
||||
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2))
|
||||
r.Get("/*", api.oauth2AuthorizationServerMetadata())
|
||||
})
|
||||
// OAuth2 protected resource metadata endpoint for RFC 9728 discovery
|
||||
r.Route("/.well-known/oauth-protected-resource", func(r chi.Router) {
|
||||
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2))
|
||||
r.Get("/*", api.oauth2ProtectedResourceMetadata())
|
||||
})
|
||||
|
||||
@@ -1162,9 +1149,6 @@ func New(options *Options) *API {
|
||||
r.Get("/summary", api.chatCostSummary)
|
||||
})
|
||||
})
|
||||
r.Route("/insights", func(r chi.Router) {
|
||||
r.Get("/pull-requests", api.prInsights)
|
||||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
|
||||
r.Post("/", api.postChatFile)
|
||||
@@ -1173,8 +1157,6 @@ func New(options *Options) *API {
|
||||
r.Route("/config", func(r chi.Router) {
|
||||
r.Get("/system-prompt", api.getChatSystemPrompt)
|
||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/user-prompt", api.getUserChatCustomPrompt)
|
||||
r.Put("/user-prompt", api.putUserChatCustomPrompt)
|
||||
})
|
||||
@@ -1212,15 +1194,14 @@ func New(options *Options) *API {
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Patch("/", api.patchChat)
|
||||
r.Get("/git/watch", api.watchChatGit)
|
||||
r.Get("/desktop", api.watchChatDesktop)
|
||||
r.Post("/archive", api.archiveChat)
|
||||
r.Post("/unarchive", api.unarchiveChat)
|
||||
r.Get("/messages", api.getChatMessages)
|
||||
r.Post("/messages", api.postChatMessages)
|
||||
r.Patch("/messages/{message}", api.patchChatMessage)
|
||||
r.Route("/stream", func(r chi.Router) {
|
||||
r.Get("/", api.streamChat)
|
||||
r.Get("/desktop", api.watchChatDesktop)
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Get("/stream", api.streamChat)
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
@@ -1233,27 +1214,10 @@ func New(options *Options) *API {
|
||||
r.Route("/mcp", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP),
|
||||
)
|
||||
// MCP server configuration endpoints.
|
||||
r.Route("/servers", func(r chi.Router) {
|
||||
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentAgents))
|
||||
r.Get("/", api.listMCPServerConfigs)
|
||||
r.Post("/", api.createMCPServerConfig)
|
||||
r.Route("/{mcpServer}", func(r chi.Router) {
|
||||
r.Get("/", api.getMCPServerConfig)
|
||||
r.Patch("/", api.updateMCPServerConfig)
|
||||
r.Delete("/", api.deleteMCPServerConfig)
|
||||
// OAuth2 user flow
|
||||
r.Get("/oauth2/connect", api.mcpServerOAuth2Connect)
|
||||
r.Get("/oauth2/callback", api.mcpServerOAuth2Callback)
|
||||
r.Delete("/oauth2/disconnect", api.mcpServerOAuth2Disconnect)
|
||||
})
|
||||
})
|
||||
// MCP HTTP transport endpoint with mandatory authentication
|
||||
r.Route("/http", func(r chi.Router) {
|
||||
r.Use(httpmw.RequireExperimentWithDevBypass(api.Experiments, codersdk.ExperimentOAuth2, codersdk.ExperimentMCPServerHTTP))
|
||||
r.Mount("/", api.mcpHTTPHandler())
|
||||
})
|
||||
r.Mount("/http", api.mcpHTTPHandler())
|
||||
})
|
||||
r.Route("/watch-all-workspacebuilds", func(r chi.Router) {
|
||||
r.Use(
|
||||
@@ -1515,8 +1479,6 @@ func New(options *Options) *API {
|
||||
r.Post("/", api.postUser)
|
||||
r.Get("/", api.users)
|
||||
r.Post("/logout", api.postLogout)
|
||||
r.Post("/me/session/token-to-cookie", api.postSessionTokenCookie)
|
||||
r.Get("/oidc-claims", api.userOIDCClaims)
|
||||
// These routes query information about site wide roles.
|
||||
r.Route("/roles", func(r chi.Router) {
|
||||
r.Get("/", api.AssignableSiteRoles)
|
||||
|
||||
@@ -879,15 +879,6 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
|
||||
m(&req)
|
||||
}
|
||||
|
||||
// Service accounts cannot have a password or email and must
|
||||
// use login_type=none. Enforce this after mutators so callers
|
||||
// only need to set ServiceAccount=true.
|
||||
if req.ServiceAccount {
|
||||
req.Password = ""
|
||||
req.Email = ""
|
||||
req.UserLoginType = codersdk.LoginTypeNone
|
||||
}
|
||||
|
||||
user, err := client.CreateUserWithOrgs(context.Background(), req)
|
||||
var apiError *codersdk.Error
|
||||
// If the user already exists by username or email conflict, try again up to "retries" times.
|
||||
|
||||
@@ -13,64 +13,32 @@ var _ usage.Inserter = (*UsageInserter)(nil)
|
||||
|
||||
type UsageInserter struct {
|
||||
sync.Mutex
|
||||
discreteEvents []usagetypes.DiscreteEvent
|
||||
heartbeatEvents []usagetypes.HeartbeatEvent
|
||||
seenHeartbeats map[string]struct{}
|
||||
events []usagetypes.DiscreteEvent
|
||||
}
|
||||
|
||||
func NewUsageInserter() *UsageInserter {
|
||||
return &UsageInserter{
|
||||
discreteEvents: []usagetypes.DiscreteEvent{},
|
||||
seenHeartbeats: map[string]struct{}{},
|
||||
heartbeatEvents: []usagetypes.HeartbeatEvent{},
|
||||
events: []usagetypes.DiscreteEvent{},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
u.discreteEvents = append(u.discreteEvents, event)
|
||||
u.events = append(u.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UsageInserter) InsertHeartbeatUsageEvent(_ context.Context, _ database.Store, id string, event usagetypes.HeartbeatEvent) error {
|
||||
func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
if _, seen := u.seenHeartbeats[id]; seen {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.seenHeartbeats[id] = struct{}{}
|
||||
u.heartbeatEvents = append(u.heartbeatEvents, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UsageInserter) GetHeartbeatEvents() []usagetypes.HeartbeatEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
eventsCopy := make([]usagetypes.HeartbeatEvent, len(u.heartbeatEvents))
|
||||
copy(eventsCopy, u.heartbeatEvents)
|
||||
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events))
|
||||
copy(eventsCopy, u.events)
|
||||
return eventsCopy
|
||||
}
|
||||
|
||||
func (u *UsageInserter) GetDiscreteEvents() []usagetypes.DiscreteEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.discreteEvents))
|
||||
copy(eventsCopy, u.discreteEvents)
|
||||
return eventsCopy
|
||||
}
|
||||
|
||||
func (u *UsageInserter) TotalEventCount() int {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
return len(u.discreteEvents) + len(u.heartbeatEvents)
|
||||
}
|
||||
|
||||
func (u *UsageInserter) Reset() {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
u.seenHeartbeats = map[string]struct{}{}
|
||||
u.discreteEvents = []usagetypes.DiscreteEvent{}
|
||||
u.heartbeatEvents = []usagetypes.HeartbeatEvent{}
|
||||
u.events = []usagetypes.DiscreteEvent{}
|
||||
}
|
||||
|
||||
@@ -20,9 +20,6 @@ const (
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs
|
||||
CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs
|
||||
CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
|
||||
@@ -195,14 +195,13 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
|
||||
|
||||
func ReducedUser(user database.User) codersdk.ReducedUser {
|
||||
return codersdk.ReducedUser{
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
IsServiceAccount: user.IsServiceAccount,
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -705,7 +705,7 @@ var (
|
||||
DisplayName: "Chat Daemon",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate},
|
||||
rbac.ResourceWorkspace.Type: {policy.ActionRead},
|
||||
rbac.ResourceDeploymentConfig.Type: {policy.ActionRead},
|
||||
rbac.ResourceUser.Type: {policy.ActionReadPersonal},
|
||||
}),
|
||||
@@ -769,9 +769,6 @@ func AsSubAgentAPI(ctx context.Context, orgID uuid.UUID, userID uuid.UUID) conte
|
||||
|
||||
// AsSystemRestricted returns a context with an actor that has permissions
|
||||
// required for various system operations (login, logout, metrics cache).
|
||||
// DO NOT USE THIS UNLESS YOU HAVE ABSOLUTELY NO OTHER CHOICE. Prefer using a
|
||||
// more specific As* helper above (or adding a new, narrowly-scoped one) so
|
||||
// that permissions remain limited to the operation you need.
|
||||
func AsSystemRestricted(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectSystemRestricted)
|
||||
}
|
||||
@@ -1267,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
|
||||
// System roles are stored in the database but have a fixed, code-defined
|
||||
// meaning. Do not rewrite the name for them so the static "who can assign
|
||||
// what" mapping applies.
|
||||
if !rolestore.IsSystemRoleName(roleName.Name) {
|
||||
if !rbac.SystemRoleName(roleName.Name) {
|
||||
// To support a dynamic mapping of what roles can assign what, we need
|
||||
// to store this in the database. For now, just use a static role so
|
||||
// owners and org admins can assign roles.
|
||||
@@ -1694,13 +1691,6 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
|
||||
return q.db.CleanTailnetTunnels(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -1834,6 +1824,18 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
|
||||
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -1930,20 +1932,6 @@ func (q *querier) DeleteLicense(ctx context.Context, id int32) (int32, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteMCPServerConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteMCPServerUserToken(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceOauth2App); err != nil {
|
||||
return err
|
||||
@@ -2157,12 +2145,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
// This is a system-only function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
@@ -2494,17 +2482,6 @@ func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCo
|
||||
return q.db.GetChatCostSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
// The desktop-enabled flag is a deployment-wide setting read by any
|
||||
// authenticated chat user and by chatd when deciding whether to expose
|
||||
// computer-use tooling. We only require that an explicit actor is present
|
||||
// in the context so unauthenticated calls fail closed.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return false, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatDesktopEnabled(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
@@ -2676,12 +2653,8 @@ func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid
|
||||
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.GetAuthorizedChats(ctx, arg, prep)
|
||||
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
@@ -2787,13 +2760,6 @@ func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatP
|
||||
return q.db.GetEnabledChatProviders(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetEnabledMCPServerConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetExternalAuthLink)(ctx, arg)
|
||||
}
|
||||
@@ -2852,13 +2818,6 @@ func (q *querier) GetFilteredInboxNotificationsByUserID(ctx context.Context, arg
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetFilteredInboxNotificationsByUserID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetForcedMCPServerConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetGitSSHKey)(ctx, userID)
|
||||
}
|
||||
@@ -2999,48 +2958,6 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
|
||||
return q.db.GetLogoURL(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigBySlug(ctx, slug)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetMCPServerConfigsByIDs(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
return q.db.GetMCPServerUserToken(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetMCPServerUserTokensByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceNotificationMessage); err != nil {
|
||||
return nil, err
|
||||
@@ -3227,34 +3144,6 @@ func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg da
|
||||
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsRecentPRs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetPRInsightsSummaryRow{}, err
|
||||
}
|
||||
return q.db.GetPRInsightsSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsTimeSeries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
|
||||
if err != nil {
|
||||
@@ -4608,13 +4497,6 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
|
||||
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
return database.AIBridgeModelThought{}, err
|
||||
}
|
||||
return q.db.InsertAIBridgeModelThought(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
// All aibridge_token_usages records belong to the initiator of their associated interception.
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
@@ -4671,16 +4553,16 @@ func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFil
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
||||
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
// Authorize create on the parent chat (using update permission).
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return nil, err
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatMessages(ctx, arg)
|
||||
return q.db.InsertChatMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
@@ -4807,13 +4689,6 @@ func (q *querier) InsertLicense(ctx context.Context, arg database.InsertLicenseP
|
||||
return q.db.InsertLicense(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.InsertMCPServerConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceWorkspaceAgentResourceMonitor); err != nil {
|
||||
return database.WorkspaceAgentMemoryResourceMonitor{}, err
|
||||
@@ -5466,32 +5341,6 @@ func (q *querier) SelectUsageEventsForPublishing(ctx context.Context, arg time.T
|
||||
return q.db.SelectUsageEventsForPublishing(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
|
||||
msg, err := q.db.GetChatMessageByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chat, err := q.db.GetChatByID(ctx, msg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteChatMessageByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
@@ -5573,17 +5422,6 @@ func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateCh
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatMCPServerIDs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
// Authorize update on the parent chat of the edited message.
|
||||
msg, err := q.db.GetChatMessageByID(ctx, arg.ID)
|
||||
@@ -5748,13 +5586,6 @@ func (q *querier) UpdateInboxNotificationReadStatus(ctx context.Context, args da
|
||||
return update(q.log, q.auth, fetchFunc, q.db.UpdateInboxNotificationReadStatus)(ctx, args)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerConfig{}, err
|
||||
}
|
||||
return q.db.UpdateMCPServerConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
|
||||
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
|
||||
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||
@@ -6706,13 +6537,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -6800,13 +6624,6 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error {
|
||||
return q.db.UpsertLogoURL(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.MCPServerUserToken{}, err
|
||||
}
|
||||
return q.db.UpsertMCPServerUserToken(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -6959,13 +6776,6 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa
|
||||
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUsageEvent); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return q.db.UsageEventExistsByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
// This check is probably overly restrictive, but the "correct" check isn't
|
||||
// necessarily obvious. It's only used as a verification check for ACLs right
|
||||
@@ -7061,7 +6871,3 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
return q.GetChats(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -401,27 +401,16 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("SoftDeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
s.Run("DeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.SoftDeleteChatMessagesAfterIDParams{
|
||||
arg := database.DeleteChatMessagesAfterIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 123,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("SoftDeleteChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := database.ChatMessage{
|
||||
ID: 456,
|
||||
ChatID: chat.ID,
|
||||
}
|
||||
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), msg.ID).Return(nil).AnyTimes()
|
||||
check.Args(msg.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
@@ -629,17 +618,12 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because it re-routes through GetChats which uses SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
c1 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
c2 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
|
||||
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
|
||||
check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
|
||||
}))
|
||||
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -652,10 +636,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
@@ -686,13 +666,13 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
|
||||
}))
|
||||
s.Run("InsertChatMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessagesParams{ChatID: chat.ID})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatMessages(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msgs)
|
||||
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
|
||||
}))
|
||||
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -865,10 +845,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
|
||||
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetUserChatSpendInPeriodParams{
|
||||
UserID: uuid.New(),
|
||||
@@ -1005,114 +981,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("CleanupDeletedMCPServerIDsFromChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().CleanupDeletedMCPServerIDsFromChats(gomock.Any()).Return(nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceChat, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteMCPServerConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.DeleteMCPServerUserTokenParams{
|
||||
MCPServerConfigID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().DeleteMCPServerUserToken(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetEnabledMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetEnabledMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetForcedMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetForcedMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetMCPServerConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetMCPServerConfigByID(gomock.Any(), config.ID).Return(config, nil).AnyTimes()
|
||||
check.Args(config.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetMCPServerConfigBySlug", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
slug := "test-mcp-server"
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{Slug: slug})
|
||||
dbm.EXPECT().GetMCPServerConfigBySlug(gomock.Any(), slug).Return(config, nil).AnyTimes()
|
||||
check.Args(slug).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetMCPServerConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
dbm.EXPECT().GetMCPServerConfigs(gomock.Any()).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetMCPServerConfigsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
ids := []uuid.UUID{configA.ID, configB.ID}
|
||||
dbm.EXPECT().GetMCPServerConfigsByIDs(gomock.Any(), ids).Return([]database.MCPServerConfig{configA, configB}, nil).AnyTimes()
|
||||
check.Args(ids).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.MCPServerConfig{configA, configB})
|
||||
}))
|
||||
s.Run("GetMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetMCPServerUserTokenParams{
|
||||
MCPServerConfigID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
}
|
||||
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
|
||||
dbm.EXPECT().GetMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(token)
|
||||
}))
|
||||
s.Run("GetMCPServerUserTokensByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
tokens := []database.MCPServerUserToken{testutil.Fake(s.T(), faker, database.MCPServerUserToken{UserID: userID})}
|
||||
dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens)
|
||||
}))
|
||||
s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.InsertMCPServerConfigParams{
|
||||
DisplayName: "Test MCP Server",
|
||||
Slug: "test-mcp-server",
|
||||
}
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{DisplayName: arg.DisplayName, Slug: arg.Slug})
|
||||
dbm.EXPECT().InsertMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpdateChatMCPServerIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatMCPServerIDsParams{
|
||||
ID: chat.ID,
|
||||
MCPServerIDs: []uuid.UUID{uuid.New()},
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatMCPServerIDs(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
config := testutil.Fake(s.T(), faker, database.MCPServerConfig{})
|
||||
arg := database.UpdateMCPServerConfigParams{
|
||||
ID: config.ID,
|
||||
DisplayName: "Updated MCP Server",
|
||||
Slug: "updated-mcp-server",
|
||||
}
|
||||
dbm.EXPECT().UpdateMCPServerConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpsertMCPServerUserToken", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertMCPServerUserTokenParams{
|
||||
MCPServerConfigID: uuid.New(),
|
||||
UserID: uuid.New(),
|
||||
AccessToken: "test-access-token",
|
||||
TokenType: "bearer",
|
||||
}
|
||||
token := testutil.Fake(s.T(), faker, database.MCPServerUserToken{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
|
||||
dbm.EXPECT().UpsertMCPServerUserToken(gomock.Any(), arg).Return(token, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(token)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
@@ -1489,8 +1357,8 @@ func (s *MethodTestSuite) TestLicense() {
|
||||
check.Args().Asserts().Returns("value")
|
||||
}))
|
||||
s.Run("GetDefaultProxyConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"}, nil).AnyTimes()
|
||||
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconURL: "/emojis/1f3e1.png"})
|
||||
dbm.EXPECT().GetDefaultProxyConfig(gomock.Any()).Return(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"}, nil).AnyTimes()
|
||||
check.Args().Asserts().Returns(database.GetDefaultProxyConfigRow{DisplayName: "Default", IconUrl: "/emojis/1f3e1.png"})
|
||||
}))
|
||||
s.Run("GetLogoURL", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetLogoURL(gomock.Any()).Return("value", nil).AnyTimes()
|
||||
@@ -1612,7 +1480,7 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
||||
WorkspaceSharingDisabled: true,
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
|
||||
@@ -2043,26 +1911,6 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
|
||||
}))
|
||||
s.Run("GetPRInsightsSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsSummaryParams{}
|
||||
dbm.EXPECT().GetPRInsightsSummary(gomock.Any(), arg).Return(database.GetPRInsightsSummaryRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsTimeSeries", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsTimeSeriesParams{}
|
||||
dbm.EXPECT().GetPRInsightsTimeSeries(gomock.Any(), arg).Return([]database.GetPRInsightsTimeSeriesRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsPerModelParams{}
|
||||
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsRecentPRsParams{}
|
||||
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTelemetryTaskEventsParams{}
|
||||
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
|
||||
@@ -2551,12 +2399,9 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
check.Args(w.ID).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: uuid.New(),
|
||||
ExcludeServiceAccounts: false,
|
||||
}
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
orgID := uuid.New()
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
|
||||
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
@@ -5262,12 +5107,6 @@ func (s *MethodTestSuite) TestUsageEvents() {
|
||||
check.Args(params).Asserts(rbac.ResourceUsageEvent, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("UsageEventExistsByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
id := uuid.NewString()
|
||||
db.EXPECT().UsageEventExistsByID(gomock.Any(), id).Return(true, nil)
|
||||
check.Args(id).Asserts(rbac.ResourceUsageEvent, policy.ActionRead)
|
||||
}))
|
||||
|
||||
s.Run("SelectUsageEventsForPublishing", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
db.EXPECT().SelectUsageEventsForPublishing(gomock.Any(), now).Return([]database.UsageEvent{}, nil)
|
||||
@@ -5328,17 +5167,6 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params).Asserts(intc, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
|
||||
|
||||
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
|
||||
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
|
||||
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
|
||||
check.Args(params).Asserts(intc, policy.ActionUpdate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
@@ -5697,16 +5525,12 @@ func TestAsChatd(t *testing.T) {
|
||||
require.NoError(t, err, "chat %s should be allowed", action)
|
||||
}
|
||||
|
||||
// Workspace read + update (update needed for ActivityBumpWorkspace).
|
||||
for _, action := range []policy.Action{
|
||||
policy.ActionRead, policy.ActionUpdate,
|
||||
} {
|
||||
err := auth.Authorize(ctx, actor, action, rbac.ResourceWorkspace)
|
||||
require.NoError(t, err, "workspace %s should be allowed", action)
|
||||
}
|
||||
// Workspace read.
|
||||
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceWorkspace)
|
||||
require.NoError(t, err, "workspace read should be allowed")
|
||||
|
||||
// DeploymentConfig read.
|
||||
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceDeploymentConfig)
|
||||
err = auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceDeploymentConfig)
|
||||
require.NoError(t, err, "deployment config read should be allowed")
|
||||
|
||||
// User read_personal (needed for GetUserChatCustomPrompt).
|
||||
@@ -5717,12 +5541,16 @@ func TestAsChatd(t *testing.T) {
|
||||
t.Run("DeniedActions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Cannot delete workspaces.
|
||||
err := auth.Authorize(ctx, actor, policy.ActionDelete, rbac.ResourceWorkspace)
|
||||
require.Error(t, err, "workspace delete should be denied")
|
||||
// Cannot write workspaces.
|
||||
for _, action := range []policy.Action{
|
||||
policy.ActionUpdate, policy.ActionDelete,
|
||||
} {
|
||||
err := auth.Authorize(ctx, actor, action, rbac.ResourceWorkspace)
|
||||
require.Error(t, err, "workspace %s should be denied", action)
|
||||
}
|
||||
|
||||
// Cannot access users.
|
||||
err = auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceUser)
|
||||
err := auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceUser)
|
||||
require.Error(t, err, "user read should be denied")
|
||||
|
||||
// Cannot access API keys.
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/v2/coderd/rbac/rolestore"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
)
|
||||
|
||||
@@ -144,7 +143,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
|
||||
UUID: pair.OrganizationID,
|
||||
Valid: pair.OrganizationID != uuid.Nil,
|
||||
},
|
||||
IsSystem: rolestore.IsSystemRoleName(pair.Name),
|
||||
IsSystem: rbac.SystemRoleName(pair.Name),
|
||||
ID: uuid.New(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -650,26 +650,34 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
|
||||
})
|
||||
require.NoError(t, err, "insert organization")
|
||||
|
||||
// Populate the placeholder system roles (created by DB
|
||||
// trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileSystemRole needs the system:update
|
||||
// Populate the placeholder organization-member system role (created by
|
||||
// DB trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
|
||||
// permission that `genCtx` does not have.
|
||||
sysCtx := dbauthz.AsSystemRestricted(genCtx)
|
||||
for roleName := range rolestore.SystemRoleNames {
|
||||
role := database.CustomRole{
|
||||
Name: roleName,
|
||||
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
|
||||
}
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
|
||||
require.NoError(t, err, "create role "+roleName)
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
}
|
||||
require.NoError(t, err, "reconcile role "+roleName)
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
|
||||
require.NoError(t, err, "create organization-member role")
|
||||
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
}
|
||||
require.NoError(t, err, "reconcile organization-member role")
|
||||
|
||||
return org
|
||||
}
|
||||
|
||||
@@ -264,14 +264,6 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
m.queryLatencies.WithLabelValues("CleanupDeletedMCPServerIDsFromChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CleanupDeletedMCPServerIDsFromChats").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -392,6 +384,14 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatMessagesAfterID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesAfterID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
|
||||
@@ -488,22 +488,6 @@ func (m queryMetricsStore) DeleteLicense(ctx context.Context, id int32) (int32,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteMCPServerConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteMCPServerConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerConfigByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteMCPServerUserToken(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteMCPServerUserToken").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserToken").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteOAuth2ProviderAppByClientID(ctx, id)
|
||||
@@ -712,11 +696,10 @@ func (m queryMetricsStore) DeleteWorkspaceACLByID(ctx context.Context, id uuid.U
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, arg)
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
m.queryLatencies.WithLabelValues("DeleteWorkspaceACLsByOrganization").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteWorkspaceACLsByOrganization").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -1064,14 +1047,6 @@ func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDesktopEnabled").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
@@ -1216,11 +1191,11 @@ func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, us
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChats").Inc()
|
||||
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -1352,14 +1327,6 @@ func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]datab
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetEnabledMCPServerConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledMCPServerConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetExternalAuthLink(ctx, arg)
|
||||
@@ -1416,14 +1383,6 @@ func (m queryMetricsStore) GetFilteredInboxNotificationsByUserID(ctx context.Con
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetForcedMCPServerConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetForcedMCPServerConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetForcedMCPServerConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetGitSSHKey(ctx, userID)
|
||||
@@ -1584,54 +1543,6 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) {
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigBySlug(ctx, slug)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigBySlug").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigBySlug").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigs(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerConfigsByIDs(ctx, ids)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerConfigsByIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerConfigsByIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerUserToken(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerUserToken").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserToken").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetMCPServerUserTokensByUserID(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetMCPServerUserTokensByUserID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserTokensByUserID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg)
|
||||
@@ -1824,38 +1735,6 @@ func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Contex
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsPerModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsPerModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPerModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsSummary(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsSummary").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsSummary").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsTimeSeries(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsTimeSeries").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsTimeSeries").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetParameterSchemasByJobID(ctx, jobID)
|
||||
@@ -3072,14 +2951,6 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeModelThought(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertAIBridgeModelThought").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertAIBridgeModelThought").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
|
||||
@@ -3144,11 +3015,11 @@ func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.Inse
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
||||
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatMessages(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessages").Inc()
|
||||
r0, r1 := m.s.InsertChatMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -3272,14 +3143,6 @@ func (m queryMetricsStore) InsertLicense(ctx context.Context, arg database.Inser
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertMCPServerConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertMCPServerConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertMCPServerConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertMemoryResourceMonitor(ctx, arg)
|
||||
@@ -3864,22 +3727,6 @@ func (m queryMetricsStore) SelectUsageEventsForPublishing(ctx context.Context, n
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteChatMessageByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteChatMessageByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteChatMessageByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteChatMessagesAfterID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
@@ -3952,14 +3799,6 @@ func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatMCPServerIDs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatMCPServerIDs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMessageByID(ctx, arg)
|
||||
@@ -4064,14 +3903,6 @@ func (m queryMetricsStore) UpdateInboxNotificationReadStatus(ctx context.Context
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateMCPServerConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateMCPServerConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateMCPServerConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateMemberRoles(ctx, arg)
|
||||
@@ -4132,7 +3963,6 @@ func (m queryMetricsStore) UpdateOrganizationWorkspaceSharingSettings(ctx contex
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateOrganizationWorkspaceSharingSettings(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateOrganizationWorkspaceSharingSettings").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOrganizationWorkspaceSharingSettings").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4712,14 +4542,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDesktopEnabled").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
|
||||
@@ -4808,14 +4630,6 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertMCPServerUserToken").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserToken").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertNotificationReportGeneratorLog(ctx, arg)
|
||||
@@ -4952,14 +4766,6 @@ func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, a
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UsageEventExistsByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UsageEventExistsByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UsageEventExistsByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ValidateGroupIDs(ctx, groupIds)
|
||||
@@ -5079,11 +4885,3 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -334,20 +334,6 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx)
|
||||
}
|
||||
|
||||
// CleanupDeletedMCPServerIDsFromChats mocks base method.
|
||||
func (m *MockStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupDeletedMCPServerIDsFromChats", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupDeletedMCPServerIDsFromChats indicates an expected call of CleanupDeletedMCPServerIDsFromChats.
|
||||
func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -612,6 +598,20 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteChatMessagesAfterID mocks base method.
|
||||
func (m *MockStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatMessagesAfterID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatMessagesAfterID indicates an expected call of DeleteChatMessagesAfterID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -783,34 +783,6 @@ func (mr *MockStoreMockRecorder) DeleteLicense(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteLicense", reflect.TypeOf((*MockStore)(nil).DeleteLicense), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteMCPServerConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteMCPServerConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteMCPServerConfigByID indicates an expected call of DeleteMCPServerConfigByID.
|
||||
func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteMCPServerUserToken mocks base method.
|
||||
func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteMCPServerUserToken", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteMCPServerUserToken indicates an expected call of DeleteMCPServerUserToken.
|
||||
func (mr *MockStoreMockRecorder) DeleteMCPServerUserToken(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserToken), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOAuth2ProviderAppByClientID mocks base method.
|
||||
func (m *MockStore) DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1183,17 +1155,17 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, arg any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceAgentPortShare mocks base method.
|
||||
@@ -1759,21 +1731,6 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedChats mocks base method.
|
||||
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthorizedChats indicates an expected call of GetAuthorizedChats.
|
||||
func (mr *MockStoreMockRecorder) GetAuthorizedChats(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedChats", reflect.TypeOf((*MockStore)(nil).GetAuthorizedChats), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1939,21 +1896,6 @@ func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDesktopEnabled", ctx)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled indicates an expected call of GetChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) GetChatDesktopEnabled(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).GetChatDesktopEnabled), ctx)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2224,19 +2166,19 @@ func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
|
||||
}
|
||||
|
||||
// GetChats mocks base method.
|
||||
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
// GetChatsByOwnerID mocks base method.
|
||||
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChats indicates an expected call of GetChats.
|
||||
func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
|
||||
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
@@ -2479,21 +2421,6 @@ func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx)
|
||||
}
|
||||
|
||||
// GetEnabledMCPServerConfigs mocks base method.
|
||||
func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEnabledMCPServerConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEnabledMCPServerConfigs indicates an expected call of GetEnabledMCPServerConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetEnabledMCPServerConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledMCPServerConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetExternalAuthLink mocks base method.
|
||||
func (m *MockStore) GetExternalAuthLink(ctx context.Context, arg database.GetExternalAuthLinkParams) (database.ExternalAuthLink, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2599,21 +2526,6 @@ func (mr *MockStoreMockRecorder) GetFilteredInboxNotificationsByUserID(ctx, arg
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFilteredInboxNotificationsByUserID", reflect.TypeOf((*MockStore)(nil).GetFilteredInboxNotificationsByUserID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetForcedMCPServerConfigs mocks base method.
|
||||
func (m *MockStore) GetForcedMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetForcedMCPServerConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetForcedMCPServerConfigs indicates an expected call of GetForcedMCPServerConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetForcedMCPServerConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForcedMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetForcedMCPServerConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetGitSSHKey mocks base method.
|
||||
func (m *MockStore) GetGitSSHKey(ctx context.Context, userID uuid.UUID) (database.GitSSHKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2914,96 +2826,6 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigByID mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigByID", ctx, id)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigByID indicates an expected call of GetMCPServerConfigByID.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigBySlug mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigBySlug(ctx context.Context, slug string) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigBySlug", ctx, slug)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigBySlug indicates an expected call of GetMCPServerConfigBySlug.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigBySlug(ctx, slug any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigBySlug", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigBySlug), ctx, slug)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigs mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigs", ctx)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigs indicates an expected call of GetMCPServerConfigs.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigs(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetMCPServerConfigsByIDs mocks base method.
|
||||
func (m *MockStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerConfigsByIDs", ctx, ids)
|
||||
ret0, _ := ret[0].([]database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerConfigsByIDs indicates an expected call of GetMCPServerConfigsByIDs.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetMCPServerUserToken mocks base method.
|
||||
func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerUserToken", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerUserToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerUserToken indicates an expected call of GetMCPServerUserToken.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerUserToken(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserToken), ctx, arg)
|
||||
}
|
||||
|
||||
// GetMCPServerUserTokensByUserID mocks base method.
|
||||
func (m *MockStore) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]database.MCPServerUserToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMCPServerUserTokensByUserID", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.MCPServerUserToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetMCPServerUserTokensByUserID indicates an expected call of GetMCPServerUserTokensByUserID.
|
||||
func (mr *MockStoreMockRecorder) GetMCPServerUserTokensByUserID(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserTokensByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserTokensByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// GetNotificationMessagesByStatus mocks base method.
|
||||
func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3364,66 +3186,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsPerModel mocks base method.
|
||||
func (m *MockStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsPerModel", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsPerModelRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsPerModel indicates an expected call of GetPRInsightsPerModel.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs mocks base method.
|
||||
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary mocks base method.
|
||||
func (m *MockStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsSummary", ctx, arg)
|
||||
ret0, _ := ret[0].(database.GetPRInsightsSummaryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary indicates an expected call of GetPRInsightsSummary.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsSummary(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsSummary", reflect.TypeOf((*MockStore)(nil).GetPRInsightsSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsTimeSeries mocks base method.
|
||||
func (m *MockStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsTimeSeries", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsTimeSeriesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsTimeSeries indicates an expected call of GetPRInsightsTimeSeries.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsTimeSeries(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsTimeSeries", reflect.TypeOf((*MockStore)(nil).GetPRInsightsTimeSeries), ctx, arg)
|
||||
}
|
||||
|
||||
// GetParameterSchemasByJobID mocks base method.
|
||||
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5748,21 +5510,6 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeInterception", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeInterception), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertAIBridgeModelThought", ctx, arg)
|
||||
ret0, _ := ret[0].(database.AIBridgeModelThought)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought indicates an expected call of InsertAIBridgeModelThought.
|
||||
func (mr *MockStoreMockRecorder) InsertAIBridgeModelThought(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeModelThought", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeModelThought), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeTokenUsage mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5883,19 +5630,19 @@ func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatMessages mocks base method.
|
||||
func (m *MockStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
||||
// InsertChatMessage mocks base method.
|
||||
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatMessages", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret := m.ctrl.Call(m, "InsertChatMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatMessages indicates an expected call of InsertChatMessages.
|
||||
func (mr *MockStoreMockRecorder) InsertChatMessages(ctx, arg any) *gomock.Call {
|
||||
// InsertChatMessage indicates an expected call of InsertChatMessage.
|
||||
func (mr *MockStoreMockRecorder) InsertChatMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessages", reflect.TypeOf((*MockStore)(nil).InsertChatMessages), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessage", reflect.TypeOf((*MockStore)(nil).InsertChatMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatModelConfig mocks base method.
|
||||
@@ -6119,21 +5866,6 @@ func (mr *MockStoreMockRecorder) InsertLicense(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertLicense", reflect.TypeOf((*MockStore)(nil).InsertLicense), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertMCPServerConfig mocks base method.
|
||||
func (m *MockStore) InsertMCPServerConfig(ctx context.Context, arg database.InsertMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertMCPServerConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertMCPServerConfig indicates an expected call of InsertMCPServerConfig.
|
||||
func (mr *MockStoreMockRecorder) InsertMCPServerConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMCPServerConfig", reflect.TypeOf((*MockStore)(nil).InsertMCPServerConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertMemoryResourceMonitor mocks base method.
|
||||
func (m *MockStore) InsertMemoryResourceMonitor(ctx context.Context, arg database.InsertMemoryResourceMonitorParams) (database.WorkspaceAgentMemoryResourceMonitor, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7275,34 +7007,6 @@ func (mr *MockStoreMockRecorder) SelectUsageEventsForPublishing(ctx, now any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SelectUsageEventsForPublishing", reflect.TypeOf((*MockStore)(nil).SelectUsageEventsForPublishing), ctx, now)
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessageByID mocks base method.
|
||||
func (m *MockStore) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteChatMessageByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessageByID indicates an expected call of SoftDeleteChatMessageByID.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteChatMessageByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessageByID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessagesAfterID mocks base method.
|
||||
func (m *MockStore) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteChatMessagesAfterID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessagesAfterID indicates an expected call of SoftDeleteChatMessagesAfterID.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7433,21 +7137,6 @@ func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMCPServerIDs mocks base method.
|
||||
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatMCPServerIDs", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatMCPServerIDs indicates an expected call of UpdateChatMCPServerIDs.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatMCPServerIDs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatMCPServerIDs", reflect.TypeOf((*MockStore)(nil).UpdateChatMCPServerIDs), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMessageByID mocks base method.
|
||||
func (m *MockStore) UpdateChatMessageByID(ctx context.Context, arg database.UpdateChatMessageByIDParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7641,21 +7330,6 @@ func (mr *MockStoreMockRecorder) UpdateInboxNotificationReadStatus(ctx, arg any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateInboxNotificationReadStatus", reflect.TypeOf((*MockStore)(nil).UpdateInboxNotificationReadStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateMCPServerConfig mocks base method.
|
||||
func (m *MockStore) UpdateMCPServerConfig(ctx context.Context, arg database.UpdateMCPServerConfigParams) (database.MCPServerConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateMCPServerConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateMCPServerConfig indicates an expected call of UpdateMCPServerConfig.
|
||||
func (mr *MockStoreMockRecorder) UpdateMCPServerConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMCPServerConfig", reflect.TypeOf((*MockStore)(nil).UpdateMCPServerConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateMemberRoles mocks base method.
|
||||
func (m *MockStore) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8806,20 +8480,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDesktopEnabled", ctx, enableDesktop)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled indicates an expected call of UpsertChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDesktopEnabled(ctx, enableDesktop any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).UpsertChatDesktopEnabled), ctx, enableDesktop)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8980,21 +8640,6 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertMCPServerUserToken mocks base method.
|
||||
func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertMCPServerUserToken", ctx, arg)
|
||||
ret0, _ := ret[0].(database.MCPServerUserToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertMCPServerUserToken indicates an expected call of UpsertMCPServerUserToken.
|
||||
func (mr *MockStoreMockRecorder) UpsertMCPServerUserToken(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserToken", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserToken), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertNotificationReportGeneratorLog mocks base method.
|
||||
func (m *MockStore) UpsertNotificationReportGeneratorLog(ctx context.Context, arg database.UpsertNotificationReportGeneratorLogParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9241,21 +8886,6 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceAppAuditSession(ctx, arg any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertWorkspaceAppAuditSession", reflect.TypeOf((*MockStore)(nil).UpsertWorkspaceAppAuditSession), ctx, arg)
|
||||
}
|
||||
|
||||
// UsageEventExistsByID mocks base method.
|
||||
func (m *MockStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UsageEventExistsByID", ctx, id)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UsageEventExistsByID indicates an expected call of UsageEventExistsByID.
|
||||
func (mr *MockStoreMockRecorder) UsageEventExistsByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageEventExistsByID", reflect.TypeOf((*MockStore)(nil).UsageEventExistsByID), ctx, id)
|
||||
}
|
||||
|
||||
// ValidateGroupIDs mocks base method.
|
||||
func (m *MockStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+12
-148
@@ -512,12 +512,6 @@ CREATE TYPE resource_type AS ENUM (
|
||||
'ai_seat'
|
||||
);
|
||||
|
||||
CREATE TYPE shareable_workspace_owners AS ENUM (
|
||||
'none',
|
||||
'everyone',
|
||||
'service_accounts'
|
||||
);
|
||||
|
||||
CREATE TYPE startup_script_behavior AS ENUM (
|
||||
'blocking',
|
||||
'non-blocking'
|
||||
@@ -620,35 +614,28 @@ CREATE FUNCTION aggregate_usage_event() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
-- Check for supported event types and throw error for unknown types.
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
|
||||
-- Check for supported event types and throw error for unknown types
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
|
||||
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
|
||||
END IF;
|
||||
|
||||
INSERT INTO usage_events_daily (day, event_type, usage_data)
|
||||
VALUES (
|
||||
-- Extract the date from the created_at timestamp, always using UTC for
|
||||
-- consistency
|
||||
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
|
||||
NEW.event_type,
|
||||
NEW.event_data
|
||||
)
|
||||
ON CONFLICT (day, event_type) DO UPDATE SET
|
||||
usage_data = CASE
|
||||
-- Handle simple counter events by summing the count.
|
||||
-- Handle simple counter events by summing the count
|
||||
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
-- Heartbeat events: keep the max value seen that day
|
||||
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
GREATEST(
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
)
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
@@ -805,7 +792,7 @@ BEGIN
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION insert_organization_system_roles() RETURNS trigger
|
||||
CREATE FUNCTION insert_org_member_system_role() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
@@ -820,8 +807,7 @@ BEGIN
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES
|
||||
(
|
||||
) VALUES (
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
@@ -832,18 +818,6 @@ BEGIN
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
),
|
||||
(
|
||||
'organization-service-account',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
@@ -1112,15 +1086,6 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
|
||||
|
||||
CREATE TABLE aibridge_token_usages (
|
||||
id uuid NOT NULL,
|
||||
interception_id uuid NOT NULL,
|
||||
@@ -1289,9 +1254,7 @@ CREATE TABLE chat_messages (
|
||||
compressed boolean DEFAULT false NOT NULL,
|
||||
created_by uuid,
|
||||
content_version smallint NOT NULL,
|
||||
total_cost_micros bigint,
|
||||
runtime_ms bigint,
|
||||
deleted boolean DEFAULT false NOT NULL
|
||||
total_cost_micros bigint
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
@@ -1393,8 +1356,7 @@ CREATE TABLE chats (
|
||||
last_model_config_id uuid NOT NULL,
|
||||
archived boolean DEFAULT false NOT NULL,
|
||||
last_error text,
|
||||
mode chat_mode,
|
||||
mcp_server_ids uuid[] DEFAULT '{}'::uuid[] NOT NULL
|
||||
mode chat_mode
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -1671,53 +1633,6 @@ CREATE SEQUENCE licenses_id_seq
|
||||
|
||||
ALTER SEQUENCE licenses_id_seq OWNED BY licenses.id;
|
||||
|
||||
CREATE TABLE mcp_server_configs (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
display_name text NOT NULL,
|
||||
slug text NOT NULL,
|
||||
description text DEFAULT ''::text NOT NULL,
|
||||
icon_url text DEFAULT ''::text NOT NULL,
|
||||
transport text DEFAULT 'streamable_http'::text NOT NULL,
|
||||
url text NOT NULL,
|
||||
auth_type text DEFAULT 'none'::text NOT NULL,
|
||||
oauth2_client_id text DEFAULT ''::text NOT NULL,
|
||||
oauth2_client_secret text DEFAULT ''::text NOT NULL,
|
||||
oauth2_client_secret_key_id text,
|
||||
oauth2_auth_url text DEFAULT ''::text NOT NULL,
|
||||
oauth2_token_url text DEFAULT ''::text NOT NULL,
|
||||
oauth2_scopes text DEFAULT ''::text NOT NULL,
|
||||
api_key_header text DEFAULT 'Authorization'::text NOT NULL,
|
||||
api_key_value text DEFAULT ''::text NOT NULL,
|
||||
api_key_value_key_id text,
|
||||
custom_headers text DEFAULT '{}'::text NOT NULL,
|
||||
custom_headers_key_id text,
|
||||
tool_allow_list text[] DEFAULT '{}'::text[] NOT NULL,
|
||||
tool_deny_list text[] DEFAULT '{}'::text[] NOT NULL,
|
||||
availability text DEFAULT 'default_off'::text NOT NULL,
|
||||
enabled boolean DEFAULT false NOT NULL,
|
||||
created_by uuid,
|
||||
updated_by uuid,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text]))),
|
||||
CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))),
|
||||
CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text])))
|
||||
);
|
||||
|
||||
CREATE TABLE mcp_server_user_tokens (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
mcp_server_config_id uuid NOT NULL,
|
||||
user_id uuid NOT NULL,
|
||||
access_token text NOT NULL,
|
||||
access_token_key_id text,
|
||||
refresh_token text DEFAULT ''::text NOT NULL,
|
||||
refresh_token_key_id text,
|
||||
token_type text DEFAULT 'Bearer'::text NOT NULL,
|
||||
expiry timestamp with time zone,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE notification_messages (
|
||||
id uuid NOT NULL,
|
||||
notification_template_id uuid NOT NULL,
|
||||
@@ -1908,11 +1823,9 @@ CREATE TABLE organizations (
|
||||
display_name text NOT NULL,
|
||||
icon text DEFAULT ''::text NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
shareable_workspace_owners shareable_workspace_owners DEFAULT 'everyone'::shareable_workspace_owners NOT NULL
|
||||
workspace_sharing_disabled boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
|
||||
|
||||
CREATE TABLE parameter_schemas (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -2712,7 +2625,7 @@ CREATE TABLE usage_events (
|
||||
publish_started_at timestamp with time zone,
|
||||
published_at timestamp with time zone,
|
||||
failure_message text,
|
||||
CONSTRAINT usage_event_type_check CHECK ((event_type = ANY (ARRAY['dc_managed_agents_v1'::text, 'hb_ai_seats_v1'::text])))
|
||||
CONSTRAINT usage_event_type_check CHECK ((event_type = 'dc_managed_agents_v1'::text))
|
||||
);
|
||||
|
||||
COMMENT ON TABLE usage_events IS 'usage_events contains usage data that is collected from the product and potentially shipped to the usage collector service.';
|
||||
@@ -3391,18 +3304,6 @@ ALTER TABLE ONLY licenses
|
||||
ALTER TABLE ONLY licenses
|
||||
ADD CONSTRAINT licenses_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY notification_messages
|
||||
ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3661,8 +3562,6 @@ CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptio
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
|
||||
@@ -3751,12 +3650,6 @@ CREATE INDEX idx_inbox_notifications_user_id_read_at ON inbox_notifications USIN
|
||||
|
||||
CREATE INDEX idx_inbox_notifications_user_id_template_id_targets ON inbox_notifications USING btree (user_id, template_id, targets);
|
||||
|
||||
CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs USING btree (enabled) WHERE (enabled = true);
|
||||
|
||||
CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs USING btree (enabled, availability) WHERE ((enabled = true) AND (availability = 'force_on'::text));
|
||||
|
||||
CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens USING btree (user_id);
|
||||
|
||||
CREATE INDEX idx_notification_messages_status ON notification_messages USING btree (status);
|
||||
|
||||
CREATE INDEX idx_organization_member_organization_id_uuid ON organization_members USING btree (organization_id);
|
||||
@@ -3785,8 +3678,6 @@ CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree
|
||||
|
||||
CREATE UNIQUE INDEX idx_unique_preset_name ON template_version_presets USING btree (name, template_version_id);
|
||||
|
||||
CREATE INDEX idx_usage_events_ai_seats ON usage_events USING btree (event_type, created_at) WHERE (event_type = 'hb_ai_seats_v1'::text);
|
||||
|
||||
CREATE INDEX idx_usage_events_select_for_publishing ON usage_events USING btree (published_at, publish_started_at, created_at);
|
||||
|
||||
CREATE INDEX idx_user_deleted_deleted_at ON user_deleted USING btree (deleted_at);
|
||||
@@ -3961,7 +3852,7 @@ CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_p
|
||||
|
||||
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
|
||||
|
||||
CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles();
|
||||
CREATE TRIGGER trigger_insert_org_member_system_role AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_org_member_system_role();
|
||||
|
||||
CREATE TRIGGER trigger_nullify_next_start_at_on_workspace_autostart_modificati AFTER UPDATE ON workspaces FOR EACH ROW EXECUTE FUNCTION nullify_next_start_at_on_workspace_autostart_modification();
|
||||
|
||||
@@ -4081,33 +3972,6 @@ ALTER TABLE ONLY jfrog_xray_scans
|
||||
ALTER TABLE ONLY jfrog_xray_scans
|
||||
ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_configs
|
||||
ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
|
||||
ALTER TABLE ONLY mcp_server_user_tokens
|
||||
ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY notification_messages
|
||||
ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -40,15 +40,6 @@ const (
|
||||
ForeignKeyInboxNotificationsUserID ForeignKeyConstraint = "inbox_notifications_user_id_fkey" // ALTER TABLE ONLY inbox_notifications ADD CONSTRAINT inbox_notifications_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyJfrogXrayScansAgentID ForeignKeyConstraint = "jfrog_xray_scans_agent_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE CASCADE;
|
||||
ForeignKeyJfrogXrayScansWorkspaceID ForeignKeyConstraint = "jfrog_xray_scans_workspace_id_fkey" // ALTER TABLE ONLY jfrog_xray_scans ADD CONSTRAINT jfrog_xray_scans_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE;
|
||||
ForeignKeyMcpServerConfigsAPIKeyValueKeyID ForeignKeyConstraint = "mcp_server_configs_api_key_value_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_api_key_value_key_id_fkey FOREIGN KEY (api_key_value_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerConfigsCreatedBy ForeignKeyConstraint = "mcp_server_configs_created_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
ForeignKeyMcpServerConfigsCustomHeadersKeyID ForeignKeyConstraint = "mcp_server_configs_custom_headers_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerConfigsOauth2ClientSecretKeyID ForeignKeyConstraint = "mcp_server_configs_oauth2_client_secret_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerConfigsUpdatedBy ForeignKeyConstraint = "mcp_server_configs_updated_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL;
|
||||
ForeignKeyMcpServerUserTokensAccessTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_access_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerUserTokensMcpServerConfigID ForeignKeyConstraint = "mcp_server_user_tokens_mcp_server_config_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE;
|
||||
ForeignKeyMcpServerUserTokensRefreshTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_refresh_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest);
|
||||
ForeignKeyMcpServerUserTokensUserID ForeignKeyConstraint = "mcp_server_user_tokens_user_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyNotificationMessagesNotificationTemplateID ForeignKeyConstraint = "notification_messages_notification_template_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
|
||||
ForeignKeyNotificationMessagesUserID ForeignKeyConstraint = "notification_messages_user_id_fkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyNotificationPreferencesNotificationTemplateID ForeignKeyConstraint = "notification_preferences_notification_template_id_fkey" // ALTER TABLE ONLY notification_preferences ADD CONSTRAINT notification_preferences_notification_template_id_fkey FOREIGN KEY (notification_template_id) REFERENCES notification_templates(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -26,7 +26,6 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
|
||||
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
|
||||
"GetWorkspaces": "GetAuthorizedWorkspaces",
|
||||
"GetUsers": "GetAuthorizedUsers",
|
||||
"GetChats": "GetAuthorizedChats",
|
||||
}
|
||||
|
||||
// Scan custom
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
DROP INDEX idx_aibridge_model_thoughts_interception_id;
|
||||
|
||||
DROP TABLE aibridge_model_thoughts;
|
||||
@@ -1,10 +0,0 @@
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id UUID NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
|
||||
|
||||
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts(interception_id);
|
||||
-52
@@ -1,52 +0,0 @@
|
||||
DELETE FROM custom_roles
|
||||
WHERE name = 'organization-service-account' AND is_system = true;
|
||||
|
||||
ALTER TABLE organizations
|
||||
ADD COLUMN workspace_sharing_disabled boolean NOT NULL DEFAULT false;
|
||||
|
||||
-- Migrate back: 'none' -> disabled, everything else -> enabled.
|
||||
UPDATE organizations
|
||||
SET workspace_sharing_disabled = true
|
||||
WHERE shareable_workspace_owners = 'none';
|
||||
|
||||
ALTER TABLE organizations DROP COLUMN shareable_workspace_owners;
|
||||
|
||||
DROP TYPE shareable_workspace_owners;
|
||||
|
||||
-- Restore the original single-role trigger from migration 408.
|
||||
DROP TRIGGER IF EXISTS trigger_insert_organization_system_roles ON organizations;
|
||||
DROP FUNCTION IF EXISTS insert_organization_system_roles;
|
||||
|
||||
CREATE OR REPLACE FUNCTION insert_org_member_system_role() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_insert_org_member_system_role
|
||||
AFTER INSERT ON organizations
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION insert_org_member_system_role();
|
||||
@@ -1,101 +0,0 @@
|
||||
CREATE TYPE shareable_workspace_owners AS ENUM ('none', 'everyone', 'service_accounts');
|
||||
|
||||
ALTER TABLE organizations
|
||||
ADD COLUMN shareable_workspace_owners shareable_workspace_owners NOT NULL DEFAULT 'everyone';
|
||||
|
||||
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
|
||||
|
||||
-- Migrate existing data from the boolean column.
|
||||
UPDATE organizations
|
||||
SET shareable_workspace_owners = 'none'
|
||||
WHERE workspace_sharing_disabled = true;
|
||||
|
||||
ALTER TABLE organizations DROP COLUMN workspace_sharing_disabled;
|
||||
|
||||
-- Defensively rename any existing 'organization-service-account' roles
|
||||
-- so they don't collide with the new system role.
|
||||
UPDATE custom_roles
|
||||
SET name = name || '-' || id::text
|
||||
-- lower(name) is part of the existing unique index
|
||||
WHERE lower(name) = 'organization-service-account';
|
||||
|
||||
-- Create skeleton organization-service-account system roles for all
|
||||
-- existing organizations, mirroring what migration 408 did for
|
||||
-- organization-member.
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
'organization-service-account',
|
||||
'',
|
||||
id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
FROM
|
||||
organizations;
|
||||
|
||||
-- Replace the single-role trigger with one that creates both system
|
||||
-- roles when a new organization is inserted.
|
||||
DROP TRIGGER IF EXISTS trigger_insert_org_member_system_role ON organizations;
|
||||
DROP FUNCTION IF EXISTS insert_org_member_system_role;
|
||||
|
||||
CREATE OR REPLACE FUNCTION insert_organization_system_roles() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
INSERT INTO custom_roles (
|
||||
name,
|
||||
display_name,
|
||||
organization_id,
|
||||
site_permissions,
|
||||
org_permissions,
|
||||
user_permissions,
|
||||
member_permissions,
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES
|
||||
(
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
),
|
||||
(
|
||||
'organization-service-account',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_insert_organization_system_roles
|
||||
AFTER INSERT ON organizations
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION insert_organization_system_roles();
|
||||
@@ -1,38 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_usage_events_ai_seats;
|
||||
|
||||
-- Remove hb_ai_seats_v1 rows so the original constraint can be restored.
|
||||
DELETE FROM usage_events WHERE event_type = 'hb_ai_seats_v1';
|
||||
DELETE FROM usage_events_daily WHERE event_type = 'hb_ai_seats_v1';
|
||||
|
||||
-- Restore original constraint.
|
||||
ALTER TABLE usage_events
|
||||
DROP CONSTRAINT usage_event_type_check,
|
||||
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1'));
|
||||
|
||||
-- Restore the original aggregate function without hb_ai_seats_v1 support.
|
||||
CREATE OR REPLACE FUNCTION aggregate_usage_event()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
|
||||
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
|
||||
END IF;
|
||||
|
||||
INSERT INTO usage_events_daily (day, event_type, usage_data)
|
||||
VALUES (
|
||||
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
|
||||
NEW.event_type,
|
||||
NEW.event_data
|
||||
)
|
||||
ON CONFLICT (day, event_type) DO UPDATE SET
|
||||
usage_data = CASE
|
||||
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -1,50 +0,0 @@
|
||||
-- Expand the CHECK constraint to allow hb_ai_seats_v1.
|
||||
ALTER TABLE usage_events
|
||||
DROP CONSTRAINT usage_event_type_check,
|
||||
ADD CONSTRAINT usage_event_type_check CHECK (event_type IN ('dc_managed_agents_v1', 'hb_ai_seats_v1'));
|
||||
|
||||
-- Partial index for efficient lookups of AI seat heartbeat events by time.
|
||||
-- This will be used for the admin dashboard to see seat count over time.
|
||||
CREATE INDEX idx_usage_events_ai_seats
|
||||
ON usage_events (event_type, created_at)
|
||||
WHERE event_type = 'hb_ai_seats_v1';
|
||||
|
||||
-- Update the aggregate function to handle hb_ai_seats_v1 events.
|
||||
-- Heartbeat events replace the previous value for the same time period.
|
||||
CREATE OR REPLACE FUNCTION aggregate_usage_event()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
-- Check for supported event types and throw error for unknown types.
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
|
||||
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
|
||||
END IF;
|
||||
|
||||
INSERT INTO usage_events_daily (day, event_type, usage_data)
|
||||
VALUES (
|
||||
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
|
||||
NEW.event_type,
|
||||
NEW.event_data
|
||||
)
|
||||
ON CONFLICT (day, event_type) DO UPDATE SET
|
||||
usage_data = CASE
|
||||
-- Handle simple counter events by summing the count.
|
||||
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
-- Heartbeat events: keep the max value seen that day
|
||||
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
GREATEST(
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
)
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE chat_messages DROP COLUMN runtime_ms;
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE chat_messages ADD COLUMN runtime_ms bigint;
|
||||
@@ -1,2 +0,0 @@
|
||||
DELETE FROM chat_messages WHERE deleted = true;
|
||||
ALTER TABLE chat_messages DROP COLUMN deleted;
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE chat_messages ADD COLUMN deleted boolean NOT NULL DEFAULT false;
|
||||
@@ -1,6 +0,0 @@
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS mcp_server_ids;
|
||||
DROP INDEX IF EXISTS idx_mcp_server_configs_enabled;
|
||||
DROP INDEX IF EXISTS idx_mcp_server_configs_forced;
|
||||
DROP INDEX IF EXISTS idx_mcp_server_user_tokens_user_id;
|
||||
DROP TABLE IF EXISTS mcp_server_user_tokens;
|
||||
DROP TABLE IF EXISTS mcp_server_configs;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user