Compare commits

..

1 Commits

Author SHA1 Message Date
Kyle Carberry 64629eaf15 fix(agents): make chat title nullable to eliminate sidebar flicker
The chat title column is now nullable in PostgreSQL. Chats are created
with title = NULL instead of a truncated fallback string. The async LLM
title generator fires whenever title IS NULL, providing a clean
one-way transition from skeleton to real title.

This eliminates the sidebar flicker caused by:
- The fallback-to-generated title swap on first render
- Query invalidations (refetchOnWindowFocus, reconnect, follow-up
  messages) briefly reverting to stale fallback titles

Backend changes:
- Migration makes chats.title nullable, drops 'New Chat' default
- InsertChat no longer accepts a title parameter
- titleInput() triggers on NULL title instead of comparing fallback
- Removed chatTitleFromMessage, fallbackChatTitle, and
  subagentFallbackChatTitle helper functions

Frontend changes:
- Chat.title is now string | null in TypeScript types
- Sidebar renders a shimmer skeleton for null titles
- Search, aria-labels, and TopBar handle null gracefully
2026-03-10 15:07:03 +00:00
702 changed files with 20580 additions and 80750 deletions
-72
View File
@@ -1,72 +0,0 @@
---
name: pull-requests
description: "Guide for creating, updating, and following up on pull requests in the Coder repository. Use when asked to open a PR, update a PR, rewrite a PR description, or follow up on CI/check failures."
---
# Pull Request Skill
## When to Use This Skill
Use this skill when asked to:
- Create a pull request for the current branch.
- Update an existing PR branch or description.
- Rewrite a PR body.
- Follow up on CI or check failures for an existing PR.
## References
Use the canonical docs for shared conventions and validation guidance:
- PR title and description conventions:
`.claude/docs/PR_STYLE_GUIDE.md`
- Local validation commands and git hooks: `AGENTS.md` (Essential Commands and
Git Hooks sections)
## Lifecycle Rules
1. **Check for an existing PR** before creating a new one:
```bash
gh pr list --head "$(git branch --show-current)" --author @me --json number --jq '.[0].number // empty'
```
If that returns a number, update that PR. If it returns empty output,
create a new one.
2. **Check you are not on main.** If the current branch is `main` or `master`,
create a feature branch before doing PR work.
3. **Default to draft.** Use `gh pr create --draft` unless the user explicitly
asks for ready-for-review.
4. **Keep description aligned with the full diff.** Re-read the diff against
the base branch before writing or updating the title and body. Describe the
entire PR diff, not just the last commit.
5. **Never auto-merge.** Do not merge or mark ready for review unless the user
explicitly asks.
6. **Never push to main or master.**
## CI / Checks Follow-up
**Always watch CI checks after pushing.** Do not push and walk away.
After pushing:
- Monitor CI with `gh pr checks <PR_NUMBER> --watch`.
- Use `gh pr view <PR_NUMBER> --json statusCheckRollup` for programmatic check
status.
If checks fail:
1. Find the failed run ID from the `gh pr checks` output.
2. Read the logs with `gh run view <run-id> --log-failed`.
3. Fix the problem locally.
4. Run `make pre-commit`.
5. Push the fix.
## What Not to Do
- Do not reference or call helper scripts that do not exist in this
repository.
- Do not auto-merge or mark ready for review without explicit user request.
- Do not push to `origin/main` or `origin/master`.
- Do not skip local validation before pushing.
- Do not fabricate or embellish PR descriptions.
+1 -1
View File
@@ -113,7 +113,7 @@ Coder emphasizes clear error handling, with specific patterns required:
All tests should run in parallel using `t.Parallel()` to ensure efficient testing and expose potential race conditions. The codebase is rigorously linted with golangci-lint to maintain consistent code quality.
Git contributions follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/). See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Git contributions follow a standard format with commit messages structured as `type: <message>`, where type is one of `feat`, `fix`, or `chore`.
## Development Workflow
+14 -5
View File
@@ -4,13 +4,22 @@ This guide documents the PR description style used in the Coder repository, base
## PR Title Format
Format: `type(scope): description`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/) format:
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
```text
type(scope): brief description
```
Examples:
**Common types:**
- `feat`: New features
- `fix`: Bug fixes
- `refactor`: Code refactoring without behavior change
- `perf`: Performance improvements
- `docs`: Documentation changes
- `chore`: Dependency updates, tooling changes
**Examples:**
- `feat: add tracing to aibridge`
- `fix: move contexts to appropriate locations`
+3 -5
View File
@@ -136,11 +136,9 @@ Then make your changes and push normally. Don't use `git push --force` unless th
## Commit Style
Format: `type(scope): message`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
- Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/)
- Format: `type(scope): message`
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
- Keep message titles concise (~70 characters)
- Use imperative, present tense in commit titles
-9
View File
@@ -1,9 +0,0 @@
paths:
# The triage workflow uses a quoted heredoc (<<'EOF') with ${VAR}
# placeholders that envsubst expands later. Shellcheck's SC2016
# warns about unexpanded variables in single-quoted strings, but
# the non-expansion is intentional here. Actionlint doesn't honor
# inline shellcheck disable directives inside heredocs.
.github/workflows/triage-via-chat-api.yaml:
ignore:
- 'SC2016'
-1
View File
@@ -64,7 +64,6 @@ runs:
TEST_PACKAGES: ${{ inputs.test-packages }}
RACE_DETECTION: ${{ inputs.race-detection }}
TS_DEBUG_DISCO: "true"
TS_DEBUG_DERP: "true"
LC_CTYPE: "en_US.UTF-8"
LC_ALL: "en_US.UTF-8"
run: |
+34 -96
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -191,7 +191,7 @@ jobs:
# Check for any typos
- name: Check for typos
uses: crate-ci/typos@631208b7aac2daa8b707f55e7331f9112b0e062d # v1.44.0
uses: crate-ci/typos@2d0ce569feab1f8752f1dde43cc2f2aa53236e06 # v1.40.0
with:
config: .github/workflows/typos.toml
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -327,7 +327,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -379,7 +379,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -537,7 +537,7 @@ jobs:
embedded-pg-cache: ${{ steps.embedded-pg-cache.outputs.embedded-pg-cache }}
- name: Upload failed test db dumps
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
@@ -575,7 +575,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -637,7 +637,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -709,7 +709,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -736,7 +736,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -769,7 +769,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -818,7 +818,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -826,7 +826,7 @@ jobs:
- name: Upload debug log
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coderd-debug-logs${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/e2e/test-results/debug.log
@@ -834,7 +834,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -849,7 +849,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -930,7 +930,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1005,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1043,7 +1043,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1097,7 +1097,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -1108,7 +1108,7 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -1198,7 +1198,7 @@ jobs:
make -j \
build/coder_linux_{amd64,arm64,armv7} \
build/coder_"$version"_windows_amd64.zip \
build/coder_"$version"_linux_{amd64,arm64,armv7}.{tar.gz,deb}
build/coder_"$version"_linux_amd64.{tar.gz,deb}
env:
# The Windows and Darwin slim binaries must be signed for Coder
# Desktop to accept them.
@@ -1216,28 +1216,11 @@ jobs:
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
JSIGN_PATH: /tmp/jsign-6.0.jar
# Free up disk space before building Docker images. The preceding
# Build step produces ~2 GB of binaries and packages, the Go build
# cache is ~1.3 GB, and node_modules is ~500 MB. Docker image
# builds, pushes, and SBOM generation need headroom that isn't
# available without reclaiming some of that space.
- name: Clean up build cache
run: |
set -euxo pipefail
# Go caches are no longer needed — binaries are already compiled.
go clean -cache -modcache
# Remove .apk and .rpm packages that are not uploaded as
# artifacts and were only built as make prerequisites.
rm -f ./build/*.apk ./build/*.rpm
- name: Build Linux Docker images
id: build-docker
env:
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
DOCKER_CLI_EXPERIMENTAL: "enabled"
# Skip building .deb/.rpm/.apk/.tar.gz as prerequisites for
# the Docker image targets — they were already built above.
DOCKER_IMAGE_NO_PREREQUISITES: "true"
run: |
set -euxo pipefail
@@ -1319,7 +1302,7 @@ jobs:
id: attest_main
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:main"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1356,7 +1339,7 @@ jobs:
id: attest_latest
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:latest"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1393,7 +1376,7 @@ jobs:
id: attest_version
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1455,60 +1438,15 @@ jobs:
^v
prune-untagged: true
- name: Upload build artifact (coder-linux-amd64.tar.gz)
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-amd64.tar.gz
path: ./build/*_linux_amd64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-amd64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-amd64.deb
path: ./build/*_linux_amd64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-arm64.tar.gz
path: ./build/*_linux_arm64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-arm64.deb
path: ./build/*_linux_arm64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-armv7.tar.gz
path: ./build/*_linux_armv7.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-linux-armv7.deb
path: ./build/*_linux_armv7.deb
retention-days: 7
- name: Upload build artifact (coder-windows-amd64.zip)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: coder-windows-amd64.zip
path: ./build/*_windows_amd64.zip
name: coder
path: |
./build/*.zip
./build/*.tar.gz
./build/*.deb
retention-days: 7
# Deploy is handled in deploy.yaml so we can apply concurrency limits.
@@ -1543,7 +1481,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-141
View File
@@ -23,44 +23,6 @@ permissions:
concurrency: pr-${{ github.ref }}
jobs:
community-label:
runs-on: ubuntu-latest
permissions:
pull-requests: write
if: >-
${{
github.event_name == 'pull_request_target' &&
github.event.action == 'opened' &&
github.event.pull_request.author_association != 'MEMBER' &&
github.event.pull_request.author_association != 'COLLABORATOR' &&
github.event.pull_request.author_association != 'OWNER'
}}
steps:
- name: Add community label
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const params = {
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
}
const labels = context.payload.pull_request.labels.map((label) => label.name)
if (labels.includes("community")) {
console.log('PR already has "community" label.')
return
}
console.log(
'Adding "community" label for author association "%s".',
context.payload.pull_request.author_association,
)
await github.rest.issues.addLabels({
...params,
labels: ["community"],
})
cla:
runs-on: ubuntu-latest
permissions:
@@ -83,109 +45,6 @@ jobs:
# Some users have signed a corporate CLA with Coder so are exempt from signing our community one.
allowlist: "coryb,aaronlehmann,dependabot*,blink-so*,blinkagent*"
title:
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request_target' }}
steps:
- name: Validate PR title
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
const { pull_request } = context.payload;
const title = pull_request.title;
const repo = { owner: context.repo.owner, repo: context.repo.repo };
const allowedTypes = [
"feat", "fix", "docs", "style", "refactor",
"perf", "test", "build", "ci", "chore", "revert",
];
const expectedFormat = `"type(scope): description" or "type: description"`;
const guidelinesLink = `See: https://github.com/coder/coder/blob/main/docs/about/contributing/CONTRIBUTING.md#commit-messages`;
const scopeHint = (type) =>
`Use a broader scope or no scope (e.g., "${type}: ...") for cross-cutting changes.\n` +
guidelinesLink;
console.log("Title: %s", title);
// Parse conventional commit format: type(scope)!: description
const match = title.match(/^(\w+)(\(([^)]*)\))?(!)?\s*:\s*.+/);
if (!match) {
core.setFailed(
`PR title does not match conventional commit format.\n` +
`Expected: ${expectedFormat}\n` +
`Allowed types: ${allowedTypes.join(", ")}\n` +
guidelinesLink
);
return;
}
const type = match[1];
const scope = match[3]; // undefined if no parentheses
// Validate type.
if (!allowedTypes.includes(type)) {
core.setFailed(
`PR title has invalid type "${type}".\n` +
`Expected: ${expectedFormat}\n` +
`Allowed types: ${allowedTypes.join(", ")}\n` +
guidelinesLink
);
return;
}
// If no scope, we're done.
if (!scope) {
console.log("No scope provided, title is valid.");
return;
}
console.log("Scope: %s", scope);
// Fetch changed files.
const files = await github.paginate(github.rest.pulls.listFiles, {
...repo,
pull_number: pull_request.number,
per_page: 100,
});
const changedPaths = files.map(f => f.filename);
console.log("Changed files: %d", changedPaths.length);
// Derive scope type from the changed files. The diff is the
// source of truth: if files exist under the scope, the path
// exists on the PR branch. No need for Contents API calls.
const isDir = changedPaths.some(f => f.startsWith(scope + "/"));
const isFile = changedPaths.some(f => f === scope);
const isStem = changedPaths.some(f => f.startsWith(scope + "."));
if (!isDir && !isFile && !isStem) {
core.setFailed(
`PR title scope "${scope}" does not match any files changed in this PR.\n` +
`Scopes must reference a path (directory or file stem) that contains changed files.\n` +
scopeHint(type)
);
return;
}
// Verify all changed files fall under the scope.
const outsideFiles = changedPaths.filter(f => {
if (isDir && f.startsWith(scope + "/")) return false;
if (f === scope) return false;
if (isStem && f.startsWith(scope + ".")) return false;
return true;
});
if (outsideFiles.length > 0) {
const listed = outsideFiles.map(f => " - " + f).join("\n");
core.setFailed(
`PR title scope "${scope}" does not contain all changed files.\n` +
`Files outside scope:\n${listed}\n\n` +
scopeHint(type)
);
return;
}
console.log("PR title is valid.");
release-labels:
runs-on: ubuntu-latest
permissions:
+19 -15
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -61,11 +61,11 @@ jobs:
if: needs.should-deploy.outputs.verdict == 'DEPLOY'
permissions:
contents: read
id-token: write # to authenticate to EKS cluster
id-token: write
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,29 +76,33 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
with:
role-to-assume: ${{ vars.AWS_DOGFOOD_DEPLOY_ROLE }}
aws-region: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
- name: Get Cluster Credentials
run: aws eks update-kubeconfig --name "$AWS_DOGFOOD_CLUSTER_NAME" --region "$AWS_DOGFOOD_DEPLOY_REGION"
env:
AWS_DOGFOOD_CLUSTER_NAME: ${{ vars.AWS_DOGFOOD_CLUSTER_NAME }}
AWS_DOGFOOD_DEPLOY_REGION: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Set up Flux CLI
uses: fluxcd/flux2/action@8454b02a32e48d775b9f563cb51fdcb1787b5b93 # v2.7.5
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.8.2"
version: "2.7.0"
- name: Get Cluster Credentials
uses: google-github-actions/get-gke-credentials@3da1e46a907576cefaa90c484278bb5b259dd395 # v3.0.0
with:
cluster_name: dogfood-v2
location: us-central1-a
project_id: coder-dogfood-v2
# Retag image as dogfood while maintaining the multi-arch manifest
- name: Tag image as dogfood
@@ -142,7 +146,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -48,7 +48,7 @@ jobs:
persist-credentials: false
- name: Docker login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
+1 -1
View File
@@ -30,7 +30,7 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v45.0.7
- uses: tj-actions/changed-files@e0021407031f5be11a464abee9a0776171c79891 # v45.0.7
id: changed-files
with:
files: |
+4 -4
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -78,11 +78,11 @@ jobs:
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
- name: Login to DockerHub
if: github.ref == 'refs/heads/main'
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -30,7 +30,7 @@ jobs:
- name: Sync issues
id: sync
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0.5.0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: sync
@@ -52,7 +52,7 @@ jobs:
- name: Complete release
id: complete
uses: linear/linear-release-action@5cbaabc187ceb63eee9d446e62e68e5c29a03ae8 # v0
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
with:
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
command: complete
+1 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+6 -6
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -248,7 +248,7 @@ jobs:
uses: ./.github/actions/setup-sqlc
- name: GHCR Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -14,12 +14,12 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
- name: Run Schmoder CI
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4
with:
workflow: ci.yaml
repo: coder/schmoder
+12 -10
View File
@@ -80,7 +80,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -155,7 +155,7 @@ jobs:
cat "$CODER_RELEASE_NOTES_FILE"
- name: Docker Login
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -358,7 +358,7 @@ jobs:
id: attest_base
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.image-base-tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -474,7 +474,7 @@ jobs:
id: attest_main
if: ${{ !inputs.dry_run }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -518,7 +518,7 @@ jobs:
id: attest_latest
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
continue-on-error: true
uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0
uses: actions/attest@e59cbc1ad1ac2d59339667419eb8cdde6eb61e3d # v3.2.0
with:
subject-name: ${{ steps.latest_tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -665,7 +665,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: release-artifacts
path: |
@@ -681,7 +681,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
@@ -700,11 +700,13 @@ jobs:
name: Publish to Homebrew tap
runs-on: ubuntu-latest
needs: release
if: ${{ !inputs.dry_run && inputs.release_channel == 'mainline' }}
if: ${{ !inputs.dry_run }}
steps:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -780,7 +782,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: SARIF file
path: results.sarif
+4 -4
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -69,7 +69,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.34.0
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
@@ -160,7 +160,7 @@ jobs:
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: trivy
path: trivy-results.sarif
+3 -3
View File
@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
-295
View File
@@ -1,295 +0,0 @@
# This workflow reimplements the AI Triage Automation using the Coder Chat API
# instead of the Tasks API. The Chat API (/api/experimental/chats) is a simpler
# interface that does not require a dedicated GitHub Action or workspace
# provisioning — we just create a chat, poll for completion, and link the
# result on the issue. All API calls use curl + jq directly.
#
# Key differences from the Tasks API workflow (traiage.yaml):
# - No checkout of coder/create-task-action; everything is inline curl/jq.
# - No template_name / template_preset / prefix inputs — the Chat API handles
# resource allocation internally.
# - Uses POST /api/experimental/chats to create a chat session.
# - Polls GET /api/experimental/chats/<id> until the agent finishes.
# - Chat URL format: ${CODER_URL}/agents?chat=${CHAT_ID}
name: AI Triage via Chat API
on:
issues:
types:
- labeled
workflow_dispatch:
inputs:
issue_url:
description: "GitHub Issue URL to process"
required: true
type: string
permissions:
contents: read
jobs:
triage-chat:
name: Triage GitHub Issue via Chat API
runs-on: ubuntu-latest
if: github.event.label.name == 'chat-triage' || github.event_name == 'workflow_dispatch'
timeout-minutes: 30
env:
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
permissions:
contents: read
issues: write
steps:
# ------------------------------------------------------------------
# Step 1: Determine the GitHub user and issue URL.
# Identical to the Tasks API workflow — resolve the actor for
# workflow_dispatch or the issue sender for label events.
# ------------------------------------------------------------------
- name: Determine Inputs
id: determine-inputs
if: always()
env:
GITHUB_ACTOR: ${{ github.actor }}
GITHUB_EVENT_ISSUE_HTML_URL: ${{ github.event.issue.html_url }}
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# For workflow_dispatch, use the actor who triggered it.
# For issues events, use the issue sender.
if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
if ! GITHUB_USER_ID=$(gh api "users/${GITHUB_ACTOR}" --jq '.id'); then
echo "::error::Failed to get GitHub user ID for actor ${GITHUB_ACTOR}"
exit 1
fi
echo "Using workflow_dispatch actor: ${GITHUB_ACTOR} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_ACTOR}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${INPUTS_ISSUE_URL}"
echo "issue_url=${INPUTS_ISSUE_URL}" >> "${GITHUB_OUTPUT}"
exit 0
elif [[ "${GITHUB_EVENT_NAME}" == "issues" ]]; then
GITHUB_USER_ID=${GITHUB_EVENT_USER_ID}
echo "Using issue author: ${GITHUB_EVENT_USER_LOGIN} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_EVENT_USER_LOGIN}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${GITHUB_EVENT_ISSUE_HTML_URL}"
echo "issue_url=${GITHUB_EVENT_ISSUE_HTML_URL}" >> "${GITHUB_OUTPUT}"
exit 0
else
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
exit 1
fi
# ------------------------------------------------------------------
# Step 2: Verify the triggering user has push access.
# Unchanged from the Tasks API workflow.
# ------------------------------------------------------------------
- name: Verify push access
env:
GITHUB_REPOSITORY: ${{ github.repository }}
GH_TOKEN: ${{ github.token }}
GITHUB_USERNAME: ${{ steps.determine-inputs.outputs.github_username }}
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
run: |
set -euo pipefail
can_push="$(gh api "/repos/${GITHUB_REPOSITORY}/collaborators/${GITHUB_USERNAME}/permission" --jq '.user.permissions.push')"
if [[ "${can_push}" != "true" ]]; then
echo "::error title=Access Denied::${GITHUB_USERNAME} does not have push access to ${GITHUB_REPOSITORY}"
exit 1
fi
# ------------------------------------------------------------------
# Step 3: Create a chat via the Coder Chat API.
# Unlike the Tasks API which provisions a full workspace, the Chat
# API creates a lightweight chat session. We POST to
# /api/experimental/chats with the triage prompt as the initial
# message and receive a chat ID back.
# ------------------------------------------------------------------
- name: Create chat via Coder Chat API
id: create-chat
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
# Build the same triage prompt used by the Tasks API workflow.
TASK_PROMPT=$(cat <<'EOF'
Fix ${ISSUE_URL}
1. Use the gh CLI to read the issue description and comments.
2. Think carefully and try to understand the root cause. If the issue is unclear or not well defined, ask me to clarify and provide more information.
3. Write a proposed implementation plan to PLAN.md for me to review before starting implementation. Your plan should use TDD and only make the minimal changes necessary to fix the root cause.
4. When I approve your plan, start working on it. If you encounter issues with the plan, ask me for clarification and update the plan as required.
5. When you have finished implementation according to the plan, commit and push your changes, and create a PR using the gh CLI for me to review.
EOF
)
# Perform variable substitution on the prompt — scoped to $ISSUE_URL only.
# Using envsubst without arguments would expand every env var in scope
# (including CODER_SESSION_TOKEN), so we name the variable explicitly.
TASK_PROMPT=$(echo "${TASK_PROMPT}" | envsubst '$ISSUE_URL')
echo "Creating chat with prompt:"
echo "${TASK_PROMPT}"
# POST to the Chat API to create a new chat session.
RESPONSE=$(curl --silent --fail-with-body \
-X POST \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
-H "Content-Type: application/json" \
-d "$(jq -n --arg prompt "${TASK_PROMPT}" \
'{content: [{type: "text", text: $prompt}]}')" \
"${CODER_URL}/api/experimental/chats")
echo "Chat API response:"
echo "${RESPONSE}" | jq .
CHAT_ID=$(echo "${RESPONSE}" | jq -r '.id')
CHAT_STATUS=$(echo "${RESPONSE}" | jq -r '.status')
if [[ -z "${CHAT_ID}" || "${CHAT_ID}" == "null" ]]; then
echo "::error::Failed to create chat — no ID returned"
echo "Response: ${RESPONSE}"
exit 1
fi
# Validate that CHAT_ID is a UUID before using it in URL paths.
# This guards against unexpected API responses being interpolated
# into subsequent curl calls.
if [[ ! "${CHAT_ID}" =~ ^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$ ]]; then
echo "::error::CHAT_ID is not a valid UUID: ${CHAT_ID}"
exit 1
fi
CHAT_URL="${CODER_URL}/agents?chat=${CHAT_ID}"
echo "Chat created: ${CHAT_ID} (status: ${CHAT_STATUS})"
echo "Chat URL: ${CHAT_URL}"
echo "chat_id=${CHAT_ID}" >> "${GITHUB_OUTPUT}"
echo "chat_url=${CHAT_URL}" >> "${GITHUB_OUTPUT}"
# ------------------------------------------------------------------
# Step 4: Poll the chat status until the agent finishes.
# The Chat API is asynchronous — after creation the agent begins
# working in the background. We poll GET /api/experimental/chats/<id>
# every 5 seconds until the status is "waiting" (agent needs input),
# "completed" (agent finished), or "error". Timeout after 10 minutes.
# ------------------------------------------------------------------
- name: Poll chat status
id: poll-status
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
run: |
set -euo pipefail
POLL_INTERVAL=5
# 10 minutes = 600 seconds.
TIMEOUT=600
ELAPSED=0
echo "Polling chat ${CHAT_ID} every ${POLL_INTERVAL}s (timeout: ${TIMEOUT}s)..."
while true; do
RESPONSE=$(curl --silent --fail-with-body \
-H "Coder-Session-Token: ${CODER_SESSION_TOKEN}" \
"${CODER_URL}/api/experimental/chats/${CHAT_ID}")
STATUS=$(echo "${RESPONSE}" | jq -r '.status')
echo "[${ELAPSED}s] Chat status: ${STATUS}"
case "${STATUS}" in
waiting|completed)
echo "Chat reached terminal status: ${STATUS}"
echo "final_status=${STATUS}" >> "${GITHUB_OUTPUT}"
exit 0
;;
error)
echo "::error::Chat entered error state"
echo "${RESPONSE}" | jq .
echo "final_status=error" >> "${GITHUB_OUTPUT}"
exit 1
;;
pending|running)
# Still working — keep polling.
;;
*)
echo "::warning::Unknown chat status: ${STATUS}"
;;
esac
if [[ ${ELAPSED} -ge ${TIMEOUT} ]]; then
echo "::error::Timed out after ${TIMEOUT}s waiting for chat to finish"
echo "final_status=timeout" >> "${GITHUB_OUTPUT}"
exit 1
fi
sleep "${POLL_INTERVAL}"
ELAPSED=$((ELAPSED + POLL_INTERVAL))
done
# ------------------------------------------------------------------
# Step 5: Comment on the GitHub issue with a link to the chat.
# Only comment if the issue belongs to this repository (same guard
# as the Tasks API workflow).
# ------------------------------------------------------------------
- name: Comment on issue
if: startsWith(steps.determine-inputs.outputs.issue_url, format('{0}/{1}', github.server_url, github.repository))
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
GH_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
COMMENT_BODY=$(cat <<EOF
🤖 **AI Triage Chat Created**
A Coder chat session has been created to investigate this issue.
**Chat URL:** ${CHAT_URL}
**Chat ID:** \`${CHAT_ID}\`
**Status:** ${FINAL_STATUS}
The agent is working on a triage plan. Visit the chat to follow progress or provide guidance.
EOF
)
gh issue comment "${ISSUE_URL}" --body "${COMMENT_BODY}"
echo "Comment posted on ${ISSUE_URL}"
# ------------------------------------------------------------------
# Step 6: Write a summary to the GitHub Actions step summary.
# ------------------------------------------------------------------
- name: Write summary
env:
CHAT_ID: ${{ steps.create-chat.outputs.chat_id }}
CHAT_URL: ${{ steps.create-chat.outputs.chat_url }}
FINAL_STATUS: ${{ steps.poll-status.outputs.final_status }}
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
run: |
set -euo pipefail
{
echo "## AI Triage via Chat API"
echo ""
echo "**Issue:** ${ISSUE_URL}"
echo "**Chat ID:** \`${CHAT_ID}\`"
echo "**Chat URL:** ${CHAT_URL}"
echo "**Status:** ${FINAL_STATUS}"
} >> "${GITHUB_STEP_SUMMARY}"
-2
View File
@@ -29,8 +29,6 @@ EDE = "EDE"
HELO = "HELO"
LKE = "LKE"
byt = "byt"
cpy = "cpy"
Cpy = "Cpy"
typ = "typ"
# file extensions used in seti icon theme
styl = "styl"
+1 -17
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
with:
egress-policy: audit
@@ -30,22 +30,6 @@ jobs:
with:
persist-credentials: false
- name: Rewrite same-repo links for PR branch
if: github.event_name == 'pull_request'
env:
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
run: |
# Rewrite same-repo blob/tree main links to the PR head SHA
# so that files or directories introduced in the PR are
# reachable during link checking.
{
echo 'replacementPatterns:'
echo " - pattern: \"https://github.com/coder/coder/blob/main/\""
echo " replacement: \"https://github.com/coder/coder/blob/${HEAD_SHA}/\""
echo " - pattern: \"https://github.com/coder/coder/tree/main/\""
echo " replacement: \"https://github.com/coder/coder/tree/${HEAD_SHA}/\""
} >> .github/.linkspector.yml
- name: Check Markdown links
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
id: markdown-link-check
+6 -60
View File
@@ -50,7 +50,7 @@ Only pause to ask for confirmation when:
| **Format** | `make fmt` | Auto-format code |
| **Clean** | `make clean` | Clean build artifacts |
| **Pre-commit** | `make pre-commit` | Fast CI checks (gen/fmt/lint/build) |
| **Pre-push** | `make pre-push` | Heavier CI checks (allowlisted) |
| **Pre-push** | `make pre-push` | All CI checks including tests |
### Documentation Commands
@@ -100,31 +100,6 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
```
### API Design
- Add swagger annotations when introducing new HTTP endpoints. Do this in
the same change as the handler so the docs do not get missed before
release.
- For user-scoped or resource-scoped routes, prefer path parameters over
query parameters when that matches existing route patterns.
- For experimental or unstable API paths, skip public doc generation with
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
keeps them out of the published API reference until they stabilize.
### Database Query Naming
- Use `ByX` when `X` is the lookup or filter column.
- Use `PerX` or `GroupedByX` when `X` is the aggregation or grouping
dimension.
- Avoid `ByX` names for grouped queries.
### Database-to-SDK Conversions
- Extract explicit db-to-SDK conversion helpers instead of inlining large
conversion blocks inside handlers.
- Keep nullable-field handling, type coercion, and response shaping in the
converter so handlers stay focused on request flow and authorization.
## Quick Reference
### Full workflows available in imported WORKFLOWS.md
@@ -146,20 +121,11 @@ git config core.hooksPath scripts/githooks
Two hooks run automatically:
- **pre-commit**: Classifies staged files by type and runs either
the full `make pre-commit` or the lightweight `make pre-commit-light`
depending on whether Go, TypeScript, SQL, proto, or Makefile
changes are present. Falls back to the full target when
`CODER_HOOK_RUN_ALL=1` is set. A markdown-only commit takes
seconds; a Go change takes several minutes.
- **pre-push**: Classifies changed files (vs remote branch or
merge-base) and runs `make pre-push` when Go, TypeScript, SQL,
proto, or Makefile changes are detected. Skips tests entirely
for lightweight changes. Allowlisted in
`scripts/githooks/pre-push`. Runs only for developers who opt
in. Falls back to `make pre-push` when the diff range can't
be determined or `CODER_HOOK_RUN_ALL=1` is set. Allow at least
15 minutes for a full run.
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
Fast checks that catch most CI failures. Allow at least 5 minutes.
- **pre-push**: `make pre-push` (full CI suite including tests).
Runs before pushing to catch everything CI would. Allow at least
15 minutes (race tests are slow without cache).
`git commit` and `git push` will appear to hang while hooks run.
This is normal. Do not interrupt, retry, or reduce the timeout.
@@ -217,26 +183,6 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- Commit format: `type(scope): message`
- PR titles follow the same `type(scope): message` format.
- When you use a scope, it must be a real filesystem path containing every
changed file.
- Use a broader path scope, or omit the scope, for cross-cutting changes.
- Example: `fix(coderd/chatd): ...` for changes only in `coderd/chatd/`.
### Frontend Patterns
- Prefer existing shared UI components and utilities over custom
implementations. Reuse common primitives such as loading, table, and error
handling components when they fit the use case.
- Use Storybook stories for all component and page testing, including
visual presentation, user interactions, keyboard navigation, focus
management, and accessibility behavior. Do not create standalone
vitest/RTL test files for components or pages. Stories double as living
documentation, visual regression coverage, and interaction test suites
via `play` functions. Reserve plain vitest files for pure logic only:
utility functions, data transformations, hooks tested via
`renderHook()` that do not require DOM assertions, and query/cache
operations with no rendered output.
### Writing Comments
+68 -92
View File
@@ -27,7 +27,6 @@ ifdef MAKE_TIMED
SHELL := $(CURDIR)/scripts/lib/timed-shell.sh
.SHELLFLAGS = $@ -ceu
export MAKE_TIMED
export MAKE_LOGDIR
endif
# This doesn't work on directories.
@@ -115,18 +114,15 @@ POSTGRES_VERSION ?= 17
POSTGRES_IMAGE ?= us-docker.pkg.dev/coder-v2-images-public/public/postgres:$(POSTGRES_VERSION)
# Limit parallel Make jobs in pre-commit/pre-push. Defaults to
# nproc/4 (min 2) since test, lint, and build targets have internal
# nproc/4 (min 2) since test and lint targets have internal
# parallelism. Override: make pre-push PARALLEL_JOBS=8
PARALLEL_JOBS ?= $(shell n=$$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8); echo $$(( n / 4 > 2 ? n / 4 : 2 )))
# Use the highest ZSTD compression level in release builds to
# minimize artifact size. For non-release CI builds (e.g. main
# branch preview), use multithreaded level 6 which is ~99% faster
# at the cost of ~30% larger archives.
ifeq ($(CODER_RELEASE),true)
# Use the highest ZSTD compression level in CI.
ifdef CI
ZSTDFLAGS := -22 --ultra
else
ZSTDFLAGS := -6 -T0
ZSTDFLAGS := -6
endif
# Common paths to exclude from find commands, this rule is written so
@@ -136,10 +132,18 @@ endif
# the search path so that these exclusions match.
FIND_EXCLUSIONS= \
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
# Source files used for make targets, evaluated on use.
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
# Same as GO_SRC_FILES but excluding certain files that have problematic
# Makefile dependencies (e.g. pnpm).
MOST_GO_SRC_FILES := $(shell \
find . \
$(FIND_EXCLUSIONS) \
-type f \
-name '*.go' \
-not -name '*_test.go' \
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
)
# All the shell files in the repo, excluding ignored files.
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
@@ -506,26 +510,13 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
# Only wildcard the go files in the develop directory to avoid rebuilds
# when project files are changd. Technically changes to some imports may
# not be detected, but it's unlikely to cause any issues.
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
CGO_ENABLED=0 go build -o $@ ./scripts/develop
BOLD := $(shell tput bold 2>/dev/null)
GREEN := $(shell tput setaf 2 2>/dev/null)
RED := $(shell tput setaf 1 2>/dev/null)
YELLOW := $(shell tput setaf 3 2>/dev/null)
DIM := $(shell tput dim 2>/dev/null || tput setaf 8 2>/dev/null)
RESET := $(shell tput sgr0 2>/dev/null)
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/biome fmt/markdown
.PHONY: fmt
# Subset of fmt that does not require Go or Node toolchains.
fmt-light: fmt/shfmt fmt/terraform fmt/markdown
.PHONY: fmt-light
fmt/go:
ifdef FILE
# Format single file
@@ -633,10 +624,6 @@ LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
.PHONY: lint
# Subset of lint that does not require Go or Node toolchains.
lint-light: lint/shellcheck lint/markdown lint/helm lint/bootstrap lint/migrations lint/actions/actionlint lint/typos
.PHONY: lint-light
lint/site-icons:
./scripts/check_site_icons.sh
.PHONY: lint/site-icons
@@ -723,93 +710,89 @@ lint/typos: build/typos-$(TYPOS_VERSION)
build/typos-$(TYPOS_VERSION) --config .github/workflows/typos.toml
.PHONY: lint/typos
# pre-commit and pre-push mirror CI checks locally.
# pre-commit and pre-push mirror CI "required" jobs locally.
# See the "required" job's needs list in .github/workflows/ci.yaml.
#
# pre-commit runs checks that don't need external services (Docker,
# Playwright). This is the git pre-commit hook default since Docker
# and browser issues in the local environment would otherwise block
# Playwright). This is the git pre-commit hook default since test
# and Docker failures in the local environment would otherwise block
# all commits.
#
# pre-push adds heavier checks: Go tests, JS tests, and site build.
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
# pre-push runs the full CI suite including tests. This is the git
# pre-push hook default, catching everything CI would before pushing.
#
# pre-commit uses two phases: gen+fmt first, then lint+build. This
# avoids races where gen's `go run` creates temporary .go files that
# lint's find-based checks pick up. Within each phase, targets run in
# parallel via -j. It fails if any tracked files have unstaged
# changes afterward.
# pre-push uses two-phase execution: gen+fmt+test-postgres-docker
# first (writes files, starts Docker), then lint+build+test in
# parallel. pre-commit uses two phases: gen+fmt first, then
# lint+build. This avoids races where gen's `go run` creates
# temporary .go files that lint's find-based checks pick up.
# Within each phase, targets run in parallel via -j. Both fail if
# any tracked files have unstaged changes afterward.
#
# Both pre-commit and pre-push:
# gen, fmt, lint, lint/typos, slim binary (local arch)
#
# pre-push only (need external services or are slow):
# site/out/index.html (pnpm build)
# test-postgres-docker + test (needs Docker)
# test-js, test-e2e (needs Playwright)
# sqlc-vet (needs Docker)
# offlinedocs/check
#
# Omitted:
# test-go-pg-17 (same tests, different PG version)
define check-unstaged
unstaged="$$(git diff --name-only)"
if [[ -n $$unstaged ]]; then
echo "$(RED)✗ check unstaged changes$(RESET)"
echo "$$unstaged" | sed 's/^/ - /'
echo ""
echo "$(DIM) Verify generated changes are correct before staging:$(RESET)"
echo "$(DIM) git diff$(RESET)"
echo "$(DIM) git add -u && git commit$(RESET)"
echo "ERROR: unstaged changes in tracked files:"
echo "$$unstaged"
echo
echo "Review each change (git diff), verify correctness, then stage:"
echo " git add -u && git commit"
exit 1
fi
endef
define check-untracked
untracked=$$(git ls-files --other --exclude-standard)
if [[ -n $$untracked ]]; then
echo "$(YELLOW)? check untracked files$(RESET)"
echo "$$untracked" | sed 's/^/ - /'
echo ""
echo "$(DIM) Review if these should be committed or added to .gitignore.$(RESET)"
echo "WARNING: untracked files (not in this commit, won't be in CI):"
echo "$$untracked"
echo
fi
endef
pre-commit:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit.XXXXXX")
echo "$(BOLD)pre-commit$(RESET) ($$logdir)"
echo "gen + fmt:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir gen fmt
echo "=== Phase 1/2: gen + fmt ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt
$(check-unstaged)
echo "lint + build:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
echo "=== Phase 2/2: lint + build ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
lint \
lint/typos \
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
$(check-unstaged)
$(check-untracked)
rm -rf $$logdir
echo "$(GREEN)✓ pre-commit passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
echo "$(BOLD)$(GREEN)=== pre-commit passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
.PHONY: pre-commit
# Lightweight pre-commit for changes that don't touch Go or
# TypeScript. Skips gen, lint/go, lint/ts, fmt/go, fmt/ts, and
# the binary build. Used by the pre-commit hook when only docs,
# shell, terraform, helm, or other fast-to-check files changed.
pre-commit-light:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit-light.XXXXXX")
echo "$(BOLD)pre-commit-light$(RESET) ($$logdir)"
echo "fmt:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir fmt-light
$(check-unstaged)
echo "lint:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir lint-light
$(check-unstaged)
$(check-untracked)
rm -rf $$logdir
echo "$(GREEN)✓ pre-commit-light passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
.PHONY: pre-commit-light
pre-push:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-push.XXXXXX")
echo "$(BOLD)pre-push$(RESET) ($$logdir)"
echo "test + build site:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
echo "=== Phase 1/2: gen + fmt + postgres ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt test-postgres-docker
$(check-unstaged)
echo "=== Phase 2/2: lint + build + test ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
lint \
lint/typos \
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT) \
site/out/index.html \
test \
test-js \
test-storybook \
site/out/index.html
rm -rf $$logdir
echo "$(GREEN)✓ pre-push passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
test-e2e \
test-race \
sqlc-vet \
offlinedocs/check
$(check-unstaged)
echo "$(BOLD)$(GREEN)=== pre-push passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
.PHONY: pre-push
offlinedocs/check: offlinedocs/node_modules/.installed
@@ -1341,11 +1324,6 @@ test-js: site/node_modules/.installed
pnpm test:ci
.PHONY: test-js
test-storybook: site/node_modules/.installed
cd site/
pnpm exec vitest run --project=storybook
.PHONY: test-storybook
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
# dependency for any sqlc-cloud related targets.
sqlc-cloud-is-setup:
@@ -1494,5 +1472,3 @@ dogfood/coder/nix.hash: flake.nix flake.lock
count-test-databases:
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
.PHONY: count-test-databases
.PHONY: count-test-databases
+2 -16
View File
@@ -39,7 +39,6 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/clistat"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentdesktop"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentgit"
@@ -311,7 +310,6 @@ type agent struct {
filesAPI *agentfiles.API
gitAPI *agentgit.API
processAPI *agentproc.API
desktopAPI *agentdesktop.API
socketServerEnabled bool
socketPath string
@@ -385,18 +383,10 @@ func (a *agent) init() {
pathStore := agentgit.NewPathStore()
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore, func() string {
if m := a.manifest.Load(); m != nil {
return m.Directory
}
return ""
})
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
@@ -2067,10 +2057,6 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
}
if err := a.desktopAPI.Close(); err != nil {
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
}
if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {
-56
View File
@@ -3040,62 +3040,6 @@ func TestAgent_Reconnect(t *testing.T) {
closer.Close()
}
func TestAgent_ReconnectNoLifecycleReemit(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := testutil.Logger(t)
fCoordinator := tailnettest.NewFakeCoordinator()
agentID := uuid.New()
statsCh := make(chan *proto.Stats, 50)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t,
logger,
agentID,
agentsdk.Manifest{
DERPMap: derpMap,
Scripts: []codersdk.WorkspaceAgentScript{{
Script: "echo hello",
Timeout: 30 * time.Second,
RunOnStart: true,
}},
},
statsCh,
fCoordinator,
)
defer client.Close()
closer := agent.New(agent.Options{
Client: client,
Logger: logger.Named("agent"),
})
defer closer.Close()
// Wait for the agent to reach Ready state.
require.Eventually(t, func() bool {
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
}, testutil.WaitShort, testutil.IntervalFast)
statesBefore := slices.Clone(client.GetLifecycleStates())
// Disconnect by closing the coordinator response channel.
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
close(call1.Resps)
// Wait for reconnect.
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
// Wait for a stats report as a deterministic steady-state proof.
testutil.RequireReceive(ctx, t, statsCh)
statesAfter := client.GetLifecycleStates()
require.Equal(t, statesBefore, statesAfter,
"lifecycle states should not be re-reported after reconnect")
closer.Close()
}
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
t.Parallel()
logger := testutil.Logger(t)
-14
View File
@@ -57,26 +57,18 @@ type fakeContainerCLI struct {
}
func (f *fakeContainerCLI) List(_ context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.containers, f.listErr
}
func (f *fakeContainerCLI) DetectArchitecture(_ context.Context, _ string) (string, error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.arch, f.archErr
}
func (f *fakeContainerCLI) Copy(ctx context.Context, name, src, dst string) error {
f.mu.Lock()
defer f.mu.Unlock()
return f.copyErr
}
func (f *fakeContainerCLI) ExecAs(ctx context.Context, name, user string, args ...string) ([]byte, error) {
f.mu.Lock()
defer f.mu.Unlock()
return nil, f.execErr
}
@@ -2697,9 +2689,7 @@ func TestAPI(t *testing.T) {
// When: The container is recreated (new container ID) with config changes.
terraformContainer.ID = "new-container-id"
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = terraformContainer.ID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
Apps: []agentcontainers.SubAgentApp{{Slug: "app2"}}, // Changed app triggers recreation logic.
@@ -2831,9 +2821,7 @@ func TestAPI(t *testing.T) {
// Simulate container rebuild: new container ID, changed display apps.
newContainerID := "new-container-id"
terraformContainer.ID = newContainerID
fCCLI.mu.Lock()
fCCLI.containers.Containers = []codersdk.WorkspaceAgentContainer{terraformContainer}
fCCLI.mu.Unlock()
fDCCLI.upID = newContainerID
fDCCLI.readConfig.MergedConfiguration.Customizations.Coder = []agentcontainers.CoderCustomization{{
DisplayApps: map[codersdk.DisplayApp]bool{
@@ -4938,11 +4926,9 @@ func TestDevcontainerPrebuildSupport(t *testing.T) {
)
api.Start()
fCCLI.mu.Lock()
fCCLI.containers = codersdk.WorkspaceAgentListContainersResponse{
Containers: []codersdk.WorkspaceAgentContainer{testContainer},
}
fCCLI.mu.Unlock()
// Given: We allow the dev container to be created.
fDCCLI.upID = testContainer.ID
-536
View File
@@ -1,536 +0,0 @@
package agentdesktop
import (
"encoding/json"
"math"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
"github.com/coder/websocket"
)
// DesktopAction is the request body for the desktop action endpoint.
type DesktopAction struct {
Action string `json:"action"`
Coordinate *[2]int `json:"coordinate,omitempty"`
StartCoordinate *[2]int `json:"start_coordinate,omitempty"`
Text *string `json:"text,omitempty"`
Duration *int `json:"duration,omitempty"`
ScrollAmount *int `json:"scroll_amount,omitempty"`
ScrollDirection *string `json:"scroll_direction,omitempty"`
// ScaledWidth and ScaledHeight are the coordinate space the
// model is using. When provided, coordinates are linearly
// mapped from scaled → native before dispatching.
ScaledWidth *int `json:"scaled_width,omitempty"`
ScaledHeight *int `json:"scaled_height,omitempty"`
}
// DesktopActionResponse is the response from the desktop action
// endpoint.
type DesktopActionResponse struct {
Output string `json:"output,omitempty"`
ScreenshotData string `json:"screenshot_data,omitempty"`
ScreenshotWidth int `json:"screenshot_width,omitempty"`
ScreenshotHeight int `json:"screenshot_height,omitempty"`
}
// API exposes the desktop streaming HTTP routes for the agent.
type API struct {
logger slog.Logger
desktop Desktop
clock quartz.Clock
}
// NewAPI creates a new desktop streaming API.
func NewAPI(logger slog.Logger, desktop Desktop, clock quartz.Clock) *API {
if clock == nil {
clock = quartz.NewReal()
}
return &API{
logger: logger,
desktop: desktop,
clock: clock,
}
}
// Routes returns the chi router for mounting at /api/v0/desktop.
func (a *API) Routes() http.Handler {
r := chi.NewRouter()
r.Get("/vnc", a.handleDesktopVNC)
r.Post("/action", a.handleAction)
return r
}
func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Start the desktop session (idempotent).
_, err := a.desktop.Start(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start desktop session.",
Detail: err.Error(),
})
return
}
// Get a VNC connection.
vncConn, err := a.desktop.VNCConn(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to connect to VNC server.",
Detail: err.Error(),
})
return
}
defer vncConn.Close()
// Accept WebSocket from coderd.
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
a.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
return
}
// No read limit — RFB framebuffer updates can be large.
conn.SetReadLimit(-1)
wsCtx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
defer wsNetConn.Close()
// Bicopy raw bytes between WebSocket and VNC TCP.
agentssh.Bicopy(wsCtx, wsNetConn, vncConn)
}
func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
handlerStart := a.clock.Now()
// Ensure the desktop is running and grab native dimensions.
cfg, err := a.desktop.Start(ctx)
if err != nil {
a.logger.Warn(ctx, "handleAction: desktop.Start failed",
slog.Error(err),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start desktop session.",
Detail: err.Error(),
})
return
}
var action DesktopAction
if err := json.NewDecoder(r.Body).Decode(&action); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to decode request body.",
Detail: err.Error(),
})
return
}
a.logger.Info(ctx, "handleAction: started",
slog.F("action", action.Action),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
// Helper to scale a coordinate pair from the model's space to
// native display pixels.
scaleXY := func(x, y int) (int, int) {
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
}
return x, y
}
var resp DesktopActionResponse
switch action.Action {
case "key":
if action.Text == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"text\" for key action.",
})
return
}
if err := a.desktop.KeyPress(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Key press failed.",
Detail: err.Error(),
})
return
}
resp.Output = "key action performed"
case "type":
if action.Text == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"text\" for type action.",
})
return
}
if err := a.desktop.Type(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Type action failed.",
Detail: err.Error(),
})
return
}
resp.Output = "type action performed"
case "cursor_position":
x, y, err := a.desktop.CursorPosition(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Cursor position failed.",
Detail: err.Error(),
})
return
}
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
case "mouse_move":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.Move(ctx, x, y); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Mouse move failed.",
Detail: err.Error(),
})
return
}
resp.Output = "mouse_move action performed"
case "left_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
stepStart := a.clock.Now()
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
a.logger.Warn(ctx, "handleAction: Click failed",
slog.F("action", "left_click"),
slog.F("step", "click"),
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
slog.Error(err),
)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left click failed.",
Detail: err.Error(),
})
return
}
a.logger.Debug(ctx, "handleAction: Click completed",
slog.F("action", "left_click"),
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
resp.Output = "left_click action performed"
case "left_click_drag":
if action.Coordinate == nil || action.StartCoordinate == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"coordinate\" or \"start_coordinate\" for left_click_drag.",
})
return
}
sx, sy := scaleXY(action.StartCoordinate[0], action.StartCoordinate[1])
ex, ey := scaleXY(action.Coordinate[0], action.Coordinate[1])
if err := a.desktop.Drag(ctx, sx, sy, ex, ey); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left click drag failed.",
Detail: err.Error(),
})
return
}
resp.Output = "left_click_drag action performed"
case "left_mouse_down":
if err := a.desktop.ButtonDown(ctx, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left mouse down failed.",
Detail: err.Error(),
})
return
}
resp.Output = "left_mouse_down action performed"
case "left_mouse_up":
if err := a.desktop.ButtonUp(ctx, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left mouse up failed.",
Detail: err.Error(),
})
return
}
resp.Output = "left_mouse_up action performed"
case "right_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.Click(ctx, x, y, MouseButtonRight); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Right click failed.",
Detail: err.Error(),
})
return
}
resp.Output = "right_click action performed"
case "middle_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.Click(ctx, x, y, MouseButtonMiddle); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Middle click failed.",
Detail: err.Error(),
})
return
}
resp.Output = "middle_click action performed"
case "double_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.DoubleClick(ctx, x, y, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Double click failed.",
Detail: err.Error(),
})
return
}
resp.Output = "double_click action performed"
case "triple_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
for range 3 {
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Triple click failed.",
Detail: err.Error(),
})
return
}
}
resp.Output = "triple_click action performed"
case "scroll":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
amount := 3
if action.ScrollAmount != nil {
amount = *action.ScrollAmount
}
direction := "down"
if action.ScrollDirection != nil {
direction = *action.ScrollDirection
}
var dx, dy int
switch direction {
case "up":
dy = -amount
case "down":
dy = amount
case "left":
dx = -amount
case "right":
dx = amount
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid scroll direction: " + direction,
})
return
}
if err := a.desktop.Scroll(ctx, x, y, dx, dy); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Scroll failed.",
Detail: err.Error(),
})
return
}
resp.Output = "scroll action performed"
case "hold_key":
if action.Text == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"text\" for hold_key action.",
})
return
}
dur := 1000
if action.Duration != nil {
dur = *action.Duration
}
if err := a.desktop.KeyDown(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Key down failed.",
Detail: err.Error(),
})
return
}
timer := a.clock.NewTimer(time.Duration(dur)*time.Millisecond, "agentdesktop", "hold_key")
defer timer.Stop()
select {
case <-ctx.Done():
// Context canceled; release the key immediately.
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
}
return
case <-timer.C:
}
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Key up failed.",
Detail: err.Error(),
})
return
}
resp.Output = "hold_key action performed"
case "screenshot":
var opts ScreenshotOptions
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
opts.TargetWidth = *action.ScaledWidth
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
opts.TargetHeight = *action.ScaledHeight
}
result, err := a.desktop.Screenshot(ctx, opts)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Screenshot failed.",
Detail: err.Error(),
})
return
}
resp.Output = "screenshot"
resp.ScreenshotData = result.Data
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
resp.ScreenshotWidth = *action.ScaledWidth
} else {
resp.ScreenshotWidth = cfg.Width
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
resp.ScreenshotHeight = *action.ScaledHeight
} else {
resp.ScreenshotHeight = cfg.Height
}
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unknown action: " + action.Action,
})
return
}
elapsedMs := a.clock.Since(handlerStart).Milliseconds()
if ctx.Err() != nil {
a.logger.Error(ctx, "handleAction: context canceled before writing response",
slog.F("action", action.Action),
slog.F("elapsed_ms", elapsedMs),
slog.Error(ctx.Err()),
)
return
}
a.logger.Info(ctx, "handleAction: writing response",
slog.F("action", action.Action),
slog.F("elapsed_ms", elapsedMs),
)
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
// Close shuts down the desktop session if one is running.
func (a *API) Close() error {
return a.desktop.Close()
}
// coordFromAction extracts the coordinate pair from a DesktopAction,
// returning an error if the coordinate field is missing.
func coordFromAction(action DesktopAction) (x, y int, err error) {
if action.Coordinate == nil {
return 0, 0, &missingFieldError{field: "coordinate", action: action.Action}
}
return action.Coordinate[0], action.Coordinate[1], nil
}
// missingFieldError is returned when a required field is absent from
// a DesktopAction.
type missingFieldError struct {
field string
action string
}
func (e *missingFieldError) Error() string {
return "Missing \"" + e.field + "\" for " + e.action + " action."
}
// scaleCoordinate maps a coordinate from scaled → native space.
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
if scaledDim == 0 || scaledDim == nativeDim {
return scaled
}
native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
// Clamp to valid range.
native = math.Max(native, 0)
native = math.Min(native, float64(nativeDim-1))
return int(native)
}
-467
View File
@@ -1,467 +0,0 @@
package agentdesktop_test
import (
"bytes"
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentdesktop"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
// fakeDesktop is a minimal Desktop implementation for unit tests.
type fakeDesktop struct {
startErr error
startCfg agentdesktop.DisplayConfig
vncConnErr error
screenshotErr error
screenshotRes agentdesktop.ScreenshotResult
closed bool
// Track calls for assertions.
lastMove [2]int
lastClick [3]int // x, y, button
lastScroll [4]int // x, y, dx, dy
lastKey string
lastTyped string
lastKeyDown string
lastKeyUp string
}
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
return f.startCfg, f.startErr
}
func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
return nil, f.vncConnErr
}
func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
return f.screenshotRes, f.screenshotErr
}
func (f *fakeDesktop) Move(_ context.Context, x, y int) error {
f.lastMove = [2]int{x, y}
return nil
}
func (f *fakeDesktop) Click(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
f.lastClick = [3]int{x, y, 1}
return nil
}
func (f *fakeDesktop) DoubleClick(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
f.lastClick = [3]int{x, y, 2}
return nil
}
func (*fakeDesktop) ButtonDown(context.Context, agentdesktop.MouseButton) error { return nil }
func (*fakeDesktop) ButtonUp(context.Context, agentdesktop.MouseButton) error { return nil }
func (f *fakeDesktop) Scroll(_ context.Context, x, y, dx, dy int) error {
f.lastScroll = [4]int{x, y, dx, dy}
return nil
}
func (*fakeDesktop) Drag(context.Context, int, int, int, int) error { return nil }
func (f *fakeDesktop) KeyPress(_ context.Context, key string) error {
f.lastKey = key
return nil
}
func (f *fakeDesktop) KeyDown(_ context.Context, key string) error {
f.lastKeyDown = key
return nil
}
func (f *fakeDesktop) KeyUp(_ context.Context, key string) error {
f.lastKeyUp = key
return nil
}
func (f *fakeDesktop) Type(_ context.Context, text string) error {
f.lastTyped = text
return nil
}
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
return 10, 20, nil
}
func (f *fakeDesktop) Close() error {
f.closed = true
return nil
}
func TestHandleDesktopVNC_StartError(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{startErr: xerrors.New("no desktop")}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
var resp codersdk.Response
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Failed to start desktop session.", resp.Message)
}
func TestHandleAction_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{Action: "screenshot"}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
var result agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&result)
require.NoError(t, err)
// Dimensions come from DisplayConfig, not the screenshot CLI.
assert.Equal(t, "screenshot", result.Output)
assert.Equal(t, "base64data", result.ScreenshotData)
assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
}
func TestHandleAction_LeftClick(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{
Action: "left_click",
Coordinate: &[2]int{100, 200},
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
var resp agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "left_click action performed", resp.Output)
assert.Equal(t, [3]int{100, 200, 1}, fake.lastClick)
}
func TestHandleAction_UnknownAction(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{Action: "explode"}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestHandleAction_KeyAction(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
text := "Return"
body := agentdesktop.DesktopAction{
Action: "key",
Text: &text,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "Return", fake.lastKey)
}
func TestHandleAction_TypeAction(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
text := "hello world"
body := agentdesktop.DesktopAction{
Action: "type",
Text: &text,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
assert.Equal(t, "hello world", fake.lastTyped)
}
func TestHandleAction_HoldKey(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
mClk := quartz.NewMock(t)
trap := mClk.Trap().NewTimer("agentdesktop", "hold_key")
defer trap.Close()
api := agentdesktop.NewAPI(logger, fake, mClk)
defer api.Close()
text := "Shift_L"
dur := 100
body := agentdesktop.DesktopAction{
Action: "hold_key",
Text: &text,
Duration: &dur,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
done := make(chan struct{})
go func() {
defer close(done)
handler.ServeHTTP(rr, req)
}()
// Wait for the timer to be created, then advance past it.
trap.MustWait(req.Context()).MustRelease(req.Context())
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
<-done
assert.Equal(t, http.StatusOK, rr.Code)
var resp agentdesktop.DesktopActionResponse
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "hold_key action performed", resp.Output)
assert.Equal(t, "Shift_L", fake.lastKeyDown)
assert.Equal(t, "Shift_L", fake.lastKeyUp)
}
func TestHandleAction_HoldKeyMissingText(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
body := agentdesktop.DesktopAction{Action: "hold_key"}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusBadRequest, rr.Code)
var resp codersdk.Response
err = json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, "Missing \"text\" for hold_key action.", resp.Message)
}
func TestHandleAction_ScrollDown(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
dir := "down"
amount := 5
body := agentdesktop.DesktopAction{
Action: "scroll",
Coordinate: &[2]int{500, 400},
ScrollDirection: &dir,
ScrollAmount: &amount,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
// dy should be positive 5 for "down".
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
}
func TestHandleAction_CoordinateScaling(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{
// Native display is 1920x1080.
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
}
api := agentdesktop.NewAPI(logger, fake, nil)
defer api.Close()
// Model is working in a 1280x720 coordinate space.
sw := 1280
sh := 720
body := agentdesktop.DesktopAction{
Action: "mouse_move",
Coordinate: &[2]int{640, 360},
ScaledWidth: &sw,
ScaledHeight: &sh,
}
b, err := json.Marshal(body)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
// 640 in 1280-space → 960 in 1920-space (midpoint maps to
// midpoint).
assert.Equal(t, 960, fake.lastMove[0])
assert.Equal(t, 540, fake.lastMove[1])
}
func TestClose_DelegatesToDesktop(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
fake := &fakeDesktop{}
api := agentdesktop.NewAPI(logger, fake, nil)
err := api.Close()
require.NoError(t, err)
assert.True(t, fake.closed)
}
func TestClose_PreventsNewSessions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// After Close(), Start() will return an error because the
// underlying Desktop is closed.
fake := &fakeDesktop{}
api := agentdesktop.NewAPI(logger, fake, nil)
err := api.Close()
require.NoError(t, err)
// Simulate the closed desktop returning an error on Start().
fake.startErr = xerrors.New("desktop is closed")
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
handler := api.Routes()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
}
-91
View File
@@ -1,91 +0,0 @@
package agentdesktop
import (
"context"
"net"
)
// Desktop abstracts a virtual desktop session running inside a workspace.
type Desktop interface {
// Start launches the desktop session. It is idempotent — calling
// Start on an already-running session returns the existing
// config. The returned DisplayConfig describes the running
// session.
Start(ctx context.Context) (DisplayConfig, error)
// VNCConn dials the desktop's VNC server and returns a raw
// net.Conn carrying RFB binary frames. Each call returns a new
// connection; multiple clients can connect simultaneously.
// Start must be called before VNCConn.
VNCConn(ctx context.Context) (net.Conn, error)
// Screenshot captures the current framebuffer as a PNG and
// returns it base64-encoded. TargetWidth/TargetHeight in opts
// are the desired output dimensions (the implementation
// rescales); pass 0 to use native resolution.
Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error)
// Mouse operations.
// Move moves the mouse cursor to absolute coordinates.
Move(ctx context.Context, x, y int) error
// Click performs a mouse button click at the given coordinates.
Click(ctx context.Context, x, y int, button MouseButton) error
// DoubleClick performs a double-click at the given coordinates.
DoubleClick(ctx context.Context, x, y int, button MouseButton) error
// ButtonDown presses and holds a mouse button.
ButtonDown(ctx context.Context, button MouseButton) error
// ButtonUp releases a mouse button.
ButtonUp(ctx context.Context, button MouseButton) error
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
Scroll(ctx context.Context, x, y, dx, dy int) error
// Drag moves from (startX,startY) to (endX,endY) while holding
// the left mouse button.
Drag(ctx context.Context, startX, startY, endX, endY int) error
// Keyboard operations.
// KeyPress sends a key-down then key-up for a key combo string
// (e.g. "Return", "ctrl+c").
KeyPress(ctx context.Context, keys string) error
// KeyDown presses and holds a key.
KeyDown(ctx context.Context, key string) error
// KeyUp releases a key.
KeyUp(ctx context.Context, key string) error
// Type types a string of text character-by-character.
Type(ctx context.Context, text string) error
// CursorPosition returns the current cursor coordinates.
CursorPosition(ctx context.Context) (x, y int, err error)
// Close shuts down the desktop session and cleans up resources.
Close() error
}
// DisplayConfig describes a running desktop session.
type DisplayConfig struct {
Width int // native width in pixels
Height int // native height in pixels
VNCPort int // local TCP port for the VNC server
Display int // X11 display number (e.g. 1 for :1), -1 if N/A
}
// MouseButton identifies a mouse button.
type MouseButton string
const (
MouseButtonLeft MouseButton = "left"
MouseButtonRight MouseButton = "right"
MouseButtonMiddle MouseButton = "middle"
)
// ScreenshotOptions configures a screenshot capture.
type ScreenshotOptions struct {
TargetWidth int // 0 = native
TargetHeight int // 0 = native
}
// ScreenshotResult is a captured screenshot.
type ScreenshotResult struct {
Data string // base64-encoded PNG
}
-399
View File
@@ -1,399 +0,0 @@
package agentdesktop
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
VNCPort int `json:"vncPort"`
Geometry string `json:"geometry"` // e.g. "1920x1080"
}
// desktopSession tracks a running portabledesktop process.
type desktopSession struct {
cmd *exec.Cmd
vncPort int
width int // native width, parsed from geometry
height int // native height, parsed from geometry
display int // X11 display number, -1 if not available
cancel context.CancelFunc
}
// cursorOutput is the JSON output from `portabledesktop cursor --json`.
type cursorOutput struct {
X int `json:"x"`
Y int `json:"y"`
}
// screenshotOutput is the JSON output from
// `portabledesktop screenshot --json`.
type screenshotOutput struct {
Data string `json:"data"`
}
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
logger slog.Logger
execer agentexec.Execer
scriptBinDir string // coder script bin directory
mu sync.Mutex
session *desktopSession // nil until started
binPath string // resolved path to binary, cached
closed bool
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. scriptBinDir is
// the coder script bin directory checked for the binary.
func NewPortableDesktop(
logger slog.Logger,
execer agentexec.Execer,
scriptBinDir string,
) Desktop {
return &portableDesktop{
logger: logger,
execer: execer,
scriptBinDir: scriptBinDir,
}
}
// Start launches the desktop session (idempotent).
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return DisplayConfig{}, xerrors.New("desktop is closed")
}
if err := p.ensureBinary(ctx); err != nil {
return DisplayConfig{}, xerrors.Errorf("ensure portabledesktop binary: %w", err)
}
// If we have an existing session, check if it's still alive.
if p.session != nil {
if !(p.session.cmd.ProcessState != nil && p.session.cmd.ProcessState.Exited()) {
return DisplayConfig{
Width: p.session.width,
Height: p.session.height,
VNCPort: p.session.vncPort,
Display: p.session.display,
}, nil
}
// Process died — clean up and recreate.
p.logger.Warn(ctx, "portabledesktop process died, recreating session")
p.session.cancel()
p.session = nil
}
// Spawn portabledesktop up --json.
sessionCtx, sessionCancel := context.WithCancel(context.Background())
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
stdout, err := cmd.StdoutPipe()
if err != nil {
sessionCancel()
return DisplayConfig{}, xerrors.Errorf("create stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
sessionCancel()
return DisplayConfig{}, xerrors.Errorf("start portabledesktop: %w", err)
}
// Parse the JSON output to get VNC port and geometry.
var output portableDesktopOutput
if err := json.NewDecoder(stdout).Decode(&output); err != nil {
sessionCancel()
_ = cmd.Process.Kill()
_ = cmd.Wait()
return DisplayConfig{}, xerrors.Errorf("parse portabledesktop output: %w", err)
}
if output.VNCPort == 0 {
sessionCancel()
_ = cmd.Process.Kill()
_ = cmd.Wait()
return DisplayConfig{}, xerrors.New("portabledesktop returned port 0")
}
var w, h int
if output.Geometry != "" {
if _, err := fmt.Sscanf(output.Geometry, "%dx%d", &w, &h); err != nil {
p.logger.Warn(ctx, "failed to parse geometry, using defaults",
slog.F("geometry", output.Geometry),
slog.Error(err),
)
}
}
p.logger.Info(ctx, "started portabledesktop session",
slog.F("vnc_port", output.VNCPort),
slog.F("width", w),
slog.F("height", h),
slog.F("pid", cmd.Process.Pid),
)
p.session = &desktopSession{
cmd: cmd,
vncPort: output.VNCPort,
width: w,
height: h,
display: -1,
cancel: sessionCancel,
}
return DisplayConfig{
Width: w,
Height: h,
VNCPort: output.VNCPort,
Display: -1,
}, nil
}
// VNCConn dials the desktop's VNC server and returns a raw
// net.Conn carrying RFB binary frames.
func (p *portableDesktop) VNCConn(_ context.Context) (net.Conn, error) {
p.mu.Lock()
session := p.session
p.mu.Unlock()
if session == nil {
return nil, xerrors.New("desktop session not started")
}
return net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", session.vncPort))
}
// Screenshot captures the current framebuffer as a base64-encoded PNG.
func (p *portableDesktop) Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error) {
args := []string{"screenshot", "--json"}
if opts.TargetWidth > 0 {
args = append(args, "--target-width", strconv.Itoa(opts.TargetWidth))
}
if opts.TargetHeight > 0 {
args = append(args, "--target-height", strconv.Itoa(opts.TargetHeight))
}
out, err := p.runCmd(ctx, args...)
if err != nil {
return ScreenshotResult{}, err
}
var result screenshotOutput
if err := json.Unmarshal([]byte(out), &result); err != nil {
return ScreenshotResult{}, xerrors.Errorf("parse screenshot output: %w", err)
}
return ScreenshotResult(result), nil
}
// Move moves the mouse cursor to absolute coordinates.
func (p *portableDesktop) Move(ctx context.Context, x, y int) error {
_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
return err
}
// Click performs a mouse button click at the given coordinates.
func (p *portableDesktop) Click(ctx context.Context, x, y int, button MouseButton) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "click", string(button))
return err
}
// DoubleClick performs a double-click at the given coordinates.
func (p *portableDesktop) DoubleClick(ctx context.Context, x, y int, button MouseButton) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "click", string(button)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "click", string(button))
return err
}
// ButtonDown presses and holds a mouse button.
func (p *portableDesktop) ButtonDown(ctx context.Context, button MouseButton) error {
_, err := p.runCmd(ctx, "mouse", "down", string(button))
return err
}
// ButtonUp releases a mouse button.
func (p *portableDesktop) ButtonUp(ctx context.Context, button MouseButton) error {
_, err := p.runCmd(ctx, "mouse", "up", string(button))
return err
}
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
func (p *portableDesktop) Scroll(ctx context.Context, x, y, dx, dy int) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "scroll", strconv.Itoa(dx), strconv.Itoa(dy))
return err
}
// Drag moves from (startX,startY) to (endX,endY) while holding the
// left mouse button.
func (p *portableDesktop) Drag(ctx context.Context, startX, startY, endX, endY int) error {
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(startX), strconv.Itoa(startY)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "down", string(MouseButtonLeft)); err != nil {
return err
}
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(endX), strconv.Itoa(endY)); err != nil {
return err
}
_, err := p.runCmd(ctx, "mouse", "up", string(MouseButtonLeft))
return err
}
// KeyPress sends a key-down then key-up for a key combo string.
func (p *portableDesktop) KeyPress(ctx context.Context, keys string) error {
_, err := p.runCmd(ctx, "keyboard", "key", keys)
return err
}
// KeyDown presses and holds a key.
func (p *portableDesktop) KeyDown(ctx context.Context, key string) error {
_, err := p.runCmd(ctx, "keyboard", "down", key)
return err
}
// KeyUp releases a key.
func (p *portableDesktop) KeyUp(ctx context.Context, key string) error {
_, err := p.runCmd(ctx, "keyboard", "up", key)
return err
}
// Type types a string of text character-by-character.
func (p *portableDesktop) Type(ctx context.Context, text string) error {
_, err := p.runCmd(ctx, "keyboard", "type", text)
return err
}
// CursorPosition returns the current cursor coordinates.
func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err error) {
out, err := p.runCmd(ctx, "cursor", "--json")
if err != nil {
return 0, 0, err
}
var result cursorOutput
if err := json.Unmarshal([]byte(out), &result); err != nil {
return 0, 0, xerrors.Errorf("parse cursor output: %w", err)
}
return result.X, result.Y, nil
}
// Close shuts down the desktop session and cleans up resources.
func (p *portableDesktop) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
p.closed = true
if p.session != nil {
p.session.cancel()
// Xvnc is a child process — killing it cleans up the X
// session.
_ = p.session.cmd.Process.Kill()
_ = p.session.cmd.Wait()
p.session = nil
}
return nil
}
// runCmd executes a portabledesktop subcommand and returns combined
// output. The caller must have previously called ensureBinary.
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
start := time.Now()
//nolint:gosec // args are constructed by the caller, not user input.
cmd := p.execer.CommandContext(ctx, p.binPath, args...)
out, err := cmd.CombinedOutput()
elapsed := time.Since(start)
if err != nil {
p.logger.Warn(ctx, "portabledesktop command failed",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
slog.Error(err),
slog.F("output", string(out)),
)
return "", xerrors.Errorf("portabledesktop %s: %w: %s", args[0], err, string(out))
}
if elapsed > 5*time.Second {
p.logger.Warn(ctx, "portabledesktop command slow",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
)
} else {
p.logger.Debug(ctx, "portabledesktop command completed",
slog.F("args", args),
slog.F("elapsed_ms", elapsed.Milliseconds()),
)
}
return string(out), nil
}
// ensureBinary resolves the portabledesktop binary from PATH or the
// coder script bin directory. It must be called while p.mu is held.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
if p.binPath != "" {
return nil
}
// 1. Check PATH.
if path, err := exec.LookPath("portabledesktop"); err == nil {
p.logger.Info(ctx, "found portabledesktop in PATH",
slog.F("path", path),
)
p.binPath = path
return nil
}
// 2. Check the coder script bin directory.
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
// On Windows, permission bits don't indicate executability,
// so accept any regular file.
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
p.logger.Info(ctx, "found portabledesktop in script bin directory",
slog.F("path", scriptBinPath),
)
p.binPath = scriptBinPath
return nil
}
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
slog.F("path", scriptBinPath),
slog.F("mode", info.Mode().String()),
)
}
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
}
@@ -1,545 +0,0 @@
package agentdesktop
import (
"context"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/pty"
)
// recordedExecer implements agentexec.Execer by recording every
// invocation and delegating to a real shell command built from a
// caller-supplied mapping of subcommand → shell script body.
type recordedExecer struct {
mu sync.Mutex
commands [][]string
// scripts maps a subcommand keyword (e.g. "up", "screenshot")
// to a shell snippet whose stdout will be the command output.
scripts map[string]string
}
func (r *recordedExecer) record(cmd string, args ...string) {
r.mu.Lock()
defer r.mu.Unlock()
r.commands = append(r.commands, append([]string{cmd}, args...))
}
func (r *recordedExecer) allCommands() [][]string {
r.mu.Lock()
defer r.mu.Unlock()
out := make([][]string, len(r.commands))
copy(out, r.commands)
return out
}
// scriptFor finds the first matching script key present in args.
func (r *recordedExecer) scriptFor(args []string) string {
for _, a := range args {
if s, ok := r.scripts[a]; ok {
return s
}
}
// Fallback: succeed silently.
return "true"
}
func (r *recordedExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
r.record(cmd, args...)
script := r.scriptFor(args)
//nolint:gosec // Test helper — script content is controlled by the test.
return exec.CommandContext(ctx, "sh", "-c", script)
}
func (r *recordedExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
r.record(cmd, args...)
return pty.CommandContext(ctx, "sh", "-c", r.scriptFor(args))
}
// --- portableDesktop tests ---
func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// The "up" script prints the JSON line then sleeps until
// the context is canceled (simulating a long-running process).
rec := &recordedExecer{
scripts: map[string]string{
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
}
ctx := t.Context()
cfg, err := pd.Start(ctx)
require.NoError(t, err)
assert.Equal(t, 1920, cfg.Width)
assert.Equal(t, 1080, cfg.Height)
assert.Equal(t, 5901, cfg.VNCPort)
assert.Equal(t, -1, cfg.Display)
// Clean up the long-running process.
require.NoError(t, pd.Close())
}
func TestPortableDesktop_Start_Idempotent(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
cfg1, err := pd.Start(ctx)
require.NoError(t, err)
cfg2, err := pd.Start(ctx)
require.NoError(t, err)
assert.Equal(t, cfg1, cfg2, "second Start should return the same config")
// The execer should have been called exactly once for "up".
cmds := rec.allCommands()
upCalls := 0
for _, c := range cmds {
for _, a := range c {
if a == "up" {
upCalls++
}
}
}
assert.Equal(t, 1, upCalls, "expected exactly one 'up' invocation")
require.NoError(t, pd.Close())
}
func TestPortableDesktop_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"screenshot": `echo '{"data":"abc123"}'`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
require.NoError(t, err)
assert.Equal(t, "abc123", result.Data)
}
func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"screenshot": `echo '{"data":"x"}'`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
_, err := pd.Screenshot(ctx, ScreenshotOptions{
TargetWidth: 800,
TargetHeight: 600,
})
require.NoError(t, err)
cmds := rec.allCommands()
require.NotEmpty(t, cmds)
// The last command should contain the target dimension flags.
last := cmds[len(cmds)-1]
joined := strings.Join(last, " ")
assert.Contains(t, joined, "--target-width 800")
assert.Contains(t, joined, "--target-height 600")
}
func TestPortableDesktop_MouseMethods(t *testing.T) {
t.Parallel()
// Each sub-test verifies a single mouse method dispatches the
// correct CLI arguments.
tests := []struct {
name string
invoke func(context.Context, *portableDesktop) error
wantArgs []string // substrings expected in a recorded command
}{
{
name: "Move",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.Move(ctx, 42, 99)
},
wantArgs: []string{"mouse", "move", "42", "99"},
},
{
name: "Click",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.Click(ctx, 10, 20, MouseButtonLeft)
},
// Click does move then click.
wantArgs: []string{"mouse", "click", "left"},
},
{
name: "DoubleClick",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.DoubleClick(ctx, 5, 6, MouseButtonRight)
},
wantArgs: []string{"mouse", "click", "right"},
},
{
name: "ButtonDown",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.ButtonDown(ctx, MouseButtonMiddle)
},
wantArgs: []string{"mouse", "down", "middle"},
},
{
name: "ButtonUp",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.ButtonUp(ctx, MouseButtonLeft)
},
wantArgs: []string{"mouse", "up", "left"},
},
{
name: "Scroll",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.Scroll(ctx, 50, 60, 3, 4)
},
wantArgs: []string{"mouse", "scroll", "3", "4"},
},
{
name: "Drag",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.Drag(ctx, 10, 20, 30, 40)
},
// Drag ends with mouse up left.
wantArgs: []string{"mouse", "up", "left"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"mouse": `echo ok`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
require.NotEmpty(t, cmds, "expected at least one command")
// Find at least one recorded command that contains
// all expected argument substrings.
found := false
for _, cmd := range cmds {
joined := strings.Join(cmd, " ")
match := true
for _, want := range tt.wantArgs {
if !strings.Contains(joined, want) {
match = false
break
}
}
if match {
found = true
break
}
}
assert.True(t, found,
"no recorded command matched %v; got %v", tt.wantArgs, cmds)
})
}
}
func TestPortableDesktop_KeyboardMethods(t *testing.T) {
t.Parallel()
tests := []struct {
name string
invoke func(context.Context, *portableDesktop) error
wantArgs []string
}{
{
name: "KeyPress",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.KeyPress(ctx, "Return")
},
wantArgs: []string{"keyboard", "key", "Return"},
},
{
name: "KeyDown",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.KeyDown(ctx, "shift")
},
wantArgs: []string{"keyboard", "down", "shift"},
},
{
name: "KeyUp",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.KeyUp(ctx, "shift")
},
wantArgs: []string{"keyboard", "up", "shift"},
},
{
name: "Type",
invoke: func(ctx context.Context, pd *portableDesktop) error {
return pd.Type(ctx, "hello world")
},
wantArgs: []string{"keyboard", "type", "hello world"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"keyboard": `echo ok`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
err := tt.invoke(t.Context(), pd)
require.NoError(t, err)
cmds := rec.allCommands()
require.NotEmpty(t, cmds)
last := cmds[len(cmds)-1]
joined := strings.Join(last, " ")
for _, want := range tt.wantArgs {
assert.Contains(t, joined, want)
}
})
}
}
func TestPortableDesktop_CursorPosition(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"cursor": `echo '{"x":100,"y":200}'`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
x, y, err := pd.CursorPosition(t.Context())
require.NoError(t, err)
assert.Equal(t, 100, x)
assert.Equal(t, 200, y)
}
func TestPortableDesktop_Close(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"up": `printf '{"vncPort":5901,"geometry":"1024x768"}\n' && sleep 120`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
scriptBinDir: t.TempDir(),
binPath: "portabledesktop",
}
ctx := t.Context()
_, err := pd.Start(ctx)
require.NoError(t, err)
// Session should exist.
pd.mu.Lock()
require.NotNil(t, pd.session)
pd.mu.Unlock()
require.NoError(t, pd.Close())
// Session should be cleaned up.
pd.mu.Lock()
assert.Nil(t, pd.session)
assert.True(t, pd.closed)
pd.mu.Unlock()
// Subsequent Start must fail.
_, err = pd.Start(ctx)
require.Error(t, err)
assert.Contains(t, err.Error(), "desktop is closed")
}
// --- ensureBinary tests ---
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
t.Parallel()
// When binPath is already set, ensureBinary should return
// immediately without doing any work.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(),
binPath: "/already/set",
}
err := pd.ensureBinary(t.Context())
require.NoError(t, err)
assert.Equal(t, "/already/set", pd.binPath)
}
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
require.NoError(t, os.Chmod(binPath, 0o755))
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.NoError(t, err)
assert.Equal(t, binPath, pd.binPath)
}
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Windows does not support Unix permission bits")
}
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
scriptBinDir := t.TempDir()
binPath := filepath.Join(scriptBinDir, "portabledesktop")
// Write without execute permission.
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
_ = binPath
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: scriptBinDir,
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestEnsureBinary_NotFound(t *testing.T) {
// Cannot use t.Parallel because t.Setenv modifies the process
// environment.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
scriptBinDir: t.TempDir(), // empty directory
}
// Clear PATH so LookPath won't find a real binary.
t.Setenv("PATH", "")
err := pd.ensureBinary(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
// Ensure that portableDesktop satisfies the Desktop interface at
// compile time. This uses the unexported type so it lives in the
// internal test package.
var _ Desktop = (*portableDesktop)(nil)
+51 -174
View File
@@ -333,68 +333,22 @@ func (api *API) writeFile(ctx context.Context, r *http.Request, path string) (HT
return status, err
}
// Check if the target already exists so we can preserve its
// permissions on the temp file before rename.
var origMode os.FileMode
var haveOrigMode bool
if stat, serr := api.filesystem.Stat(path); serr == nil {
if stat.IsDir() {
return http.StatusBadRequest, xerrors.Errorf("open %s: is a directory", path)
}
origMode = stat.Mode()
haveOrigMode = true
}
// Write to a temp file in the same directory so the rename is
// always on the same device (atomic).
tmpfile, err := afero.TempFile(api.filesystem, dir, filepath.Base(path))
f, err := api.filesystem.Create(path)
if err != nil {
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.EISDIR):
status = http.StatusBadRequest
}
return status, err
}
tmpName := tmpfile.Name()
defer f.Close()
_, err = io.Copy(tmpfile, r.Body)
if err != nil && !errors.Is(err, io.EOF) {
_ = tmpfile.Close()
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Close before rename to flush buffered data and catch write
// errors (e.g. delayed allocation failures).
if err := tmpfile.Close(); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("write %s: %w", path, err)
}
// Set permissions on the temp file before rename so there is
// no window where the target has wrong permissions.
if haveOrigMode {
if err := api.filesystem.Chmod(tmpName, origMode); err != nil {
api.logger.Warn(ctx, "unable to set file permissions",
slog.F("path", path),
slog.Error(err),
)
}
}
if err := api.filesystem.Rename(tmpName, path); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
}
return status, err
_, err = io.Copy(f, r.Body)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
api.logger.Error(ctx, "workspace agent write file", slog.Error(err))
}
return 0, nil
@@ -493,10 +447,13 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
content := string(data)
for _, edit := range edits {
var err error
content, err = fuzzyReplace(content, edit)
if err != nil {
return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err)
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
}
@@ -506,135 +463,68 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
if err != nil {
return http.StatusInternalServerError, err
}
tmpName := tmpfile.Name()
defer tmpfile.Close()
if _, err := tmpfile.Write([]byte(content)); err != nil {
_ = tmpfile.Close()
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
// Close before rename to flush buffered data and catch write
// errors (e.g. delayed allocation failures).
if err := tmpfile.Close(); err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
// Set permissions on the temp file before rename so there is
// no window where the target has wrong permissions.
if err := api.filesystem.Chmod(tmpName, stat.Mode()); err != nil {
api.logger.Warn(ctx, "unable to set file permissions",
slog.F("path", path),
slog.Error(err),
)
}
err = api.filesystem.Rename(tmpName, path)
err = api.filesystem.Rename(tmpfile.Name(), path)
if err != nil {
if rerr := api.filesystem.Remove(tmpName); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
status := http.StatusInternalServerError
if errors.Is(err, os.ErrPermission) {
status = http.StatusForbidden
}
return status, err
return http.StatusInternalServerError, err
}
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace it
// with `replace`. It uses a cascading match strategy inspired by
// fuzzyReplace attempts to find `search` inside `content` and replace its first
// occurrence with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace
// (indentation-tolerant).
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When edit.ReplaceAll is false (the default), the search string must
// match exactly one location. If multiple matches are found, an error
// is returned asking the caller to include more context or set
// replace_all.
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still
// applied at the byte offsets of the original content so that surrounding
// text (including indentation of untouched lines) is preserved.
func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
search := edit.Search
replace := edit.Replace
// Pass 1 exact substring match.
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1 exact substring (replace all occurrences).
if strings.Contains(content, search) {
if edit.ReplaceAll {
return strings.ReplaceAll(content, search, replace), nil
}
count := strings.Count(content, search)
if count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
// Exactly one match.
return strings.Replace(content, search, replace, 1), nil
return strings.ReplaceAll(content, search, replace), true
}
// For line-level fuzzy matching we split both content and search
// into lines.
// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element
// from SplitAfter. Drop it so it doesn't interfere with line
// matching.
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
trimRight := func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}
trimAll := func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3 trim all leading and trailing whitespace
// (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
// Pass 3 trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return "", xerrors.New("search string not found in file. Verify the search " +
"string matches the file content exactly, including whitespace " +
"and indentation")
return content, false
}
// seekLines scans contentLines looking for a contiguous subsequence that matches
@@ -659,26 +549,6 @@ outer:
return 0, 0, false
}
// countLineMatches counts how many non-overlapping contiguous
// subsequences of contentLines match searchLines according to eq.
func countLineMatches(contentLines, searchLines []string, eq func(a, b string) bool) int {
count := 0
if len(searchLines) == 0 || len(searchLines) > len(contentLines) {
return count
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
count++
i += len(searchLines) - 1 // skip past this match
}
return count
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
@@ -692,3 +562,10 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
+3 -207
View File
@@ -14,7 +14,6 @@ import (
"strings"
"syscall"
"testing"
"testing/iotest"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -400,83 +399,6 @@ func TestWriteFile(t *testing.T) {
}
}
func TestWriteFile_ReportsIOError(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
fs := afero.NewMemMapFs()
api := agentfiles.NewAPI(logger, fs, nil)
tmpdir := os.TempDir()
path := filepath.Join(tmpdir, "write-io-error")
err := afero.WriteFile(fs, path, []byte("original"), 0o644)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// A reader that always errors simulates a failed body read
// (e.g. network interruption). The atomic write should leave
// the original file intact.
body := iotest.ErrReader(xerrors.New("simulated I/O error"))
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path), body)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusInternalServerError, w.Code)
got := &codersdk.Error{}
err = json.NewDecoder(w.Body).Decode(got)
require.NoError(t, err)
require.ErrorContains(t, got, "simulated I/O error")
// The original file must survive the failed write.
data, err := afero.ReadFile(fs, path)
require.NoError(t, err)
require.Equal(t, "original", string(data))
}
func TestWriteFile_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// Overwrite the file with new content.
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost,
fmt.Sprintf("/write-file?path=%s", path),
bytes.NewReader([]byte("#!/bin/sh\necho world\n")))
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"write_file should preserve the original file's permissions")
}
func TestEditFiles(t *testing.T) {
t.Parallel()
@@ -654,9 +576,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
},
{
// When the second edit creates ambiguity (two "bar"
// occurrences), it should fail.
name: "EditEditAmbiguous",
name: "EditEdit", // Edits affect previous edits.
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
@@ -673,33 +593,7 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 2 occurrences"},
// File should not be modified on error.
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
},
{
// With replace_all the cascading edit replaces
// both occurrences.
name: "EditEditReplaceAll",
contents: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "edit-edit-ra"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
{
Search: "bar",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "qux qux"},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
},
{
name: "Multiline",
@@ -826,7 +720,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchErrors",
name: "NoMatchStillSucceeds",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
@@ -839,46 +733,9 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"search string not found in file"},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "AmbiguousExactMatch",
contents: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ambig-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 3 occurrences"},
expected: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
},
{
name: "ReplaceAllExact",
contents: map[string]string{filepath.Join(tmpdir, "ra-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ra-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
@@ -985,67 +842,6 @@ func TestEditFiles(t *testing.T) {
}
}
func TestEditFiles_PreservesPermissions(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("file permissions are not reliably supported on Windows")
}
dir := t.TempDir()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
osFs := afero.NewOsFs()
api := agentfiles.NewAPI(logger, osFs, nil)
path := filepath.Join(dir, "script.sh")
err := afero.WriteFile(osFs, path, []byte("#!/bin/sh\necho hello\n"), 0o755)
require.NoError(t, err)
// Sanity-check the initial mode.
info, err := osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm())
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
body := workspacesdk.FileEditRequest{
Files: []workspacesdk.FileEdits{
{
Path: path,
Edits: []workspacesdk.FileEdit{
{
Search: "hello",
Replace: "world",
},
},
},
},
}
buf := bytes.NewBuffer(nil)
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(body)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/edit-files", buf)
api.Routes().ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.Code)
// Verify content was updated.
data, err := afero.ReadFile(osFs, path)
require.NoError(t, err)
require.Equal(t, "#!/bin/sh\necho world\n", string(data))
// Verify permissions are preserved after the
// temp-file-and-rename cycle.
info, err = osFs.Stat(path)
require.NoError(t, err)
require.Equal(t, os.FileMode(0o755), info.Mode().Perm(),
"edit_files should preserve the original file's permissions")
}
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
t.Parallel()
+4 -87
View File
@@ -1,13 +1,10 @@
package agentproc
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -20,13 +17,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// maxWaitDuration is the maximum time a blocking
// process output request can wait, regardless of
// what the client requests.
maxWaitDuration = 5 * time.Minute
)
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
@@ -35,10 +25,10 @@ type API struct {
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore, workingDir func() string) *API {
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv, workingDir),
manager: newManager(logger, execer, updateEnv),
pathStore: pathStore,
}
}
@@ -79,12 +69,7 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
return
}
var chatID string
if id, _, ok := agentgit.ExtractChatContext(r); ok {
chatID = id.String()
}
proc, err := api.manager.start(req, chatID)
proc, err := api.manager.start(req)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start process.",
@@ -120,28 +105,7 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var chatID string
if id, _, ok := agentgit.ExtractChatContext(r); ok {
chatID = id.String()
}
infos := api.manager.list(chatID)
// Sort by running state (running first), then by started_at
// descending so the most recent processes appear first.
sort.Slice(infos, func(i, j int) bool {
if infos[i].Running != infos[j].Running {
return infos[i].Running
}
return infos[i].StartedAt > infos[j].StartedAt
})
// Cap the response to avoid bloating LLM context.
const maxListProcesses = 10
if len(infos) > maxListProcesses {
infos = infos[:maxListProcesses]
}
infos := api.manager.list()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
Processes: infos,
})
@@ -160,42 +124,6 @@ func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
return
}
// Enforce chat ID isolation. If the request carries
// a chat context, only allow access to processes
// belonging to that chat.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
if proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
// Check for blocking mode via query params.
waitStr := r.URL.Query().Get("wait")
wantWait := waitStr == "true"
if wantWait {
// Extend the write deadline so the HTTP server's
// WriteTimeout does not kill the connection while
// we block.
rc := http.NewResponseController(rw)
if err := rc.SetWriteDeadline(time.Now().Add(maxWaitDuration)); err != nil {
api.logger.Error(ctx, "extend write deadline for blocking process output",
slog.Error(err),
)
}
// Cap the wait at maxWaitDuration regardless of
// client-supplied timeout.
waitCtx, waitCancel := context.WithTimeout(ctx, maxWaitDuration)
defer waitCancel()
_ = proc.waitForOutput(waitCtx)
// Fall through to read snapshot below.
}
output, truncated := proc.output()
info := proc.info()
@@ -213,17 +141,6 @@ func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
// Enforce chat ID isolation.
if chatID, _, ok := agentgit.ExtractChatContext(r); ok {
proc, procOK := api.manager.get(id)
if procOK && proc.chatID != "" && proc.chatID != chatID.String() {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
}
var req workspacesdk.SignalProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
+6 -478
View File
@@ -7,10 +7,8 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
@@ -29,7 +27,7 @@ import (
)
// postStart sends a POST /start request and returns the recorder.
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest, headers ...http.Header) *httptest.ResponseRecorder {
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@@ -40,13 +38,6 @@ func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcess
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
for _, h := range headers {
for k, vals := range h {
for _, v := range vals {
r.Header.Add(k, v)
}
}
}
handler.ServeHTTP(w, r)
return w
}
@@ -78,22 +69,6 @@ func getOutput(t *testing.T, handler http.Handler, id string) *httptest.Response
return w
}
// getOutputWithHeaders sends a GET /{id}/output request with
// custom headers and returns the recorder.
func getOutputWithHeaders(t *testing.T, handler http.Handler, id string, headers http.Header) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
path := fmt.Sprintf("/%s/output", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
for k, v := range headers {
req.Header[k] = v
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
@@ -115,25 +90,18 @@ func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, nil, nil)
return newTestAPIWithUpdateEnv(t, nil)
}
// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
t.Helper()
return newTestAPIWithOptions(t, updateEnv, nil)
}
// newTestAPIWithOptions creates a new API with optional
// updateEnv and workingDir hooks.
func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, error), workingDir func() string) http.Handler {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil, workingDir)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
t.Cleanup(func() {
_ = api.Close()
})
@@ -172,10 +140,10 @@ func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.Pro
// startAndGetID is a helper that starts a process and returns
// the process ID.
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest, headers ...http.Header) string {
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
t.Helper()
w := postStart(t, handler, req, headers...)
w := postStart(t, handler, req)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
@@ -278,100 +246,6 @@ func TestStartProcess(t *testing.T) {
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirIsHome", func(t *testing.T) {
t.Parallel()
// No working directory closure, so the process
// should fall back to $HOME. We verify through
// the process list API which reports the resolved
// working directory using native OS paths,
// avoiding shell path format mismatches on
// Windows (Git Bash returns POSIX paths).
handler := newTestAPI(t)
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("DefaultWorkDirFromClosure", func(t *testing.T) {
t.Parallel()
// The closure provides a valid directory, so the
// process should start there. Use the marker file
// pattern to avoid path format mismatches on
// Windows.
tmpDir := t.TempDir()
handler := newTestAPIWithOptions(t, nil, func() string {
return tmpDir
})
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "touch marker.txt && ls marker.txt",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("DefaultWorkDirClosureNonExistentFallsBackToHome", func(t *testing.T) {
t.Parallel()
// The closure returns a path that doesn't exist,
// so the process should fall back to $HOME.
handler := newTestAPIWithOptions(t, nil, func() string {
return "/tmp/nonexistent-dir-" + fmt.Sprintf("%d", time.Now().UnixNano())
})
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo ok",
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var listResp workspacesdk.ListProcessesResponse
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
var proc *workspacesdk.ProcessInfo
for i := range listResp.Processes {
if listResp.Processes[i].ID == id {
proc = &listResp.Processes[i]
break
}
}
require.NotNil(t, proc, "process not found in list")
require.Equal(t, homeDir, proc.WorkDir)
})
t.Run("CustomEnv", func(t *testing.T) {
t.Parallel()
@@ -459,180 +333,6 @@ func TestListProcesses(t *testing.T) {
require.Empty(t, resp.Processes)
})
t.Run("FilterByChatID", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
chatA := uuid.New().String()
chatB := uuid.New().String()
headersA := http.Header{workspacesdk.CoderChatIDHeader: {chatA}}
headersB := http.Header{workspacesdk.CoderChatIDHeader: {chatB}}
// Start processes with different chat IDs.
id1 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo chat-a",
}, headersA)
waitForExit(t, handler, id1)
id2 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo chat-b",
}, headersB)
waitForExit(t, handler, id2)
id3 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo chat-a-2",
}, headersA)
waitForExit(t, handler, id3)
// List with chat A header should return 2 processes.
w := getListWithChatHeader(t, handler, chatA)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Len(t, resp.Processes, 2)
ids := make(map[string]bool)
for _, p := range resp.Processes {
ids[p.ID] = true
}
require.True(t, ids[id1])
require.True(t, ids[id3])
// List with chat B header should return 1 process.
w2 := getListWithChatHeader(t, handler, chatB)
require.Equal(t, http.StatusOK, w2.Code)
var resp2 workspacesdk.ListProcessesResponse
err = json.NewDecoder(w2.Body).Decode(&resp2)
require.NoError(t, err)
require.Len(t, resp2.Processes, 1)
require.Equal(t, id2, resp2.Processes[0].ID)
// List without chat header should return all 3.
w3 := getList(t, handler)
require.Equal(t, http.StatusOK, w3.Code)
var resp3 workspacesdk.ListProcessesResponse
err = json.NewDecoder(w3.Body).Decode(&resp3)
require.NoError(t, err)
require.Len(t, resp3.Processes, 3)
})
t.Run("ChatIDFiltering", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
chatID := uuid.New().String()
headers := http.Header{workspacesdk.CoderChatIDHeader: {chatID}}
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo with-chat",
}, headers)
waitForExit(t, handler, id)
// Listing with the same chat header should return
// the process.
w := getListWithChatHeader(t, handler, chatID)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Len(t, resp.Processes, 1)
require.Equal(t, id, resp.Processes[0].ID)
// Listing with a different chat header should not
// return the process.
w2 := getListWithChatHeader(t, handler, uuid.New().String())
require.Equal(t, http.StatusOK, w2.Code)
var resp2 workspacesdk.ListProcessesResponse
err = json.NewDecoder(w2.Body).Decode(&resp2)
require.NoError(t, err)
require.Empty(t, resp2.Processes)
// Listing without a chat header should return the
// process (no filtering).
w3 := getList(t, handler)
require.Equal(t, http.StatusOK, w3.Code)
var resp3 workspacesdk.ListProcessesResponse
err = json.NewDecoder(w3.Body).Decode(&resp3)
require.NoError(t, err)
require.Len(t, resp3.Processes, 1)
})
t.Run("SortAndLimit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start 12 short-lived processes so we exceed the
// limit of 10.
for i := 0; i < 12; i++ {
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("echo proc-%d", i),
})
waitForExit(t, handler, id)
}
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Len(t, resp.Processes, 10, "should be capped at 10")
// All returned processes are exited, so they should
// be sorted by StartedAt descending (newest first).
for i := 1; i < len(resp.Processes); i++ {
require.GreaterOrEqual(t, resp.Processes[i-1].StartedAt, resp.Processes[i].StartedAt,
"processes should be sorted by started_at descending")
}
})
t.Run("RunningProcessesSortedFirst", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start an exited process first.
exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, exitedID)
// Start a running process after.
runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Len(t, resp.Processes, 2)
// Running process should come first regardless of
// start order.
require.Equal(t, runningID, resp.Processes[0].ID)
require.True(t, resp.Processes[0].Running)
require.Equal(t, exitedID, resp.Processes[1].ID)
require.False(t, resp.Processes[1].Running)
// Clean up.
postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
})
t.Run("MixedRunningAndExited", func(t *testing.T) {
t.Parallel()
@@ -681,23 +381,6 @@ func TestListProcesses(t *testing.T) {
})
}
// getListWithChatHeader sends a GET /list request with the
// Coder-Chat-Id header set and returns the recorder.
func getListWithChatHeader(t *testing.T, handler http.Handler, chatID string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
if chatID != "" {
r.Header.Set(workspacesdk.CoderChatIDHeader, chatID)
}
handler.ServeHTTP(w, r)
return w
}
func TestProcessOutput(t *testing.T) {
t.Parallel()
@@ -756,161 +439,6 @@ func TestProcessOutput(t *testing.T) {
require.NoError(t, err)
require.Contains(t, resp.Message, "not found")
})
t.Run("ChatIDEnforcement", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a process with chat-a.
chatA := uuid.New()
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo secret",
Background: true,
}, http.Header{
workspacesdk.CoderChatIDHeader: {chatA.String()},
})
waitForExit(t, handler, id)
// Chat-b should NOT see this process.
chatB := uuid.New()
w1 := getOutputWithHeaders(t, handler, id, http.Header{
workspacesdk.CoderChatIDHeader: {chatB.String()},
})
require.Equal(t, http.StatusNotFound, w1.Code)
// Without any chat ID header, should return 200
// (backwards compatible).
w2 := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w2.Code)
})
t.Run("WaitForExit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello-wait && sleep 0.1",
})
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "hello-wait")
})
t.Run("WaitAlreadyExited", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, id)
w := getOutputWithWait(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.False(t, resp.Running)
require.Contains(t, resp.Output, "done")
})
t.Run("WaitTimeout", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.IntervalMedium)
defer cancel()
w := getOutputWithWaitCtx(ctx, t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Running)
// Kill and wait for the process so cleanup does
// not hang.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
waitForExit(t, handler, id)
})
t.Run("ConcurrentWaiters", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
var (
wg sync.WaitGroup
resps [2]workspacesdk.ProcessOutputResponse
codes [2]int
)
for i := range 2 {
wg.Add(1)
go func() {
defer wg.Done()
w := getOutputWithWait(t, handler, id)
codes[i] = w.Code
_ = json.NewDecoder(w.Body).Decode(&resps[i])
}()
}
// Signal the process to exit so both waiters unblock.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
wg.Wait()
for i := range 2 {
require.Equal(t, http.StatusOK, codes[i], "waiter %d", i)
require.False(t, resps[i].Running, "waiter %d", i)
}
})
}
func getOutputWithWait(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
return getOutputWithWaitCtx(ctx, t, handler, id)
}
func getOutputWithWaitCtx(ctx context.Context, t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
path := fmt.Sprintf("/%s/output?wait=true", id)
req := httptest.NewRequestWithContext(ctx, http.MethodGet, path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
return w
}
func TestSignalProcess(t *testing.T) {
@@ -1055,7 +583,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
return current, nil
}, pathStore, nil)
}, pathStore)
defer api.Close()
routes := api.Routes()
+2 -19
View File
@@ -39,13 +39,11 @@ const (
// how much output is written.
type HeadTailBuffer struct {
mu sync.Mutex
cond *sync.Cond
head []byte
tail []byte
tailPos int
tailFull bool
headFull bool
closed bool
totalBytes int
maxHead int
maxTail int
@@ -54,24 +52,20 @@ type HeadTailBuffer struct {
// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: MaxHeadBytes,
maxTail: MaxTailBytes,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
b := &HeadTailBuffer{
return &HeadTailBuffer{
maxHead: maxHead,
maxTail: maxTail,
}
b.cond = sync.NewCond(&b.mu)
return b
}
// Write implements io.Writer. It is safe for concurrent use.
@@ -302,15 +296,6 @@ func truncateLines(s string) string {
return b.String()
}
// Close marks the buffer as closed and wakes any waiters.
// This is called when the process exits.
func (b *HeadTailBuffer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closed = true
b.cond.Broadcast()
}
// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
b.mu.Lock()
@@ -320,7 +305,5 @@ func (b *HeadTailBuffer) Reset() {
b.tailPos = 0
b.tailFull = false
b.headFull = false
b.closed = false
b.totalBytes = 0
b.cond.Broadcast()
}
-26
View File
@@ -1,26 +0,0 @@
//go:build !windows
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Unix, Setpgid creates a new process group so
// that signals can be delivered to the entire group (the shell
// and all its children).
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setpgid: true,
}
}
// signalProcess sends a signal to the process group rooted at p.
// Using the negative PID sends the signal to every process in the
// group, ensuring child processes (e.g. from shell pipelines) are
// also signaled.
func signalProcess(p *os.Process, sig syscall.Signal) error {
return syscall.Kill(-p.Pid, sig)
}
-20
View File
@@ -1,20 +0,0 @@
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Windows, process groups are not supported in the
// same way as Unix, so this returns an empty struct.
func procSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{}
}
// signalProcess sends a signal directly to the process. Windows
// does not support process group signaling, so we fall back to
// sending the signal to the process itself.
func signalProcess(p *os.Process, _ syscall.Signal) error {
return p.Kill()
}
+26 -107
View File
@@ -21,10 +21,6 @@ import (
var (
errProcessNotFound = xerrors.New("process not found")
errProcessNotRunning = xerrors.New("process is not running")
// exitedProcessReapAge is how long an exited process is
// kept before being automatically removed from the map.
exitedProcessReapAge = 5 * time.Minute
)
// process represents a running or completed process.
@@ -34,7 +30,6 @@ type process struct {
command string
workDir string
background bool
chatID string
cmd *exec.Cmd
cancel context.CancelFunc
buf *HeadTailBuffer
@@ -70,25 +65,23 @@ func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
// manager tracks processes spawned by the agent.
type manager struct {
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
workingDir func() string
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
}
// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), workingDir func() string) *manager {
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
return &manager{
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
workingDir: workingDir,
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
}
}
@@ -96,7 +89,7 @@ func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(curr
// processes use a long-lived context so the process survives
// the HTTP request lifecycle. The background flag only affects
// client-side polling behavior.
func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*process, error) {
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
@@ -111,9 +104,10 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
// the process is not tied to any HTTP request.
ctx, cancel := context.WithCancel(context.Background())
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
cmd.Dir = m.resolveWorkDir(req.WorkDir)
if req.WorkDir != "" {
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
cmd.SysProcAttr = procSysProcAttr()
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
@@ -158,9 +152,8 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc := &process{
id: id,
command: req.Command,
workDir: cmd.Dir,
workDir: req.WorkDir,
background: req.Background,
chatID: chatID,
cmd: cmd,
cancel: cancel,
buf: buf,
@@ -208,9 +201,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
proc.exitCode = &code
proc.mu.Unlock()
// Wake any waiters blocked on new output or
// process exit before closing the done channel.
proc.buf.Close()
close(proc.done)
}()
@@ -225,32 +215,14 @@ func (m *manager) get(id string) (*process, bool) {
return proc, ok
}
// list returns info about all tracked processes. Exited
// processes older than exitedProcessReapAge are removed.
// If chatID is non-empty, only processes belonging to that
// chat are returned.
func (m *manager) list(chatID string) []workspacesdk.ProcessInfo {
// list returns info about all tracked processes.
func (m *manager) list() []workspacesdk.ProcessInfo {
m.mu.Lock()
defer m.mu.Unlock()
now := m.clock.Now()
infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
for id, proc := range m.procs {
info := proc.info()
// Reap processes that exited more than 5 minutes ago
// to prevent unbounded map growth.
if !info.Running && info.ExitedAt != nil {
exitedAt := time.Unix(*info.ExitedAt, 0)
if now.Sub(exitedAt) > exitedProcessReapAge {
delete(m.procs, id)
continue
}
}
// Filter by chatID if provided.
if chatID != "" && proc.chatID != chatID {
continue
}
infos = append(infos, info)
for _, proc := range m.procs {
infos = append(infos, proc.info())
}
return infos
}
@@ -276,15 +248,13 @@ func (m *manager) signal(id string, sig string) error {
switch sig {
case "kill":
// Use process group kill to ensure child processes
// (e.g. from shell pipelines) are also killed.
if err := signalProcess(proc.cmd.Process, syscall.SIGKILL); err != nil {
if err := proc.cmd.Process.Kill(); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
// Use process group signal to ensure child processes
// are also terminated.
if err := signalProcess(proc.cmd.Process, syscall.SIGTERM); err != nil {
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
@@ -322,54 +292,3 @@ func (m *manager) Close() error {
return nil
}
// waitForOutput blocks until the buffer is closed (process
// exited) or the context is canceled. Returns nil when the
// buffer closed, ctx.Err() when the context expired.
func (p *process) waitForOutput(ctx context.Context) error {
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
nevermind := make(chan struct{})
defer close(nevermind)
go func() {
select {
case <-ctx.Done():
// Acquire the lock before broadcasting to
// guarantee the waiter has entered cond.Wait()
// (which atomically releases the lock).
// Without this, a Broadcast between the loop
// predicate check and cond.Wait() is lost.
p.buf.cond.L.Lock()
defer p.buf.cond.L.Unlock()
p.buf.cond.Broadcast()
case <-nevermind:
}
}()
for ctx.Err() == nil && !p.buf.closed {
p.buf.cond.Wait()
}
return ctx.Err()
}
// resolveWorkDir returns the directory a process should start in.
// Priority: explicit request dir > agent configured dir > $HOME.
// Falls through when a candidate is empty or does not exist on
// disk, matching the behavior of SSH sessions.
func (m *manager) resolveWorkDir(requested string) string {
if requested != "" {
return requested
}
if m.workingDir != nil {
if dir := m.workingDir(); dir != "" {
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return dir
}
}
}
if home, err := os.UserHomeDir(); err == nil {
return home
}
return ""
}
+2 -2
View File
@@ -398,11 +398,11 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
},
})
if err != nil {
logger.Warn(ctx, "reporting script completed", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: %s", err.Error()))
}
})
if err != nil {
logger.Warn(ctx, "reporting script completed: track command goroutine", slog.Error(err))
logger.Error(ctx, fmt.Sprintf("reporting script completed: track command goroutine: %s", err.Error()))
}
}()
-1
View File
@@ -30,7 +30,6 @@ func (a *agent) apiHandler() http.Handler {
r.Mount("/api/v0", a.filesAPI.Routes())
r.Mount("/api/v0/git", a.gitAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
+27 -6
View File
@@ -6,6 +6,7 @@ import (
"context"
"net"
"path/filepath"
"sync"
"testing"
"github.com/google/uuid"
@@ -22,6 +23,26 @@ import (
"github.com/coder/coder/v2/testutil"
)
// logSink captures structured log entries for testing.
type logSink struct {
mu sync.Mutex
entries []slog.SinkEntry
}
func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.mu.Lock()
defer s.mu.Unlock()
s.entries = append(s.entries, e)
}
func (*logSink) Sync() {}
func (s *logSink) getEntries() []slog.SinkEntry {
s.mu.Lock()
defer s.mu.Unlock()
return append([]slog.SinkEntry{}, s.entries...)
}
// getField returns the value of a field by name from a slog.Map.
func getField(fields slog.Map, name string) interface{} {
for _, f := range fields {
@@ -55,8 +76,8 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
sink := testutil.NewFakeSink(t)
logger := sink.Logger(slog.LevelInfo)
sink := &logSink{}
logger := slog.Make(sink)
workspaceID := uuid.New()
templateID := uuid.New()
templateVersionID := uuid.New()
@@ -97,10 +118,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 1
return len(sink.getEntries()) >= 1
}, testutil.WaitShort, testutil.IntervalFast)
entries := sink.Entries()
entries := sink.getEntries()
require.Len(t, entries, 1)
entry := entries[0]
require.Equal(t, slog.LevelInfo, entry.Level)
@@ -131,10 +152,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
sendBoundaryLogsRequest(t, conn, req2)
require.Eventually(t, func() bool {
return len(sink.Entries()) >= 2
return len(sink.getEntries()) >= 2
}, testutil.WaitShort, testutil.IntervalFast)
entries = sink.Entries()
entries = sink.getEntries()
entry = entries[1]
require.Len(t, entries, 2)
require.Equal(t, slog.LevelInfo, entry.Level)
+8 -31
View File
@@ -2,7 +2,6 @@ package reaper
import (
"os"
"sync"
"github.com/hashicorp/go-reap"
@@ -43,42 +42,20 @@ func WithLogger(logger slog.Logger) Option {
}
}
// WithReaperStop sets a channel that, when closed, stops the reaper
// WithDone sets a channel that, when closed, stops the reaper
// goroutine. Callers that invoke ForkReap more than once in the
// same process (e.g. tests) should use this to prevent goroutine
// accumulation.
func WithReaperStop(ch chan struct{}) Option {
func WithDone(ch chan struct{}) Option {
return func(o *options) {
o.ReaperStop = ch
}
}
// WithReaperStopped sets a channel that is closed after the
// reaper goroutine has fully exited.
func WithReaperStopped(ch chan struct{}) Option {
return func(o *options) {
o.ReaperStopped = ch
}
}
// WithReapLock sets a mutex shared between the reaper and Wait4.
// The reaper holds the write lock while reaping, and ForkReap
// holds the read lock during Wait4, preventing the reaper from
// stealing the child's exit status. This is only needed for
// tests with instant-exit children where the race window is
// large.
func WithReapLock(mu *sync.RWMutex) Option {
return func(o *options) {
o.ReapLock = mu
o.Done = ch
}
}
type options struct {
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
Logger slog.Logger
ReaperStop chan struct{}
ReaperStopped chan struct{}
ReapLock *sync.RWMutex
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
Logger slog.Logger
Done chan struct{}
}
+34 -99
View File
@@ -7,7 +7,6 @@ import (
"os"
"os/exec"
"os/signal"
"sync"
"syscall"
"testing"
"time"
@@ -19,82 +18,35 @@ import (
"github.com/coder/coder/v2/testutil"
)
// subprocessEnvKey is set when a test re-execs itself as an
// isolated subprocess. Tests that call ForkReap or send signals
// to their own process check this to decide whether to run real
// test logic or launch the subprocess and wait for it.
const subprocessEnvKey = "CODER_REAPER_TEST_SUBPROCESS"
// withDone returns an option that stops the reaper goroutine when t
// completes, preventing goroutine accumulation across subtests.
func withDone(t *testing.T) reaper.Option {
t.Helper()
done := make(chan struct{})
t.Cleanup(func() { close(done) })
return reaper.WithDone(done)
}
// runSubprocess re-execs the current test binary in a new process
// running only the named test. This isolates ForkReap's
// syscall.ForkExec and any process-directed signals (e.g. SIGINT)
// from the parent test binary, making these tests safe to run in
// CI and alongside other tests.
// TestReap checks that's the reaper is successfully reaping
// exited processes and passing the PIDs through the shared
// channel.
//
// Returns true inside the subprocess (caller should proceed with
// the real test logic). Returns false in the parent after the
// subprocess exits successfully (caller should return).
func runSubprocess(t *testing.T) bool {
t.Helper()
if os.Getenv(subprocessEnvKey) == "1" {
return true
}
ctx := testutil.Context(t, testutil.WaitMedium)
//nolint:gosec // Test-controlled arguments.
cmd := exec.CommandContext(ctx, os.Args[0],
"-test.run=^"+t.Name()+"$",
"-test.v",
)
cmd.Env = append(os.Environ(), subprocessEnvKey+"=1")
out, err := cmd.CombinedOutput()
t.Logf("Subprocess output:\n%s", out)
require.NoError(t, err, "subprocess failed")
return false
}
// withDone returns options that stop the reaper goroutine when t
// completes and wait for it to fully exit, preventing
// overlapping reapers across sequential subtests.
func withDone(t *testing.T) []reaper.Option {
t.Helper()
stop := make(chan struct{})
stopped := make(chan struct{})
t.Cleanup(func() {
close(stop)
<-stopped
})
return []reaper.Option{
reaper.WithReaperStop(stop),
reaper.WithReaperStopped(stopped),
}
}
// TestReap checks that the reaper successfully reaps exited
// processes and passes their PIDs through the shared channel.
//nolint:paralleltest
func TestReap(t *testing.T) {
t.Parallel()
// Don't run the reaper test in CI. It does weird
// things like forkexecing which may have unintended
// consequences in CI.
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
pids := make(reap.PidCh, 1)
var reapLock sync.RWMutex
opts := append([]reaper.Option{
exitCode, err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
// Provide some argument that immediately exits.
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
reaper.WithReapLock(&reapLock),
}, withDone(t)...)
reapLock.RLock()
exitCode, err := reaper.ForkReap(opts...)
reapLock.RUnlock()
withDone(t),
)
require.NoError(t, err)
require.Equal(t, 0, exitCode)
@@ -114,7 +66,7 @@ func TestReap(t *testing.T) {
expectedPIDs := []int{cmd.Process.Pid, cmd2.Process.Pid}
for range len(expectedPIDs) {
for i := 0; i < len(expectedPIDs); i++ {
select {
case <-time.After(testutil.WaitShort):
t.Fatalf("Timed out waiting for process")
@@ -124,15 +76,11 @@ func TestReap(t *testing.T) {
}
}
//nolint:tparallel // Subtests must be sequential, each starts its own reaper.
//nolint:paralleltest
func TestForkReapExitCodes(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
tests := []struct {
name string
@@ -147,35 +95,26 @@ func TestForkReapExitCodes(t *testing.T) {
{"SIGTERM", "kill -15 $$", 128 + 15},
}
//nolint:paralleltest // Subtests must be sequential, each starts its own reaper.
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var reapLock sync.RWMutex
opts := append([]reaper.Option{
exitCode, err := reaper.ForkReap(
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
reaper.WithReapLock(&reapLock),
}, withDone(t)...)
reapLock.RLock()
exitCode, err := reaper.ForkReap(opts...)
reapLock.RUnlock()
withDone(t),
)
require.NoError(t, err)
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
})
}
}
// TestReapInterrupt verifies that ForkReap forwards caught signals
// to the child process. The test sends SIGINT to its own process
// and checks that the child receives it. Running in a subprocess
// ensures SIGINT cannot kill the parent test binary.
//nolint:paralleltest // Signal handling.
func TestReapInterrupt(t *testing.T) {
t.Parallel()
// Don't run the reaper test in CI. It does weird
// things like forkexecing which may have unintended
// consequences in CI.
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
errC := make(chan error, 1)
pids := make(reap.PidCh, 1)
@@ -187,28 +126,24 @@ func TestReapInterrupt(t *testing.T) {
defer signal.Stop(usrSig)
go func() {
opts := append([]reaper.Option{
exitCode, err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
reaper.WithCatchSignals(os.Interrupt),
withDone(t),
// Signal propagation does not extend to children of children, so
// we create a little bash script to ensure sleep is interrupted.
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf(
"pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait",
os.Getpid(), os.Getpid(),
)),
}, withDone(t)...)
exitCode, err := reaper.ForkReap(opts...)
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
)
// The child exits with 128 + SIGTERM (15) = 143, but the trap catches
// SIGINT and sends SIGTERM to the sleep process, so exit code varies.
_ = exitCode
errC <- err
}()
require.Equal(t, syscall.SIGUSR1, <-usrSig)
require.Equal(t, <-usrSig, syscall.SIGUSR1)
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
require.NoError(t, err)
require.Equal(t, <-usrSig, syscall.SIGUSR2)
require.Equal(t, syscall.SIGUSR2, <-usrSig)
require.NoError(t, <-errC)
}
+14 -24
View File
@@ -19,36 +19,31 @@ func IsInitProcess() bool {
return os.Getpid() == 1
}
// startSignalForwarding registers signal handlers synchronously
// then forwards caught signals to the child in a background
// goroutine. Registering before the goroutine starts ensures no
// signal is lost between ForkExec and the handler being ready.
func startSignalForwarding(logger slog.Logger, pid int, sigs []os.Signal) {
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
if len(sigs) == 0 {
return
}
sc := make(chan os.Signal, 1)
signal.Notify(sc, sigs...)
defer signal.Stop(sc)
logger.Info(context.Background(), "reaper catching signals",
slog.F("signals", sigs),
slog.F("child_pid", pid),
)
go func() {
defer signal.Stop(sc)
for s := range sc {
sig, ok := s.(syscall.Signal)
if ok {
logger.Info(context.Background(), "reaper caught signal, killing child process",
slog.F("signal", sig.String()),
slog.F("child_pid", pid),
)
_ = syscall.Kill(pid, sig)
}
for {
s := <-sc
sig, ok := s.(syscall.Signal)
if ok {
logger.Info(context.Background(), "reaper caught signal, killing child process",
slog.F("signal", sig.String()),
slog.F("child_pid", pid),
)
_ = syscall.Kill(pid, sig)
}
}()
}
}
// ForkReap spawns a goroutine that reaps children. In order to avoid
@@ -69,12 +64,7 @@ func ForkReap(opt ...Option) (int, error) {
o(opts)
}
go func() {
reap.ReapChildren(opts.PIDs, nil, opts.ReaperStop, opts.ReapLock)
if opts.ReaperStopped != nil {
close(opts.ReaperStopped)
}
}()
go reap.ReapChildren(opts.PIDs, nil, opts.Done, nil)
pwd, err := os.Getwd()
if err != nil {
@@ -100,7 +90,7 @@ func ForkReap(opt ...Option) (int, error) {
return 1, xerrors.Errorf("fork exec: %w", err)
}
startSignalForwarding(opts.Logger, pid, opts.CatchSignals)
go catchSignals(opts.Logger, pid, opts.CatchSignals)
var wstatus syscall.WaitStatus
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
-13
View File
@@ -24,7 +24,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
"github.com/coder/serpent"
)
@@ -41,18 +40,6 @@ func New(t testing.TB, args ...string) (*serpent.Invocation, config.Root) {
return NewWithCommand(t, cmd, args...)
}
// NewWithClock is like New, but injects the given clock for
// tests that are time-dependent.
func NewWithClock(t testing.TB, clk quartz.Clock, args ...string) (*serpent.Invocation, config.Root) {
var root cli.RootCmd
root.SetClock(clk)
cmd, err := root.Command(root.AGPL())
require.NoError(t, err)
return NewWithCommand(t, cmd, args...)
}
type logWriter struct {
prefix string
log slog.Logger
-15
View File
@@ -46,7 +46,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
autoUpdates string
copyParametersFrom string
useParameterDefaults bool
noWait bool
// Organization context is only required if more than 1 template
// shares the same name across multiple organizations.
orgContext = NewOrganizationContext()
@@ -373,14 +372,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
cliutil.WarnMatchedProvisioners(inv.Stderr, workspace.LatestBuild.MatchedProvisioners, workspace.LatestBuild.Job)
if noWait {
_, _ = fmt.Fprintf(inv.Stdout,
"\nThe %s workspace has been created and is building in the background.\n",
cliui.Keyword(workspace.Name),
)
return nil
}
err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
if err != nil {
return xerrors.Errorf("watch build: %w", err)
@@ -454,12 +445,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
Description: "Automatically accept parameter defaults when no value is provided.",
Value: serpent.BoolOf(&useParameterDefaults),
},
serpent.Option{
Flag: "no-wait",
Env: "CODER_CREATE_NO_WAIT",
Description: "Return immediately after creating the workspace. The build will run in the background.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
)
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
-75
View File
@@ -603,81 +603,6 @@ func TestCreate(t *testing.T) {
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
}
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was actually created.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
})
t.Run("NoWaitWithParameterDefaults", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses([]*proto.RichParameter{
{Name: "region", Type: "string", DefaultValue: "us-east-1"},
{Name: "instance_type", Type: "string", DefaultValue: "t3.micro"},
}))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--use-parameter-defaults",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was created and parameters were applied.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
buildParams, err := member.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
require.NoError(t, err)
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "region", Value: "us-east-1"})
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "instance_type", Value: "t3.micro"})
})
}
func prepareEchoResponses(parameters []*proto.RichParameter, presets ...*proto.Preset) *echo.Responses {
-6
View File
@@ -1000,12 +1000,6 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
Properties: sdkTool.Schema.Properties,
Required: sdkTool.Schema.Required,
},
Annotations: mcp.ToolAnnotation{
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
},
},
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var buf bytes.Buffer
+1 -16
View File
@@ -81,13 +81,7 @@ func TestExpMcpServer(t *testing.T) {
var toolsResponse struct {
Result struct {
Tools []struct {
Name string `json:"name"`
Annotations struct {
ReadOnlyHint *bool `json:"readOnlyHint"`
DestructiveHint *bool `json:"destructiveHint"`
IdempotentHint *bool `json:"idempotentHint"`
OpenWorldHint *bool `json:"openWorldHint"`
} `json:"annotations"`
Name string `json:"name"`
} `json:"tools"`
} `json:"result"`
}
@@ -100,15 +94,6 @@ func TestExpMcpServer(t *testing.T) {
}
slices.Sort(foundTools)
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
annotations := toolsResponse.Result.Tools[0].Annotations
require.NotNil(t, annotations.ReadOnlyHint)
require.NotNil(t, annotations.DestructiveHint)
require.NotNil(t, annotations.IdempotentHint)
require.NotNil(t, annotations.OpenWorldHint)
assert.True(t, *annotations.ReadOnlyHint)
assert.False(t, *annotations.DestructiveHint)
assert.True(t, *annotations.IdempotentHint)
assert.False(t, *annotations.OpenWorldHint)
// Call the tool and ensure it works.
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
+45 -74
View File
@@ -1732,18 +1732,19 @@ const (
func (r *RootCmd) scaletestAutostart() *serpent.Command {
var (
workspaceCount int64
workspaceJobTimeout time.Duration
autostartBuildTimeout time.Duration
autostartDelay time.Duration
template string
noCleanup bool
workspaceCount int64
workspaceJobTimeout time.Duration
autostartDelay time.Duration
autostartTimeout time.Duration
template string
noCleanup bool
parameterFlags workspaceParameterFlags
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
@@ -1771,7 +1772,7 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("parse output flags: %w", err)
return xerrors.Errorf("could not parse --output flags")
}
tpl, err := parseTemplate(ctx, client, me.OrganizationIDs, template)
@@ -1802,41 +1803,15 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := autostart.NewMetrics(reg)
setupBarrier := new(sync.WaitGroup)
setupBarrier.Add(int(workspaceCount))
// The workspace-build-updates experiment must be enabled to use
// the centralized pubsub channel for coordinating workspace builds.
experiments, err := client.Experiments(ctx)
if err != nil {
return xerrors.Errorf("get experiments: %w", err)
}
if !experiments.Enabled(codersdk.ExperimentWorkspaceBuildUpdates) {
return xerrors.New("the workspace-build-updates experiment must be enabled to run the autostart scaletest")
}
workspaceNames := make([]string, 0, workspaceCount)
resultSink := make(chan autostart.RunResult, workspaceCount)
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i := range workspaceCount {
id := strconv.Itoa(int(i))
workspaceNames = append(workspaceNames, loadtestutil.GenerateDeterministicWorkspaceName(id))
}
dispatcher := autostart.NewWorkspaceDispatcher(workspaceNames)
decoder, err := client.WatchAllWorkspaceBuilds(ctx)
if err != nil {
return xerrors.Errorf("watch all workspace builds: %w", err)
}
defer decoder.Close()
// Start the dispatcher. It will run in a goroutine and automatically
// close all workspace channels when the build updates channel closes.
dispatcher.Start(ctx, decoder.Chan())
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for workspaceName, buildUpdatesChannel := range dispatcher.Channels {
id := strings.TrimPrefix(workspaceName, loadtestutil.ScaleTestPrefix+"-")
config := autostart.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
@@ -1846,16 +1821,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Request: codersdk.CreateWorkspaceRequest{
TemplateID: tpl.ID,
RichParameterValues: richParameters,
// Use deterministic workspace name so we can pre-create the channel.
Name: workspaceName,
},
},
WorkspaceJobTimeout: workspaceJobTimeout,
AutostartBuildTimeout: autostartBuildTimeout,
AutostartDelay: autostartDelay,
SetupBarrier: setupBarrier,
BuildUpdates: buildUpdatesChannel,
ResultSink: resultSink,
WorkspaceJobTimeout: workspaceJobTimeout,
AutostartDelay: autostartDelay,
AutostartTimeout: autostartTimeout,
Metrics: metrics,
SetupBarrier: setupBarrier,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
@@ -1877,11 +1849,18 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
th.AddRun(autostartTestName, id, runner)
}
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Running autostart load test...")
@@ -1892,40 +1871,31 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// Collect all metrics from the channel.
close(resultSink)
var runResults []autostart.RunResult
for r := range resultSink {
runResults = append(runResults, r)
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
_, _ = fmt.Fprintf(inv.Stderr, "\nAll %d autostart builds completed successfully (elapsed: %s)\n", res.TotalRuns, time.Duration(res.Elapsed).Round(time.Millisecond))
if len(runResults) > 0 {
results := autostart.NewRunResults(runResults)
for _, out := range outputs {
if err := out.write(results.ToHarnessResults(), inv.Stdout); err != nil {
return xerrors.Errorf("write output: %w", err)
}
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(context.Background())
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
_, _ = fmt.Fprintln(inv.Stderr, "Cleanup complete")
} else {
_, _ = fmt.Fprintln(inv.Stderr, "\nSkipping cleanup (--no-cleanup specified). Resources left running.")
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
@@ -1948,13 +1918,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Description: "Timeout for workspace jobs (e.g. build, start).",
Value: serpent.DurationOf(&workspaceJobTimeout),
},
{
Flag: "autostart-build-timeout",
Env: "CODER_SCALETEST_AUTOSTART_BUILD_TIMEOUT",
Default: "15m",
Description: "Timeout for the autostart build to complete. Must be longer than workspace-job-timeout to account for queueing time in high-load scenarios.",
Value: serpent.DurationOf(&autostartBuildTimeout),
},
{
Flag: "autostart-delay",
Env: "CODER_SCALETEST_AUTOSTART_DELAY",
@@ -1962,6 +1925,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Description: "How long after all the workspaces have been stopped to schedule them to be started again.",
Value: serpent.DurationOf(&autostartDelay),
},
{
Flag: "autostart-timeout",
Env: "CODER_SCALETEST_AUTOSTART_TIMEOUT",
Default: "5m",
Description: "Timeout for the autostart build to be initiated after the scheduled start time.",
Value: serpent.DurationOf(&autostartTimeout),
},
{
Flag: "template",
FlagShorthand: "t",
@@ -1980,9 +1950,10 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
tracingFlags.attach(&cmd.Options)
output.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
+6 -12
View File
@@ -19,18 +19,12 @@ func OverrideVSCodeConfigs(fs afero.Fs) error {
return err
}
mutate := func(m map[string]interface{}) {
// These defaults prevent VS Code from overriding
// GIT_ASKPASS and using its own GitHub authentication,
// which would circumvent cloning with Coder-configured
// providers. We only set them if they are not already
// present so that template authors can override them
// via module settings (e.g. the vscode-web module).
if _, ok := m["git.useIntegratedAskPass"]; !ok {
m["git.useIntegratedAskPass"] = false
}
if _, ok := m["github.gitAuthentication"]; !ok {
m["github.gitAuthentication"] = false
}
// This prevents VS Code from overriding GIT_ASKPASS, which
// we use to automatically authenticate Git providers.
m["git.useIntegratedAskPass"] = false
// This prevents VS Code from using it's own GitHub authentication
// which would circumvent cloning with Coder-configured providers.
m["github.gitAuthentication"] = false
}
for _, configPath := range []string{
-27
View File
@@ -61,31 +61,4 @@ func TestOverrideVSCodeConfigs(t *testing.T) {
require.Equal(t, "something", mapping["hotdogs"])
}
})
t.Run("NoOverwrite", func(t *testing.T) {
t.Parallel()
fs := afero.NewMemMapFs()
mapping := map[string]interface{}{
"git.useIntegratedAskPass": true,
"github.gitAuthentication": true,
"other.setting": "preserved",
}
data, err := json.Marshal(mapping)
require.NoError(t, err)
for _, configPath := range configPaths {
err = afero.WriteFile(fs, configPath, data, 0o600)
require.NoError(t, err)
}
err = gitauth.OverrideVSCodeConfigs(fs)
require.NoError(t, err)
for _, configPath := range configPaths {
data, err := afero.ReadFile(fs, configPath)
require.NoError(t, err)
mapping := map[string]interface{}{}
err = json.Unmarshal(data, &mapping)
require.NoError(t, err)
require.Equal(t, true, mapping["git.useIntegratedAskPass"])
require.Equal(t, true, mapping["github.gitAuthentication"])
require.Equal(t, "preserved", mapping["other.setting"])
}
})
}
+1 -1
View File
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
} else {
updated, err = client.CreateOrganizationRole(ctx, customRole)
if err != nil {
return xerrors.Errorf("create role: %w", err)
return xerrors.Errorf("patch role: %w", err)
}
}
-15
View File
@@ -39,7 +39,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/pretty"
"github.com/coder/quartz"
"github.com/coder/serpent"
)
@@ -231,10 +230,6 @@ func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) {
}
func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, error) {
if r.clock == nil {
r.clock = quartz.NewReal()
}
fmtLong := `Coder %s A tool for provisioning self-hosted development environments with Terraform.
`
hiddenAgentAuth := &AgentAuth{}
@@ -553,16 +548,6 @@ type RootCmd struct {
useKeyring bool
keyringServiceName string
useKeyringWithGlobalConfig bool
// clock is used for time-dependent operations. Initialized to
// quartz.NewReal() in Command() if not set via SetClock.
clock quartz.Clock
}
// SetClock sets the clock used for time-dependent operations.
// Must be called before Command() to take effect.
func (r *RootCmd) SetClock(clk quartz.Clock) {
r.clock = clk
}
// ensureClientURL loads the client URL from the config file if it
-2
View File
@@ -2909,8 +2909,6 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder
provider.MCPToolDenyRegex = v.Value
case "PKCE_METHODS":
provider.CodeChallengeMethodsSupported = strings.Split(v.Value, " ")
case "API_BASE_URL":
provider.APIBaseURL = v.Value
}
providers[providerNum] = provider
}
+10 -11
View File
@@ -188,17 +188,16 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "Creating user...")
newUser, err = tx.InsertUser(ctx, database.InsertUserParams{
ID: uuid.New(),
Email: newUserEmail,
Username: newUserUsername,
Name: "Admin User",
HashedPassword: []byte(hashedPassword),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
RBACRoles: []string{rbac.RoleOwner().String()},
LoginType: database.LoginTypePassword,
Status: "",
IsServiceAccount: false,
ID: uuid.New(),
Email: newUserEmail,
Username: newUserUsername,
Name: "Admin User",
HashedPassword: []byte(hashedPassword),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
RBACRoles: []string{rbac.RoleOwner().String()},
LoginType: database.LoginTypePassword,
Status: "",
})
if err != nil {
return xerrors.Errorf("insert user: %w", err)
-23
View File
@@ -108,29 +108,6 @@ func TestReadExternalAuthProvidersFromEnv(t *testing.T) {
})
}
func TestReadExternalAuthProvidersFromEnv_APIBaseURL(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_EXTERNAL_AUTH_0_TYPE=github",
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
"CODER_EXTERNAL_AUTH_0_API_BASE_URL=https://ghes.corp.com/api/v3",
})
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, "https://ghes.corp.com/api/v3", providers[0].APIBaseURL)
}
func TestReadExternalAuthProvidersFromEnv_APIBaseURLDefault(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_EXTERNAL_AUTH_0_TYPE=github",
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
})
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, "", providers[0].APIBaseURL)
}
// TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_`
// environment variables are still supported.
func TestReadGitAuthProvidersFromEnv(t *testing.T) {
+15 -1
View File
@@ -21,8 +21,9 @@ type storedCredentials map[string]struct {
APIToken string `json:"api_token"`
}
//nolint:paralleltest, tparallel // OS keyring is flaky under concurrent access
func TestKeyring(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" && runtime.GOOS != "darwin" {
t.Skip("linux is not supported yet")
}
@@ -36,6 +37,8 @@ func TestKeyring(t *testing.T) {
)
t.Run("ReadNonExistent", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -47,6 +50,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("DeleteNonExistent", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -58,6 +63,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("WriteAndRead", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -84,6 +91,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("WriteAndDelete", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -106,6 +115,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("OverwriteToken", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -135,6 +146,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("MultipleServers", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -186,6 +199,7 @@ func TestKeyring(t *testing.T) {
})
t.Run("StorageFormat", func(t *testing.T) {
t.Parallel()
// The storage format must remain consistent to ensure we don't break
// compatibility with other Coder related applications that may read
// or decode the same credential.
@@ -25,8 +25,9 @@ func readRawKeychainCredential(t *testing.T, serviceName string) []byte {
return winCred.CredentialBlob
}
//nolint:paralleltest, tparallel // OS keyring is flaky under concurrent access
func TestWindowsKeyring_WriteReadDelete(t *testing.T) {
t.Parallel()
const testURL = "http://127.0.0.1:1337"
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
+16 -8
View File
@@ -180,11 +180,15 @@ func TestSSH(t *testing.T) {
// Delay until workspace is starting, otherwise the agent may be
// booted due to outdated build.
require.Eventually(t, func() bool {
var err error
var err error
for {
workspace, err = client.Workspace(ctx, workspace.ID)
return err == nil && workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, err)
if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
break
}
time.Sleep(testutil.IntervalFast)
}
// When the agent connects, the workspace was started, and we should
// have access to the shell.
@@ -759,11 +763,15 @@ func TestSSH(t *testing.T) {
// Delay until workspace is starting, otherwise the agent may be
// booted due to outdated build.
require.Eventually(t, func() bool {
var err error
var err error
for {
workspace, err = client.Workspace(ctx, workspace.ID)
return err == nil && workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, err)
if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
break
}
time.Sleep(testutil.IntervalFast)
}
// When the agent connects, the workspace was started, and we should
// have access to the shell.
-23
View File
@@ -79,29 +79,6 @@ func (r *RootCmd) start() *serpent.Command {
)
build = workspace.LatestBuild
default:
// If the last build was a failed start, run a stop
// first to clean up any partially-provisioned
// resources.
if workspace.LatestBuild.Status == codersdk.WorkspaceStatusFailed &&
workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
_, _ = fmt.Fprintf(inv.Stdout, "The last start build failed. Cleaning up before retrying...\n")
stopBuild, stopErr := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStop,
})
if stopErr != nil {
return xerrors.Errorf("cleanup stop after failed start: %w", stopErr)
}
stopErr = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, stopBuild.ID)
if stopErr != nil {
return xerrors.Errorf("wait for cleanup stop: %w", stopErr)
}
// Re-fetch workspace after stop completes so
// startWorkspace sees the latest state.
workspace, err = namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return err
}
}
build, err = startWorkspace(inv, client, workspace, parameterFlags, bflags, WorkspaceStart)
// It's possible for a workspace build to fail due to the template requiring starting
// workspaces with the active version.
-52
View File
@@ -534,55 +534,3 @@ func TestStart_WithReason(t *testing.T) {
workspace = coderdtest.MustWorkspace(t, member, workspace.ID)
require.Equal(t, codersdk.BuildReasonCLI, workspace.LatestBuild.Reason)
}
func TestStart_FailedStartCleansUp(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
store, ps := dbtestutil.NewDB(t)
client := coderdtest.New(t, &coderdtest.Options{
Database: store,
Pubsub: ps,
IncludeProvisionerDaemon: true,
})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, memberClient, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// Insert a failed start build directly into the database so that
// the workspace's latest build is a failed "start" transition.
dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
ID: workspace.ID,
OwnerID: member.ID,
OrganizationID: owner.OrganizationID,
TemplateID: template.ID,
}).
Seed(database.WorkspaceBuild{
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStart,
BuildNumber: workspace.LatestBuild.BuildNumber + 1,
}).
Failed().
Do()
inv, root := clitest.New(t, "start", workspace.Name)
clitest.SetupConfig(t, memberClient, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
// The CLI should detect the failed start and clean up first.
pty.ExpectMatch("Cleaning up before retrying")
pty.ExpectMatch("workspace has been started")
_ = testutil.TryReceive(ctx, t, doneChan)
}
+17 -26
View File
@@ -113,20 +113,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
)
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
// Bypass rate limiting for support bundle collection since it makes many API calls.
// Note: this can only be done by the owner user.
if ok, err := support.CanGenerateFull(inv.Context(), client); err == nil && ok {
cliLog.Debug(inv.Context(), "running as owner")
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
} else if !ok {
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
} else {
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
}
// Check if we're running inside a workspace
if val, found := os.LookupEnv("CODER"); found && val == "true" {
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
@@ -214,6 +200,12 @@ func (r *RootCmd) supportBundle() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
}
// Bypass rate limiting for support bundle collection since it makes many API calls.
client.HTTPClient.Transport = &codersdk.HeaderTransport{
Transport: client.HTTPClient.Transport,
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
}
deps := support.Deps{
Client: client,
// Support adds a sink so we don't need to supply one ourselves.
@@ -362,20 +354,19 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
return
}
var docsURL string
if bun.Deployment.Config != nil {
docsURL = bun.Deployment.Config.Values.DocsURL.String()
} else {
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
if bun.Deployment.Config == nil {
cliui.Error(inv.Stdout, "No deployment configuration available!")
return
}
if bun.Deployment.HealthReport != nil {
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
} else {
cliui.Warn(inv.Stdout, "No deployment health report available.")
docsURL := bun.Deployment.Config.Values.DocsURL.String()
if bun.Deployment.HealthReport == nil {
cliui.Error(inv.Stdout, "No deployment health report available!")
return
}
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
if len(deployHealthSummary) > 0 {
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
}
if bun.Network.Netcheck == nil {
+22 -47
View File
@@ -28,9 +28,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/healthcheck"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
@@ -52,21 +50,9 @@ func TestSupportBundle(t *testing.T) {
dc.Values.Prometheus.Enable = true
secretValue := uuid.NewString()
seedSecretDeploymentOptions(t, &dc, secretValue)
// Use a mock healthcheck function to avoid flaky DERP health
// checks in CI. The DERP checker performs real network operations
// (portmapper gateway probing, STUN) that can hang for 60s+ on
// macOS CI runners. Since this test validates support bundle
// generation, not healthcheck correctness, a canned report is
// sufficient.
client, closer, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
DeploymentValues: dc.Values,
HealthcheckFunc: func(_ context.Context, _ string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
return &healthsdk.HealthcheckReport{
Time: time.Now(),
Healthy: true,
Severity: health.SeverityOK,
}
},
DeploymentValues: dc.Values,
HealthcheckTimeout: testutil.WaitSuperLong,
})
t.Cleanup(func() { closer.Close() })
@@ -74,7 +60,7 @@ func TestSupportBundle(t *testing.T) {
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Set up test fixtures
setupCtx := testutil.Context(t, testutil.WaitLong)
setupCtx := testutil.Context(t, testutil.WaitSuperLong)
workspaceWithAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, func(agents []*proto.Agent) []*proto.Agent {
// This should not show up in the bundle output
agents[0].Env["SECRET_VALUE"] = secretValue
@@ -83,6 +69,22 @@ func TestSupportBundle(t *testing.T) {
workspaceWithoutAgent := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, owner.UserID, nil)
memberWorkspace := setupSupportBundleTestFixture(setupCtx, t, api.Database, owner.OrganizationID, member.ID, nil)
// Wait for healthcheck to complete successfully before continuing with sub-tests.
// The result is cached so subsequent requests will be fast.
healthcheckDone := make(chan *healthsdk.HealthcheckReport)
go func() {
defer close(healthcheckDone)
hc, err := healthsdk.New(client).DebugHealth(setupCtx)
if err != nil {
assert.NoError(t, err, "seed healthcheck cache")
return
}
healthcheckDone <- &hc
}()
if _, ok := testutil.AssertReceive(setupCtx, t, healthcheckDone); !ok {
t.Fatal("healthcheck did not complete in time -- this may be a transient issue")
}
t.Run("WorkspaceWithAgent", func(t *testing.T) {
t.Parallel()
@@ -130,35 +132,12 @@ func TestSupportBundle(t *testing.T) {
assertBundleContents(t, path, true, false, []string{secretValue})
})
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
t.Run("NoPrivilege", func(t *testing.T) {
t.Parallel()
d := t.TempDir()
path := filepath.Join(d, "bundle.zip")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
clitest.SetupConfig(t, memberClient, root)
err := inv.Run()
require.NoError(t, err)
r, err := zip.OpenReader(path)
require.NoError(t, err, "open zip file")
defer r.Close()
fileNames := make(map[string]struct{}, len(r.File))
for _, f := range r.File {
fileNames[f.Name] = struct{}{}
}
// These should always be present in the zip structure, even if
// the content is null/empty for non-admin users.
for _, name := range []string{
"deployment/buildinfo.json",
"deployment/config.json",
"workspace/workspace.json",
"logs.txt",
"cli_logs.txt",
"network/netcheck.json",
"network/interfaces.json",
} {
require.Contains(t, fileNames, name)
}
require.ErrorContains(t, err, "failed authorization check")
})
// This ensures that the CLI does not panic when trying to generate a support bundle
@@ -180,10 +159,6 @@ func TestSupportBundle(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("received request: %s %s", r.Method, r.URL)
switch r.URL.Path {
case "/api/v2/users/me":
resp := codersdk.User{}
w.WriteHeader(http.StatusOK)
assert.NoError(t, json.NewEncoder(w).Encode(resp))
case "/api/v2/authcheck":
// Fake auth check
resp := codersdk.AuthorizationResponse{
+14 -12
View File
@@ -7,6 +7,7 @@ import (
"path/filepath"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -102,13 +103,13 @@ func TestSyncCommands_Golden(t *testing.T) {
require.NoError(t, err)
client.Close()
outBuf := testutil.NewWaitBuffer()
// Start a goroutine to complete the dependency after a short delay
// This simulates the dependency being satisfied while start is waiting
// The delay ensures the "Waiting..." message appears in the output
done := make(chan error, 1)
go func() {
if err := outBuf.WaitFor(ctx, "Waiting"); err != nil {
done <- err
return
}
// Wait a moment to let the start command begin waiting and print the message
time.Sleep(100 * time.Millisecond)
compCtx := context.Background()
compClient, err := agentsocket.NewClient(compCtx, agentsocket.WithPath(path))
@@ -118,7 +119,7 @@ func TestSyncCommands_Golden(t *testing.T) {
}
defer compClient.Close()
// Start and complete the dependency unit.
// Start and complete the dependency unit
err = compClient.SyncStart(compCtx, "dep-unit")
if err != nil {
done <- err
@@ -128,20 +129,21 @@ func TestSyncCommands_Golden(t *testing.T) {
done <- err
}()
var outBuf bytes.Buffer
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--socket-path", path)
inv.Stdout = outBuf
inv.Stderr = outBuf
inv.Stdout = &outBuf
inv.Stderr = &outBuf
// Run the start command - it should wait for the dependency.
// Run the start command - it should wait for the dependency
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
// Ensure the completion goroutine finished.
// Ensure the completion goroutine finished
select {
case err := <-done:
require.NoError(t, err, "complete dependency")
case <-ctx.Done():
t.Fatal("timed out waiting for dependency completion goroutine")
case <-time.After(time.Second):
// Goroutine should have finished by now
}
clitest.TestGoldenFile(t, "TestSyncCommands_Golden/start_with_dependencies", outBuf.Bytes(), nil)
+5 -5
View File
@@ -90,7 +90,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
return err
}
tsr := toStatusRow(task, r.clock.Now())
tsr := toStatusRow(task)
out, err := formatter.Format(ctx, []taskStatusRow{tsr})
if err != nil {
return xerrors.Errorf("format task status: %w", err)
@@ -112,7 +112,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
}
// Only print if something changed
newStatusRow := toStatusRow(task, r.clock.Now())
newStatusRow := toStatusRow(task)
if !taskStatusRowEqual(lastStatusRow, newStatusRow) {
out, err := formatter.Format(ctx, []taskStatusRow{newStatusRow})
if err != nil {
@@ -166,10 +166,10 @@ func taskStatusRowEqual(r1, r2 taskStatusRow) bool {
taskStateEqual(r1.CurrentState, r2.CurrentState)
}
func toStatusRow(task codersdk.Task, now time.Time) taskStatusRow {
func toStatusRow(task codersdk.Task) taskStatusRow {
tsr := taskStatusRow{
Task: task,
ChangedAgo: now.Sub(task.UpdatedAt).Truncate(time.Second).String() + " ago",
ChangedAgo: time.Since(task.UpdatedAt).Truncate(time.Second).String() + " ago",
}
tsr.Healthy = task.WorkspaceAgentHealth != nil &&
task.WorkspaceAgentHealth.Healthy &&
@@ -178,7 +178,7 @@ func toStatusRow(task codersdk.Task, now time.Time) taskStatusRow {
!task.WorkspaceAgentLifecycle.ShuttingDown()
if task.CurrentState != nil {
tsr.ChangedAgo = now.Sub(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
tsr.ChangedAgo = time.Since(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
}
return tsr
}
+9 -12
View File
@@ -19,7 +19,6 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func Test_TaskStatus(t *testing.T) {
@@ -29,12 +28,12 @@ func Test_TaskStatus(t *testing.T) {
args []string
expectOutput string
expectError string
hf func(context.Context, quartz.Clock) func(http.ResponseWriter, *http.Request)
hf func(context.Context, time.Time) func(http.ResponseWriter, *http.Request)
}{
{
args: []string{"doesnotexist"},
expectError: httpapi.ResourceNotFoundResponse.Message,
hf: func(ctx context.Context, _ quartz.Clock) func(w http.ResponseWriter, r *http.Request) {
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/tasks/me/doesnotexist":
@@ -50,8 +49,7 @@ func Test_TaskStatus(t *testing.T) {
args: []string{"exists"},
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
0s ago active true working Thinking furiously...`,
hf: func(ctx context.Context, clk quartz.Clock) func(w http.ResponseWriter, r *http.Request) {
now := clk.Now()
hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/tasks/me/exists":
@@ -86,8 +84,7 @@ func Test_TaskStatus(t *testing.T) {
4s ago active true
3s ago active true working Reticulating splines...
2s ago active true complete Splines reticulated successfully!`,
hf: func(ctx context.Context, clk quartz.Clock) func(http.ResponseWriter, *http.Request) {
now := clk.Now()
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
var calls atomic.Int64
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@@ -218,7 +215,7 @@ func Test_TaskStatus(t *testing.T) {
"created_at": "2025-08-26T12:34:56Z",
"updated_at": "2025-08-26T12:34:56Z"
}`,
hf: func(ctx context.Context, _ quartz.Clock) func(http.ResponseWriter, *http.Request) {
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
ts := time.Date(2025, 8, 26, 12, 34, 56, 0, time.UTC)
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@@ -255,8 +252,8 @@ func Test_TaskStatus(t *testing.T) {
var (
ctx = testutil.Context(t, testutil.WaitShort)
mClock = quartz.NewMock(t)
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, mClock)))
now = time.Now().UTC() // TODO: replace with quartz
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
client = codersdk.New(testutil.MustURL(t, srv.URL))
sb = strings.Builder{}
args = []string{"task", "status", "--watch-interval", testutil.IntervalFast.String()}
@@ -264,10 +261,10 @@ func Test_TaskStatus(t *testing.T) {
t.Cleanup(srv.Close)
args = append(args, tc.args...)
inv, cfgDir := clitest.NewWithClock(t, mClock, args...)
inv, root := clitest.New(t, args...)
inv.Stdout = &sb
inv.Stderr = &sb
clitest.SetupConfig(t, client, cfgDir)
clitest.SetupConfig(t, client, root)
err := inv.WithContext(ctx).Run()
if tc.expectError == "" {
assert.NoError(t, err)
-4
View File
@@ -20,10 +20,6 @@ OPTIONS:
--copy-parameters-from string, $CODER_WORKSPACE_COPY_PARAMETERS_FROM
Specify the source workspace name to copy parameters from.
--no-wait bool, $CODER_CREATE_NO_WAIT
Return immediately after creating the workspace. The build will run in
the background.
--parameter string-array, $CODER_RICH_PARAMETER
Rich parameter value in the format "name=value".
+1 -1
View File
@@ -7,7 +7,7 @@
"last_seen_at": "====[timestamp]=====",
"name": "test-daemon",
"version": "v0.0.0-devel",
"api_version": "1.16",
"api_version": "1.15",
"provisioners": [
"echo"
],
+5 -6
View File
@@ -143,6 +143,11 @@ AI BRIDGE OPTIONS:
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
Whether to start an in-memory aibridged instance.
--aibridge-inject-coder-mcp-tools bool, $CODER_AIBRIDGE_INJECT_CODER_MCP_TOOLS (default: false)
Whether to inject Coder's MCP tools into intercepted AI Bridge
requests (requires the "oauth2" and "mcp-server-http" experiments to
be enabled).
--aibridge-max-concurrency int, $CODER_AIBRIDGE_MAX_CONCURRENCY (default: 0)
Maximum number of concurrent AI Bridge requests per replica. Set to 0
to disable (unlimited).
@@ -170,12 +175,6 @@ AI BRIDGE OPTIONS:
exporting these records to external SIEM or observability systems.
AI BRIDGE PROXY OPTIONS:
--aibridge-proxy-allowed-private-cidrs string-array, $CODER_AIBRIDGE_PROXY_ALLOWED_PRIVATE_CIDRS
Comma-separated list of CIDR ranges that are permitted even though
they fall within blocked private/reserved IP ranges. By default all
private ranges are blocked to prevent SSRF attacks. Use this to allow
access to specific internal networks.
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
provider requests.
+10 -11
View File
@@ -8,17 +8,16 @@ USAGE:
Aliases: user
SUBCOMMANDS:
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
oidc-claims Display the OIDC claims for the authenticated user.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user
cannot log into the platform
activate Update a user's status to 'active'. Active users can fully
interact with the platform
create Create a new user.
delete Delete a user by username or user_id.
edit-roles Edit a user's roles by username or id
list Prints the list of users.
show Show a single user. Use 'me' to indicate the currently
authenticated user.
suspend Update a user's status to 'suspended'. A suspended user cannot
log into the platform
———
Run `coder --help` for a list of global options.
-4
View File
@@ -24,10 +24,6 @@ OPTIONS:
-p, --password string
Specifies a password for the new user.
--service-account bool
Create a user account intended to be used by a service or as an
intermediary rather than by a human.
-u, --username string
Specifies a username for the new user.
-24
View File
@@ -1,24 +0,0 @@
coder v0.0.0-devel
USAGE:
coder users oidc-claims [flags]
Display the OIDC claims for the authenticated user.
- Display your OIDC claims:
$ coder users oidc-claims
- Display your OIDC claims as JSON:
$ coder users oidc-claims -o json
OPTIONS:
-c, --column [key|value] (default: key,value)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
+2 -15
View File
@@ -752,11 +752,6 @@ workspace_prebuilds:
# limit; disabled when set to zero.
# (default: 3, type: int)
failure_hard_limit: 3
# Configure the background chat processing daemon.
chat:
# How many pending chats a worker should acquire per polling cycle.
# (default: 10, type: int)
acquireBatchSize: 10
aibridge:
# Whether to start an in-memory aibridged instance.
# (default: false, type: bool)
@@ -783,10 +778,8 @@ aibridge:
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
# Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a
# future release. Whether to inject Coder's MCP tools into intercepted AI Bridge
# requests (requires the "oauth2" and "mcp-server-http" experiments to be
# enabled).
# Whether to inject Coder's MCP tools into intercepted AI Bridge requests
# (requires the "oauth2" and "mcp-server-http" experiments to be enabled).
# (default: false, type: bool)
inject_coder_mcp_tools: false
# Length of time to retain data such as interceptions and all related records
@@ -873,12 +866,6 @@ aibridgeproxy:
# by the system. If not provided, the system certificate pool is used.
# (default: <unset>, type: string)
upstream_proxy_ca: ""
# Comma-separated list of CIDR ranges that are permitted even though they fall
# within blocked private/reserved IP ranges. By default all private ranges are
# blocked to prevent SSRF attacks. Use this to allow access to specific internal
# networks.
# (default: <unset>, type: string-array)
allowed_private_cidrs: []
# Configure data retention policies for various database tables. Retention
# policies automatically purge old data to reduce database size and improve
# performance. Setting a retention duration to 0 disables automatic purging for
+12 -37
View File
@@ -17,14 +17,13 @@ import (
func (r *RootCmd) userCreate() *serpent.Command {
var (
email string
username string
name string
password string
disableLogin bool
loginType string
serviceAccount bool
orgContext = NewOrganizationContext()
email string
username string
name string
password string
disableLogin bool
loginType string
orgContext = NewOrganizationContext()
)
cmd := &serpent.Command{
Use: "create",
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
if serviceAccount {
switch {
case loginType != "":
return xerrors.New("You cannot use --login-type with --service-account")
case password != "":
return xerrors.New("You cannot use --password with --service-account")
case email != "":
return xerrors.New("You cannot use --email with --service-account")
case disableLogin:
return xerrors.New("You cannot use --disable-login with --service-account")
}
}
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
client, err := r.InitClient(inv)
if err != nil {
return err
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
return err
}
}
if email == "" && !serviceAccount {
if email == "" {
email, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: "Email:",
Validate: func(s string) error {
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
}
}
userLoginType := codersdk.LoginTypePassword
if disableLogin || serviceAccount {
if disableLogin && loginType != "" {
return xerrors.New("You cannot specify both --disable-login and --login-type")
}
if disableLogin {
userLoginType = codersdk.LoginTypeNone
} else if loginType != "" {
userLoginType = codersdk.LoginType(loginType)
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
Password: password,
OrganizationIDs: []uuid.UUID{organization.ID},
UserLoginType: userLoginType,
ServiceAccount: serviceAccount,
})
if err != nil {
return err
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
case codersdk.LoginTypeOIDC:
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
}
if serviceAccount {
email = "n/a"
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
}
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
Share the instructions below to get them started.
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
)),
Value: serpent.StringOf(&loginType),
},
{
Flag: "service-account",
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
Value: serpent.BoolOf(&serviceAccount),
},
}
orgContext.AttachOptions(cmd)
-53
View File
@@ -8,7 +8,6 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
assert.Equal(t, args[5], created.Username)
assert.Empty(t, created.Name)
})
tests := []struct {
name string
args []string
err string
}{
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
},
{
name: "ServiceAccountLoginType",
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
err: "You cannot use --login-type with --service-account",
},
{
name: "ServiceAccountDisableLogin",
args: []string{"--service-account", "-u", "dean", "--disable-login"},
err: "You cannot use --disable-login with --service-account",
},
{
name: "ServiceAccountEmail",
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
err: "You cannot use --email with --service-account",
},
{
name: "ServiceAccountPassword",
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
err: "You cannot use --password with --service-account",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
clitest.SetupConfig(t, client, root)
err := inv.Run()
if tt.err == "" {
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitShort)
created, err := client.User(ctx, "dean")
require.NoError(t, err)
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
} else {
require.Error(t, err)
require.ErrorContains(t, err, tt.err)
}
})
}
}
-79
View File
@@ -1,79 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) userOIDCClaims() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.ChangeFormatterData(
cliui.TableFormat([]claimRow{}, []string{"key", "value"}),
func(data any) (any, error) {
resp, ok := data.(codersdk.OIDCClaimsResponse)
if !ok {
return nil, xerrors.Errorf("expected type %T, got %T", resp, data)
}
rows := make([]claimRow, 0, len(resp.Claims))
for k, v := range resp.Claims {
rows = append(rows, claimRow{
Key: k,
Value: fmt.Sprintf("%v", v),
})
}
return rows, nil
},
),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "oidc-claims",
Short: "Display the OIDC claims for the authenticated user.",
Long: FormatExamples(
Example{
Description: "Display your OIDC claims",
Command: "coder users oidc-claims",
},
Example{
Description: "Display your OIDC claims as JSON",
Command: "coder users oidc-claims -o json",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
resp, err := client.UserOIDCClaims(inv.Context())
if err != nil {
return xerrors.Errorf("get oidc claims: %w", err)
}
out, err := formatter.Format(inv.Context(), resp)
if err != nil {
return err
}
_, err = fmt.Fprintln(inv.Stdout, out)
return err
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
type claimRow struct {
Key string `json:"-" table:"key,default_sort"`
Value string `json:"-" table:"value"`
}
-161
View File
@@ -1,161 +0,0 @@
package cli_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func TestUserOIDCClaims(t *testing.T) {
t.Parallel()
newOIDCTest := func(t *testing.T) (*oidctest.FakeIDP, *codersdk.Client) {
t.Helper()
fake := oidctest.NewFakeIDP(t,
oidctest.WithServing(),
)
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true
})
ownerClient := coderdtest.New(t, &coderdtest.Options{
OIDCConfig: cfg,
})
return fake, ownerClient
}
t.Run("OwnClaims", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "alice@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
"groups": []string{"admin", "eng"},
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims", "-o", "json")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
var resp codersdk.OIDCClaimsResponse
err = json.Unmarshal(buf.Bytes(), &resp)
require.NoError(t, err, "unmarshal JSON output")
require.NotEmpty(t, resp.Claims, "claims should not be empty")
assert.Equal(t, "alice@coder.com", resp.Claims["email"])
})
t.Run("Table", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
claims := jwt.MapClaims{
"email": "bob@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, userClient, root)
buf := bytes.NewBuffer(nil)
inv.Stdout = buf
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.NoError(t, err)
output := buf.String()
require.Contains(t, output, "email")
require.Contains(t, output, "bob@coder.com")
})
t.Run("NotOIDCUser", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
inv, root := clitest.New(t, "users", "oidc-claims")
clitest.SetupConfig(t, client, root)
err := inv.WithContext(testutil.Context(t, testutil.WaitMedium)).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "not an OIDC user")
})
// Verify that two different OIDC users each only see their own
// claims. The endpoint has no user parameter, so there is no way
// to request another user's claims by design.
t.Run("OnlyOwnClaims", func(t *testing.T) {
t.Parallel()
aliceFake, aliceOwnerClient := newOIDCTest(t)
aliceClaims := jwt.MapClaims{
"email": "alice-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
aliceClient, aliceLoginResp := aliceFake.Login(t, aliceOwnerClient, aliceClaims)
defer aliceLoginResp.Body.Close()
bobFake, bobOwnerClient := newOIDCTest(t)
bobClaims := jwt.MapClaims{
"email": "bob-isolation@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
bobClient, bobLoginResp := bobFake.Login(t, bobOwnerClient, bobClaims)
defer bobLoginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
// Alice sees her own claims.
aliceResp, err := aliceClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "alice-isolation@coder.com", aliceResp.Claims["email"])
// Bob sees his own claims.
bobResp, err := bobClient.UserOIDCClaims(ctx)
require.NoError(t, err)
assert.Equal(t, "bob-isolation@coder.com", bobResp.Claims["email"])
})
t.Run("ClaimsNeverNull", func(t *testing.T) {
t.Parallel()
fake, ownerClient := newOIDCTest(t)
// Use minimal claims — just enough for OIDC login.
claims := jwt.MapClaims{
"email": "minimal@coder.com",
"email_verified": true,
"sub": uuid.NewString(),
}
userClient, loginResp := fake.Login(t, ownerClient, claims)
defer loginResp.Body.Close()
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.UserOIDCClaims(ctx)
require.NoError(t, err)
require.NotNil(t, resp.Claims, "claims should never be nil, expected empty map")
})
}
-1
View File
@@ -19,7 +19,6 @@ func (r *RootCmd) users() *serpent.Command {
r.userSingle(),
r.userDelete(),
r.userEditRoles(),
r.userOIDCClaims(),
r.createUserStatusCommand(codersdk.UserStatusActive),
r.createUserStatusCommand(codersdk.UserStatusSuspended),
},
+4 -4
View File
@@ -116,10 +116,10 @@ func TestWorkspaceActivityBump(t *testing.T) {
// is required. The Activity Bump behavior is also coupled with
// Last Used, so it would be obvious to the user if we
// are falsely recognizing activity.
require.Never(t, func() bool {
workspace, err = client.Workspace(ctx, workspace.ID)
return err == nil && !workspace.LatestBuild.Deadline.Time.Equal(firstDeadline)
}, testutil.IntervalMedium, testutil.IntervalFast, "deadline should not change")
time.Sleep(testutil.IntervalMedium)
workspace, err = client.Workspace(ctx, workspace.ID)
require.NoError(t, err)
require.Equal(t, workspace.LatestBuild.Deadline.Time, firstDeadline)
return
}
+3 -6
View File
@@ -134,12 +134,9 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
case database.WorkspaceAgentLifecycleStateReady,
database.WorkspaceAgentLifecycleStateStartTimeout,
database.WorkspaceAgentLifecycleStateStartError:
// Only emit metrics for the parent agent, this metric is not intended to measure devcontainer durations.
if !workspaceAgent.ParentID.Valid {
a.emitMetricsOnce.Do(func() {
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
})
}
a.emitMetricsOnce.Do(func() {
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
})
}
return req.Lifecycle, nil
-58
View File
@@ -582,64 +582,6 @@ func TestUpdateLifecycle(t *testing.T) {
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
t.Run("SubAgentDoesNotEmitMetric", func(t *testing.T) {
t.Parallel()
parentID := uuid.New()
subAgent := database.WorkspaceAgent{
ID: uuid.New(),
ParentID: uuid.NullUUID{UUID: parentID, Valid: true},
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
StartedAt: sql.NullTime{Valid: true, Time: someTime},
ReadyAt: sql.NullTime{Valid: false},
}
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: subAgent.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: subAgent.StartedAt,
ReadyAt: sql.NullTime{
Time: now,
Valid: true,
},
}).Return(nil)
// GetWorkspaceBuildMetricsByResourceID should NOT be called
// because sub-agents should be skipped before querying.
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return subAgent, nil
},
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: nil,
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
// We don't expect the metric to be emitted for sub-agents, by default this will fail anyway but it doesn't hurt
// to document the test explicitly.
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), gomock.Any()).Times(0)
// If we were emitting the metric we would have failed by now since it would include a call to the database that we're not expecting.
pm, err := reg.Gather()
require.NoError(t, err)
for _, m := range pm {
if m.GetName() == fullMetricName {
t.Fatal("metric should not be emitted for sub-agent")
}
}
})
}
func TestUpdateStartup(t *testing.T) {
-38
View File
@@ -1,38 +0,0 @@
// Package aiseats is the AGPL version the package.
// The actual implementation is in `enterprise/aiseats`.
package aiseats
import (
"context"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
)
type Reason struct {
EventType database.AiSeatUsageReason
Description string
}
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
func ReasonAIBridge(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
}
// ReasonTask constructs a reason for usage originating from tasks.
func ReasonTask(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
}
// SeatTracker records AI seat consumption state.
type SeatTracker interface {
// RecordUsage does not return an error to prevent blocking the user from using
// AI features. This method is used to record usage, not enforce it.
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
}
// Noop is an AGPL seat tracker that does nothing.
type Noop struct{}
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
+1683 -2101
View File
File diff suppressed because it is too large Load Diff
+1683 -2074
View File
File diff suppressed because it is too large Load Diff
+1 -2
View File
@@ -32,8 +32,7 @@ type Auditable interface {
idpsync.OrganizationSyncSettings |
idpsync.GroupSyncSettings |
idpsync.RoleSyncSettings |
database.TaskTable |
database.AiSeatState
database.TaskTable
}
// Map is a map of changed fields in an audited resource. It maps field names to
-8
View File
@@ -132,8 +132,6 @@ func ResourceTarget[T Auditable](tgt T) string {
return "Organization Role Sync"
case database.TaskTable:
return typed.Name
case database.AiSeatState:
return "AI Seat"
default:
panic(fmt.Sprintf("unknown resource %T for ResourceTarget", tgt))
}
@@ -198,8 +196,6 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
return noID // Org field on audit log has org id
case database.TaskTable:
return typed.ID
case database.AiSeatState:
return typed.UserID
default:
panic(fmt.Sprintf("unknown resource %T for ResourceID", tgt))
}
@@ -255,8 +251,6 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
return database.ResourceTypeIdpSyncSettingsGroup
case database.TaskTable:
return database.ResourceTypeTask
case database.AiSeatState:
return database.ResourceTypeAiSeat
default:
panic(fmt.Sprintf("unknown resource %T for ResourceType", typed))
}
@@ -315,8 +309,6 @@ func ResourceRequiresOrgID[T Auditable]() bool {
return true
case database.TaskTable:
return true
case database.AiSeatState:
return false
default:
panic(fmt.Sprintf("unknown resource %T for ResourceRequiresOrgID", tgt))
}
+3 -1
View File
@@ -240,7 +240,9 @@ func (c *Compressor) serveRef(w http.ResponseWriter, r *http.Request, headers ht
}
for key, values := range headers {
w.Header()[key] = values
for _, value := range values {
w.Header().Add(key, value)
}
}
w.Header().Set("Content-Encoding", cref.key.encoding)
w.Header().Add("Vary", "Accept-Encoding")
@@ -155,41 +155,6 @@ type nopEncoder struct {
func (nopEncoder) Close() error { return nil }
func TestCompressorPresetHeaders(t *testing.T) {
t.Parallel()
logger := testutil.Logger(t)
tempDir := t.TempDir()
cacheDir := filepath.Join(tempDir, "cache")
err := os.MkdirAll(cacheDir, 0o700)
require.NoError(t, err)
srcDir := filepath.Join(tempDir, "src")
err = os.MkdirAll(srcDir, 0o700)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcDir, "file.html"), []byte("textstring"), 0o600)
require.NoError(t, err)
compressor := NewCompressor(logger, 5, cacheDir, http.FS(os.DirFS(srcDir)))
for range 2 {
ctx := testutil.Context(t, testutil.WaitShort)
req := httptest.NewRequestWithContext(ctx, "GET", "/file.html", nil)
req.Header.Set("Accept-Encoding", "gzip")
respRec := httptest.NewRecorder()
respRec.Header().Set("X-Original-Content-Length", "10")
respRec.Header().Set("ETag", `"abc123"`)
compressor.ServeHTTP(respRec, req)
resp := respRec.Result()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, []string{"10"}, resp.Header.Values("X-Original-Content-Length"))
require.Equal(t, []string{`"abc123"`}, resp.Header.Values("ETag"))
require.NoError(t, resp.Body.Close())
}
}
// nolint: tparallel // we want to assert the state of the cache, so run synchronously
func TestCompressorHeadings(t *testing.T) {
t.Parallel()
-71
View File
@@ -1,71 +0,0 @@
package chatcost
import (
"github.com/shopspring/decimal"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
// Returns cost in micros -- millionths of a dollar, rounded up to the next
// whole microdollar.
// Returns nil when pricing is not configured or when all priced usage fields
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
func CalculateTotalCostMicros(
usage codersdk.ChatMessageUsage,
cost *codersdk.ModelCostConfig,
) *int64 {
if cost == nil {
return nil
}
// A cost config with no prices set means pricing is effectively
// unconfigured — return nil (unpriced) rather than zero.
if cost.InputPricePerMillionTokens == nil &&
cost.OutputPricePerMillionTokens == nil &&
cost.CacheReadPricePerMillionTokens == nil &&
cost.CacheWritePricePerMillionTokens == nil {
return nil
}
if usage.InputTokens == nil &&
usage.OutputTokens == nil &&
usage.ReasoningTokens == nil &&
usage.CacheCreationTokens == nil &&
usage.CacheReadTokens == nil {
return nil
}
// OutputTokens already includes reasoning tokens per provider
// semantics (e.g. OpenAI's completion_tokens encompasses
// reasoning_tokens). Adding ReasoningTokens here would
// double-count.
// Preserve nil when usage exists only in categories without configured
// pricing, so callers can distinguish "unpriced" from "priced at zero".
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
if !hasMatchingPrice {
return nil
}
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
outputMicros := calcCost(usage.OutputTokens, cost.OutputPricePerMillionTokens)
cacheReadMicros := calcCost(usage.CacheReadTokens, cost.CacheReadPricePerMillionTokens)
cacheWriteMicros := calcCost(usage.CacheCreationTokens, cost.CacheWritePricePerMillionTokens)
total := inputMicros.
Add(outputMicros).
Add(cacheReadMicros).
Add(cacheWriteMicros)
rounded := total.Ceil().IntPart()
return &rounded
}
// calcCost returns the cost in fractional microdollars (millionths of a USD)
// for the given token count at the specified per-million-token price.
func calcCost(tokens *int64, pricePerMillion *decimal.Decimal) decimal.Decimal {
return decimal.NewFromInt(ptr.NilToEmpty(tokens)).Mul(ptr.NilToEmpty(pricePerMillion))
}
-163
View File
@@ -1,163 +0,0 @@
package chatcost_test
import (
"testing"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
func TestCalculateTotalCostMicros(t *testing.T) {
t.Parallel()
tests := []struct {
name string
usage codersdk.ChatMessageUsage
cost *codersdk.ModelCostConfig
want *int64
}{
{
name: "nil cost returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: nil,
want: nil,
},
{
name: "all priced usage fields nil returns nil",
usage: codersdk.ChatMessageUsage{
TotalTokens: ptr.Ref[int64](1234),
ContextLimit: ptr.Ref[int64](8192),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: nil,
},
{
name: "sub-micro total rounds up to 1",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
},
want: ptr.Ref[int64](1),
},
{
name: "simple input only",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](3000),
},
{
name: "simple output only",
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
},
{
name: "reasoning tokens included in output total",
usage: codersdk.ChatMessageUsage{
OutputTokens: ptr.Ref[int64](500),
ReasoningTokens: ptr.Ref[int64](200),
},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
},
{
name: "cache read tokens",
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
cost: &codersdk.ModelCostConfig{
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
},
want: ptr.Ref[int64](3000),
},
{
name: "cache creation tokens",
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
cost: &codersdk.ModelCostConfig{
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
},
want: ptr.Ref[int64](18750),
},
{
name: "full mixed usage totals all components exactly",
usage: codersdk.ChatMessageUsage{
InputTokens: ptr.Ref[int64](101),
OutputTokens: ptr.Ref[int64](201),
ReasoningTokens: ptr.Ref[int64](52),
CacheReadTokens: ptr.Ref[int64](1005),
CacheCreationTokens: ptr.Ref[int64](33),
TotalTokens: ptr.Ref[int64](1391),
ContextLimit: ptr.Ref[int64](4096),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")),
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")),
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
},
want: ptr.Ref[int64](2005),
},
{
name: "partial pricing only input contributes",
usage: codersdk.ChatMessageUsage{
InputTokens: ptr.Ref[int64](1234),
OutputTokens: ptr.Ref[int64](999),
ReasoningTokens: ptr.Ref[int64](111),
CacheReadTokens: ptr.Ref[int64](500),
CacheCreationTokens: ptr.Ref[int64](250),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
},
want: ptr.Ref[int64](3085),
},
{
name: "zero tokens with pricing returns zero pointer",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](0),
},
{
name: "usage only in unpriced categories returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: nil,
},
{
name: "non nil usage with empty cost config returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
cost: &codersdk.ModelCostConfig{},
want: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
if tt.want == nil {
require.Nil(t, got)
} else {
require.NotNil(t, got)
require.Equal(t, *tt.want, *got)
}
})
}
}
+668 -1507
View File
File diff suppressed because it is too large Load Diff
-534
View File
@@ -2,26 +2,13 @@ package chatd
import (
"context"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
@@ -97,524 +84,3 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
require.ErrorContains(t, err, loadErr.Error())
require.Equal(t, chat, refreshed)
}
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
workspaceAgent := database.WorkspaceAgent{
ID: uuid.New(),
OperatingSystem: "linux",
Directory: "/home/coder/project",
ExpandedDirectory: "/home/coder/project",
}
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
workspacesdk.LSResponse{},
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
).Times(1)
conn.EXPECT().ReadFile(
gomock.Any(),
"/home/coder/project/AGENTS.md",
int64(0),
int64(maxInstructionFileBytes+1),
).Return(
nil,
"",
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
).Times(1)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := &Server{
db: db,
logger: logger,
instructionCache: make(map[uuid.UUID]cachedInstruction),
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
return conn, func() {}, nil
},
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
instruction := server.resolveInstructions(
ctx,
chat,
workspaceCtx.getWorkspaceAgent,
workspaceCtx.getWorkspaceConn,
)
require.Contains(t, instruction, "Operating System: linux")
require.Contains(t, instruction, "Working Directory: /home/coder/project")
}
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
t.Parallel()
ctx := context.Background()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
workspaceID := uuid.New()
chat := database.Chat{
ID: uuid.New(),
WorkspaceID: uuid.NullUUID{
UUID: workspaceID,
Valid: true,
},
}
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
gomock.InOrder(
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{initialAgent}, nil),
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
gomock.Any(),
workspaceID,
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
)
conn := agentconnmock.NewMockAgentConn(ctrl)
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
var dialed []uuid.UUID
server := &Server{db: db}
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
dialed = append(dialed, agentID)
if agentID == initialAgent.ID {
return nil, nil, xerrors.New("dial failed")
}
return conn, func() {}, nil
}
chatStateMu := &sync.Mutex{}
currentChat := chat
workspaceCtx := turnWorkspaceContext{
server: server,
chatStateMu: chatStateMu,
currentChat: &currentChat,
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
}
t.Cleanup(workspaceCtx.close)
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
require.NoError(t, err)
require.Same(t, conn, gotConn)
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
}
func TestSubscribeSkipsDatabaseCatchupForLocallyDeliveredMessage(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
localMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishMessage(chatID, localMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
cachedMessage := codersdk.ChatMessage{
ID: 2,
ChatID: chatID,
Role: codersdk.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &cachedMessage,
})
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
catchupMessage := database.ChatMessage{
ID: 2,
ChatID: chatID,
Role: database.ChatMessageRoleAssistant,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 1,
}).Return([]database.ChatMessage{catchupMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
AfterMessageID: 1,
})
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(2), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
initialMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
editedMessage := database.ChatMessage{
ID: 1,
ChatID: chatID,
Role: database.ChatMessageRoleUser,
}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{initialMessage}, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return([]database.ChatMessage{editedMessage}, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
server.publishEditedMessage(chatID, editedMessage)
event := requireStreamMessageEvent(t, events)
require.Equal(t, int64(1), event.Message.ID)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) {
t.Parallel()
ctx, cancelCtx := context.WithCancel(context.Background())
defer cancelCtx()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
chatID := uuid.New()
chat := database.Chat{ID: chatID, Status: database.ChatStatusPending}
gomock.InOrder(
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: 0,
}).Return(nil, nil),
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil),
db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil),
)
server := newSubscribeTestServer(t, db)
_, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0)
require.True(t, ok)
defer cancel()
retryingAt := time.Unix(1_700_000_000, 0).UTC()
expected := &codersdk.ChatStreamRetry{
Attempt: 1,
DelayMs: (1500 * time.Millisecond).Milliseconds(),
Error: "rate limit exceeded",
RetryingAt: retryingAt,
}
server.publishRetry(chatID, expected)
event := requireStreamRetryEvent(t, events)
require.Equal(t, expected, event.Retry)
requireNoStreamEvent(t, events, 200*time.Millisecond)
}
func newSubscribeTestServer(t *testing.T, db database.Store) *Server {
t.Helper()
return &Server{
db: db,
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
pubsub: dbpubsub.NewInMemory(),
}
}
func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type)
require.NotNil(t, event.Message)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream message event")
return codersdk.ChatStreamEvent{}
}
}
func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent {
t.Helper()
select {
case event, ok := <-events:
require.True(t, ok, "chat stream closed before delivering an event")
require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type)
require.NotNil(t, event.Retry)
return event
case <-time.After(time.Second):
t.Fatal("timed out waiting for chat stream retry event")
return codersdk.ChatStreamEvent{}
}
}
func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) {
t.Helper()
select {
case event, ok := <-events:
if !ok {
t.Fatal("chat stream closed unexpectedly")
}
t.Fatalf("unexpected chat stream event: %+v", event)
case <-time.After(wait):
}
}
// TestPublishToStream_DropWarnRateLimiting walks through a
// realistic lifecycle: buffer fills up, subscriber channel fills
// up, counters get reset between steps. It verifies that WARN
// logs are rate-limited to at most once per streamDropWarnInterval
// and that counter resets re-enable an immediate WARN.
func TestPublishToStream_DropWarnRateLimiting(t *testing.T) {
t.Parallel()
sink := testutil.NewFakeSink(t)
mClock := quartz.NewMock(t)
server := &Server{
logger: sink.Logger(),
clock: mClock,
}
chatID := uuid.New()
subCh := make(chan codersdk.ChatStreamEvent, 1)
subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop
// Set up state that mirrors a running chat: buffer at capacity,
// buffering enabled, one saturated subscriber.
state := &chatStreamState{
buffering: true,
buffer: make([]codersdk.ChatStreamEvent, maxStreamBufferSize),
subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{
uuid.New(): subCh,
},
}
server.chatStreams.Store(chatID, state)
bufferMsg := "chat stream buffer full, dropping oldest event"
subMsg := "dropping chat stream event"
filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool {
return func(e slog.SinkEntry) bool {
return e.Level == level && e.Message == msg
}
}
// --- Phase 1: buffer-full rate limiting ---
// message_part events hit both the buffer-full and subscriber-full
// paths. The first publish triggers a WARN for each; the rest
// within the window are DEBUG.
partEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessagePart,
MessagePart: &codersdk.ChatStreamMessagePart{},
}
for i := 0; i < 50; i++ {
server.publishToStream(chatID, partEvent)
}
require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1))
// Subscriber also saw 50 drops (one per publish).
require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1)
require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg)))
requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1))
// --- Phase 2: clock advance triggers second WARN with count ---
mClock.Advance(streamDropWarnInterval + time.Second)
server.publishToStream(chatID, partEvent)
bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 2)
requireFieldValue(t, bufWarn[1], "dropped_count", int64(50))
subWarn := sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 2)
requireFieldValue(t, subWarn[1], "dropped_count", int64(50))
// --- Phase 3: counter reset (simulates step persist) ---
state.mu.Lock()
state.buffer = make([]codersdk.ChatStreamEvent, maxStreamBufferSize)
state.resetDropCounters()
state.mu.Unlock()
// The very next drop should WARN immediately — the reset zeroed
// lastWarnAt so the interval check passes.
server.publishToStream(chatID, partEvent)
bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg))
require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset")
requireFieldValue(t, bufWarn[2], "dropped_count", int64(1))
subWarn = sink.Entries(filter(slog.LevelWarn, subMsg))
require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset")
requireFieldValue(t, subWarn[2], "dropped_count", int64(1))
}
// requireFieldValue asserts that a SinkEntry contains a field with
// the given name and value.
func requireFieldValue(t *testing.T, entry slog.SinkEntry, name string, expected interface{}) {
t.Helper()
for _, f := range entry.Fields {
if f.Name == name {
require.Equal(t, expected, f.Value, "field %q value mismatch", name)
return
}
}
t.Fatalf("field %q not found in log entry", name)
}
+96 -1953
View File
File diff suppressed because it is too large Load Diff
+91 -196
View File
@@ -8,7 +8,6 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
"charm.land/fantasy"
@@ -42,11 +41,6 @@ type PersistedStep struct {
Content []fantasy.Content
Usage fantasy.Usage
ContextLimit sql.NullInt64
// Runtime is the wall-clock duration of this step,
// covering LLM streaming, tool execution, and retries.
// Zero indicates the duration was not measured (e.g.
// interrupted steps).
Runtime time.Duration
}
// RunOptions configures a single streaming chat loop run.
@@ -68,16 +62,9 @@ type RunOptions struct {
// of the provider, which lives in chatd, not chatloop.
ProviderOptions fantasy.ProviderOptions
// ProviderTools are provider-native tools (like web search
// and computer use) whose definitions are passed directly
// to the provider API. When a ProviderTool has a non-nil
// Runner, tool calls are executed locally; otherwise the
// provider handles execution (e.g. web search).
ProviderTools []ProviderTool
PersistStep func(context.Context, PersistedStep) error
PublishMessagePart func(
role codersdk.ChatMessageRole,
role fantasy.MessageRole,
part codersdk.ChatMessagePart,
)
Compaction *CompactionOptions
@@ -94,16 +81,6 @@ type RunOptions struct {
OnInterruptedPersistError func(error)
}
// ProviderTool pairs a provider-native tool definition with an
// optional local executor. When Runner is nil the tool is fully
// provider-executed (e.g. web search). When Runner is non-nil
// the definition is sent to the API but execution is handled
// locally (e.g. computer use).
type ProviderTool struct {
Definition fantasy.Tool
Runner fantasy.AgentTool
}
// stepResult holds the accumulated output of a single streaming
// step. Since we own the stream consumer, all content is tracked
// directly here — no shadow draft state needed.
@@ -127,7 +104,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
switch c.GetType() {
case fantasy.ContentTypeText:
text, ok := fantasy.AsContentType[fantasy.TextContent](c)
if !ok || strings.TrimSpace(text.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.TextPart{
@@ -136,7 +113,7 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
})
case fantasy.ContentTypeReasoning:
reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
if !ok || strings.TrimSpace(reasoning.Text) == "" {
if !ok {
continue
}
assistantParts = append(assistantParts, fantasy.ReasoningPart{
@@ -174,23 +151,11 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
if !ok {
continue
}
part := fantasy.ToolResultPart{
ToolCallID: result.ToolCallID,
Output: result.Result,
ProviderExecuted: result.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
}
// Provider-executed tool results (e.g. web_search)
// must stay in the assistant message so the result
// block appears inline after the corresponding
// server_tool_use block. This matches the persistence
// layer in chatd.go which keeps them in
// assistantBlocks.
if result.ProviderExecuted {
assistantParts = append(assistantParts, part)
} else {
toolParts = append(toolParts, part)
}
toolParts = append(toolParts, fantasy.ToolResultPart{
ToolCallID: result.ToolCallID,
Output: result.Result,
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
})
default:
continue
}
@@ -232,14 +197,14 @@ func Run(ctx context.Context, opts RunOptions) error {
opts.MaxSteps = 1
}
publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
if opts.PublishMessagePart == nil {
return
}
opts.PublishMessagePart(role, part)
}
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools, opts.ProviderTools)
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools)
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
messages := opts.Messages
@@ -265,7 +230,6 @@ func Run(ctx context.Context, opts RunOptions) error {
for step := 0; totalSteps < opts.MaxSteps; step++ {
totalSteps++
stepStart := time.Now()
// Copy messages so that provider-specific caching
// mutations don't leak back to the caller's slice.
// copy copies Message structs by value, so field
@@ -332,29 +296,17 @@ func Run(ctx context.Context, opts RunOptions) error {
return ctx.Err()
}
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) {
toolResults = executeTools(ctx, opts.Tools, result.toolCalls, func(tr fantasy.ToolResultContent) {
publishMessagePart(
codersdk.ChatMessageRoleTool,
fantasy.MessageRoleTool,
chatprompt.PartFromContent(tr),
)
})
for _, tr := range toolResults {
result.content = append(result.content, tr)
}
// Check for interruption after tool execution.
// Tools that were canceled mid-flight produce error
// results via ctx cancellation. Persist the full
// step (assistant blocks + tool results) through
// the interrupt-safe path so nothing is lost.
if ctx.Err() != nil {
if errors.Is(context.Cause(ctx), ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return ctx.Err()
}
}
// Extract context limit from provider metadata.
contextLimit := extractContextLimit(result.providerMetadata)
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
@@ -363,22 +315,16 @@ func Run(ctx context.Context, opts RunOptions) error {
Valid: true,
}
}
// Persist the step. If persistence fails because
// the chat was interrupted between the previous
// check and here, fall back to the interrupt-safe
// path so partial content is not lost.
// Persist the step — errors propagate directly.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return xerrors.Errorf("persist step: %w", err)
}
lastUsage = result.usage
lastProviderMetadata = result.providerMetadata
@@ -473,7 +419,7 @@ func Run(ctx context.Context, opts RunOptions) error {
func processStepStream(
ctx context.Context,
stream fantasy.StreamResponse,
publishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart),
publishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart),
) (stepResult, error) {
var result stepResult
@@ -482,6 +428,27 @@ func processStepStream(
activeReasoningContent := make(map[string]reasoningState)
// Track tool names by ID for input delta publishing.
toolNames := make(map[string]string)
// Track reasoning text/titles for title extraction.
reasoningTitles := make(map[string]string)
reasoningText := make(map[string]string)
setReasoningTitleFromText := func(id string, text string) {
if id == "" || strings.TrimSpace(text) == "" {
return
}
if reasoningTitles[id] != "" {
return
}
reasoningText[id] += text
if !strings.ContainsAny(reasoningText[id], "\r\n") {
return
}
title := chatprompt.ReasoningTitleFromFirstLine(reasoningText[id])
if title == "" {
return
}
reasoningTitles[id] = title
}
for part := range stream {
switch part.Type {
@@ -492,7 +459,10 @@ func processStepStream(
if _, exists := activeTextContent[part.ID]; exists {
activeTextContent[part.ID] += part.Delta
}
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(part.Delta))
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: part.Delta,
})
case fantasy.StreamPartTypeTextEnd:
if text, exists := activeTextContent[part.ID]; exists {
@@ -515,7 +485,13 @@ func processStepStream(
active.options = part.ProviderMetadata
activeReasoningContent[part.ID] = active
}
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageReasoning(part.Delta))
setReasoningTitleFromText(part.ID, part.Delta)
title := reasoningTitles[part.ID]
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: part.Delta,
Title: title,
})
case fantasy.StreamPartTypeReasoningEnd:
if active, exists := activeReasoningContent[part.ID]; exists {
@@ -528,6 +504,21 @@ func processStepStream(
}
result.content = append(result.content, content)
delete(activeReasoningContent, part.ID)
// Derive reasoning title at end of reasoning
// block if we haven't yet.
if reasoningTitles[part.ID] == "" {
reasoningTitles[part.ID] = chatprompt.ReasoningTitleFromFirstLine(
reasoningText[part.ID],
)
}
title := reasoningTitles[part.ID]
if title != "" {
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Title: title,
})
}
}
case fantasy.StreamPartTypeToolInputStart:
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
@@ -541,19 +532,17 @@ func processStepStream(
}
case fantasy.StreamPartTypeToolInputDelta:
var providerExecuted bool
if toolCall, exists := activeToolCalls[part.ID]; exists {
toolCall.Input += part.Delta
providerExecuted = toolCall.ProviderExecuted
}
toolName := toolNames[part.ID]
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: part.ID,
ToolName: toolName,
ArgsDelta: part.Delta,
ProviderExecuted: providerExecuted,
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: part.ID,
ToolName: toolName,
ArgsDelta: part.Delta,
})
case fantasy.StreamPartTypeToolInputEnd:
// No callback needed; the full tool call arrives in
// StreamPartTypeToolCall.
@@ -575,7 +564,7 @@ func processStepStream(
delete(activeToolCalls, part.ID)
publishMessagePart(
codersdk.ChatMessageRoleAssistant,
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(tc),
)
@@ -589,40 +578,20 @@ func processStepStream(
}
result.content = append(result.content, sourceContent)
publishMessagePart(
codersdk.ChatMessageRoleAssistant,
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(sourceContent),
)
case fantasy.StreamPartTypeToolResult:
// Provider-executed tool results (e.g. web search)
// are emitted by the provider and added directly
// to the step content for multi-turn round-tripping.
// This mirrors fantasy's agent.go accumulation logic.
if part.ProviderExecuted {
tr := fantasy.ToolResultContent{
ToolCallID: part.ID,
ToolName: part.ToolCallName,
ProviderExecuted: part.ProviderExecuted,
ProviderMetadata: part.ProviderMetadata,
}
result.content = append(result.content, tr)
publishMessagePart(
codersdk.ChatMessageRoleTool,
chatprompt.PartFromContent(tr),
)
}
case fantasy.StreamPartTypeFinish:
result.usage = part.Usage
result.finishReason = part.FinishReason
result.providerMetadata = part.ProviderMetadata
case fantasy.StreamPartTypeError:
// Detect interruption: the stream may surface the
// cancel as context.Canceled or propagate the
// ErrInterrupted cause directly, depending on
// the provider implementation.
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
// Detect interruption: context canceled with
// ErrInterrupted as the cause.
if errors.Is(part.Error, context.Canceled) &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
// Flush in-progress content so that
// persistInterruptedStep has access to partial
// text, reasoning, and tool calls that were
@@ -640,43 +609,17 @@ func processStepStream(
}
}
// The stream iterator may stop yielding parts without
// producing a StreamPartTypeError when the context is
// canceled (e.g. some providers close the response body
// silently). Detect this case and flush partial content
// so that persistInterruptedStep can save it.
if ctx.Err() != nil &&
errors.Is(context.Cause(ctx), ErrInterrupted) {
flushActiveState(
&result,
activeTextContent,
activeReasoningContent,
activeToolCalls,
toolNames,
)
return result, ErrInterrupted
}
hasLocalToolCalls := false
for _, tc := range result.toolCalls {
if !tc.ProviderExecuted {
hasLocalToolCalls = true
break
}
}
result.shouldContinue = hasLocalToolCalls &&
result.shouldContinue = len(result.toolCalls) > 0 &&
result.finishReason == fantasy.FinishReasonToolCalls
return result, nil
}
// executeTools runs all tool calls concurrently after the stream
// completes. Results are published via onResult in the original
// tool-call order after all tools finish, preserving deterministic
// event ordering for SSE subscribers.
// executeTools runs each tool call sequentially after the stream
// completes. Results are published via onResult as each tool
// finishes.
func executeTools(
ctx context.Context,
allTools []fantasy.AgentTool,
providerTools []ProviderTool,
toolCalls []fantasy.ToolCallContent,
onResult func(fantasy.ToolResultContent),
) []fantasy.ToolResultContent {
@@ -684,58 +627,16 @@ func executeTools(
return nil
}
// Filter out provider-executed tool calls. These were
// handled server-side by the LLM provider (e.g., web
// search) and their results are already in the stream
// content.
localToolCalls := make([]fantasy.ToolCallContent, 0, len(toolCalls))
for _, tc := range toolCalls {
if !tc.ProviderExecuted {
localToolCalls = append(localToolCalls, tc)
}
}
if len(localToolCalls) == 0 {
return nil
}
toolMap := make(map[string]fantasy.AgentTool, len(allTools))
for _, t := range allTools {
toolMap[t.Info().Name] = t
}
// Include runners from provider tools so locally-executed
// provider tools (e.g. computer use) can be dispatched.
for _, pt := range providerTools {
if pt.Runner != nil {
toolMap[pt.Runner.Info().Name] = pt.Runner
}
}
results := make([]fantasy.ToolResultContent, len(localToolCalls))
var wg sync.WaitGroup
wg.Add(len(localToolCalls))
for i, tc := range localToolCalls {
go func(i int, tc fantasy.ToolCallContent) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
results[i] = fantasy.ToolResultContent{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.Errorf("tool panicked: %v", r),
},
}
}
}()
results[i] = executeSingleTool(ctx, toolMap, tc)
}(i, tc)
}
wg.Wait()
// Publish results in the original tool-call order so SSE
// subscribers see a deterministic event sequence.
if onResult != nil {
for _, tr := range results {
results := make([]fantasy.ToolResultContent, 0, len(toolCalls))
for _, tc := range toolCalls {
tr := executeSingleTool(ctx, toolMap, tc)
results = append(results, tr)
if onResult != nil {
onResult(tr)
}
}
@@ -885,9 +786,8 @@ func persistInterruptedStep(
continue
}
content = append(content, fantasy.ToolResultContent{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
ProviderExecuted: tc.ProviderExecuted,
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.New(interruptedToolResultErrorMessage),
},
@@ -907,17 +807,15 @@ func persistInterruptedStep(
// buildToolDefinitions converts AgentTool definitions into the
// fantasy.Tool slice expected by fantasy.Call. When activeTools
// is non-empty, only function tools whose name appears in the
// list are included. Provider tool definitions are always
// appended unconditionally.
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, providerTools []ProviderTool) []fantasy.Tool {
prepared := make([]fantasy.Tool, 0, len(tools)+len(providerTools))
// is non-empty, only tools whose name appears in the list are
// included. This mirrors fantasy's agent.prepareTools filtering.
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string) []fantasy.Tool {
prepared := make([]fantasy.Tool, 0, len(tools))
for _, tool := range tools {
info := tool.Info()
if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) {
continue
}
inputSchema := map[string]any{
"type": "object",
"properties": info.Parameters,
@@ -931,9 +829,6 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, provi
ProviderOptions: tool.ProviderOptions(),
})
}
for _, pt := range providerTools {
prepared = append(prepared, pt.Definition)
}
return prepared
}

Some files were not shown because too many files have changed in this diff Show More