Compare commits

..

1 Commits

Author SHA1 Message Date
Kyle Carberry 8de13f6bf6 feat(chatd): add chat_runs and chat_run_steps tables
Introduces a run/step model for chat processing that replaces the
status/worker_id columns on the chats table with normalized tables
(chat_runs, chat_run_steps) and SQL views that derive status.

Key changes:
- New tables: chat_runs, chat_run_steps with triggers for numbering
  and constraint enforcement
- chat_run_steps tracks per-LLM-call token usage, cost, tool call
  counts, and message references (response_message_id, first/last)
- Token/cost columns moved from chat_messages to chat_run_steps
  (immutable audit trail independent of message edits/deletes)
- Cost analytics queries (GetChatCost*) read from steps, not messages
- SQL views: chat_run_steps_with_status, chat_runs_with_status,
  chats_with_status derive computed_status
- PersistStep callback now completes the current step with real token
  data, message IDs, and tool call counts, then creates a new step
  if the agentic loop continues (continuation_reason: tool_call)
- Each message inserted during a run is tagged with chat_run_id and
  chat_run_step_id linking it to its originating step
- createRunAndStep helper creates run + first step inside user-facing
  TXs to eliminate TOCTOU gaps
- processOnce batches up to maxAcquirePerCycle (10) acquisitions
- recoverStaleChatRunSteps replaces recoverStaleChats
- Status view uses 'streaming' (matching blink) for active steps
2026-03-14 19:55:44 +00:00
690 changed files with 26797 additions and 69953 deletions
-72
View File
@@ -1,72 +0,0 @@
---
name: pull-requests
description: "Guide for creating, updating, and following up on pull requests in the Coder repository. Use when asked to open a PR, update a PR, rewrite a PR description, or follow up on CI/check failures."
---
# Pull Request Skill
## When to Use This Skill
Use this skill when asked to:
- Create a pull request for the current branch.
- Update an existing PR branch or description.
- Rewrite a PR body.
- Follow up on CI or check failures for an existing PR.
## References
Use the canonical docs for shared conventions and validation guidance:
- PR title and description conventions:
`.claude/docs/PR_STYLE_GUIDE.md`
- Local validation commands and git hooks: `AGENTS.md` (Essential Commands and
Git Hooks sections)
## Lifecycle Rules
1. **Check for an existing PR** before creating a new one:
```bash
gh pr list --head "$(git branch --show-current)" --author @me --json number --jq '.[0].number // empty'
```
If that returns a number, update that PR. If it returns empty output,
create a new one.
2. **Check you are not on main.** If the current branch is `main` or `master`,
create a feature branch before doing PR work.
3. **Default to draft.** Use `gh pr create --draft` unless the user explicitly
asks for ready-for-review.
4. **Keep description aligned with the full diff.** Re-read the diff against
the base branch before writing or updating the title and body. Describe the
entire PR diff, not just the last commit.
5. **Never auto-merge.** Do not merge or mark ready for review unless the user
explicitly asks.
6. **Never push to main or master.**
## CI / Checks Follow-up
**Always watch CI checks after pushing.** Do not push and walk away.
After pushing:
- Monitor CI with `gh pr checks <PR_NUMBER> --watch`.
- Use `gh pr view <PR_NUMBER> --json statusCheckRollup` for programmatic check
status.
If checks fail:
1. Find the failed run ID from the `gh pr checks` output.
2. Read the logs with `gh run view <run-id> --log-failed`.
3. Fix the problem locally.
4. Run `make pre-commit`.
5. Push the fix.
## What Not to Do
- Do not reference or call helper scripts that do not exist in this
repository.
- Do not auto-merge or mark ready for review without explicit user request.
- Do not push to `origin/main` or `origin/master`.
- Do not skip local validation before pushing.
- Do not fabricate or embellish PR descriptions.
+1 -1
View File
@@ -113,7 +113,7 @@ Coder emphasizes clear error handling, with specific patterns required:
All tests should run in parallel using `t.Parallel()` to ensure efficient testing and expose potential race conditions. The codebase is rigorously linted with golangci-lint to maintain consistent code quality.
Git contributions follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/). See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Git contributions follow a standard format with commit messages structured as `type: <message>`, where type is one of `feat`, `fix`, or `chore`.
## Development Workflow
+25 -7
View File
@@ -4,13 +4,22 @@ This guide documents the PR description style used in the Coder repository, base
## PR Title Format
Format: `type(scope): description`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/) format:
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
```text
type(scope): brief description
```
Examples:
**Common types:**
- `feat`: New features
- `fix`: Bug fixes
- `refactor`: Code refactoring without behavior change
- `perf`: Performance improvements
- `docs`: Documentation changes
- `chore`: Dependency updates, tooling changes
**Examples:**
- `feat: add tracing to aibridge`
- `fix: move contexts to appropriate locations`
@@ -177,6 +186,16 @@ Dependabot PRs are auto-generated - don't try to match their verbose style for m
Changes from https://github.com/upstream/repo/pull/XXX/
```
## Attribution Footer
For AI-generated PRs, end with:
```markdown
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
```
## Creating PRs as Draft
**IMPORTANT**: Unless explicitly told otherwise, always create PRs as drafts using the `--draft` flag:
@@ -187,12 +206,11 @@ gh pr create --draft --title "..." --body "..."
After creating the PR, encourage the user to review it before marking as ready:
```text
```
I've created draft PR #XXXX. Please review the changes and mark it as ready for review when you're satisfied.
```
This allows the user to:
- Review the code changes before requesting reviews from maintainers
- Make additional adjustments if needed
- Ensure CI passes before notifying reviewers
+3 -5
View File
@@ -136,11 +136,9 @@ Then make your changes and push normally. Don't use `git push --force` unless th
## Commit Style
Format: `type(scope): message`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
- Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/)
- Format: `type(scope): message`
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
- Keep message titles concise (~70 characters)
- Use imperative, present tense in commit titles
-9
View File
@@ -1,9 +0,0 @@
paths:
# The triage workflow uses a quoted heredoc (<<'EOF') with ${VAR}
# placeholders that envsubst expands later. Shellcheck's SC2016
# warns about unexpanded variables in single-quoted strings, but
# the non-expansion is intentional here. Actionlint doesn't honor
# inline shellcheck disable directives inside heredocs.
.github/workflows/triage-via-chat-api.yaml:
ignore:
- 'SC2016'
+34 -96
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -191,7 +191,7 @@ jobs:
# Check for any typos
- name: Check for typos
uses: crate-ci/typos@631208b7aac2daa8b707f55e7331f9112b0e062d # v1.44.0
uses: crate-ci/typos@2d0ce569feab1f8752f1dde43cc2f2aa53236e06 # v1.40.0
with:
config: .github/workflows/typos.toml
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -327,7 +327,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -379,7 +379,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -537,7 +537,7 @@ jobs:
embedded-pg-cache: ${{ steps.embedded-pg-cache.outputs.embedded-pg-cache }}
- name: Upload failed test db dumps
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
@@ -575,7 +575,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -637,7 +637,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -709,7 +709,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -736,7 +736,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -769,7 +769,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -818,7 +818,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -826,7 +826,7 @@ jobs:
- name: Upload debug log
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coderd-debug-logs${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/e2e/test-results/debug.log
@@ -834,7 +834,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -849,7 +849,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -930,7 +930,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1005,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1043,7 +1043,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1097,7 +1097,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1108,7 +1108,7 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -1198,7 +1198,7 @@ jobs:
make -j \
build/coder_linux_{amd64,arm64,armv7} \
build/coder_"$version"_windows_amd64.zip \
build/coder_"$version"_linux_{amd64,arm64,armv7}.{tar.gz,deb}
build/coder_"$version"_linux_amd64.{tar.gz,deb}
env:
# The Windows and Darwin slim binaries must be signed for Coder
# Desktop to accept them.
@@ -1216,28 +1216,11 @@ jobs:
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
JSIGN_PATH: /tmp/jsign-6.0.jar
# Free up disk space before building Docker images. The preceding
# Build step produces ~2 GB of binaries and packages, the Go build
# cache is ~1.3 GB, and node_modules is ~500 MB. Docker image
# builds, pushes, and SBOM generation need headroom that isn't
# available without reclaiming some of that space.
- name: Clean up build cache
run: |
set -euxo pipefail
# Go caches are no longer needed — binaries are already compiled.
go clean -cache -modcache
# Remove .apk and .rpm packages that are not uploaded as
# artifacts and were only built as make prerequisites.
rm -f ./build/*.apk ./build/*.rpm
- name: Build Linux Docker images
id: build-docker
env:
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
DOCKER_CLI_EXPERIMENTAL: "enabled"
# Skip building .deb/.rpm/.apk/.tar.gz as prerequisites for
# the Docker image targets — they were already built above.
DOCKER_IMAGE_NO_PREREQUISITES: "true"
run: |
set -euxo pipefail
@@ -1319,7 +1302,7 @@ jobs:
id: attest_main
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:main"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1356,7 +1339,7 @@ jobs:
id: attest_latest
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:latest"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1393,7 +1376,7 @@ jobs:
id: attest_version
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1455,60 +1438,15 @@ jobs:
^v
prune-untagged: true
- name: Upload build artifact (coder-linux-amd64.tar.gz)
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-amd64.tar.gz
path: ./build/*_linux_amd64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-amd64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-amd64.deb
path: ./build/*_linux_amd64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-arm64.tar.gz
path: ./build/*_linux_arm64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-arm64.deb
path: ./build/*_linux_arm64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-armv7.tar.gz
path: ./build/*_linux_armv7.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-armv7.deb
path: ./build/*_linux_armv7.deb
retention-days: 7
- name: Upload build artifact (coder-windows-amd64.zip)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-windows-amd64.zip
path: ./build/*_windows_amd64.zip
name: coder
path: |
./build/*.zip
./build/*.tar.gz
./build/*.deb
retention-days: 7
# Deploy is handled in deploy.yaml so we can apply concurrency limits.
@@ -1543,7 +1481,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-141
View File
@@ -23,44 +23,6 @@ permissions:
concurrency: pr-${{ github.ref }}
jobs:
community-label:
runs-on: ubuntu-latest
permissions:
pull-requests: write
if: >-
${{
github.event_name == 'pull_request_target' &&
github.event.action == 'opened' &&
github.event.pull_request.author_association != 'MEMBER' &&
github.event.pull_request.author_association != 'COLLABORATOR' &&
github.event.pull_request.author_association != 'OWNER'
}}
steps:
- name: Add community label
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const params = {
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
}
const labels = context.payload.pull_request.labels.map((label) => label.name)
if (labels.includes("community")) {
console.log('PR already has "community" label.')
return
}
console.log(
'Adding "community" label for author association "%s".',
context.payload.pull_request.author_association,
)
await github.rest.issues.addLabels({
...params,
labels: ["community"],
})
cla:
runs-on: ubuntu-latest
permissions:
@@ -83,109 +45,6 @@ jobs:
# Some users have signed a corporate CLA with Coder so are exempt from signing our community one.
allowlist: "coryb,aaronlehmann,dependabot*,blink-so*,blinkagent*"
title:
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request_target' }}
steps:
- name: Validate PR title
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const { pull_request } = context.payload;
const title = pull_request.title;
const repo = { owner: context.repo.owner, repo: context.repo.repo };
const allowedTypes = [
"feat", "fix", "docs", "style", "refactor",
"perf", "test", "build", "ci", "chore", "revert",
];
const expectedFormat = `"type(scope): description" or "type: description"`;
const guidelinesLink = `See: https://github.com/coder/coder/blob/main/docs/about/contributing/CONTRIBUTING.md#commit-messages`;
const scopeHint = (type) =>
`Use a broader scope or no scope (e.g., "${type}: ...") for cross-cutting changes.\n` +
guidelinesLink;
console.log("Title: %s", title);
// Parse conventional commit format: type(scope)!: description
const match = title.match(/^(\w+)(\(([^)]*)\))?(!)?\s*:\s*.+/);
if (!match) {
core.setFailed(
`PR title does not match conventional commit format.\n` +
`Expected: ${expectedFormat}\n` +
`Allowed types: ${allowedTypes.join(", ")}\n` +
guidelinesLink
);
return;
}
const type = match[1];
const scope = match[3]; // undefined if no parentheses
// Validate type.
if (!allowedTypes.includes(type)) {
core.setFailed(
`PR title has invalid type "${type}".\n` +
`Expected: ${expectedFormat}\n` +
`Allowed types: ${allowedTypes.join(", ")}\n` +
guidelinesLink
);
return;
}
// If no scope, we're done.
if (!scope) {
console.log("No scope provided, title is valid.");
return;
}
console.log("Scope: %s", scope);
// Fetch changed files.
const files = await github.paginate(github.rest.pulls.listFiles, {
...repo,
pull_number: pull_request.number,
per_page: 100,
});
const changedPaths = files.map(f => f.filename);
console.log("Changed files: %d", changedPaths.length);
// Derive scope type from the changed files. The diff is the
// source of truth: if files exist under the scope, the path
// exists on the PR branch. No need for Contents API calls.
const isDir = changedPaths.some(f => f.startsWith(scope + "/"));
const isFile = changedPaths.some(f => f === scope);
const isStem = changedPaths.some(f => f.startsWith(scope + "."));
if (!isDir && !isFile && !isStem) {
core.setFailed(
`PR title scope "${scope}" does not match any files changed in this PR.\n` +
`Scopes must reference a path (directory or file stem) that contains changed files.\n` +
scopeHint(type)
);
return;
}
// Verify all changed files fall under the scope.
const outsideFiles = changedPaths.filter(f => {
if (isDir && f.startsWith(scope + "/")) return false;
if (f === scope) return false;
if (isStem && f.startsWith(scope + ".")) return false;
return true;
});
if (outsideFiles.length > 0) {
const listed = outsideFiles.map(f => " - " + f).join("\n");
core.setFailed(
`PR title scope "${scope}" does not contain all changed files.\n` +
`Files outside scope:\n${listed}\n\n` +
scopeHint(type)
);
return;
}
console.log("PR title is valid.");
release-labels:
runs-on: ubuntu-latest
permissions:
+19 -15
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -61,11 +61,11 @@ jobs:
if: needs.should-deploy.outputs.verdict == 'DEPLOY'
permissions:
contents: read
id-token: write # to authenticate to EKS cluster
id-token: write
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,29 +76,33 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
with:
role-to-assume: ${{ vars.AWS_DOGFOOD_DEPLOY_ROLE }}
aws-region: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
- name: Get Cluster Credentials
run: aws eks update-kubeconfig --name "$AWS_DOGFOOD_CLUSTER_NAME" --region "$AWS_DOGFOOD_DEPLOY_REGION"
env:
AWS_DOGFOOD_CLUSTER_NAME: ${{ vars.AWS_DOGFOOD_CLUSTER_NAME }}
AWS_DOGFOOD_DEPLOY_REGION: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Set up Flux CLI
uses: fluxcd/flux2/action@8454b02a32e48d775b9f563cb51fdcb1787b5b93 # v2.7.5
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.8.2"
version: "2.7.0"
- name: Get Cluster Credentials
uses: google-github-actions/get-gke-credentials@3da1e46a907576cefaa90c484278bb5b259dd395 # v3.0.0
with:
cluster_name: dogfood-v2
location: us-central1-a
project_id: coder-dogfood-v2
# Retag image as dogfood while maintaining the multi-arch manifest
- name: Tag image as dogfood
@@ -142,7 +146,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -48,7 +48,7 @@ jobs:
persist-credentials: false
- name: Docker login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
+1 -1
View File
@@ -30,7 +30,7 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v45.0.7
- uses: tj-actions/changed-files@e0021407031f5be11a464abee9a0776171c79891 # v45.0.7
id: changed-files
with:
files: |
+4 -4
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -78,11 +78,11 @@ jobs:
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Login to DockerHub
if: github.ref == 'refs/heads/main'
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -30,7 +30,7 @@ jobs:
- name: Sync issues
id: sync
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0.5.0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: sync
@@ -52,7 +52,7 @@ jobs:
- name: Complete release
id: complete
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: complete
+1 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+6 -6
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -248,7 +248,7 @@ jobs:
uses: ./.github/actions/setup-sqlc
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -14,12 +14,12 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: Run Schmoder CI
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
with:
workflow: ci.yaml
repo: coder/schmoder
+12 -10
View File
@@ -80,7 +80,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -155,7 +155,7 @@ jobs:
cat "$CODER_RELEASE_NOTES_FILE"
- name: Docker Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -358,7 +358,7 @@ jobs:
id: attest_base
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.image-base-tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -474,7 +474,7 @@ jobs:
id: attest_main
if: ${{ !inputs.dry_run }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -518,7 +518,7 @@ jobs:
id: attest_latest
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.latest_tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -665,7 +665,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: release-artifacts
path: |
@@ -681,7 +681,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
@@ -700,11 +700,13 @@ jobs:
name: Publish to Homebrew tap
runs-on: ubuntu-latest
needs: release
if: ${{ !inputs.dry_run && inputs.release_channel == 'mainline' }}
if: ${{ !inputs.dry_run }}
steps:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -780,7 +782,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: SARIF file
path: results.sarif
+114 -1
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -63,3 +63,116 @@ jobs:
--data "{\"content\": \"$msg\"}" \
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
trivy:
permissions:
security-events: write
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Setup Node
uses: ./.github/actions/setup-node
- name: Setup sqlc
uses: ./.github/actions/setup-sqlc
- name: Install cosign
uses: ./.github/actions/install-cosign
- name: Install syft
uses: ./.github/actions/install-syft
- name: Install yq
run: go run github.com/mikefarah/yq/v4@v4.44.3
- name: Install mockgen
run: ./.github/scripts/retry.sh -- go install go.uber.org/mock/mockgen@v0.6.0
- name: Install protoc-gen-go
run: ./.github/scripts/retry.sh -- go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
- name: Install protoc-gen-go-drpc
run: ./.github/scripts/retry.sh -- go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.34
- name: Install Protoc
run: |
# protoc must be in lockstep with our dogfood Dockerfile or the
# version in the comments will differ. This is also defined in
# ci.yaml.
set -euxo pipefail
cd dogfood/coder
mkdir -p /usr/local/bin
mkdir -p /usr/local/include
DOCKER_BUILDKIT=1 docker build . --target proto -t protoc
protoc_path=/usr/local/bin/protoc
docker run --rm --entrypoint cat protoc /tmp/bin/protoc > $protoc_path
chmod +x $protoc_path
protoc --version
# Copy the generated files to the include directory.
docker run --rm -v /usr/local/include:/target protoc cp -r /tmp/include/google /target/
ls -la /usr/local/include/google/protobuf/
stat /usr/local/include/google/protobuf/timestamp.proto
- name: Build Coder linux amd64 Docker image
id: build
run: |
set -euo pipefail
version="$(./scripts/version.sh)"
image_job="build/coder_${version}_linux_amd64.tag"
# This environment variable force make to not build packages and
# archives (which the Docker image depends on due to technical reasons
# related to concurrent FS writes).
export DOCKER_IMAGE_NO_PREREQUISITES=true
# This environment variables forces scripts/build_docker.sh to build
# the base image tag locally instead of using the cached version from
# the registry.
CODER_IMAGE_BUILD_BASE_TAG="$(CODER_IMAGE_BASE=coder-base ./scripts/image_tag.sh --version "$version")"
export CODER_IMAGE_BUILD_BASE_TAG
# We would like to use make -j here, but it doesn't work with the some recent additions
# to our code generation.
make "$image_job"
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
output: trivy-results.sarif
severity: "CRITICAL,HIGH"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v3.29.5
with:
sarif_file: trivy-results.sarif
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: trivy
path: trivy-results.sarif
retention-days: 7
- name: Send Slack notification on failure
if: ${{ failure() }}
run: |
msg="❌ Trivy Failed\n\nhttps://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
curl \
-qfsSL \
-X POST \
-H "Content-Type: application/json" \
--data "{\"content\": \"$msg\"}" \
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
+3 -3
View File
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-295
View File
@@ -1,295 +0,0 @@
# This workflow reimplements the AI Triage Automation using the Coder Chat API
# instead of the Tasks API. The Chat API (/api/experimental/chats) is a simpler
# interface that does not require a dedicated GitHub Action or workspace
# provisioning — we just create a chat, poll for completion, and link the
# result on the issue. All API calls use curl + jq directly.
#
# Key differences from the Tasks API workflow (traiage.yaml):
# - No checkout of coder/create-task-action; everything is inline curl/jq.
# - No template_name / template_preset / prefix inputs — the Chat API handles
# resource allocation internally.
# - Uses POST /api/experimental/chats to create a chat session.
# - Polls GET /api/experimental/chats/<id> until the agent finishes.
# - Chat URL format: ${CODER_URL}/agents?chat=${CHAT_ID}
name: AI Triage via Chat API
on:
issues:
types:
- labeled
workflow_dispatch:
inputs:
issue_url:
description: "GitHub Issue URL to process"
required: true
type: string
permissions:
contents: read
jobs:
triage-chat:
name: Triage GitHub Issue via Chat API
runs-on: ubuntu-latest
if: github.event.label.name == 'chat-triage' || github.event_name == 'workflow_dispatch'
timeout-minutes: 30
env:
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
permissions:
contents: read
issues: write
steps:
# ------------------------------------------------------------------
# Step 1: Determine the GitHub user and issue URL.
# Identical to the Tasks API workflow — resolve the actor for
# workflow_dispatch or the issue sender for label events.
# ------------------------------------------------------------------
- name: Determine Inputs
id: determine-inputs
if: always()
env:
GITHUB_ACTOR: ${{ github.actor }}
GITHUB_EVENT_ISSUE_HTML_URL: ${{ github.event.issue.html_url }}
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# For workflow_dispatch, use the actor who triggered it.
# For issues events, use the issue sender.
if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
if ! GITHUB_USER_ID=$(gh api "users/${GITHUB_ACTOR}" --jq '.id'); then
echo "::error::Failed to get GitHub user ID for actor ${GITHUB_ACTOR}"
exit 1
fi
echo "Using workflow_dispatch actor: ${GITHUB_ACTOR} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_ACTOR}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${INPUTS_ISSUE_URL}"
echo "issue_url=${INPUTS_ISSUE_URL}" >> "${GITHUB_OUTPUT}"
exit 0
elif [[ "${GITHUB_EVENT_NAME}" == "issues" ]]; then
GITHUB_USER_ID=${GITHUB_EVENT_USER_ID}
echo "Using issue author: ${GITHUB_EVENT_USER_LOGIN} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_EVENT_USER_LOGIN}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${GITHUB_EVENT_ISSUE_HTML_URL}"
echo "issue_url=${GITHUB_EVENT_ISSUE_HTML_URL}" >> "${GITHUB_OUTPUT}"
exit 0
else
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
exit 1
fi
# ------------------------------------------------------------------
# Step 2: Verify the triggering user has push access.
# Unchanged from the Tasks API workflow.
# ------------------------------------------------------------------
- name: Verify push access
env:
GITHUB_REPOSITORY: ${{ github.repository }}
GH_TOKEN: ${{ github.token }}
GITHUB_USERNAME: ${{ steps.determine-inputs.outputs.github_username }}
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
run: |
set -euo pipefail
can_push="$(gh api "/repos/${GITHUB_REPOSITORY}/collaborators/${GITHUB_USERNAME}/permission" --jq '.user.permissions.push')"
if [[ "${can_push}" != "true" ]]; then
echo "::error title=Access Denied::${GITHUB_USERNAME} does not have push access to ${GITHUB_REPOSITORY}"
exit 1
fi
# ------------------------------------------------------------------
# Step 3: Create a chat via the Coder Chat API.
# Unlike the Tasks API which provisions a full workspace, the Chat
# API creates a lightweight chat session. We POST to
# /api/experimental/chats with the triage prompt as the initial
# message and receive a chat ID back.
# ------------------------------------------------------------------
- name: Create chat via Coder Chat API
id: create-chat
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# Build the same triage prompt used by the Tasks API workflow.
TASK_PROMPT=$(cat <<'EOF'
Fix ${ISSUE_URL}
1. Use the gh CLI to read the issue description and comments.
2. Think carefully and try to understand the root cause. If the issue is unclear or not well defined, ask me to clarify and provide more information.
3. Write a proposed implementation plan to PLAN.md for me to review before starting implementation. Your plan should use TDD and only make the minimal changes necessary to fix the root cause.
4. When I approve your plan, start working on it. If you encounter issues with the plan, ask me for clarification and update the plan as required.
5. When you have finished implementation according to the plan, commit and push your changes, and create a PR using the gh CLI for me to review.
EOF
)
# Perform variable substitution on the prompt — scoped to $ISSUE_URL only.
# Using envsubst without arguments would expand every env var in scope
# (including CODER_SESSION_TOKEN), so we name the variable explicitly.
TASK_PROMPT=$(echo "${TASK_PROMPT}" | envsubst '$ISSUE_URL')
echo "Creating chat with prompt:"
echo "${TASK_PROMPT}"
# POST to the Chat API to create a new chat session.
RESPONSE=$(curl --silent --fail-with-body \
-X POST \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
-H "Content-Type: application/json" \
-d "$(jq -n --arg prompt "${TASK_PROMPT}" \
'{content: [{type: "text", text: $prompt}]}')" \
"${CODER_URL}/api/experimental/chats")
echo "Chat API response:"
echo "${RESPONSE}" | jq .
CHAT_ID=$(echo "${RESPONSE}" | jq -r '.id')
CHAT_STATUS=$(echo "${RESPONSE}" | jq -r '.status')
if [[ -z "${CHAT_ID}" || "${CHAT_ID}" == "null" ]]; then
echo "::error::Failed to create chat — no ID returned"
echo "Response: ${RESPONSE}"
exit 1
fi
# Validate that CHAT_ID is a UUID before using it in URL paths.
# This guards against unexpected API responses being interpolated
# into subsequent curl calls.
if [[ ! "${CHAT_ID}" =~ ^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$ ]]; then
echo "::error::CHAT_ID is not a valid UUID: ${CHAT_ID}"
exit 1
fi
CHAT_URL="${CODER_URL}/agents?chat=${CHAT_ID}"
echo "Chat created: ${CHAT_ID} (status: ${CHAT_STATUS})"
echo "Chat URL: ${CHAT_URL}"
echo "chat_id=${CHAT_ID}" >> "${GITHUB_OUTPUT}"
echo "chat_url=${CHAT_URL}" >> "${GITHUB_OUTPUT}"
# ------------------------------------------------------------------
# Step 4: Poll the chat status until the agent finishes.
# The Chat API is asynchronous — after creation the agent begins
# working in the background. We poll GET /api/experimental/chats/<id>
# every 5 seconds until the status is "waiting" (agent needs input),
# "completed" (agent finished), or "error". Timeout after 10 minutes.
# ------------------------------------------------------------------
- name: Poll chat status
id: poll-status
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
run: |
set -euo pipefail
POLL_INTERVAL=5
# 10 minutes = 600 seconds.
TIMEOUT=600
ELAPSED=0
echo "Polling chat ${CHAT_ID} every ${POLL_INTERVAL}s (timeout: ${TIMEOUT}s)..."
while true; do
RESPONSE=$(curl --silent --fail-with-body \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
"${CODER_URL}/api/experimental/chats/${CHAT_ID}")
STATUS=$(echo "${RESPONSE}" | jq -r '.status')
echo "[${ELAPSED}s] Chat status: ${STATUS}"
case "${STATUS}" in
waiting|completed)
echo "Chat reached terminal status: ${STATUS}"
echo "final_status=${STATUS}" >> "${GITHUB_OUTPUT}"
exit 0
;;
error)
echo "::error::Chat entered error state"
echo "${RESPONSE}" | jq .
echo "final_status=error" >> "${GITHUB_OUTPUT}"
exit 1
;;
pending|running)
# Still working — keep polling.
;;
*)
echo "::warning::Unknown chat status: ${STATUS}"
;;
esac
if [[ ${ELAPSED} -ge ${TIMEOUT} ]]; then
echo "::error::Timed out after ${TIMEOUT}s waiting for chat to finish"
echo "final_status=timeout" >> "${GITHUB_OUTPUT}"
exit 1
fi
sleep "${POLL_INTERVAL}"
ELAPSED=$((ELAPSED + POLL_INTERVAL))
done
# ------------------------------------------------------------------
# Step 5: Comment on the GitHub issue with a link to the chat.
# Only comment if the issue belongs to this repository (same guard
# as the Tasks API workflow).
# ------------------------------------------------------------------
- name: Comment on issue
if: startsWith(steps.determine-inputs.outputs.issue_url, format('{0}/{1}', github.server_url, github.repository))
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
COMMENT_BODY=$(cat <<EOF
🤖 **AI Triage Chat Created**
A Coder chat session has been created to investigate this issue.
**Chat URL:** ${CHAT_URL}
**Chat ID:** \`${CHAT_ID}\`
**Status:** ${FINAL_STATUS}
The agent is working on a triage plan. Visit the chat to follow progress or provide guidance.
EOF
)
gh issue comment "${ISSUE_URL}" --body "${COMMENT_BODY}"
echo "Comment posted on ${ISSUE_URL}"
# ------------------------------------------------------------------
# Step 6: Write a summary to the GitHub Actions step summary.
# ------------------------------------------------------------------
- name: Write summary
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
run: |
set -euo pipefail
{
echo "## AI Triage via Chat API"
echo ""
echo "**Issue:** ${ISSUE_URL}"
echo "**Chat ID:** \`${CHAT_ID}\`"
echo "**Chat URL:** ${CHAT_URL}"
echo "**Status:** ${FINAL_STATUS}"
} >> "${GITHUB_STEP_SUMMARY}"
-2
View File
@@ -29,8 +29,6 @@ EDE = "EDE"
HELO = "HELO"
LKE = "LKE"
byt = "byt"
cpy = "cpy"
Cpy = "Cpy"
typ = "typ"
# file extensions used in seti icon theme
styl = "styl"
+1 -1
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+5 -19
View File
@@ -146,20 +146,11 @@ git config core.hooksPath scripts/githooks
Two hooks run automatically:
- **pre-commit**: Classifies staged files by type and runs either
the full `make pre-commit` or the lightweight `make pre-commit-light`
depending on whether Go, TypeScript, SQL, proto, or Makefile
changes are present. Falls back to the full target when
`CODER_HOOK_RUN_ALL=1` is set. A markdown-only commit takes
seconds; a Go change takes several minutes.
- **pre-push**: Classifies changed files (vs remote branch or
merge-base) and runs `make pre-push` when Go, TypeScript, SQL,
proto, or Makefile changes are detected. Skips tests entirely
for lightweight changes. Allowlisted in
`scripts/githooks/pre-push`. Runs only for developers who opt
in. Falls back to `make pre-push` when the diff range can't
be determined or `CODER_HOOK_RUN_ALL=1` is set. Allow at least
15 minutes for a full run.
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
Fast checks that catch most CI failures. Allow at least 5 minutes.
- **pre-push**: `make pre-push` (heavier checks including tests).
Allowlisted in `scripts/githooks/pre-push`. Runs only for developers
who opt in. Allow at least 15 minutes.
`git commit` and `git push` will appear to hang while hooks run.
This is normal. Do not interrupt, retry, or reduce the timeout.
@@ -217,11 +208,6 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- Commit format: `type(scope): message`
- PR titles follow the same `type(scope): message` format.
- When you use a scope, it must be a real filesystem path containing every
changed file.
- Use a broader path scope, or omit the scope, for cross-cutting changes.
- Example: `fix(coderd/chatd): ...` for changes only in `coderd/chatd/`.
### Frontend Patterns
+10 -42
View File
@@ -136,10 +136,18 @@ endif
# the search path so that these exclusions match.
FIND_EXCLUSIONS= \
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
# Source files used for make targets, evaluated on use.
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
# Same as GO_SRC_FILES but excluding certain files that have problematic
# Makefile dependencies (e.g. pnpm).
MOST_GO_SRC_FILES := $(shell \
find . \
$(FIND_EXCLUSIONS) \
-type f \
-name '*.go' \
-not -name '*_test.go' \
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
)
# All the shell files in the repo, excluding ignored files.
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
@@ -506,12 +514,6 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
# Only wildcard the go files in the develop directory to avoid rebuilds
# when project files are changd. Technically changes to some imports may
# not be detected, but it's unlikely to cause any issues.
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
CGO_ENABLED=0 go build -o $@ ./scripts/develop
BOLD := $(shell tput bold 2>/dev/null)
GREEN := $(shell tput setaf 2 2>/dev/null)
RED := $(shell tput setaf 1 2>/dev/null)
@@ -522,10 +524,6 @@ RESET := $(shell tput sgr0 2>/dev/null)
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/biome fmt/markdown
.PHONY: fmt
# Subset of fmt that does not require Go or Node toolchains.
fmt-light: fmt/shfmt fmt/terraform fmt/markdown
.PHONY: fmt-light
fmt/go:
ifdef FILE
# Format single file
@@ -633,10 +631,6 @@ LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
.PHONY: lint
# Subset of lint that does not require Go or Node toolchains.
lint-light: lint/shellcheck lint/markdown lint/helm lint/bootstrap lint/migrations lint/actions/actionlint lint/typos
.PHONY: lint-light
lint/site-icons:
./scripts/check_site_icons.sh
.PHONY: lint/site-icons
@@ -779,25 +773,6 @@ pre-commit:
echo "$(GREEN)✓ pre-commit passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
.PHONY: pre-commit
# Lightweight pre-commit for changes that don't touch Go or
# TypeScript. Skips gen, lint/go, lint/ts, fmt/go, fmt/ts, and
# the binary build. Used by the pre-commit hook when only docs,
# shell, terraform, helm, or other fast-to-check files changed.
pre-commit-light:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit-light.XXXXXX")
echo "$(BOLD)pre-commit-light$(RESET) ($$logdir)"
echo "fmt:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir fmt-light
$(check-unstaged)
echo "lint:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir lint-light
$(check-unstaged)
$(check-untracked)
rm -rf $$logdir
echo "$(GREEN)✓ pre-commit-light passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
.PHONY: pre-commit-light
pre-push:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-push.XXXXXX")
@@ -806,7 +781,6 @@ pre-push:
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
test \
test-js \
test-storybook \
site/out/index.html
rm -rf $$logdir
echo "$(GREEN)✓ pre-push passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
@@ -1341,12 +1315,6 @@ test-js: site/node_modules/.installed
pnpm test:ci
.PHONY: test-js
test-storybook: site/node_modules/.installed
cd site/
pnpm playwright:install
pnpm exec vitest run --project=storybook
.PHONY: test-storybook
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
# dependency for any sqlc-cloud related targets.
sqlc-cloud-is-setup:
+2 -7
View File
@@ -385,16 +385,11 @@ func (a *agent) init() {
pathStore := agentgit.NewPathStore()
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore, func() string {
if m := a.manifest.Load(); m != nil {
return m.Directory
}
return ""
})
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.reconnectingPTYServer = reconnectingpty.NewServer(
+9 -21
View File
@@ -713,15 +713,15 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
},
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
setSBInterval := func(_ *agenttest.Client, opts *agent.Options) {
opts.ServiceBannerRefreshInterval = testutil.IntervalFast
opts.ServiceBannerRefreshInterval = 5 * time.Millisecond
}
//nolint:dogsled // Allow the blank identifiers.
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, setSBInterval)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:paralleltest // These tests need to swap the banner func.
for _, port := range sshPorts {
sshClient, err := conn.SSHClientOnPort(ctx, port)
@@ -733,10 +733,7 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
for i, test := range tests {
t.Run(fmt.Sprintf("(:%d)/%d", port, i), func(t *testing.T) {
// Set new banner func and wait for the agent to call it to update the
// banner. We wait for two calls to ensure the value has been stored:
// the second call can only begin after the first iteration of
// fetchServiceBannerLoop completes (call + store), so after
// receiving two signals at least one store has happened.
// banner.
ready := make(chan struct{}, 2)
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
select {
@@ -745,8 +742,8 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
}
return []codersdk.BannerConfig{test.banner}, nil
})
testutil.TryReceive(ctx, t, ready)
testutil.TryReceive(ctx, t, ready)
<-ready
<-ready // Wait for two updates to ensure the value has propagated.
session, err := sshClient.NewSession()
require.NoError(t, err)
@@ -3553,17 +3550,8 @@ func testSessionOutput(t *testing.T, session *ssh.Session, expected, unexpected
require.NoError(t, err)
ptty.WriteLine("exit 0")
waitErr := make(chan error, 1)
go func() {
waitErr <- session.Wait()
}()
select {
case err = <-waitErr:
require.NoError(t, err)
case <-time.After(testutil.WaitLong):
require.Fail(t, "timed out waiting for session to exit")
}
err = session.Wait()
require.NoError(t, err)
for _, unexpected := range unexpected {
require.NotContains(t, stdout.String(), unexpected, "should not show output")
-14
View File
@@ -57,26 +57,18 @@ type fakeContainerCLI struct {
}
func (f *fakeContainerCLI) List(_ context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.containers, f.listErr
}
func (f *fakeContainerCLI) DetectArchitecture(_ context.Context, _ string) (string, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.arch, f.archErr
}
func (f *fakeContainerCLI) Copy(ctx context.Context, name, src, dst string) error {
f.mu.Lock()
defer f.mu.Unlock()
return f.copyErr
}
func (f *fakeContainerCLI) ExecAs(ctx context.Context, name, user string, args ...string) ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()
return nil, f.execErr
}
@@ -2697,9 +2689,7 @@ func TestAPI(t *testing.T) {
// When: The container is recreated (new container ID) with config changes.
terraformContainer.ID = "new-container-id"
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = terraformContainer.ID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
Apps: []agentcontainers.SubAgentApp{{Slug: "app2"}}, // Changed app triggers recreation logic.
@@ -2831,9 +2821,7 @@ func TestAPI(t *testing.T) {
// Simulate container rebuild: new container ID, changed display apps.
newContainerID := "new-container-id"
terraformContainer.ID = newContainerID
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = newContainerID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
DisplayApps: map[codersdk.DisplayApp]bool{
@@ -4938,11 +4926,9 @@ func TestDevcontainerPrebuildSupport(t *testing.T) {
)
api.Start()
fCCLI.mu.Lock()
fCCLI.containers = codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{testContainer},
}
fCCLI.mu.Unlock()
// Given: We allow the dev container to be created.
fDCCLI.upID = testContainer.ID
+46 -31
View File
@@ -2,6 +2,7 @@ package agentdesktop
import (
"encoding/json"
"math"
"net/http"
"strconv"
"time"
@@ -12,7 +13,6 @@ import (
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
"github.com/coder/websocket"
)
@@ -26,9 +26,9 @@ type DesktopAction struct {
Duration *int `json:"duration,omitempty"`
ScrollAmount *int `json:"scroll_amount,omitempty"`
ScrollDirection *string `json:"scroll_direction,omitempty"`
// ScaledWidth and ScaledHeight describe the declared model-facing desktop
// geometry. When provided, input coordinates are mapped from declared space
// to native desktop pixels before dispatching.
// ScaledWidth and ScaledHeight are the coordinate space the
// model is using. When provided, coordinates are linearly
// mapped from scaled → native before dispatching.
ScaledWidth *int `json:"scaled_width,omitempty"`
ScaledHeight *int `json:"scaled_height,omitempty"`
}
@@ -144,8 +144,17 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
geometry := desktopGeometryForAction(cfg, action)
scaleXY := geometry.DeclaredPointToNative
// Helper to scale a coordinate pair from the model's space to
// native display pixels.
scaleXY := func(x, y int) (int, int) {
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
}
return x, y
}
var resp DesktopActionResponse
@@ -183,7 +192,7 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
resp.Output = "type action performed"
case "cursor_position":
nativeX, nativeY, err := a.desktop.CursorPosition(ctx)
x, y, err := a.desktop.CursorPosition(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Cursor position failed.",
@@ -191,7 +200,6 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
})
return
}
x, y := geometry.NativePointToDeclared(nativeX, nativeY)
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
case "mouse_move":
@@ -439,10 +447,14 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
resp.Output = "hold_key action performed"
case "screenshot":
result, err := a.desktop.Screenshot(ctx, ScreenshotOptions{
TargetWidth: geometry.DeclaredWidth,
TargetHeight: geometry.DeclaredHeight,
})
var opts ScreenshotOptions
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
opts.TargetWidth = *action.ScaledWidth
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
opts.TargetHeight = *action.ScaledHeight
}
result, err := a.desktop.Screenshot(ctx, opts)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Screenshot failed.",
@@ -452,8 +464,16 @@ func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
}
resp.Output = "screenshot"
resp.ScreenshotData = result.Data
resp.ScreenshotWidth = geometry.DeclaredWidth
resp.ScreenshotHeight = geometry.DeclaredHeight
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
resp.ScreenshotWidth = *action.ScaledWidth
} else {
resp.ScreenshotWidth = cfg.Width
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
resp.ScreenshotHeight = *action.ScaledHeight
} else {
resp.ScreenshotHeight = cfg.Height
}
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -492,23 +512,6 @@ func coordFromAction(action DesktopAction) (x, y int, err error) {
return action.Coordinate[0], action.Coordinate[1], nil
}
func desktopGeometryForAction(cfg DisplayConfig, action DesktopAction) workspacesdk.DesktopGeometry {
declaredWidth := cfg.Width
declaredHeight := cfg.Height
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
declaredWidth = *action.ScaledWidth
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
declaredHeight = *action.ScaledHeight
}
return workspacesdk.NewDesktopGeometryWithDeclared(
cfg.Width,
cfg.Height,
declaredWidth,
declaredHeight,
)
}
// missingFieldError is returned when a required field is absent from
// a DesktopAction.
type missingFieldError struct {
@@ -519,3 +522,15 @@ type missingFieldError struct {
func (e *missingFieldError) Error() string {
return "Missing \"" + e.field + "\" for " + e.action + " action."
}
// scaleCoordinate maps a coordinate from scaled → native space.
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
if scaledDim == 0 || scaledDim == nativeDim {
return scaled
}
native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
// Clamp to valid range.
native = math.Max(native, 0)
native = math.Min(native, float64(nativeDim-1))
return int(native)
}
+16 -125
View File
@@ -27,12 +27,10 @@ var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
// fakeDesktop is a minimal Desktop implementation for unit tests.
type fakeDesktop struct {
startErr error
cursorPos [2]int
startCfg agentdesktop.DisplayConfig
vncConnErr error
screenshotErr error
screenshotRes agentdesktop.ScreenshotResult
lastShotOpts agentdesktop.ScreenshotOptions
closed bool
// Track calls for assertions.
@@ -53,8 +51,7 @@ func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
return nil, f.vncConnErr
}
func (f *fakeDesktop) Screenshot(_ context.Context, opts agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
f.lastShotOpts = opts
func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
return f.screenshotRes, f.screenshotErr
}
@@ -103,8 +100,8 @@ func (f *fakeDesktop) Type(_ context.Context, text string) error {
return nil
}
func (f *fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
return f.cursorPos[0], f.cursorPos[1], nil
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
return 10, 20, nil
}
func (f *fakeDesktop) Close() error {
@@ -138,12 +135,8 @@ func TestHandleAction_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
geometry := workspacesdk.DefaultDesktopGeometry()
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{
Width: geometry.NativeWidth,
Height: geometry.NativeHeight,
},
startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
}
api := agentdesktop.NewAPI(logger, fake, nil)
@@ -165,52 +158,11 @@ func TestHandleAction_Screenshot(t *testing.T) {
var result agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&result)
require.NoError(t, err)
// Dimensions come from DisplayConfig, not the screenshot CLI.
assert.Equal(t, "screenshot", result.Output)
assert.Equal(t, "base64data", result.ScreenshotData)
assert.Equal(t, geometry.NativeWidth, result.ScreenshotWidth)
assert.Equal(t, geometry.NativeHeight, result.ScreenshotHeight)
assert.Equal(t, agentdesktop.ScreenshotOptions{
TargetWidth: geometry.NativeWidth,
TargetHeight: geometry.NativeHeight,
}, fake.lastShotOpts)
}
func TestHandleAction_ScreenshotUsesDeclaredDimensionsFromRequest(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "screenshot",
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, agentdesktop.ScreenshotOptions{TargetWidth: 1280, TargetHeight: 720}, fake.lastShotOpts)
var result agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&result)
require.NoError(t, err)
assert.Equal(t, 1280, result.ScreenshotWidth)
assert.Equal(t, 720, result.ScreenshotHeight)
assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
}
func TestHandleAction_LeftClick(t *testing.T) {
@@ -363,6 +315,7 @@ func TestHandleAction_HoldKey(t *testing.T) {
handler.ServeHTTP(rr, req)
}()
// Wait for the timer to be created, then advance past it.
trap.MustWait(req.Context()).MustRelease(req.Context())
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
@@ -436,6 +389,7 @@ func TestHandleAction_ScrollDown(t *testing.T) {
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
// dy should be positive 5 for "down".
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
}
@@ -444,11 +398,13 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
// Native display is 1920x1080.
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
// Model is working in a 1280x720 coordinate space.
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
@@ -468,43 +424,12 @@ func TestHandleAction_CoordinateScaling(t *testing.T) {
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
// 640 in 1280-space → 960 in 1920-space (midpoint maps to
// midpoint).
assert.Equal(t, 960, fake.lastMove[0])
assert.Equal(t, 540, fake.lastMove[1])
}
func TestHandleAction_CoordinateScalingClampsToLastPixel(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1366
sh := 768
body := agentdesktop.DesktopAction{
Action: "mouse_move",
Coordinate: &[2]int{1365, 767},
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, 1919, fake.lastMove[0])
assert.Equal(t, 1079, fake.lastMove[1])
}
func TestClose_DelegatesToDesktop(t *testing.T) {
t.Parallel()
@@ -521,12 +446,15 @@ func TestClose_PreventsNewSessions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// After Close(), Start() will return an error because the
// underlying Desktop is closed.
fake := &fakeDesktop{}
api := agentdesktop.NewAPI(logger, fake, nil)
err := api.Close()
require.NoError(t, err)
// Simulate the closed desktop returning an error on Start().
fake.startErr = xerrors.New("desktop is closed")
rr := httptest.NewRecorder()
@@ -537,40 +465,3 @@ func TestClose_PreventsNewSessions(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rr.Code)
}
func TestHandleAction_CursorPositionReturnsDeclaredCoordinates(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
cursorPos: [2]int{960, 540},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "cursor_position",
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
var resp agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
// Native (960,540) in 1920x1080 should map to declared space in 1280x720.
assert.Equal(t, "x=640,y=360", resp.Output)
}
+170 -25
View File
@@ -2,9 +2,13 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
@@ -20,6 +24,28 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
portableDesktopVersion = "v0.0.4"
downloadRetries = 3
downloadRetryDelay = time.Second
)
// platformBinaries maps GOARCH to download URL and expected SHA-256
// digest for each supported platform.
var platformBinaries = map[string]struct {
URL string
SHA256 string
}{
"amd64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
},
"arm64": {
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
},
}
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
@@ -52,31 +78,43 @@ type screenshotOutput struct {
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
logger slog.Logger
execer agentexec.Execer
dataDir string // agent's ScriptDataDir, used for binary caching
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
// httpClient is used for downloading the binary. If nil,
// http.DefaultClient is used.
httpClient *http.Client
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary.
// CLI binary, using execer to spawn child processes. dataDir is used
// to cache the downloaded binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
dataDir string,
) Desktop {
return &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: execer,
dataDir: dataDir,
}
}
// httpDo returns the HTTP client to use for downloads.
func (p *portableDesktop) httpDo() *http.Client {
if p.httpClient != nil {
return p.httpClient
}
return http.DefaultClient
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
@@ -111,7 +149,7 @@ func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopNativeWidth, workspacesdk.DesktopNativeHeight))
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
stdout, err := cmd.StdoutPipe()
if err != nil {
sessionCancel()
@@ -361,8 +399,8 @@ func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, e
return string(out), nil
}
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
// ensureBinary resolves or downloads the portabledesktop binary. It
// must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
@@ -377,23 +415,130 @@ func (p *portableDesktop) ensureBinary(ctx context.Context) error {
return nil
}
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
// 2. Platform checks.
if runtime.GOOS != "linux" {
return xerrors.New("portabledesktop is only supported on Linux")
}
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
}
// 3. Check cache.
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
cachedPath := filepath.Join(cacheDir, "portabledesktop")
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
// Verify it is executable.
if info.Mode()&0o100 != 0 {
p.logger.Info(ctx, "using cached portabledesktop binary",
slog.F("path", cachedPath),
)
p.binPath = scriptBinPath
p.binPath = cachedPath
return nil
}
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
}
// 4. Download with retry.
p.logger.Info(ctx, "downloading portabledesktop binary",
slog.F("url", bin.URL),
slog.F("version", portableDesktopVersion),
slog.F("arch", runtime.GOARCH),
)
var lastErr error
for attempt := range downloadRetries {
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
lastErr = err
p.logger.Warn(ctx, "download attempt failed",
slog.F("attempt", attempt+1),
slog.F("max_attempts", downloadRetries),
slog.Error(err),
)
if attempt < downloadRetries-1 {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(downloadRetryDelay):
}
}
continue
}
p.binPath = cachedPath
p.logger.Info(ctx, "downloaded portabledesktop binary",
slog.F("path", cachedPath),
)
return nil
}
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
}
// downloadBinary fetches a binary from url, verifies its SHA-256
// digest matches expectedSHA256, and atomically writes it to destPath.
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
return xerrors.Errorf("create cache directory: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return xerrors.Errorf("create HTTP request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return xerrors.Errorf("HTTP GET %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
}
// Write to a temp file in the same directory so the final rename
// is atomic on the same filesystem.
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
if err != nil {
return xerrors.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
// Clean up the temp file on any error path.
success := false
defer func() {
if !success {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
}
}()
// Stream the response body while computing SHA-256.
hasher := sha256.New()
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
return xerrors.Errorf("download body: %w", err)
}
if err := tmpFile.Close(); err != nil {
return xerrors.Errorf("close temp file: %w", err)
}
// Verify digest.
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
if actualSHA256 != expectedSHA256 {
return xerrors.Errorf(
"SHA-256 mismatch: expected %s, got %s",
expectedSHA256, actualSHA256,
)
}
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
if err := os.Chmod(tmpPath, 0o700); err != nil {
return xerrors.Errorf("chmod: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
return xerrors.Errorf("rename to final path: %w", err)
}
success = true
return nil
}
@@ -2,6 +2,11 @@ package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
@@ -72,6 +77,7 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
// The "up" script prints the JSON line then sleeps until
// the context is canceled (simulating a long-running process).
@@ -82,13 +88,13 @@ func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
}
ctx := t.Context()
ctx := context.Background()
cfg, err := pd.Start(ctx)
require.NoError(t, err)
@@ -105,6 +111,7 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -113,13 +120,13 @@ func TestPortableDesktop_Start_Idempotent(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
cfg1, err := pd.Start(ctx)
require.NoError(t, err)
@@ -147,6 +154,7 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -155,13 +163,13 @@ func TestPortableDesktop_Screenshot(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
require.NoError(t, err)
@@ -172,6 +180,7 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
@@ -180,13 +189,13 @@ func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Screenshot(ctx, ScreenshotOptions{
TargetWidth: 800,
TargetHeight: 600,
@@ -278,13 +287,13 @@ func TestPortableDesktop_MouseMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -363,13 +372,13 @@ func TestPortableDesktop_KeyboardMethods(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
err := tt.invoke(context.Background(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
@@ -395,13 +404,13 @@ func TestPortableDesktop_CursorPosition(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
x, y, err := pd.CursorPosition(t.Context())
x, y, err := pd.CursorPosition(context.Background())
require.NoError(t, err)
assert.Equal(t, 100, x)
assert.Equal(t, 200, y)
@@ -419,13 +428,13 @@ func TestPortableDesktop_Close(t *testing.T) {
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
ctx := context.Background()
_, err := pd.Start(ctx)
require.NoError(t, err)
@@ -448,6 +457,81 @@ func TestPortableDesktop_Close(t *testing.T) {
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- downloadBinary tests ---
func TestDownloadBinary_Success(t *testing.T) {
t.Parallel()
binaryContent := []byte("#!/bin/sh\necho portable\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
require.NoError(t, err)
// Verify the file exists and has correct content.
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
// Verify executable permissions.
info, err := os.Stat(destPath)
require.NoError(t, err)
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
}
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("real binary content"))
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "SHA-256 mismatch")
// The destination file should not exist (temp file cleaned up).
_, statErr := os.Stat(destPath)
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
// No leftover temp files in the directory.
entries, err := os.ReadDir(destDir)
require.NoError(t, err)
assert.Empty(t, entries, "no leftover temp files should remain")
}
func TestDownloadBinary_HTTPError(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "status 404")
}
// --- ensureBinary tests ---
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
@@ -457,89 +541,173 @@ func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
// immediately without doing any work.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(),
binPath: "/already/set",
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: t.TempDir(),
binPath: "/already/set",
}
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, "/already/set", pd.binPath)
}
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
require.NoError(t, os.Chmod(binPath, 0o755))
bin, ok := platformBinaries[runtime.GOARCH]
if !ok {
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
}
dataDir := t.TempDir()
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
cachedPath := filepath.Join(cacheDir, "portabledesktop")
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, binPath, pd.binPath)
assert.Equal(t, cachedPath, pd.binPath)
}
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Windows does not support Unix permission bits")
}
func TestEnsureBinary_Downloads(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
// environment and we override the package-level platformBinaries.
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
// Write without execute permission.
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
_ = binPath
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Save and restore platformBinaries for this test.
origBinaries := platformBinaries
platformBinaries = map[string]struct {
URL string
SHA256 string
}{
runtime.GOARCH: {
URL: srv.URL + "/portabledesktop",
SHA256: expectedSHA,
},
}
t.Cleanup(func() { platformBinaries = origBinaries })
dataDir := t.TempDir()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: dataDir,
httpClient: srv.Client(),
}
// Clear PATH so LookPath won't find a real binary.
// Ensure PATH doesn't contain a real portabledesktop binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
assert.Equal(t, expectedPath, pd.binPath)
// Verify the downloaded file has correct content.
got, err := os.ReadFile(expectedPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
}
func TestEnsureBinary_NotFound(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(), // empty directory
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
binaryContent := []byte("#!/bin/sh\necho retried\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
var mu sync.Mutex
attempt := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
current := attempt
attempt++
mu.Unlock()
// Fail the first 2 attempts, succeed on the third.
if current < 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Test downloadBinary directly to avoid time.Sleep in
// ensureBinary's retry loop. We call it 3 times to simulate
// what ensureBinary would do.
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
var lastErr error
for i := range 3 {
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
if lastErr == nil {
break
}
if i < 2 {
// In the real code, ensureBinary sleeps here.
// We skip the sleep in tests.
continue
}
}
require.NoError(t, lastErr, "download should succeed on the third attempt")
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
mu.Lock()
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
mu.Unlock()
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
// Silence the linter about unused imports — agentexec.DefaultExecer
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
// fmt.Sscanf is used indirectly via the implementation.
var (
_ = agentexec.DefaultExecer
_ = fmt.Sprintf
)
+70 -154
View File
@@ -14,6 +14,7 @@ import (
"syscall"
"github.com/google/uuid"
"github.com/spf13/afero"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -332,18 +333,25 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
return status, err
}
// Check if the target already exists so we can preserve its
// permissions on the temp file before rename.
var mode *os.FileMode
if stat, serr := api.filesystem.Stat(path); serr == nil {
if stat.IsDir() {
return http.StatusBadRequest, xerrors.Errorf("open %s: is a directory", path)
f, err := api.filesystem.Create(path)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.EISDIR):
status = http.StatusBadRequest
}
m := stat.Mode()
mode = &m
return status, err
}
defer f.Close()
_, err = io.Copy(f, r.Body)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
api.logger.Error(ctx, "workspace agent write file", slog.Error(err))
}
return api.atomicWrite(ctx, path, mode, r.Body)
return 0, nil
}
func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
@@ -439,163 +447,84 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
content := string(data)
for _, edit := range edits {
var err error
content, err = fuzzyReplace(content, edit)
if err != nil {
return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err)
}
}
m := stat.Mode()
return api.atomicWrite(ctx, path, &m, strings.NewReader(content))
}
// atomicWrite writes content from r to path via a temp file in the
// same directory. If the target exists, its permissions are preserved.
// On failure the temp file is cleaned up and the original is
// untouched.
func (api *API) atomicWrite(ctx context.Context, path string, mode *os.FileMode, r io.Reader) (int, error) {
dir := filepath.Dir(path)
tmpName := filepath.Join(dir, fmt.Sprintf(".%s.tmp.%s", filepath.Base(path), uuid.New().String()[:8]))
tmpfile, err := api.filesystem.OpenFile(tmpName, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o666)
if err != nil {
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
}
return status, err
}
cleanup := func() {
if err := api.filesystem.Remove(tmpName); err != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(err))
}
}
_, err = io.Copy(tmpfile, r)
if err != nil {
_ = tmpfile.Close()
cleanup()
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Close before rename to flush buffered data and catch write
// errors (e.g. delayed allocation failures).
if err := tmpfile.Close(); err != nil {
cleanup()
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Set permissions on the temp file before rename so there is
// no window where the target has wrong permissions.
if mode != nil {
if err := api.filesystem.Chmod(tmpName, *mode); err != nil {
api.logger.Warn(ctx, "unable to set file permissions",
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.Error(err),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
}
if err := api.filesystem.Rename(tmpName, path); err != nil {
cleanup()
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
// Create an adjacent file to ensure it will be on the same device and can be
// moved atomically.
tmpfile, err := afero.TempFile(api.filesystem, filepath.Dir(path), filepath.Base(path))
if err != nil {
return http.StatusInternalServerError, err
}
defer tmpfile.Close()
if _, err := tmpfile.Write([]byte(content)); err != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return status, xerrors.Errorf("write %s: %w", path, err)
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
err = api.filesystem.Rename(tmpfile.Name(), path)
if err != nil {
return http.StatusInternalServerError, err
}
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace it
// with `replace`. It uses a cascading match strategy inspired by
// fuzzyReplace attempts to find `search` inside `content` and replace its first
// occurrence with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace
// (indentation-tolerant).
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When edit.ReplaceAll is false (the default), the search string must
// match exactly one location. If multiple matches are found, an error
// is returned asking the caller to include more context or set
// replace_all.
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still
// applied at the byte offsets of the original content so that surrounding
// text (including indentation of untouched lines) is preserved.
func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
search := edit.Search
replace := edit.Replace
// Pass 1 exact substring match.
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1 exact substring (replace all occurrences).
if strings.Contains(content, search) {
if edit.ReplaceAll {
return strings.ReplaceAll(content, search, replace), nil
}
count := strings.Count(content, search)
if count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
// Exactly one match.
return strings.Replace(content, search, replace, 1), nil
return strings.ReplaceAll(content, search, replace), true
}
// For line-level fuzzy matching we split both content and search
// into lines.
// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element
// from SplitAfter. Drop it so it doesn't interfere with line
// matching.
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
trimRight := func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}
trimAll := func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3 trim all leading and trailing whitespace
// (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
// Pass 3 trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return "", xerrors.New("search string not found in file. Verify the search " +
"string matches the file content exactly, including whitespace " +
"and indentation")
return content, false
}
// seekLines scans contentLines looking for a contiguous subsequence that matches
@@ -620,26 +549,6 @@ outer:
return 0, 0, false
}
// countLineMatches counts how many non-overlapping contiguous
// subsequences of contentLines match searchLines according to eq.
func countLineMatches(contentLines, searchLines []string, eq func(a, b string) bool) int {
count := 0
if len(searchLines) == 0 || len(searchLines) > len(contentLines) {
return count
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
count++
i += len(searchLines) - 1 // skip past this match
}
return count
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
@@ -653,3 +562,10 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
+3 -209
View File
@@ -14,7 +14,6 @@ import (
"strings"
"syscall"
"testing"
"testing/iotest"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -400,83 +399,6 @@ func TestWriteFile(t *testing.T) {
}
}
func TestWriteFile_ReportsIOError(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, nil)
tmpdir := os.TempDir()
path := filepath.Join(tmpdir, "write-io-error")
err := afero.WriteFile(fs, path, []byte("original"), 0o644)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// A reader that always errors simulates a failed body read
// (e.g. network interruption). The atomic write should leave
// the original file intact.
body := iotest.ErrReader(xerrors.New("simulated I/O error"))
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path), body)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusInternalServerError, w.Code)
got := &codersdk.Error{}
err = json.NewDecoder(w.Body).Decode(got)
require.NoError(t, err)
require.ErrorContains(t, got, "simulated I/O error")
// The original file must survive the failed write.
data, err := afero.ReadFile(fs, path)
require.NoError(t, err)
require.Equal(t, "original", string(data))
}
func TestWriteFile_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// Overwrite the file with new content.
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path),
bytes.NewReader([]byte("#!/bin/sh\necho world\n")))
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"write_file should preserve the original file's permissions")
}
func TestEditFiles(t *testing.T) {
t.Parallel()
@@ -636,8 +558,6 @@ func TestEditFiles(t *testing.T) {
},
errCode: http.StatusInternalServerError,
errors: []string{"rename failed"},
// Original file must survive the failed rename.
expected: map[string]string{failRenameFilePath: "foo bar"},
},
{
name: "Edit1",
@@ -656,9 +576,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
},
{
// When the second edit creates ambiguity (two "bar"
// occurrences), it should fail.
name: "EditEditAmbiguous",
name: "EditEdit", // Edits affect previous edits.
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
@@ -675,33 +593,7 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 2 occurrences"},
// File should not be modified on error.
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
},
{
// With replace_all the cascading edit replaces
// both occurrences.
name: "EditEditReplaceAll",
contents: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "edit-edit-ra"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
{
Search: "bar",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "qux qux"},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
},
{
name: "Multiline",
@@ -828,7 +720,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchErrors",
name: "NoMatchStillSucceeds",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
@@ -841,46 +733,9 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"search string not found in file"},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "AmbiguousExactMatch",
contents: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ambig-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 3 occurrences"},
expected: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
},
{
name: "ReplaceAllExact",
contents: map[string]string{filepath.Join(tmpdir, "ra-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ra-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
@@ -987,67 +842,6 @@ func TestEditFiles(t *testing.T) {
}
}
func TestEditFiles_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
// Sanity-check the initial mode.
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
body := workspacesdk.FileEditRequest{
Files: []workspacesdk.FileEdits{
{
Path: path,
Edits: []workspacesdk.FileEdit{
{
Search: "hello",
Replace: "world",
},
},
},
},
}
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(body)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
// Verify content was updated.
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
// Verify permissions are preserved after the
// temp-file-and-rename cycle.
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"edit_files should preserve the original file's permissions")
}
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
t.Parallel()
+2 -60
View File
@@ -1,13 +1,11 @@
package agentproc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -20,13 +18,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// maxWaitDuration is the maximum time a blocking
// process output request can wait, regardless of
// what the client requests.
maxWaitDuration = 5 * time.Minute
)
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
@@ -35,10 +26,10 @@ type API struct {
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore, workingDir func() string) *API {
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv, workingDir),
manager: newManager(logger, execer, updateEnv),
pathStore: pathStore,
}
}
@@ -160,44 +151,6 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
return
}
// Enforce chat ID isolation. If the request carries
// a chat context, only allow access to processes
// belonging to that chat.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
if proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
// Check for blocking mode via query params.
waitStr := r.URL.Query().Get("wait")
wantWait := waitStr == "true"
if wantWait {
// Extend the write deadline so the HTTP server's
// WriteTimeout does not kill the connection while
// we block.
rc := http.NewResponseController(rw)
// Add headroom beyond the wait timeout so there's time to
// write the response after the blocking wait completes.
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration + 30*time.Second)); err != nil {
api.logger.Error(ctx, "extend write deadline for blocking process output",
slog.Error(err),
)
}
// Cap the wait at maxWaitDuration regardless of
// client-supplied timeout.
waitCtx, waitCancel := context.WithTimeout(ctx, maxWaitDuration)
defer waitCancel()
_ = proc.waitForOutput(waitCtx)
// Fall through to read snapshot below.
}
output, truncated := proc.output()
info := proc.info()
@@ -215,17 +168,6 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
// Enforce chat ID isolation.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
proc, procOK := api.manager.get(id)
if procOK && proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
var req workspacesdk.SignalProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
+3 -277
View File
@@ -7,10 +7,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
@@ -78,22 +76,6 @@ func getOutput(t *testing.T, handler http.Handler, id string) *httptest.Response
return w
}
// getOutputWithHeaders sends a GET /{id}/output request with
// custom headers and returns the recorder.
func getOutputWithHeaders(t *testing.T, handler http.Handler, id string, headers http.Header) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
path := fmt.Sprintf("/%s/output", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
for k, v := range headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
@@ -115,25 +97,18 @@ func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, nil, nil)
return newTestAPIWithUpdateEnv(t, nil)
}
// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, updateEnv, nil)
}
// newTestAPIWithOptions creates a new API with optional
// updateEnv and workingDir hooks.
func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, error), workingDir func() string) http.Handler {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil, workingDir)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
t.Cleanup(func() {
_ = api.Close()
})
@@ -278,100 +253,6 @@ func TestStartProcess(t *testing.T) {
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirIsHome", func(t *testing.T) {
t.Parallel()
// No working directory closure, so the process
// should fall back to $HOME. We verify through
// the process list API which reports the resolved
// working directory using native OS paths,
// avoiding shell path format mismatches on
// Windows (Git Bash returns POSIX paths).
handler := newTestAPI(t)
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("DefaultWorkDirFromClosure", func(t *testing.T) {
t.Parallel()
// The closure provides a valid directory, so the
// process should start there. Use the marker file
// pattern to avoid path format mismatches on
// Windows.
tmpDir := t.TempDir()
handler := newTestAPIWithOptions(t, nil, func() string {
return tmpDir
})
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "touch marker.txt && ls marker.txt",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirClosureNonExistentFallsBackToHome", func(t *testing.T) {
t.Parallel()
// The closure returns a path that doesn't exist,
// so the process should fall back to $HOME.
handler := newTestAPIWithOptions(t, nil, func() string {
return "/tmp/nonexistent-dir-" + fmt.Sprintf("%d", time.Now().UnixNano())
})
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("CustomEnv", func(t *testing.T) {
t.Parallel()
@@ -756,161 +637,6 @@ func TestProcessOutput(t *testing.T) {
require.NoError(t, err)
require.Contains(t, resp.Message, "not found")
})
t.Run("ChatIDEnforcement", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a process with chat-a.
chatA := uuid.New()
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo secret",
Background: true,
}, http.Header{
workspacesdk.CoderChatIDHeader: {chatA.String()},
})
waitForExit(t, handler, id)
// Chat-b should NOT see this process.
chatB := uuid.New()
w1 := getOutputWithHeaders(t, handler, id, http.Header{
workspacesdk.CoderChatIDHeader: {chatB.String()},
})
require.Equal(t, http.StatusNotFound, w1.Code)
// Without any chat ID header, should return 200
// (backwards compatible).
w2 := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w2.Code)
})
t.Run("WaitForExit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello-wait && sleep 0.1",
})
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "hello-wait")
})
t.Run("WaitAlreadyExited", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, id)
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.Contains(t, resp.Output, "done")
})
t.Run("WaitTimeout", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.IntervalMedium)
defer cancel()
w := getOutputWithWaitCtx(ctx, t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Running)
// Kill and wait for the process so cleanup does
// not hang.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
waitForExit(t, handler, id)
})
t.Run("ConcurrentWaiters", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
var (
wg sync.WaitGroup
resps [2]workspacesdk.ProcessOutputResponse
codes [2]int
)
for i := range 2 {
wg.Add(1)
go func() {
defer wg.Done()
w := getOutputWithWait(t, handler, id)
codes[i] = w.Code
_ = json.NewDecoder(w.Body).Decode(&resps[i])
}()
}
// Signal the process to exit so both waiters unblock.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
wg.Wait()
for i := range 2 {
require.Equal(t, http.StatusOK, codes[i], "waiter %d", i)
require.False(t, resps[i].Running, "waiter %d", i)
}
})
}
func getOutputWithWait(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
return getOutputWithWaitCtx(ctx, t, handler, id)
}
func getOutputWithWaitCtx(ctx context.Context, t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
path := fmt.Sprintf("/%s/output?wait=true", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
func TestSignalProcess(t *testing.T) {
@@ -1055,7 +781,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
return current, nil
}, pathStore, nil)
}, pathStore)
defer api.Close()
routes := api.Routes()
+2 -19
View File
@@ -39,13 +39,11 @@ const (
// how much output is written.
type HeadTailBuffer struct {
mu sync.Mutex
cond *sync.Cond
head []byte
tail []byte
tailPos int
tailFull bool
headFull bool
closed bool
totalBytes int
maxHead int
maxTail int
@@ -54,24 +52,20 @@ type HeadTailBuffer struct {
// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: MaxHeadBytes,
maxTail: MaxTailBytes,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: maxHead,
maxTail: maxTail,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// Write implements io.Writer. It is safe for concurrent use.
@@ -302,15 +296,6 @@ func truncateLines(s string) string {
return b.String()
}
// Close marks the buffer as closed and wakes any waiters.
// This is called when the process exits.
func (b *HeadTailBuffer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closed = true
b.cond.Broadcast()
}
// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
b.mu.Lock()
@@ -320,7 +305,5 @@ func (b *HeadTailBuffer) Reset() {
b.tailPos = 0
b.tailFull = false
b.headFull = false
b.closed = false
b.totalBytes = 0
b.cond.Broadcast()
}
-26
View File
@@ -1,26 +0,0 @@
//go:build !windows
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Unix, Setpgid creates a new process group so
// that signals can be delivered to the entire group (the shell
// and all its children).
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}
// signalProcess sends a signal to the process group rooted at p.
// Using the negative PID sends the signal to every process in the
// group, ensuring child processes (e.g. from shell pipelines) are
// also signaled.
func signalProcess(p *os.Process, sig syscall.Signal) error {
return syscall.Kill(-p.Pid, sig)
}
-20
View File
@@ -1,20 +0,0 @@
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Windows, process groups are not supported in the
// same way as Unix, so this returns an empty struct.
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{}
}
// signalProcess sends a signal directly to the process. Windows
// does not support process group signaling, so we fall back to
// sending the signal to the process itself.
func signalProcess(p *os.Process, _ syscall.Signal) error {
return p.Kill()
}
+21 -78
View File
@@ -70,25 +70,23 @@ func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
// manager tracks processes spawned by the agent.
type manager struct {
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
workingDir func() string
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
}
// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), workingDir func() string) *manager {
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
return &manager{
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
workingDir: workingDir,
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
}
}
@@ -111,9 +109,10 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
// the process is not tied to any HTTP request.
ctx, cancel := context.WithCancel(context.Background())
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
cmd.Dir = m.resolveWorkDir(req.WorkDir)
if req.WorkDir != "" {
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
cmd.SysProcAttr = procSysProcAttr()
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
@@ -158,7 +157,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc := &process{
id: id,
command: req.Command,
workDir: cmd.Dir,
workDir: req.WorkDir,
background: req.Background,
chatID: chatID,
cmd: cmd,
@@ -208,9 +207,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc.exitCode = &code
proc.mu.Unlock()
// Wake any waiters blocked on new output or
// process exit before closing the done channel.
proc.buf.Close()
close(proc.done)
}()
@@ -276,15 +272,13 @@ func (m *manager) signal(id string, sig string) error {
switch sig {
case "kill":
// Use process group kill to ensure child processes
// (e.g. from shell pipelines) are also killed.
if err := signalProcess(proc.cmd.Process, syscall.SIGKILL); err != nil {
if err := proc.cmd.Process.Kill(); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
// Use process group signal to ensure child processes
// are also terminated.
if err := signalProcess(proc.cmd.Process, syscall.SIGTERM); err != nil {
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
@@ -322,54 +316,3 @@ func (m *manager) Close() error {
return nil
}
// waitForOutput blocks until the buffer is closed (process
// exited) or the context is canceled. Returns nil when the
// buffer closed, ctx.Err() when the context expired.
func (p *process) waitForOutput(ctx context.Context) error {
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
nevermind := make(chan struct{})
defer close(nevermind)
go func() {
select {
case <-ctx.Done():
// Acquire the lock before broadcasting to
// guarantee the waiter has entered cond.Wait()
// (which atomically releases the lock).
// Without this, a Broadcast between the loop
// predicate check and cond.Wait() is lost.
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
p.buf.cond.Broadcast()
case <-nevermind:
}
}()
for ctx.Err() == nil && !p.buf.closed {
p.buf.cond.Wait()
}
return ctx.Err()
}
// resolveWorkDir returns the directory a process should start in.
// Priority: explicit request dir > agent configured dir > $HOME.
// Falls through when a candidate is empty or does not exist on
// disk, matching the behavior of SSH sessions.
func (m *manager) resolveWorkDir(requested string) string {
if requested != "" {
return requested
}
if m.workingDir != nil {
if dir := m.workingDir(); dir != "" {
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return dir
}
}
}
if home, err := os.UserHomeDir(); err == nil {
return home
}
return ""
}
+2 -2
View File
@@ -398,11 +398,11 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
},
})
if err != nil {
logger.Warn(ctx, "reporting script completed", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: %s", err.Error()))
}
})
if err != nil {
logger.Warn(ctx, "reporting script completed: track command goroutine", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: track command goroutine: %s", err.Error()))
}
}()
+27 -6
View File
@@ -6,6 +6,7 @@ import (
"context"
"net"
"path/filepath"
"sync"
"testing"
"github.com/google/uuid"
@@ -22,6 +23,26 @@ import (
"github.com/coder/coder/v2/testutil"
)
// logSink captures structured log entries for testing.
type logSink struct {
mu sync.Mutex
entries []slog.SinkEntry
}
func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.mu.Lock()
defer s.mu.Unlock()
s.entries = append(s.entries, e)
}
func (*logSink) Sync() {}
func (s *logSink) getEntries() []slog.SinkEntry {
s.mu.Lock()
defer s.mu.Unlock()
return append([]slog.SinkEntry{}, s.entries...)
}
// getField returns the value of a field by name from a slog.Map.
func getField(fields slog.Map, name string) interface{} {
for _, f := range fields {
@@ -55,8 +76,8 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
sink := testutil.NewFakeSink(t)
logger := sink.Logger(slog.LevelInfo)
sink := &logSink{}
logger := slog.Make(sink)
workspaceID := uuid.New()
templateID := uuid.New()
templateVersionID := uuid.New()
@@ -97,10 +118,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 1
return len(sink.getEntries()) >= 1
}, testutil.WaitShort, testutil.IntervalFast)
entries := sink.Entries()
entries := sink.getEntries()
require.Len(t, entries, 1)
entry := entries[0]
require.Equal(t, slog.LevelInfo, entry.Level)
@@ -131,10 +152,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req2)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 2
return len(sink.getEntries()) >= 2
}, testutil.WaitShort, testutil.IntervalFast)
entries = sink.Entries()
entries = sink.getEntries()
entry = entries[1]
require.Len(t, entries, 2)
require.Equal(t, slog.LevelInfo, entry.Level)
-9
View File
@@ -78,9 +78,6 @@ func withDone(t *testing.T) []reaper.Option {
// processes and passes their PIDs through the shared channel.
func TestReap(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
@@ -127,9 +124,6 @@ func TestReap(t *testing.T) {
//nolint:tparallel // Subtests must be sequential, each starts its own reaper.
func TestForkReapExitCodes(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
@@ -170,9 +164,6 @@ func TestForkReapExitCodes(t *testing.T) {
// ensures SIGINT cannot kill the parent test binary.
func TestReapInterrupt(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
-15
View File
@@ -46,7 +46,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
autoUpdates string
copyParametersFrom string
useParameterDefaults bool
noWait bool
// Organization context is only required if more than 1 template
// shares the same name across multiple organizations.
orgContext = NewOrganizationContext()
@@ -373,14 +372,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
cliutil.WarnMatchedProvisioners(inv.Stderr, workspace.LatestBuild.MatchedProvisioners, workspace.LatestBuild.Job)
if noWait {
_, _ = fmt.Fprintf(inv.Stdout,
"\nThe %s workspace has been created and is building in the background.\n",
cliui.Keyword(workspace.Name),
)
return nil
}
err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
if err != nil {
return xerrors.Errorf("watch build: %w", err)
@@ -454,12 +445,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
Description: "Automatically accept parameter defaults when no value is provided.",
Value: serpent.BoolOf(&useParameterDefaults),
},
serpent.Option{
Flag: "no-wait",
Env: "CODER_CREATE_NO_WAIT",
Description: "Return immediately after creating the workspace. The build will run in the background.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
)
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
-75
View File
@@ -603,81 +603,6 @@ func TestCreate(t *testing.T) {
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
}
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was actually created.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
})
t.Run("NoWaitWithParameterDefaults", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses([]*proto.RichParameter{
{Name: "region", Type: "string", DefaultValue: "us-east-1"},
{Name: "instance_type", Type: "string", DefaultValue: "t3.micro"},
}))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--use-parameter-defaults",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was created and parameters were applied.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
buildParams, err := member.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
require.NoError(t, err)
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "region", Value: "us-east-1"})
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "instance_type", Value: "t3.micro"})
})
}
func prepareEchoResponses(parameters []*proto.RichParameter, presets ...*proto.Preset) *echo.Responses {
-6
View File
@@ -1000,12 +1000,6 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
Properties: sdkTool.Schema.Properties,
Required: sdkTool.Schema.Required,
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var buf bytes.Buffer
+1 -16
View File
@@ -81,13 +81,7 @@ func TestExpMcpServer(t *testing.T) {
var toolsResponse struct {
Result struct {
Tools []struct {
Name string `json:"name"`
Annotations struct {
ReadOnlyHint *bool `json:"readOnlyHint"`
DestructiveHint *bool `json:"destructiveHint"`
IdempotentHint *bool `json:"idempotentHint"`
OpenWorldHint *bool `json:"openWorldHint"`
} `json:"annotations"`
Name string `json:"name"`
} `json:"tools"`
} `json:"result"`
}
@@ -100,15 +94,6 @@ func TestExpMcpServer(t *testing.T) {
}
slices.Sort(foundTools)
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
annotations := toolsResponse.Result.Tools[0].Annotations
require.NotNil(t, annotations.ReadOnlyHint)
require.NotNil(t, annotations.DestructiveHint)
require.NotNil(t, annotations.IdempotentHint)
require.NotNil(t, annotations.OpenWorldHint)
assert.True(t, *annotations.ReadOnlyHint)
assert.False(t, *annotations.DestructiveHint)
assert.True(t, *annotations.IdempotentHint)
assert.False(t, *annotations.OpenWorldHint)
// Call the tool and ensure it works.
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
+1 -1
View File
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
} else {
updated, err = client.CreateOrganizationRole(ctx, customRole)
if err != nil {
return xerrors.Errorf("create role: %w", err)
return xerrors.Errorf("patch role: %w", err)
}
}
-23
View File
@@ -79,29 +79,6 @@ func (r *RootCmd) start() *serpent.Command {
)
build = workspace.LatestBuild
default:
// If the last build was a failed start, run a stop
// first to clean up any partially-provisioned
// resources.
if workspace.LatestBuild.Status == codersdk.WorkspaceStatusFailed &&
workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
_, _ = fmt.Fprintf(inv.Stdout, "The last start build failed. Cleaning up before retrying...\n")
stopBuild, stopErr := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStop,
})
if stopErr != nil {
return xerrors.Errorf("cleanup stop after failed start: %w", stopErr)
}
stopErr = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, stopBuild.ID)
if stopErr != nil {
return xerrors.Errorf("wait for cleanup stop: %w", stopErr)
}
// Re-fetch workspace after stop completes so
// startWorkspace sees the latest state.
workspace, err = namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return err
}
}
build, err = startWorkspace(inv, client, workspace, parameterFlags, bflags, WorkspaceStart)
// It's possible for a workspace build to fail due to the template requiring starting
// workspaces with the active version.
-52
View File
@@ -534,55 +534,3 @@ func TestStart_WithReason(t *testing.T) {
workspace = coderdtest.MustWorkspace(t, member, workspace.ID)
require.Equal(t, codersdk.BuildReasonCLI, workspace.LatestBuild.Reason)
}
func TestStart_FailedStartCleansUp(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
store, ps := dbtestutil.NewDB(t)
client := coderdtest.New(t, &coderdtest.Options{
Database: store,
Pubsub: ps,
IncludeProvisionerDaemon: true,
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, memberClient, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// Insert a failed start build directly into the database so that
// the workspace's latest build is a failed "start" transition.
dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
ID: workspace.ID,
OwnerID: member.ID,
OrganizationID: owner.OrganizationID,
TemplateID: template.ID,
}).
Seed(database.WorkspaceBuild{
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
BuildNumber: workspace.LatestBuild.BuildNumber + 1,
}).
Failed().
Do()
inv, root := clitest.New(t, "start", workspace.Name)
clitest.SetupConfig(t, memberClient, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
// The CLI should detect the failed start and clean up first.
pty.ExpectMatch("Cleaning up before retrying")
pty.ExpectMatch("workspace has been started")
_ = testutil.TryReceive(ctx, t, doneChan)
}
+17 -26
View File
@@ -113,20 +113,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
)
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
// Bypass rate limiting for support bundle collection since it makes many API calls.
// Note: this can only be done by the owner user.
if ok, err := support.CanGenerateFull(inv.Context(), client); err == nil && ok {
cliLog.Debug(inv.Context(), "running as owner")
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
} else if !ok {
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
} else {
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
}
// Check if we're running inside a workspace
if val, found := os.LookupEnv("CODER"); found && val == "true" {
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
@@ -214,6 +200,12 @@ func (r *RootCmd) supportBundle() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
}
// Bypass rate limiting for support bundle collection since it makes many API calls.
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
deps := support.Deps{
Client: client,
// Support adds a sink so we don't need to supply one ourselves.
@@ -362,20 +354,19 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
return
}
var docsURL string
if bun.Deployment.Config != nil {
docsURL = bun.Deployment.Config.Values.DocsURL.String()
} else {
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
if bun.Deployment.Config == nil {
cliui.Error(inv.Stdout, "No deployment configuration available!")
return
}
if bun.Deployment.HealthReport != nil {
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
} else {
cliui.Warn(inv.Stdout, "No deployment health report available.")
docsURL := bun.Deployment.Config.Values.DocsURL.String()
if bun.Deployment.HealthReport == nil {
cliui.Error(inv.Stdout, "No deployment health report available!")
return
}
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
if bun.Network.Netcheck == nil {
+22 -47
View File
@@ -28,9 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/healthcheck"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
@@ -52,21 +50,9 @@ func TestSupportBundle(t *testing.T) {
dc.Values.Prometheus.Enable = true
secretValue := uuid.NewString()
seedSecretDeploymentOptions(t, &dc, secretValue)
// Use a mock healthcheck function to avoid flaky DERP health
// checks in CI. The DERP checker performs real network operations
// (portmapper gateway probing, STUN) that can hang for 60s+ on
// macOS CI runners. Since this test validates support bundle
// generation, not healthcheck correctness, a canned report is
// sufficient.
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
DeploymentValues: dc.Values,
HealthcheckFunc: func(_ context.Context, _ string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
return &healthsdk.HealthcheckReport{
Time: time.Now(),
Healthy: true,
Severity: health.SeverityOK,
}
},
DeploymentValues: dc.Values,
HealthcheckTimeout: testutil.WaitSuperLong,
})
t.Cleanup(func() { closer.Close() })
@@ -74,7 +60,7 @@ func TestSupportBundle(t *testing.T) {
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Set up test fixtures
setupCtx := testutil.Context(t, testutil.WaitLong)
setupCtx := testutil.Context(t, testutil.WaitSuperLong)
workspaceWithAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, func(agents []*proto.Agent) []*proto.Agent {
// This should not show up in the bundle output
agents[0].Env["SECRET_VALUE"] = secretValue
@@ -83,6 +69,22 @@ func TestSupportBundle(t *testing.T) {
workspaceWithoutAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, nil)
memberWorkspace := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, member.ID, nil)
// Wait for healthcheck to complete successfully before continuing with sub-tests.
// The result is cached so subsequent requests will be fast.
healthcheckDone := make(chan *healthsdk.HealthcheckReport)
go func() {
defer close(healthcheckDone)
hc, err := healthsdk.New(client).DebugHealth(setupCtx)
if err != nil {
assert.NoError(t, err, "seed healthcheck cache")
return
}
healthcheckDone <- &hc
}()
if _, ok := testutil.AssertReceive(setupCtx, t, healthcheckDone); !ok {
t.Fatal("healthcheck did not complete in time -- this may be a transient issue")
}
t.Run("WorkspaceWithAgent", func(t *testing.T) {
t.Parallel()
@@ -130,35 +132,12 @@ func TestSupportBundle(t *testing.T) {
assertBundleContents(t, path, true, false, []string{secretValue})
})
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
t.Run("NoPrivilege", func(t *testing.T) {
t.Parallel()
d := t.TempDir()
path := filepath.Join(d, "bundle.zip")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
clitest.SetupConfig(t, memberClient, root)
err := inv.Run()
require.NoError(t, err)
r, err := zip.OpenReader(path)
require.NoError(t, err, "open zip file")
defer r.Close()
fileNames := make(map[string]struct{}, len(r.File))
for _, f := range r.File {
fileNames[f.Name] = struct{}{}
}
// These should always be present in the zip structure, even if
// the content is null/empty for non-admin users.
for _, name := range []string{
"deployment/buildinfo.json",
"deployment/config.json",
"workspace/workspace.json",
"logs.txt",
"cli_logs.txt",
"network/netcheck.json",
"network/interfaces.json",
} {
require.Contains(t, fileNames, name)
}
require.ErrorContains(t, err, "failed authorization check")
})
// This ensures that the CLI does not panic when trying to generate a support bundle
@@ -180,10 +159,6 @@ func TestSupportBundle(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("received request: %s %s", r.Method, r.URL)
switch r.URL.Path {
case "/api/v2/users/me":
resp := codersdk.User{}
w.WriteHeader(http.StatusOK)
assert.NoError(t, json.NewEncoder(w).Encode(resp))
case "/api/v2/authcheck":
// Fake auth check
resp := codersdk.AuthorizationResponse{
-4
View File
@@ -20,10 +20,6 @@ OPTIONS:
--copy-parameters-from string, $CODER_WORKSPACE_COPY_PARAMETERS_FROM
Specify the source workspace name to copy parameters from.
--no-wait bool, $CODER_CREATE_NO_WAIT
Return immediately after creating the workspace. The build will run in
the background.
--parameter string-array, $CODER_RICH_PARAMETER
Rich parameter value in the format "name=value".
@@ -6,7 +6,7 @@ USAGE:
List all organization members
OPTIONS:
-c, --column [username|name|last seen at|user created at|user updated at|user id|organization id|created at|updated at|organization roles] (default: username,organization roles)
-c, --column [username|name|user id|organization id|created at|updated at|organization roles] (default: username,organization roles)
Columns to display in table output.
-o, --output table|json (default: table)
+1 -1
View File
@@ -7,7 +7,7 @@
"last_seen_at": "====[timestamp]=====",
"name": "test-daemon",
"version": "v0.0.0-devel",
"api_version": "1.16",
"api_version": "1.15",
"provisioners": [
"echo"
],
-6
View File
@@ -170,12 +170,6 @@ AI BRIDGE OPTIONS:
exporting these records to external SIEM or observability systems.
AI BRIDGE PROXY OPTIONS:
--aibridge-proxy-allowed-private-cidrs string-array, $CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS
Comma-separated list of CIDR ranges that are permitted even though
they fall within blocked private/reserved IP ranges. By default all
private ranges are blocked to prevent SSRF attacks. Use this to allow
access to specific internal networks.
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
provider requests.
+10 -11
View File
@@ -8,17 +8,16 @@ USAGE:
Aliases: user
SUBCOMMANDS:
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
oidc-claims Display the OIDC claims for the authenticated user.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user
cannot log into the platform
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user cannot
log into the platform
———
Run `coder --help` for a list of global options.
-4
View File
@@ -24,10 +24,6 @@ OPTIONS:
-p, --password string
Specifies a password for the new user.
--service-account bool
Create a user account intended to be used by a service or as an
intermediary rather than by a human.
-u, --username string
Specifies a username for the new user.
-24
View File
@@ -1,24 +0,0 @@
coder v0.0.0-devel
USAGE:
coder users oidc-claims [flags]
Display the OIDC claims for the authenticated user.
- Display your OIDC claims:
$ coder users oidc-claims
- Display your OIDC claims as JSON:
$ coder users oidc-claims -o json
OPTIONS:
-c, --column [key|value] (default: key,value)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
-11
View File
@@ -752,11 +752,6 @@ workspace_prebuilds:
# limit; disabled when set to zero.
# (default: 3, type: int)
failure_hard_limit: 3
# Configure the background chat processing daemon.
chat:
# How many pending chats a worker should acquire per polling cycle.
# (default: 10, type: int)
acquireBatchSize: 10
aibridge:
# Whether to start an in-memory aibridged instance.
# (default: false, type: bool)
@@ -873,12 +868,6 @@ aibridgeproxy:
# by the system. If not provided, the system certificate pool is used.
# (default: <unset>, type: string)
upstream_proxy_ca: ""
# Comma-separated list of CIDR ranges that are permitted even though they fall
# within blocked private/reserved IP ranges. By default all private ranges are
# blocked to prevent SSRF attacks. Use this to allow access to specific internal
# networks.
# (default: <unset>, type: string-array)
allowed_private_cidrs: []
# Configure data retention policies for various database tables. Retention
# policies automatically purge old data to reduce database size and improve
# performance. Setting a retention duration to 0 disables automatic purging for
+12 -37
View File
@@ -17,14 +17,13 @@ import (
func (r *RootCmd) userCreate() *serpent.Command {
var (
email string
username string
name string
password string
disableLogin bool
loginType string
serviceAccount bool
orgContext = NewOrganizationContext()
email string
username string
name string
password string
disableLogin bool
loginType string
orgContext = NewOrganizationContext()
)
cmd := &serpent.Command{
Use: "create",
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
if serviceAccount {
switch {
case loginType != "":
return xerrors.New("You cannot use --login-type with --service-account")
case password != "":
return xerrors.New("You cannot use --password with --service-account")
case email != "":
return xerrors.New("You cannot use --email with --service-account")
case disableLogin:
return xerrors.New("You cannot use --disable-login with --service-account")
}
}
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
client, err := r.InitClient(inv)
if err != nil {
return err
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
return err
}
}
if email == "" && !serviceAccount {
if email == "" {
email, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Email:",
Validate: func(s string) error {
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
}
}
userLoginType := codersdk.LoginTypePassword
if disableLogin || serviceAccount {
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
if disableLogin {
userLoginType = codersdk.LoginTypeNone
} else if loginType != "" {
userLoginType = codersdk.LoginType(loginType)
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
Password: password,
OrganizationIDs: []uuid.UUID{organization.ID},
UserLoginType: userLoginType,
ServiceAccount: serviceAccount,
})
if err != nil {
return err
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
case codersdk.LoginTypeOIDC:
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
}
if serviceAccount {
email = "n/a"
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
}
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
Share the instructions below to get them started.
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
)),
Value: serpent.StringOf(&loginType),
},
{
Flag: "service-account",
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
Value: serpent.BoolOf(&serviceAccount),
},
}
orgContext.AttachOptions(cmd)
-53
View File
@@ -8,7 +8,6 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
assert.Equal(t, args[5], created.Username)
assert.Empty(t, created.Name)
})
tests := []struct {
name string
args []string
err string
}{
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
},
{
name: "ServiceAccountLoginType",
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
err: "You cannot use --login-type with --service-account",
},
{
name: "ServiceAccountDisableLogin",
args: []string{"--service-account", "-u", "dean", "--disable-login"},
err: "You cannot use --disable-login with --service-account",
},
{
name: "ServiceAccountEmail",
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
err: "You cannot use --email with --service-account",
},
{
name: "ServiceAccountPassword",
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
err: "You cannot use --password with --service-account",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
clitest.SetupConfig(t, client, root)
err := inv.Run()
if tt.err == "" {
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
created, err := client.User(ctx, "dean")
require.NoError(t, err)
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
} else {
require.Error(t, err)
require.ErrorContains(t, err, tt.err)
}
})
}
}
-79
View File
@@ -1,79 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) userOIDCClaims() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.ChangeFormatterData(
cliui.TableFormat([]claimRow{}, []string{"key", "value"}),
func(data any) (any, error) {
resp, ok := data.(codersdk.OIDCClaimsResponse)
if !ok {
return nil, xerrors.Errorf("expected type %T, got %T", resp, data)
}
rows := make([]claimRow, 0, len(resp.Claims))
for k, v := range resp.Claims {
rows = append(rows, claimRow{
Key: k,
Value: fmt.Sprintf("%v", v),
})
}
return rows, nil
},
),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "oidc-claims",
Short: "Display the OIDC claims for the authenticated user.",
Long: FormatExamples(
Example{
Description: "Display your OIDC claims",
Command: "coder users oidc-claims",
},
Example{
Description: "Display your OIDC claims as JSON",
Command: "coder users oidc-claims -o json",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
resp, err := client.UserOIDCClaims(inv.Context())
if err != nil {
return xerrors.Errorf("get oidc claims: %w", err)
}
out, err := formatter.Format(inv.Context(), resp)
if err != nil {
return err
}
_, err = fmt.Fprintln(inv.Stdout, out)
return err
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
type claimRow struct {
Key string `json:"-" table:"key,default_sort"`
Value string `json:"-" table:"value"`
}
-161
View File
@@ -1,161 +0,0 @@
package cli_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestUserOIDCClaims(t *testing.T) {
t.Parallel()
newOIDCTest := func(t *testing.T) (*oidctest.FakeIDP, *codersdk.Client) {
t.Helper()
fake := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
ownerClient := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
})
return fake, ownerClient
}
t.Run("OwnClaims", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "alice@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
"groups": []string{"admin", "eng"},
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims", "-o", "json")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
var resp codersdk.OIDCClaimsResponse
err = json.Unmarshal(buf.Bytes(), &resp)
require.NoError(t, err, "unmarshal JSON output")
require.NotEmpty(t, resp.Claims, "claims should not be empty")
assert.Equal(t, "alice@coder.com", resp.Claims["email"])
})
t.Run("Table", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "bob@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
output := buf.String()
require.Contains(t, output, "email")
require.Contains(t, output, "bob@coder.com")
})
t.Run("NotOIDCUser", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, client, root)
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "not an OIDC user")
})
// Verify that two different OIDC users each only see their own
// claims. The endpoint has no user parameter, so there is no way
// to request another user's claims by design.
t.Run("OnlyOwnClaims", func(t *testing.T) {
t.Parallel()
aliceFake, aliceOwnerClient := newOIDCTest(t)
aliceClaims := jwt.MapClaims{
"email": "alice-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
aliceClient, aliceLoginResp := aliceFake.Login(t, aliceOwnerClient, aliceClaims)
defer aliceLoginResp.Body.Close()
bobFake, bobOwnerClient := newOIDCTest(t)
bobClaims := jwt.MapClaims{
"email": "bob-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
bobClient, bobLoginResp := bobFake.Login(t, bobOwnerClient, bobClaims)
defer bobLoginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
// Alice sees her own claims.
aliceResp, err := aliceClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "alice-isolation@coder.com", aliceResp.Claims["email"])
// Bob sees his own claims.
bobResp, err := bobClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "bob-isolation@coder.com", bobResp.Claims["email"])
})
t.Run("ClaimsNeverNull", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
// Use minimal claims — just enough for OIDC login.
claims := jwt.MapClaims{
"email": "minimal@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.UserOIDCClaims(ctx)
require.NoError(t, err)
require.NotNil(t, resp.Claims, "claims should never be nil, expected empty map")
})
}
-1
View File
@@ -19,7 +19,6 @@ func (r *RootCmd) users() *serpent.Command {
r.userSingle(),
r.userDelete(),
r.userEditRoles(),
r.userOIDCClaims(),
r.createUserStatusCommand(codersdk.UserStatusActive),
r.createUserStatusCommand(codersdk.UserStatusSuspended),
},
+8 -13
View File
@@ -103,7 +103,7 @@ type Options struct {
UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric)
}
func New(opts Options, workspace database.Workspace, agent database.WorkspaceAgent) *API {
func New(opts Options, workspace database.Workspace) *API {
if opts.Clock == nil {
opts.Clock = quartz.NewReal()
}
@@ -156,8 +156,7 @@ func New(opts Options, workspace database.Workspace, agent database.WorkspaceAge
}
api.StatsAPI = &StatsAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: api.agent,
Workspace: api.cachedWorkspaceFields,
Database: opts.Database,
Log: opts.Log,
@@ -176,18 +175,16 @@ func New(opts Options, workspace database.Workspace, agent database.WorkspaceAge
}
api.AppsAPI = &AppsAPI{
AgentID: agent.ID,
AgentFn: api.agent,
Database: opts.Database,
Log: opts.Log,
Workspace: api.cachedWorkspaceFields,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer,
}
api.MetadataAPI = &MetadataAPI{
AgentID: agent.ID,
AgentFn: api.agent,
Workspace: api.cachedWorkspaceFields,
Database: opts.Database,
Log: opts.Log,
@@ -207,8 +204,7 @@ func New(opts Options, workspace database.Workspace, agent database.WorkspaceAge
}
api.ConnLogAPI = &ConnLogAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: api.agent,
ConnectionLogger: opts.ConnectionLogger,
Database: opts.Database,
Workspace: api.cachedWorkspaceFields,
@@ -226,6 +222,7 @@ func New(opts Options, workspace database.Workspace, agent database.WorkspaceAge
api.SubAgentAPI = &SubAgentAPI{
OwnerID: opts.OwnerID,
OrganizationID: opts.OrganizationID,
AgentID: opts.AgentID,
AgentFn: api.agent,
Log: opts.Log,
Clock: opts.Clock,
@@ -300,10 +297,8 @@ func (a *API) agent(ctx context.Context) (database.WorkspaceAgent, error) {
func (a *API) refreshCachedWorkspace(ctx context.Context) {
ws, err := a.opts.Database.GetWorkspaceByID(ctx, a.opts.WorkspaceID)
if err != nil {
// Do not clear the cache on transient DB errors. Stale data is
// preferable to no data, which forces callers to fall back to
// expensive queries like GetWorkspaceByAgentID.
a.opts.Log.Warn(ctx, "failed to refresh cached workspace fields", slog.Error(err))
a.cachedWorkspaceFields.Clear()
return
}
@@ -346,11 +341,11 @@ func (a *API) startCacheRefreshLoop(ctx context.Context) {
a.cachedWorkspaceFields.Clear()
}
func (a *API) publishWorkspaceUpdate(ctx context.Context, agentID uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
a.opts.PublishWorkspaceUpdateFn(ctx, a.opts.OwnerID, wspubsub.WorkspaceEvent{
Kind: kind,
WorkspaceID: a.opts.WorkspaceID,
AgentID: &agentID,
AgentID: &agent.ID,
})
return nil
}
+33 -38
View File
@@ -24,19 +24,22 @@ import (
)
type AppsAPI struct {
AgentID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store
Log slog.Logger
Workspace *CachedWorkspaceFields
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock
}
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
a.Log.Debug(ctx, "got batch app health update",
slog.F("agent_id", a.AgentID.String()),
slog.F("agent_id", workspaceAgent.ID.String()),
slog.F("updates", req.Updates),
)
@@ -44,9 +47,9 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, a.AgentID)
apps, err := a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", a.AgentID, err)
return nil, xerrors.Errorf("get workspace apps by agent ID %q: %w", workspaceAgent.ID, err)
}
var newApps []database.WorkspaceApp
@@ -107,7 +110,7 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 {
err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAppHealthUpdate)
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAppHealthUpdate)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
@@ -146,8 +149,12 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
})
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: a.AgentID,
AgentID: workspaceAgent.ID,
Slug: req.Slug,
})
if err != nil {
@@ -157,10 +164,11 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
})
}
ws, ok := a.Workspace.AsWorkspaceIdentity()
if !ok {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Workspace identity not cached.",
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
Detail: err.Error(),
})
}
@@ -182,8 +190,8 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
WorkspaceID: ws.ID,
AgentID: a.AgentID,
WorkspaceID: workspace.ID,
AgentID: workspaceAgent.ID,
AppID: app.ID,
State: dbState,
Message: cleaned,
@@ -200,7 +208,7 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, a.AgentID, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to publish workspace update.",
@@ -209,14 +217,14 @@ func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateApp
}
}
// Notify on state change to Working/Idle for AI tasks.
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState)
// Notify on state change to Working/Idle for AI tasks
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
if shouldBump(dbState, latestAppStatus) {
// We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, ws.ID, time.Time{})
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
}
// just return a blank response because it doesn't contain any settable fields at present.
return new(agentproto.UpdateAppStatusResponse), nil
@@ -253,6 +261,8 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
appID uuid.UUID,
latestAppStatus database.WorkspaceAppStatus,
newAppStatus database.WorkspaceAppStatusState,
workspace database.Workspace,
agent database.WorkspaceAgent,
) {
var notificationTemplate uuid.UUID
switch newAppStatus {
@@ -269,20 +279,11 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return
}
taskID := a.Workspace.TaskID()
if !taskID.Valid {
if !workspace.TaskID.Valid {
// Workspace has no task ID, do nothing.
return
}
// Only fetch fresh agent state for task workspaces, since we need
// the current lifecycle state to decide whether to send notifications.
agent, err := a.AgentFn(ctx)
if err != nil {
a.Log.Warn(ctx, "failed to get agent for AI task notification", slog.Error(err))
return
}
// Only send notifications when the agent is ready. We want to skip
// any state transitions that occur whilst the workspace is starting
// up as it doesn't make sense to receive them.
@@ -295,7 +296,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return
}
task, err := a.Database.GetTaskByID(ctx, taskID.UUID)
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
if err != nil {
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
return
@@ -320,20 +321,14 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
return
}
ws, ok := a.Workspace.AsWorkspaceIdentity()
if !ok {
a.Log.Warn(ctx, "failed to get workspace identity for AI task notification")
return
}
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
// nolint:gocritic // Need notifier actor to enqueue notifications
dbauthz.AsNotifier(ctx),
ws.OwnerID,
workspace.OwnerID,
notificationTemplate,
map[string]string{
"task": task.Name,
"workspace": ws.Name,
"workspace": workspace.Name,
},
map[string]any{
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
@@ -343,7 +338,7 @@ func (a *AppsAPI) enqueueAITaskStateNotification(
},
"api-workspace-agent-app-status",
// Associate this notification with related entities
ws.ID, ws.OwnerID, ws.OrganizationID, appID,
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
); err != nil {
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
return
+42 -28
View File
@@ -67,10 +67,12 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -103,10 +105,12 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -140,10 +144,12 @@ func TestBatchUpdateAppHealths(t *testing.T) {
publishCalled := false
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -174,7 +180,9 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil)
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil,
@@ -201,7 +209,9 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil,
@@ -229,7 +239,9 @@ func TestBatchUpdateAppHealths(t *testing.T) {
dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil)
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: nil,
@@ -267,26 +279,14 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
}
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: uuid.UUID{7},
},
}
cachedWs := &agentapi.CachedWorkspaceFields{}
cachedWs.UpdateValues(workspace)
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
Workspace: cachedWs,
PublishWorkspaceUpdateFn: func(_ context.Context, agnt uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, agnt, agent.ID)
Database: mDB,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, *agnt, agent)
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
return nil
},
@@ -309,6 +309,14 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
},
}
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: task.ID,
},
}
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
appStatus := database.WorkspaceAppStatus{
ID: uuid.UUID{6},
}
@@ -355,7 +363,9 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
Return(database.WorkspaceApp{}, sql.ErrNoRows)
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
@@ -382,7 +392,9 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
}
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
@@ -410,7 +422,9 @@ func TestWorkspaceAgentAppStatus(t *testing.T) {
}
api := &agentapi.AppsAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
-10
View File
@@ -4,7 +4,6 @@ import (
"context"
"sync"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
@@ -24,14 +23,12 @@ type CachedWorkspaceFields struct {
lock sync.RWMutex
identity database.WorkspaceIdentity
taskID uuid.NullUUID
}
func (cws *CachedWorkspaceFields) Clear() {
cws.lock.Lock()
defer cws.lock.Unlock()
cws.identity = database.WorkspaceIdentity{}
cws.taskID = uuid.NullUUID{}
}
func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
@@ -45,13 +42,6 @@ func (cws *CachedWorkspaceFields) UpdateValues(ws database.Workspace) {
cws.identity.OwnerUsername = ws.OwnerUsername
cws.identity.TemplateName = ws.TemplateName
cws.identity.AutostartSchedule = ws.AutostartSchedule
cws.taskID = ws.TaskID
}
func (cws *CachedWorkspaceFields) TaskID() uuid.NullUUID {
cws.lock.RLock()
defer cws.lock.RUnlock()
return cws.taskID
}
// Returns the Workspace, true, unless the workspace has not been cached (nuked or was a prebuild).
+19 -4
View File
@@ -14,11 +14,11 @@ import (
"github.com/coder/coder/v2/coderd/connectionlog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
)
type ConnLogAPI struct {
AgentID uuid.UUID
AgentName string
AgentFn func(context.Context) (database.WorkspaceAgent, error)
ConnectionLogger *atomic.Pointer[connectionlog.ConnectionLogger]
Workspace *CachedWorkspaceFields
Database database.Store
@@ -53,12 +53,27 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
}
}
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceByAgentID on every metadata update.
rbacCtx := ctx
var ws database.WorkspaceIdentity
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
ws = dbws
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
// Fetch contextual data for this connection log event.
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, xerrors.Errorf("get agent: %w", err)
}
if ws.Equal(database.WorkspaceIdentity{}) {
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace by agent id: %w", err)
}
@@ -82,7 +97,7 @@ func (a *ConnLogAPI) ReportConnection(ctx context.Context, req *agentproto.Repor
WorkspaceOwnerID: ws.OwnerID,
WorkspaceID: ws.ID,
WorkspaceName: ws.Name,
AgentName: a.AgentName,
AgentName: workspaceAgent.Name,
Type: connectionType,
Code: code,
Ip: logIP,
+4 -3
View File
@@ -114,9 +114,10 @@ func TestConnectionLog(t *testing.T) {
api := &agentapi.ConnLogAPI{
ConnectionLogger: asAtomicPointer[connectionlog.ConnectionLogger](connLogger),
Database: mDB,
AgentID: agent.ID,
AgentName: agent.Name,
Workspace: &agentapi.CachedWorkspaceFields{},
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{},
}
api.ReportConnection(context.Background(), &agentproto.ReportConnectionRequest{
Connection: &agentproto.Connection{
+2 -2
View File
@@ -30,7 +30,7 @@ type LifecycleAPI struct {
WorkspaceID uuid.UUID
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
TimeNowFn func() time.Time // defaults to dbtime.Now()
Metrics *LifecycleMetrics
@@ -122,7 +122,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLifecycleUpdate)
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLifecycleUpdate)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
+4 -4
View File
@@ -85,7 +85,7 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -206,7 +206,7 @@ func TestUpdateLifecycle(t *testing.T) {
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
@@ -323,7 +323,7 @@ func TestUpdateLifecycle(t *testing.T) {
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
atomic.AddInt64(&publishCalled, 1)
return nil
},
@@ -410,7 +410,7 @@ func TestUpdateLifecycle(t *testing.T) {
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishCalled = true
return nil
},
+3 -3
View File
@@ -19,7 +19,7 @@ type LogsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID, wspubsub.WorkspaceEventKind) error
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)
TimeNowFn func() time.Time // defaults to dbtime.Now()
@@ -125,7 +125,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentLogsOverflow)
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentLogsOverflow)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
@@ -145,7 +145,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil {
// If these are the first logs being appended, we publish a UI update
// to notify the UI that logs are now available.
err = a.PublishWorkspaceUpdateFn(ctx, workspaceAgent.ID, wspubsub.WorkspaceEventKindAgentFirstLogs)
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentFirstLogs)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
+6 -6
View File
@@ -51,7 +51,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -155,7 +155,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -203,7 +203,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -296,7 +296,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -340,7 +340,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
@@ -387,7 +387,7 @@ func TestBatchCreateLogs(t *testing.T) {
},
Database: dbM,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(ctx context.Context, _ uuid.UUID, kind wspubsub.WorkspaceEventKind) error {
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
publishWorkspaceUpdateCalled = true
return nil
},
+5 -6
View File
@@ -32,12 +32,16 @@ type ManifestAPI struct {
DerpForceWebSockets bool
WorkspaceID uuid.UUID
AgentFn func(ctx context.Context) (database.WorkspaceAgent, error)
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Database database.Store
DerpMapFn func() *tailcfg.DERPMap
}
func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
var (
dbApps []database.WorkspaceApp
scripts []database.WorkspaceAgentScript
@@ -46,11 +50,6 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
devcontainers []database.WorkspaceAgentDevcontainer
)
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("getting workspace agent: %w", err)
}
var eg errgroup.Group
eg.Go(func() (err error) {
dbApps, err = a.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID)
+9 -3
View File
@@ -322,7 +322,9 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceID: workspace.ID,
Database: mDB,
DerpMapFn: derpMapFn,
@@ -387,7 +389,9 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return childAgent, nil },
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return childAgent, nil
},
WorkspaceID: workspace.ID,
Database: mDB,
DerpMapFn: derpMapFn,
@@ -508,7 +512,9 @@ func TestGetManifest(t *testing.T) {
DisableDirectConnections: true,
DerpForceWebSockets: true,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceID: workspace.ID,
Database: mDB,
DerpMapFn: derpMapFn,
+22 -4
View File
@@ -5,18 +5,18 @@ import (
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
)
type MetadataAPI struct {
AgentID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Workspace *CachedWorkspaceFields
Database database.Store
Log slog.Logger
@@ -45,11 +45,29 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
maxErrorLen = maxValueLen
)
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceByAgentID on every metadata update.
var err error
rbacCtx := ctx
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, err
}
var (
collectedAt = a.now()
allKeysLen = 0
dbUpdate = database.UpdateWorkspaceAgentMetadataParams{
WorkspaceAgentID: a.AgentID,
WorkspaceAgentID: workspaceAgent.ID,
// These need to be `make(x, 0, len(req.Metadata))` instead of
// `make(x, len(req.Metadata))` because we may not insert all
// metadata if the keys are large.
@@ -103,7 +121,7 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B
}
// Use batcher to batch metadata updates.
err := a.Batcher.Add(a.AgentID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
err = a.Batcher.Add(workspaceAgent.ID, dbUpdate.Key, dbUpdate.Value, dbUpdate.Error, dbUpdate.CollectedAt)
if err != nil {
return nil, xerrors.Errorf("add metadata to batcher: %w", err)
}
+9 -3
View File
@@ -80,7 +80,9 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t),
Batcher: batcher,
@@ -157,7 +159,9 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t),
Batcher: batcher,
@@ -237,7 +241,9 @@ func TestBatchUpdateMetadata(t *testing.T) {
t.Cleanup(batcher.Close)
api := &agentapi.MetadataAPI{
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &agentapi.CachedWorkspaceFields{},
Log: testutil.Logger(t),
Batcher: batcher,
+25 -8
View File
@@ -4,21 +4,20 @@ import (
"context"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/codersdk"
)
type StatsAPI struct {
AgentID uuid.UUID
AgentName string
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Workspace *CachedWorkspaceFields
Database database.Store
Log slog.Logger
@@ -45,13 +44,32 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
return res, nil
}
// Inject RBAC object into context for dbauthz fast path, avoid having to
// call GetWorkspaceAgentByID on every stats update.
rbacCtx := ctx
if dbws, ok := a.Workspace.AsWorkspaceIdentity(); ok {
var err error
rbacCtx, err = dbauthz.WithWorkspaceRBAC(ctx, dbws.RBACObject())
if err != nil {
// Don't error level log here, will exit the function. We want to fall back to GetWorkspaceByAgentID.
//nolint:gocritic
a.Log.Debug(ctx, "Cached workspace was present but RBAC object was invalid", slog.F("err", err))
}
}
workspaceAgent, err := a.AgentFn(rbacCtx)
if err != nil {
return nil, err
}
// If cache is empty (prebuild or invalid), fall back to DB
var ws database.WorkspaceIdentity
var ok bool
if ws, ok = a.Workspace.AsWorkspaceIdentity(); !ok {
w, err := a.Database.GetWorkspaceByAgentID(ctx, a.AgentID)
w, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", a.AgentID, err)
return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err)
}
ws = database.WorkspaceIdentityFromWorkspace(w)
}
@@ -72,12 +90,11 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR
req.Stats.SessionCountReconnectingPty = 0
}
err := a.StatsReporter.ReportAgentStats(
err = a.StatsReporter.ReportAgentStats(
ctx,
a.now(),
ws,
a.AgentID,
a.AgentName,
workspaceAgent,
req.Stats,
false,
)
+18 -12
View File
@@ -119,8 +119,9 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -228,8 +229,9 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -262,8 +264,9 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -344,8 +347,9 @@ func TestUpdateStats(t *testing.T) {
// ws.AutostartSchedule = workspace.AutostartSchedule
api := agentapi.StatsAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &ws,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -455,8 +459,9 @@ func TestUpdateStats(t *testing.T) {
)
defer wut.Close()
api := agentapi.StatsAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
@@ -591,8 +596,9 @@ func TestUpdateStats(t *testing.T) {
}
)
api := agentapi.StatsAPI{
AgentID: agent.ID,
AgentName: agent.Name,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Workspace: &workspaceAsCacheFields,
Database: dbM,
StatsReporter: workspacestats.NewReporter(workspacestats.ReporterOptions{
+2 -6
View File
@@ -25,6 +25,7 @@ import (
type SubAgentAPI struct {
OwnerID uuid.UUID
OrganizationID uuid.UUID
AgentID uuid.UUID
AgentFn func(context.Context) (database.WorkspaceAgent, error)
Log slog.Logger
@@ -294,12 +295,7 @@ func (a *SubAgentAPI) ListSubAgents(ctx context.Context, _ *agentproto.ListSubAg
//nolint:gocritic // This gives us only the permissions required to do the job.
ctx = dbauthz.AsSubAgentAPI(ctx, a.OrganizationID, a.OwnerID)
parentAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, xerrors.Errorf("get parent agent: %w", err)
}
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, parentAgent.ID)
workspaceAgents, err := a.Database.GetWorkspaceAgentsByParentID(ctx, a.AgentID)
if err != nil {
return nil, err
}
+6 -3
View File
@@ -81,9 +81,12 @@ func TestSubAgentAPI(t *testing.T) {
return &agentapi.SubAgentAPI{
OwnerID: user.ID,
OrganizationID: org.ID,
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agent, nil },
Clock: clock,
Database: dbauthz.New(db, auth, logger, accessControlStore),
AgentID: agent.ID,
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Clock: clock,
Database: dbauthz.New(db, auth, logger, accessControlStore),
}
}
-38
View File
@@ -1,38 +0,0 @@
// Package aiseats is the AGPL version the package.
// The actual implementation is in `enterprise/aiseats`.
package aiseats
import (
"context"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
)
type Reason struct {
EventType database.AiSeatUsageReason
Description string
}
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
func ReasonAIBridge(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
}
// ReasonTask constructs a reason for usage originating from tasks.
func ReasonTask(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
}
// SeatTracker records AI seat consumption state.
type SeatTracker interface {
// RecordUsage does not return an error to prevent blocking the user from using
// AI features. This method is used to record usage, not enforce it.
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
}
// Noop is an AGPL seat tracker that does nothing.
type Noop struct{}
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
+1 -1
View File
@@ -773,7 +773,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
}
if statusResp.Status != agentapisdk.StatusStable {
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
Message: "Task app is not ready to accept input.",
Detail: fmt.Sprintf("Status: %s", statusResp.Status),
})
-5
View File
@@ -789,11 +789,6 @@ func TestTasks(t *testing.T) {
})
require.Error(t, err, "wanted error due to bad status")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "not ready to accept input")
statusResponse = agentapisdk.StatusStable
//nolint:tparallel // Not intended to run in parallel.
+1725 -2200
View File
File diff suppressed because it is too large Load Diff
+1723 -2171
View File
File diff suppressed because it is too large Load Diff
+1 -2
View File
@@ -32,8 +32,7 @@ type Auditable interface {
idpsync.OrganizationSyncSettings |
idpsync.GroupSyncSettings |
idpsync.RoleSyncSettings |
database.TaskTable |
database.AiSeatState
database.TaskTable
}
// Map is a map of changed fields in an audited resource. It maps field names to
-8
View File
@@ -132,8 +132,6 @@ func ResourceTarget[T Auditable](tgt T) string {
return "Organization Role Sync"
case database.TaskTable:
return typed.Name
case database.AiSeatState:
return "AI Seat"
default:
panic(fmt.Sprintf("unknown resource %T for ResourceTarget", tgt))
}
@@ -198,8 +196,6 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
return noID // Org field on audit log has org id
case database.TaskTable:
return typed.ID
case database.AiSeatState:
return typed.UserID
default:
panic(fmt.Sprintf("unknown resource %T for ResourceID", tgt))
}
@@ -255,8 +251,6 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
return database.ResourceTypeIdpSyncSettingsGroup
case database.TaskTable:
return database.ResourceTypeTask
case database.AiSeatState:
return database.ResourceTypeAiSeat
default:
panic(fmt.Sprintf("unknown resource %T for ResourceType", typed))
}
@@ -315,8 +309,6 @@ func ResourceRequiresOrgID[T Auditable]() bool {
return true
case database.TaskTable:
return true
case database.AiSeatState:
return false
default:
panic(fmt.Sprintf("unknown resource %T for ResourceRequiresOrgID", tgt))
}
@@ -6,8 +6,8 @@ import (
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
"github.com/coder/coder/v2/codersdk"
)
File diff suppressed because it is too large Load Diff
+86
View File
@@ -0,0 +1,86 @@
package chatd
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
t.Parallel()
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
calls++
return database.Chat{}, nil
},
)
require.NoError(t, err)
require.Equal(t, chat, refreshed)
require.Equal(t, 0, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReloadsWhenWorkspaceMissing(t *testing.T) {
t.Parallel()
chatID := uuid.New()
workspaceID := uuid.New()
chat := database.Chat{ID: chatID}
reloaded := database.Chat{
ID: chatID,
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
calls := 0
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(_ context.Context, id uuid.UUID) (database.Chat, error) {
calls++
require.Equal(t, chatID, id)
return reloaded, nil
},
)
require.NoError(t, err)
require.Equal(t, reloaded, refreshed)
require.Equal(t, 1, calls)
}
func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
t.Parallel()
chat := database.Chat{ID: uuid.New()}
loadErr := xerrors.New("boom")
refreshed, err := refreshChatWorkspaceSnapshot(
context.Background(),
chat,
func(context.Context, uuid.UUID) (database.Chat, error) {
return database.Chat{}, loadErr
},
)
require.Error(t, err)
require.ErrorContains(t, err, "reload chat workspace state")
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
File diff suppressed because it is too large Load Diff
@@ -16,8 +16,8 @@ import (
"charm.land/fantasy/schema"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chatretry"
"github.com/coder/coder/v2/codersdk"
)
@@ -39,14 +39,10 @@ var ErrInterrupted = xerrors.New("chat interrupted")
// persistence layer is responsible for splitting these into
// separate database messages by role.
type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
// Runtime is the wall-clock duration of this step,
// covering LLM streaming, tool execution, and retries.
// Zero indicates the duration was not measured (e.g.
// interrupted steps).
Runtime time.Duration
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
ShouldContinue bool
}
// RunOptions configures a single streaming chat loop run.
@@ -127,7 +123,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
switch c.GetType() {
case fantasy.ContentTypeText:
text, ok := fantasy.AsContentType[fantasy.TextContent](c)
if !ok || strings.TrimSpace(text.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.TextPart{
@@ -136,7 +132,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
})
case fantasy.ContentTypeReasoning:
reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
if !ok || strings.TrimSpace(reasoning.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.ReasoningPart{
@@ -265,7 +261,6 @@ func Run(ctx context.Context, opts RunOptions) error {
for step := 0; totalSteps < opts.MaxSteps; step++ {
totalSteps++
stepStart := time.Now()
// Copy messages so that provider-specific caching
// mutations don't leak back to the caller's slice.
// copy copies Message structs by value, so field
@@ -367,13 +362,12 @@ func Run(ctx context.Context, opts RunOptions) error {
// the chat was interrupted between the previous
// check and here, fall back to the interrupt-safe
// path so partial content is not lost.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ShouldContinue: result.shouldContinue,
}); err != nil { if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
@@ -617,12 +611,10 @@ func processStepStream(
result.providerMetadata = part.ProviderMetadata
case fantasy.StreamPartTypeError:
// Detect interruption: the stream may surface the
// cancel as context.Canceled or propagate the
// ErrInterrupted cause directly, depending on
// the provider implementation.
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
// Detect interruption: context canceled with
// ErrInterrupted as the cause.
if errors.Is(part.Error, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
// Flush in-progress content so that
// persistInterruptedStep has access to partial
// text, reasoning, and tool calls that were
@@ -640,23 +632,6 @@ func processStepStream(
}
}
// The stream iterator may stop yielding parts without
// producing a StreamPartTypeError when the context is
// canceled (e.g. some providers close the response body
// silently). Detect this case and flush partial content
// so that persistInterruptedStep can save it.
if ctx.Err() != nil &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
flushActiveState(
&result,
activeTextContent,
activeReasoningContent,
activeToolCalls,
toolNames,
)
return result, ErrInterrupted
}
hasLocalToolCalls := false
for _, tc := range result.toolCalls {
if !tc.ProviderExecuted {
@@ -897,7 +872,8 @@ func persistInterruptedStep(
persistCtx := context.WithoutCancel(ctx)
if err := opts.PersistStep(persistCtx, PersistedStep{
Content: content,
Content: content,
ShouldContinue: false,
}); err != nil {
if opts.OnInterruptedPersistError != nil {
opts.OnInterruptedPersistError(err)
@@ -7,7 +7,6 @@ import (
"strings"
"sync"
"testing"
"time"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
@@ -65,8 +64,6 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
require.Equal(t, 1, persistStepCalls)
require.True(t, persistedStep.ContextLimit.Valid)
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
require.Greater(t, persistedStep.Runtime, time.Duration(0),
"step runtime should be positive")
require.NotEmpty(t, capturedCall.Prompt)
require.False(t, containsPromptSentinel(capturedCall.Prompt))
@@ -578,84 +575,6 @@ func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *test
assert.False(t, localTR.ProviderExecuted)
}
func TestToResponseMessages_FiltersEmptyTextAndReasoningParts(t *testing.T) {
t.Parallel()
sr := stepResult{
content: []fantasy.Content{
// Empty text — should be filtered.
fantasy.TextContent{Text: ""},
// Whitespace-only text — should be filtered.
fantasy.TextContent{Text: " \t\n"},
// Empty reasoning — should be filtered.
fantasy.ReasoningContent{Text: ""},
// Whitespace-only reasoning — should be filtered.
fantasy.ReasoningContent{Text: " \n"},
// Non-empty text — should pass through.
fantasy.TextContent{Text: "hello world"},
// Leading/trailing whitespace with content — kept
// with the original value (not trimmed).
fantasy.TextContent{Text: " hello "},
// Non-empty reasoning — should pass through.
fantasy.ReasoningContent{Text: "let me think"},
// Tool call — should be unaffected by filtering.
fantasy.ToolCallContent{
ToolCallID: "tc-1",
ToolName: "read_file",
Input: `{"path":"main.go"}`,
},
// Local tool result — should be unaffected by filtering.
fantasy.ToolResultContent{
ToolCallID: "tc-1",
ToolName: "read_file",
Result: fantasy.ToolResultOutputContentText{Text: "file contents"},
},
},
}
msgs := sr.toResponseMessages()
require.Len(t, msgs, 2, "expected assistant + tool messages")
// First message: assistant role with non-empty text, reasoning,
// and the tool call. The four empty/whitespace-only parts must
// have been dropped.
assistantMsg := msgs[0]
assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role)
require.Len(t, assistantMsg.Content, 4,
"assistant message should have 2x TextPart, ReasoningPart, and ToolCallPart")
// Part 0: non-empty text.
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[0])
require.True(t, ok, "part 0 should be TextPart")
assert.Equal(t, "hello world", textPart.Text)
// Part 1: padded text — original whitespace preserved.
paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[1])
require.True(t, ok, "part 1 should be TextPart")
assert.Equal(t, " hello ", paddedPart.Text)
// Part 2: non-empty reasoning.
reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](assistantMsg.Content[2])
require.True(t, ok, "part 2 should be ReasoningPart")
assert.Equal(t, "let me think", reasoningPart.Text)
// Part 3: tool call (unaffected by text/reasoning filtering).
toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[3])
require.True(t, ok, "part 3 should be ToolCallPart")
assert.Equal(t, "tc-1", toolCallPart.ToolCallID)
assert.Equal(t, "read_file", toolCallPart.ToolName)
// Second message: tool role with the local tool result.
toolMsg := msgs[1]
assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
require.Len(t, toolMsg.Content, 1,
"tool message should have only the local ToolResultPart")
toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
require.True(t, ok, "tool part should be ToolResultPart")
assert.Equal(t, "tc-1", toolResultPart.ToolCallID)
}
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
if len(message.ProviderOptions) == 0 {
return false

Some files were not shown because too many files have changed in this diff Show More