Compare commits
231 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 09d799606b | |||
| 6e5ac02c96 | |||
| f714f589c5 | |||
| 72689c2552 | |||
| 85509733f3 | |||
| eacabd8390 | |||
| 84527390c6 | |||
| 67f5494665 | |||
| 9d33c340ec | |||
| 3bd840fe27 | |||
| 03d0fc4f4c | |||
| efe114119f | |||
| c3b6284955 | |||
| 1152b61ebb | |||
| 5745ff7912 | |||
| 4a79af1a0d | |||
| bdbcd3428b | |||
| 870583224d | |||
| df2360f56a | |||
| cc6716c730 | |||
| 836a2112b6 | |||
| 690e3a87d8 | |||
| 0e7e0a959e | |||
| ff156772f2 | |||
| a5400b2208 | |||
| 4e2640e506 | |||
| 6104a000d1 | |||
| 8714aa4637 | |||
| 7777072d7a | |||
| f6f33fa480 | |||
| 84dc1a3482 | |||
| 0e1846fe2a | |||
| 322a94b23b | |||
| e9025f91e8 | |||
| 4b8c079eef | |||
| 42c12176a0 | |||
| 072e9a212f | |||
| d21a9373b6 | |||
| 2488cf0d41 | |||
| 3407fa80a4 | |||
| 1ac5418fc4 | |||
| b1e80e6f3a | |||
| fc9e04da67 | |||
| 57af7abf1f | |||
| a6697b1b29 | |||
| c8079a5b8c | |||
| 5cb820387c | |||
| 2bb483b425 | |||
| 3aada03f52 | |||
| c3923f2ccd | |||
| 2b70122e4a | |||
| fd6346265c | |||
| 53bfbf7c03 | |||
| c7abfc6ff8 | |||
| 660a3dad21 | |||
| e7e2de99ba | |||
| 5130404f2a | |||
| fba00a6b3a | |||
| 3325b86903 | |||
| 53304df70d | |||
| d495a4eddb | |||
| a342fc43c3 | |||
| 45c32d62c5 | |||
| 58f295059c | |||
| 4d7eb2ae4b | |||
| 57dc23f603 | |||
| fc607cd400 | |||
| 51198744ff | |||
| 1f37df4db3 | |||
| e5c19d0af4 | |||
| e96cd5cbb2 | |||
| 77d53d2955 | |||
| d39f69f4c2 | |||
| c33dc3e459 | |||
| 7a83d825cf | |||
| a46336c3ec | |||
| 40114b8eea | |||
| 2f2ba0ef7e | |||
| 9d2643d3aa | |||
| ac791e5bd3 | |||
| 7b846fb548 | |||
| 196c6702fd | |||
| bb59477648 | |||
| c7c789f9e4 | |||
| 71b132b9e7 | |||
| c72d3e4919 | |||
| f766ad064d | |||
| 0a026fde39 | |||
| 2d7dd73106 | |||
| c24b240934 | |||
| f2eb6d5af0 | |||
| e7f8dfbe15 | |||
| bfc58c8238 | |||
| bc27274aba | |||
| cbe46c816e | |||
| 53e52aef78 | |||
| c2534c19f6 | |||
| da71a09ab6 | |||
| 33136dfe39 | |||
| 22a87f6cf6 | |||
| b44a421412 | |||
| 4c63ed7602 | |||
| 983f362dff | |||
| 8b72feeae4 | |||
| b74d60e88c | |||
| d3986b53b9 | |||
| 8cc6473736 | |||
| 30a63009aa | |||
| f22450f29b | |||
| 01f25dd9ae | |||
| b6d1a11c58 | |||
| 6489d6f714 | |||
| 12bdbc693f | |||
| f5e5bd2d64 | |||
| fee5cc5e5b | |||
| 72fb0cd554 | |||
| ba764a24ea | |||
| 8c70170ee7 | |||
| e18ce505ec | |||
| beed379b1d | |||
| 2948400aef | |||
| f35b99a4fa | |||
| b898e45ec4 | |||
| d61772dc52 | |||
| c933ddcffd | |||
| a21f00d250 | |||
| 3167908358 | |||
| 45f62d1487 | |||
| b850d40db8 | |||
| 73bf8478d8 | |||
| 41c505f03b | |||
| abdfadf8cb | |||
| d936a99e6b | |||
| 14341edfc2 | |||
| e7ea649dc2 | |||
| 56960585af | |||
| f07e266904 | |||
| 9bc884d597 | |||
| f46692531f | |||
| 6e9e39a4e0 | |||
| 1a2eea5e76 | |||
| 9e7125f852 | |||
| e6983648aa | |||
| 47846c0ee4 | |||
| ff715c9f4c | |||
| f4ab854b06 | |||
| c6b68b2991 | |||
| 5dfd563e4b | |||
| 4957888270 | |||
| 26adc26a26 | |||
| b33b8e476b | |||
| 95bd099c77 | |||
| 3f939375fa | |||
| a072d542a5 | |||
| a96ec4c397 | |||
| 2eb3ab4cf5 | |||
| 51a627c107 | |||
| 49006685b0 | |||
| 715486465b | |||
| e205a3493d | |||
| 6b14a3eb7f | |||
| 0fea47d97c | |||
| 02b1951aac | |||
| dd34e3d3c2 | |||
| a48e4a43e2 | |||
| 5b7ba93cb2 | |||
| aba3832b15 | |||
| ca873060c6 | |||
| 896c43d5b7 | |||
| 2ad0e74e67 | |||
| 6509fb2574 | |||
| 667d501282 | |||
| 69a4a8825d | |||
| 4b3ed61210 | |||
| c01430e53b | |||
| 7d6fde35bd | |||
| 77ca772552 | |||
| 703629f5e9 | |||
| 4cf8d4414e | |||
| 3608064600 | |||
| 4e50ca6b6e | |||
| 4c83a7021f | |||
| b9c729457b | |||
| 9bd712013f | |||
| 8c52e150f6 | |||
| f404463317 | |||
| 338d30e4c4 | |||
| 4afdfc50a5 | |||
| b199ef1b69 | |||
| eecb7d0b66 | |||
| 2cd871e88f | |||
| b9b3c67c73 | |||
| 09aa7b1887 | |||
| 5712faaa2c | |||
| a104d608a3 | |||
| 30a736c49e | |||
| 537260aa22 | |||
| ec48636ba8 | |||
| 752e6ecc16 | |||
| d06bf5c75f | |||
| 6665944740 | |||
| c0ef3540a5 | |||
| eb1d194447 | |||
| 2618952598 | |||
| 24c7a09321 | |||
| 13e3df67d6 | |||
| f9891416c0 | |||
| c805c8c02c | |||
| 4e781c9323 | |||
| ba05188934 | |||
| 71ac4847cf | |||
| ffb47cea19 | |||
| 957fb556da | |||
| ecf3dccbbc | |||
| d91d9712f7 | |||
| 48ab492f49 | |||
| 81468323e0 | |||
| 6c44de951d | |||
| d034903736 | |||
| fd60fa7eb6 | |||
| 0b1e4880bd | |||
| 9f6f4ba74d | |||
| 56bdea73b8 | |||
| 719c24829a | |||
| f91475cd51 | |||
| 25dac6e5f7 | |||
| 51f298f2de | |||
| 5dd570f099 | |||
| dba688662c | |||
| 0ec27e3d48 | |||
| 8d3d537ca6 |
@@ -189,8 +189,8 @@ func (q *sqlQuerier) UpdateUser(ctx context.Context, arg UpdateUserParams) (User
|
||||
### Common Debug Commands
|
||||
|
||||
```bash
|
||||
# Check database connection
|
||||
make test-postgres
|
||||
# Run tests (starts Postgres automatically if needed)
|
||||
make test
|
||||
|
||||
# Run specific database tests
|
||||
go test ./coderd/database/... -run TestSpecificFunction
|
||||
|
||||
@@ -67,7 +67,6 @@ coderd/
|
||||
| `make test` | Run all Go tests |
|
||||
| `make test RUN=TestFunctionName` | Run specific test |
|
||||
| `go test -v ./path/to/package -run TestFunctionName` | Run test with verbose output |
|
||||
| `make test-postgres` | Run tests with Postgres database |
|
||||
| `make test-race` | Run tests with Go race detector |
|
||||
| `make test-e2e` | Run end-to-end tests |
|
||||
|
||||
|
||||
@@ -109,7 +109,6 @@
|
||||
|
||||
- Run full test suite: `make test`
|
||||
- Run specific test: `make test RUN=TestFunctionName`
|
||||
- Run with Postgres: `make test-postgres`
|
||||
- Run with race detector: `make test-race`
|
||||
- Run end-to-end tests: `make test-e2e`
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
name: "🐞 Bug"
|
||||
description: "File a bug report."
|
||||
title: "bug: "
|
||||
labels: ["needs-triage"]
|
||||
type: "Bug"
|
||||
body:
|
||||
- type: checkboxes
|
||||
|
||||
@@ -64,17 +64,14 @@ runs:
|
||||
TEST_PACKAGES: ${{ inputs.test-packages }}
|
||||
RACE_DETECTION: ${{ inputs.race-detection }}
|
||||
TS_DEBUG_DISCO: "true"
|
||||
TS_DEBUG_DERP: "true"
|
||||
LC_CTYPE: "en_US.UTF-8"
|
||||
LC_ALL: "en_US.UTF-8"
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ${RACE_DETECTION} == true ]]; then
|
||||
gotestsum --junitfile="gotests.xml" --packages="${TEST_PACKAGES}" -- \
|
||||
-tags=testsmallbatch \
|
||||
-race \
|
||||
-parallel "${TEST_NUM_PARALLEL_TESTS}" \
|
||||
-p "${TEST_NUM_PARALLEL_PACKAGES}"
|
||||
make test-race
|
||||
else
|
||||
make test
|
||||
fi
|
||||
|
||||
@@ -366,9 +366,9 @@ jobs:
|
||||
needs: changes
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -475,11 +475,6 @@ jobs:
|
||||
mkdir -p /tmp/tmpfs
|
||||
sudo mount_tmpfs -o noowners -s 8g /tmp/tmpfs
|
||||
|
||||
# Install google-chrome for scaletests.
|
||||
# As another concern, should we really have this kind of external dependency
|
||||
# requirement on standard CI?
|
||||
brew install google-chrome
|
||||
|
||||
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
|
||||
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
|
||||
|
||||
@@ -574,9 +569,9 @@ jobs:
|
||||
- changes
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
@@ -986,6 +981,9 @@ jobs:
|
||||
run: |
|
||||
make build/coder_docs_"$(./scripts/version.sh)".tgz
|
||||
|
||||
- name: Check for unstaged files
|
||||
run: ./scripts/check_unstaged.sh
|
||||
|
||||
required:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
|
||||
@@ -19,6 +19,9 @@ on:
|
||||
default: ""
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
classify-severity:
|
||||
name: AI Severity Classification
|
||||
@@ -32,7 +35,6 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
actions: write
|
||||
|
||||
steps:
|
||||
- name: Determine Issue Context
|
||||
|
||||
@@ -31,6 +31,9 @@ on:
|
||||
default: ""
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
code-review:
|
||||
name: AI Code Review
|
||||
@@ -51,7 +54,6 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
actions: write
|
||||
|
||||
steps:
|
||||
- name: Check if secrets are available
|
||||
|
||||
@@ -34,6 +34,9 @@ on:
|
||||
default: ""
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
doc-check:
|
||||
name: Analyze PR for Documentation Updates Needed
|
||||
@@ -56,7 +59,6 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
actions: write
|
||||
|
||||
steps:
|
||||
- name: Check if secrets are available
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
name: Linear Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
# This event reads the workflow from the default branch (main), not the
|
||||
# release branch. No cherry-pick needed.
|
||||
# https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#release
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
name: Sync issues to Linear release
|
||||
if: github.event_name == 'push'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Sync issues
|
||||
id: sync
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: sync
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.sync.outputs.release-url
|
||||
run: echo "Synced to $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.sync.outputs.release-url }}
|
||||
|
||||
complete:
|
||||
name: Complete Linear release
|
||||
if: github.event_name == 'release'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Complete release
|
||||
id: complete
|
||||
uses: linear/linear-release-action@f64cdc603e6eb7a7ef934bc5492ae929f88c8d1a # v0
|
||||
with:
|
||||
access_key: ${{ secrets.LINEAR_ACCESS_KEY }}
|
||||
command: complete
|
||||
version: ${{ github.event.release.tag_name }}
|
||||
|
||||
- name: Print release URL
|
||||
if: steps.complete.outputs.release-url
|
||||
run: echo "Completed $RELEASE_URL"
|
||||
env:
|
||||
RELEASE_URL: ${{ steps.complete.outputs.release-url }}
|
||||
@@ -16,9 +16,9 @@ jobs:
|
||||
# when changing runner sizes
|
||||
runs-on: ${{ matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'depot-macos-latest' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'depot-windows-2022-16' || matrix.os }}
|
||||
# This timeout must be greater than the timeout set by `go test` in
|
||||
# `make test-postgres` to ensure we receive a trace of running
|
||||
# goroutines. Setting this to the timeout +5m should work quite well
|
||||
# even if some of the preceding steps are slow.
|
||||
# `make test` to ensure we receive a trace of running goroutines.
|
||||
# Setting this to the timeout +5m should work quite well even if
|
||||
# some of the preceding steps are slow.
|
||||
timeout-minutes: 25
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -26,6 +26,9 @@ on:
|
||||
default: "traiage"
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
traiage:
|
||||
name: Triage GitHub Issue with Claude Code
|
||||
@@ -38,7 +41,6 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
actions: write
|
||||
|
||||
steps:
|
||||
# This is only required for testing locally using nektos/act, so leaving commented out.
|
||||
|
||||
@@ -30,6 +30,22 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Rewrite same-repo links for PR branch
|
||||
if: github.event_name == 'pull_request'
|
||||
env:
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
|
||||
run: |
|
||||
# Rewrite same-repo blob/tree main links to the PR head SHA
|
||||
# so that files or directories introduced in the PR are
|
||||
# reachable during link checking.
|
||||
{
|
||||
echo 'replacementPatterns:'
|
||||
echo " - pattern: \"https://github.com/coder/coder/blob/main/\""
|
||||
echo " replacement: \"https://github.com/coder/coder/blob/${HEAD_SHA}/\""
|
||||
echo " - pattern: \"https://github.com/coder/coder/tree/main/\""
|
||||
echo " replacement: \"https://github.com/coder/coder/tree/${HEAD_SHA}/\""
|
||||
} >> .github/.linkspector.yml
|
||||
|
||||
- name: Check Markdown links
|
||||
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
|
||||
id: markdown-link-check
|
||||
|
||||
@@ -38,6 +38,7 @@ site/.swc
|
||||
|
||||
# Make target for updating generated/golden files (any dir).
|
||||
.gen
|
||||
/_gen/
|
||||
.gen-golden
|
||||
|
||||
# Build
|
||||
|
||||
@@ -37,19 +37,20 @@ Only pause to ask for confirmation when:
|
||||
|
||||
## Essential Commands
|
||||
|
||||
| Task | Command | Notes |
|
||||
|-------------------|--------------------------|----------------------------------|
|
||||
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
|
||||
| **Build** | `make build` | Fat binaries (includes server) |
|
||||
| **Build Slim** | `make build-slim` | Slim binaries |
|
||||
| **Test** | `make test` | Full test suite |
|
||||
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
|
||||
| **Test Postgres** | `make test-postgres` | Run tests with Postgres database |
|
||||
| **Test Race** | `make test-race` | Run tests with Go race detector |
|
||||
| **Lint** | `make lint` | Always run after changes |
|
||||
| **Generate** | `make gen` | After database changes |
|
||||
| **Format** | `make fmt` | Auto-format code |
|
||||
| **Clean** | `make clean` | Clean build artifacts |
|
||||
| Task | Command | Notes |
|
||||
|-----------------|--------------------------|-------------------------------------|
|
||||
| **Development** | `./scripts/develop.sh` | ⚠️ Don't use manual build |
|
||||
| **Build** | `make build` | Fat binaries (includes server) |
|
||||
| **Build Slim** | `make build-slim` | Slim binaries |
|
||||
| **Test** | `make test` | Full test suite |
|
||||
| **Test Single** | `make test RUN=TestName` | Faster than full suite |
|
||||
| **Test Race** | `make test-race` | Run tests with Go race detector |
|
||||
| **Lint** | `make lint` | Always run after changes |
|
||||
| **Generate** | `make gen` | After database changes |
|
||||
| **Format** | `make fmt` | Auto-format code |
|
||||
| **Clean** | `make clean` | Clean build artifacts |
|
||||
| **Pre-commit** | `make pre-commit` | Fast CI checks (gen/fmt/lint/build) |
|
||||
| **Pre-push** | `make pre-push` | Heavier CI checks (allowlisted) |
|
||||
|
||||
### Documentation Commands
|
||||
|
||||
@@ -99,10 +100,66 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
|
||||
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
```
|
||||
|
||||
### API Design
|
||||
|
||||
- Add swagger annotations when introducing new HTTP endpoints. Do this in
|
||||
the same change as the handler so the docs do not get missed before
|
||||
release.
|
||||
- For user-scoped or resource-scoped routes, prefer path parameters over
|
||||
query parameters when that matches existing route patterns.
|
||||
- For experimental or unstable API paths, skip public doc generation with
|
||||
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
|
||||
keeps them out of the published API reference until they stabilize.
|
||||
|
||||
### Database Query Naming
|
||||
|
||||
- Use `ByX` when `X` is the lookup or filter column.
|
||||
- Use `PerX` or `GroupedByX` when `X` is the aggregation or grouping
|
||||
dimension.
|
||||
- Avoid `ByX` names for grouped queries.
|
||||
|
||||
### Database-to-SDK Conversions
|
||||
|
||||
- Extract explicit db-to-SDK conversion helpers instead of inlining large
|
||||
conversion blocks inside handlers.
|
||||
- Keep nullable-field handling, type coercion, and response shaping in the
|
||||
converter so handlers stay focused on request flow and authorization.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Full workflows available in imported WORKFLOWS.md
|
||||
|
||||
### Git Hooks (MANDATORY - DO NOT SKIP)
|
||||
|
||||
**You MUST install and use the git hooks. NEVER bypass them with
|
||||
`--no-verify`. Skipping hooks wastes CI cycles and is unacceptable.**
|
||||
|
||||
The first run will be slow as caches warm up. Consecutive runs are
|
||||
**significantly faster** (often 10x) thanks to Go build cache,
|
||||
generated file timestamps, and warm node_modules. This is NOT a
|
||||
reason to skip them. Wait for hooks to complete before proceeding,
|
||||
no matter how long they take.
|
||||
|
||||
```sh
|
||||
git config core.hooksPath scripts/githooks
|
||||
```
|
||||
|
||||
Two hooks run automatically:
|
||||
|
||||
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
|
||||
Fast checks that catch most CI failures. Allow at least 5 minutes.
|
||||
- **pre-push**: `make pre-push` (heavier checks including tests).
|
||||
Allowlisted in `scripts/githooks/pre-push`. Runs only for developers
|
||||
who opt in. Allow at least 15 minutes.
|
||||
|
||||
`git commit` and `git push` will appear to hang while hooks run.
|
||||
This is normal. Do not interrupt, retry, or reduce the timeout.
|
||||
|
||||
NEVER run `git config core.hooksPath` to change or disable hooks.
|
||||
|
||||
If a hook fails, fix the issue and retry. Do not work around the
|
||||
failure by skipping the hook.
|
||||
|
||||
### Git Workflow
|
||||
|
||||
When working on existing PRs, check out the branch first:
|
||||
@@ -152,6 +209,21 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
|
||||
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
|
||||
- Commit format: `type(scope): message`
|
||||
|
||||
### Frontend Patterns
|
||||
|
||||
- Prefer existing shared UI components and utilities over custom
|
||||
implementations. Reuse common primitives such as loading, table, and error
|
||||
handling components when they fit the use case.
|
||||
- Use Storybook stories for all component and page testing, including
|
||||
visual presentation, user interactions, keyboard navigation, focus
|
||||
management, and accessibility behavior. Do not create standalone
|
||||
vitest/RTL test files for components or pages. Stories double as living
|
||||
documentation, visual regression coverage, and interaction test suites
|
||||
via `play` functions. Reserve plain vitest files for pure logic only:
|
||||
utility functions, data transformations, hooks tested via
|
||||
`renderHook()` that do not require DOM assertions, and query/cache
|
||||
operations with no rendered output.
|
||||
|
||||
### Writing Comments
|
||||
|
||||
Code comments should be clear, well-formatted, and add meaningful context.
|
||||
|
||||
@@ -19,6 +19,17 @@ SHELL := bash
|
||||
.SHELLFLAGS := -ceu
|
||||
.ONESHELL:
|
||||
|
||||
# When MAKE_TIMED=1, replace SHELL with a wrapper that prints
|
||||
# elapsed wall-clock time for each recipe. pre-commit and pre-push
|
||||
# set this on their sub-makes so every parallel job reports its
|
||||
# duration. Ad-hoc usage: make MAKE_TIMED=1 test
|
||||
ifdef MAKE_TIMED
|
||||
SHELL := $(CURDIR)/scripts/lib/timed-shell.sh
|
||||
.SHELLFLAGS = $@ -ceu
|
||||
export MAKE_TIMED
|
||||
export MAKE_LOGDIR
|
||||
endif
|
||||
|
||||
# This doesn't work on directories.
|
||||
# See https://stackoverflow.com/questions/25752543/make-delete-on-error-for-directory-targets
|
||||
.DELETE_ON_ERROR:
|
||||
@@ -33,6 +44,25 @@ SHELL := bash
|
||||
coderd/database/unique_constraint.go \
|
||||
coderd/database/dbmetrics/querymetrics.go \
|
||||
coderd/database/dbauthz/dbauthz.go \
|
||||
coderd/database/dbmock/dbmock.go \
|
||||
coderd/database/pubsub/psmock/psmock.go \
|
||||
agent/agentcontainers/acmock/acmock.go \
|
||||
coderd/httpmw/loggermw/loggermock/loggermock.go \
|
||||
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
|
||||
tailnet/tailnettest/coordinatormock.go \
|
||||
tailnet/tailnettest/coordinateemock.go \
|
||||
tailnet/tailnettest/workspaceupdatesprovidermock.go \
|
||||
tailnet/tailnettest/subscriptionmock.go \
|
||||
enterprise/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/aibridged/aibridgedmock/poolmock.go \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
agent/proto/agent.pb.go \
|
||||
agent/agentsocket/proto/agentsocket.pb.go \
|
||||
agent/boundarylogproxy/codec/boundary.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/aibridged/proto/aibridged.pb.go \
|
||||
site/src/api/typesGenerated.ts \
|
||||
site/e2e/provisionerGenerated.ts \
|
||||
site/src/api/chatModelOptionsGenerated.json \
|
||||
@@ -50,6 +80,23 @@ SHELL := bash
|
||||
codersdk/rbacresources_gen.go \
|
||||
codersdk/apikey_scopes_gen.go
|
||||
|
||||
# atomic_write runs a command, captures stdout into a temp file, and
|
||||
# atomically replaces $@. An optional second argument is a formatting
|
||||
# command that receives the temp file path as its argument.
|
||||
# Usage: $(call atomic_write,GENERATE_CMD[,FORMAT_CMD])
|
||||
define atomic_write
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
$(1) > "$$tmpfile" && \
|
||||
$(if $(2),$(2) "$$tmpfile" &&) \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
_gen:
|
||||
mkdir -p _gen
|
||||
|
||||
# Don't print the commands in the file unless you specify VERBOSE. This is
|
||||
# essentially the same as putting "@" at the start of each line.
|
||||
ifndef VERBOSE
|
||||
@@ -67,11 +114,19 @@ VERSION := $(shell ./scripts/version.sh)
|
||||
POSTGRES_VERSION ?= 17
|
||||
POSTGRES_IMAGE ?= us-docker.pkg.dev/coder-v2-images-public/public/postgres:$(POSTGRES_VERSION)
|
||||
|
||||
# Use the highest ZSTD compression level in CI.
|
||||
ifdef CI
|
||||
# Limit parallel Make jobs in pre-commit/pre-push. Defaults to
|
||||
# nproc/4 (min 2) since test, lint, and build targets have internal
|
||||
# parallelism. Override: make pre-push PARALLEL_JOBS=8
|
||||
PARALLEL_JOBS ?= $(shell n=$$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8); echo $$(( n / 4 > 2 ? n / 4 : 2 )))
|
||||
|
||||
# Use the highest ZSTD compression level in release builds to
|
||||
# minimize artifact size. For non-release CI builds (e.g. main
|
||||
# branch preview), use multithreaded level 6 which is ~99% faster
|
||||
# at the cost of ~30% larger archives.
|
||||
ifeq ($(CODER_RELEASE),true)
|
||||
ZSTDFLAGS := -22 --ultra
|
||||
else
|
||||
ZSTDFLAGS := -6
|
||||
ZSTDFLAGS := -6 -T0
|
||||
endif
|
||||
|
||||
# Common paths to exclude from find commands, this rule is written so
|
||||
@@ -80,7 +135,7 @@ endif
|
||||
# Note, all find statements should be written with `.` or `./path` as
|
||||
# the search path so that these exclusions match.
|
||||
FIND_EXCLUSIONS= \
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' \) -prune \)
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
|
||||
# Source files used for make targets, evaluated on use.
|
||||
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
|
||||
# Same as GO_SRC_FILES but excluding certain files that have problematic
|
||||
@@ -461,6 +516,9 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
|
||||
BOLD := $(shell tput bold 2>/dev/null)
|
||||
GREEN := $(shell tput setaf 2 2>/dev/null)
|
||||
RED := $(shell tput setaf 1 2>/dev/null)
|
||||
YELLOW := $(shell tput setaf 3 2>/dev/null)
|
||||
DIM := $(shell tput dim 2>/dev/null || tput setaf 8 2>/dev/null)
|
||||
RESET := $(shell tput sgr0 2>/dev/null)
|
||||
|
||||
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/biome fmt/markdown
|
||||
@@ -570,7 +628,7 @@ endif
|
||||
# GitHub Actions linters are run in a separate CI job (lint-actions) that only
|
||||
# triggers when workflow files change, so we skip them here when CI=true.
|
||||
LINT_ACTIONS_TARGETS := $(if $(CI),,lint/actions/actionlint)
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations $(LINT_ACTIONS_TARGETS)
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/check-scopes lint/migrations lint/bootstrap $(LINT_ACTIONS_TARGETS)
|
||||
.PHONY: lint
|
||||
|
||||
lint/site-icons:
|
||||
@@ -585,7 +643,7 @@ lint/ts: site/node_modules/.installed
|
||||
lint/go:
|
||||
./scripts/check_enterprise_imports.sh
|
||||
./scripts/check_codersdk_imports.sh
|
||||
linter_ver=$(shell egrep -o 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/Dockerfile | cut -d '=' -f 2)
|
||||
linter_ver=$$(grep -oE 'GOLANGCI_LINT_VERSION=\S+' dogfood/coder/Dockerfile | cut -d '=' -f 2)
|
||||
go run github.com/golangci/golangci-lint/cmd/golangci-lint@v$$linter_ver run
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
@@ -600,6 +658,11 @@ lint/shellcheck: $(SHELL_SRC_FILES)
|
||||
shellcheck --external-sources $(SHELL_SRC_FILES)
|
||||
.PHONY: lint/shellcheck
|
||||
|
||||
lint/bootstrap:
|
||||
bash scripts/check_bootstrap_quotes.sh
|
||||
.PHONY: lint/bootstrap
|
||||
|
||||
|
||||
lint/helm:
|
||||
cd helm/
|
||||
make lint
|
||||
@@ -634,6 +697,102 @@ lint/migrations:
|
||||
./scripts/check_pg_schema.sh "Fixtures" $(FIXTURE_FILES)
|
||||
.PHONY: lint/migrations
|
||||
|
||||
TYPOS_VERSION := $(shell grep -oP 'crate-ci/typos@\S+\s+\#\s+v\K[0-9.]+' .github/workflows/ci.yaml)
|
||||
|
||||
# Map uname values to typos release asset names.
|
||||
TYPOS_ARCH := $(shell uname -m)
|
||||
ifeq ($(shell uname -s),Darwin)
|
||||
TYPOS_OS := apple-darwin
|
||||
else
|
||||
TYPOS_OS := unknown-linux-musl
|
||||
endif
|
||||
|
||||
build/typos-$(TYPOS_VERSION):
|
||||
mkdir -p build/
|
||||
curl -sSfL "https://github.com/crate-ci/typos/releases/download/v$(TYPOS_VERSION)/typos-v$(TYPOS_VERSION)-$(TYPOS_ARCH)-$(TYPOS_OS).tar.gz" \
|
||||
| tar -xzf - -C build/ ./typos
|
||||
mv build/typos "$@"
|
||||
|
||||
lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
build/typos-$(TYPOS_VERSION) --config .github/workflows/typos.toml
|
||||
.PHONY: lint/typos
|
||||
|
||||
# pre-commit and pre-push mirror CI checks locally.
|
||||
#
|
||||
# pre-commit runs checks that don't need external services (Docker,
|
||||
# Playwright). This is the git pre-commit hook default since Docker
|
||||
# and browser issues in the local environment would otherwise block
|
||||
# all commits.
|
||||
#
|
||||
# pre-push adds heavier checks: Go tests, JS tests, and site build.
|
||||
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
|
||||
#
|
||||
# pre-commit uses two phases: gen+fmt first, then lint+build. This
|
||||
# avoids races where gen's `go run` creates temporary .go files that
|
||||
# lint's find-based checks pick up. Within each phase, targets run in
|
||||
# parallel via -j. It fails if any tracked files have unstaged
|
||||
# changes afterward.
|
||||
|
||||
define check-unstaged
|
||||
unstaged="$$(git diff --name-only)"
|
||||
if [[ -n $$unstaged ]]; then
|
||||
echo "$(RED)✗ check unstaged changes$(RESET)"
|
||||
echo "$$unstaged" | sed 's/^/ - /'
|
||||
echo ""
|
||||
echo "$(DIM) Verify generated changes are correct before staging:$(RESET)"
|
||||
echo "$(DIM) git diff$(RESET)"
|
||||
echo "$(DIM) git add -u && git commit$(RESET)"
|
||||
exit 1
|
||||
fi
|
||||
endef
|
||||
define check-untracked
|
||||
untracked=$$(git ls-files --other --exclude-standard)
|
||||
if [[ -n $$untracked ]]; then
|
||||
echo "$(YELLOW)? check untracked files$(RESET)"
|
||||
echo "$$untracked" | sed 's/^/ - /'
|
||||
echo ""
|
||||
echo "$(DIM) Review if these should be committed or added to .gitignore.$(RESET)"
|
||||
fi
|
||||
endef
|
||||
|
||||
pre-commit:
|
||||
start=$$(date +%s)
|
||||
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit.XXXXXX")
|
||||
echo "$(BOLD)pre-commit$(RESET) ($$logdir)"
|
||||
echo "gen + fmt:"
|
||||
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir gen fmt
|
||||
$(check-unstaged)
|
||||
echo "lint + build:"
|
||||
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
|
||||
lint \
|
||||
lint/typos \
|
||||
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
$(check-unstaged)
|
||||
$(check-untracked)
|
||||
rm -rf $$logdir
|
||||
echo "$(GREEN)✓ pre-commit passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
|
||||
.PHONY: pre-commit
|
||||
|
||||
pre-push:
|
||||
start=$$(date +%s)
|
||||
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-push.XXXXXX")
|
||||
echo "$(BOLD)pre-push$(RESET) ($$logdir)"
|
||||
echo "test + build site:"
|
||||
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
|
||||
test \
|
||||
test-js \
|
||||
site/out/index.html
|
||||
rm -rf $$logdir
|
||||
echo "$(GREEN)✓ pre-push passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
|
||||
.PHONY: pre-push
|
||||
|
||||
offlinedocs/check: offlinedocs/node_modules/.installed
|
||||
cd offlinedocs/
|
||||
pnpm format:check
|
||||
pnpm lint
|
||||
pnpm export
|
||||
.PHONY: offlinedocs/check
|
||||
|
||||
# All files generated by the database should be added here, and this can be used
|
||||
# as a target for jobs that need to run after the database is generated.
|
||||
DB_GEN_FILES := \
|
||||
@@ -822,7 +981,7 @@ $(TAILNETTEST_MOCKS): tailnet/coordinator.go tailnet/service.go
|
||||
touch "$@"
|
||||
|
||||
tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -830,7 +989,7 @@ tailnet/proto/tailnet.pb.go: tailnet/proto/tailnet.proto
|
||||
./tailnet/proto/tailnet.proto
|
||||
|
||||
agent/proto/agent.pb.go: agent/proto/agent.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -838,7 +997,7 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
|
||||
./agent/proto/agent.proto
|
||||
|
||||
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto agent/proto/agent.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -846,7 +1005,7 @@ agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.p
|
||||
./agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -854,7 +1013,7 @@ provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
./provisionersdk/proto/provisioner.proto
|
||||
|
||||
provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
@@ -862,132 +1021,110 @@ provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
|
||||
./provisionerd/proto/provisionerd.proto
|
||||
|
||||
vpn/vpn.pb.go: vpn/vpn.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
./vpn/vpn.proto
|
||||
|
||||
agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/boundary.proto agent/proto/agent.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
./agent/boundarylogproxy/codec/boundary.proto
|
||||
|
||||
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
|
||||
protoc \
|
||||
./scripts/atomic_protoc.sh \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# Generate to a temp file, format it, then atomically move to
|
||||
# the target so that an interrupt never leaves a partial or
|
||||
# unformatted file in the working tree.
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run -C ./scripts/apitypings main.go > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
|
||||
go run ./scripts/examplegen/main.go > "$@.tmp" && mv "$@.tmp" "$@"
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
|
||||
tempdir=$(shell mktemp -d /tmp/typegen_rbac_object.XXXXXX)
|
||||
go run ./scripts/typegen/main.go rbac object > "$$tempdir/object_gen.go"
|
||||
mv -v "$$tempdir/object_gen.go" coderd/rbac/object_gen.go
|
||||
rmdir -v "$$tempdir"
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go
|
||||
# Generate typed low-level ScopeName constants from RBACPermissions
|
||||
# Write to a temp file first to avoid truncating the package during build
|
||||
# since the generator imports the rbac package.
|
||||
tempfile=$(shell mktemp /tmp/scopes_constants_gen.XXXXXX)
|
||||
go run ./scripts/typegen/main.go rbac scopenames > "$$tempfile"
|
||||
mv -v "$$tempfile" coderd/rbac/scopes_constants_gen.go
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
|
||||
# Do no overwrite codersdk/rbacresources_gen.go directly, as it would make the file empty, breaking
|
||||
# the `codersdk` package and any parallel build targets.
|
||||
go run scripts/typegen/main.go rbac codersdk > /tmp/rbacresources_gen.go
|
||||
mv /tmp/rbacresources_gen.go codersdk/rbacresources_gen.go
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
# Generate SDK constants for external API key scopes.
|
||||
go run ./scripts/apikeyscopesgen > /tmp/apikey_scopes_gen.go
|
||||
mv /tmp/apikey_scopes_gen.go codersdk/apikey_scopes_gen.go
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run scripts/typegen/main.go rbac typescript > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run scripts/typegen/main.go countries > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && \
|
||||
go run ./scripts/modeloptionsgen/main.go | tail -n +2 > "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
|
||||
go run ./scripts/metricsdocgen/scanner > $@.tmp && mv $@.tmp $@
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES)
|
||||
tmpdir=$$(mktemp -d) && \
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
cp "$$tmpdir/docs/reference/cli/"*.md docs/reference/cli/ && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
coderd/apidoc/.gen: \
|
||||
node_modules/.installed \
|
||||
@@ -1002,25 +1139,27 @@ coderd/apidoc/.gen: \
|
||||
scripts/apidocgen/generate.sh \
|
||||
scripts/apidocgen/swaginit/main.go \
|
||||
$(wildcard scripts/apidocgen/postprocess/*) \
|
||||
$(wildcard scripts/apidocgen/markdown-template/*)
|
||||
tmpdir=$$(mktemp -d) && swagtmp=$$(mktemp -d) && \
|
||||
$(wildcard scripts/apidocgen/markdown-template/*) | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && swagtmp=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && swagtmp=$$(realpath "$$swagtmp") && \
|
||||
mkdir -p "$$tmpdir/reference/api" && \
|
||||
cp docs/manifest.json "$$tmpdir/manifest.json" && \
|
||||
SWAG_OUTPUT_DIR="$$swagtmp" APIDOCGEN_DOCS_DIR="$$tmpdir" ./scripts/apidocgen/generate.sh && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/reference/api/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/reference/api/*.md" && \
|
||||
./scripts/biome_format.sh "$$swagtmp/swagger.json" && \
|
||||
cp "$$tmpdir/reference/api/"*.md docs/reference/api/ && \
|
||||
cp "$$tmpdir/manifest.json" docs/manifest.json && \
|
||||
cp "$$swagtmp/docs.go" coderd/apidoc/docs.go && \
|
||||
cp "$$swagtmp/swagger.json" coderd/apidoc/swagger.json && \
|
||||
for f in "$$tmpdir/reference/api/"*.md; do mv "$$f" "docs/reference/api/$$(basename "$$f")"; done && \
|
||||
mv "$$tmpdir/manifest.json" _gen/manifest-staging.json && \
|
||||
mv "$$swagtmp/docs.go" coderd/apidoc/docs.go && \
|
||||
mv "$$swagtmp/swagger.json" coderd/apidoc/swagger.json && \
|
||||
rm -rf "$$tmpdir" "$$swagtmp"
|
||||
touch "$@"
|
||||
|
||||
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
|
||||
tmpfile=$$(mktemp -d)/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md | _gen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
cp _gen/manifest-staging.json "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@"
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
|
||||
touch "$@"
|
||||
@@ -1107,10 +1246,22 @@ else
|
||||
GOTESTSUM_RETRY_FLAGS :=
|
||||
endif
|
||||
|
||||
# default to 8x8 parallelism to avoid overwhelming our workspaces. Hopefully we can remove these defaults
|
||||
# when we get our test suite's resource utilization under control.
|
||||
# Use testsmallbatch tag to reduce wireguard memory allocation in tests (from ~18GB to negligible).
|
||||
GOTEST_FLAGS := -tags=testsmallbatch -v -p $(or $(TEST_NUM_PARALLEL_PACKAGES),"8") -parallel=$(or $(TEST_NUM_PARALLEL_TESTS),"8")
|
||||
# Default to 8x8 parallelism to avoid overwhelming our workspaces.
|
||||
# Race detection defaults to 4x4 because the detector adds significant
|
||||
# CPU overhead. Override via TEST_NUM_PARALLEL_PACKAGES /
|
||||
# TEST_NUM_PARALLEL_TESTS.
|
||||
TEST_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),8)
|
||||
TEST_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),8)
|
||||
RACE_PARALLEL_PACKAGES := $(or $(TEST_NUM_PARALLEL_PACKAGES),4)
|
||||
RACE_PARALLEL_TESTS := $(or $(TEST_NUM_PARALLEL_TESTS),4)
|
||||
|
||||
# Use testsmallbatch tag to reduce wireguard memory allocation in tests
|
||||
# (from ~18GB to negligible). Recursively expanded so target-specific
|
||||
# overrides of TEST_PARALLEL_* take effect (e.g. test-race lowers
|
||||
# parallelism). CI job timeout is 25m (see test-go-pg in ci.yaml),
|
||||
# keep the Go timeout 5m shorter so tests produce goroutine dumps
|
||||
# instead of the CI runner killing the process with no output.
|
||||
GOTEST_FLAGS = -tags=testsmallbatch -v -timeout 20m -p $(TEST_PARALLEL_PACKAGES) -parallel=$(TEST_PARALLEL_TESTS)
|
||||
|
||||
# The most common use is to set TEST_COUNT=1 to avoid Go's test cache.
|
||||
ifdef TEST_COUNT
|
||||
@@ -1136,13 +1287,34 @@ endif
|
||||
TEST_PACKAGES ?= ./...
|
||||
|
||||
test:
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet $(GOTESTSUM_RETRY_FLAGS) --packages="$(TEST_PACKAGES)" -- $(GOTEST_FLAGS)
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="$(TEST_PACKAGES)" \
|
||||
-- \
|
||||
$(GOTEST_FLAGS)
|
||||
.PHONY: test
|
||||
|
||||
test-race: TEST_PARALLEL_PACKAGES := $(RACE_PARALLEL_PACKAGES)
|
||||
test-race: TEST_PARALLEL_TESTS := $(RACE_PARALLEL_TESTS)
|
||||
test-race:
|
||||
$(GIT_FLAGS) gotestsum --format standard-quiet \
|
||||
--junitfile="gotests.xml" \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="$(TEST_PACKAGES)" \
|
||||
-- \
|
||||
-race \
|
||||
$(GOTEST_FLAGS)
|
||||
.PHONY: test-race
|
||||
|
||||
test-cli:
|
||||
$(MAKE) test TEST_PACKAGES="./cli..."
|
||||
.PHONY: test-cli
|
||||
|
||||
test-js: site/node_modules/.installed
|
||||
cd site/
|
||||
pnpm test:ci
|
||||
.PHONY: test-js
|
||||
|
||||
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
|
||||
# dependency for any sqlc-cloud related targets.
|
||||
sqlc-cloud-is-setup:
|
||||
@@ -1154,37 +1326,22 @@ sqlc-cloud-is-setup:
|
||||
|
||||
sqlc-push: sqlc-cloud-is-setup test-postgres-docker
|
||||
echo "--- sqlc push"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
sqlc push -f coderd/database/sqlc.yaml && echo "Passed sqlc push"
|
||||
.PHONY: sqlc-push
|
||||
|
||||
sqlc-verify: sqlc-cloud-is-setup test-postgres-docker
|
||||
echo "--- sqlc verify"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
sqlc verify -f coderd/database/sqlc.yaml && echo "Passed sqlc verify"
|
||||
.PHONY: sqlc-verify
|
||||
|
||||
sqlc-vet: test-postgres-docker
|
||||
echo "--- sqlc vet"
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$(shell go run scripts/migrate-ci/main.go)" \
|
||||
SQLC_DATABASE_URL="postgresql://postgres:postgres@localhost:5432/$$(go run scripts/migrate-ci/main.go)" \
|
||||
sqlc vet -f coderd/database/sqlc.yaml && echo "Passed sqlc vet"
|
||||
.PHONY: sqlc-vet
|
||||
|
||||
# When updating -timeout for this test, keep in sync with
|
||||
# test-go-postgres (.github/workflows/coder.yaml).
|
||||
# Do add coverage flags so that test caching works.
|
||||
test-postgres: test-postgres-docker
|
||||
# The postgres test is prone to failure, so we limit parallelism for
|
||||
# more consistent execution.
|
||||
$(GIT_FLAGS) gotestsum \
|
||||
--junitfile="gotests.xml" \
|
||||
--jsonfile="gotests.json" \
|
||||
$(GOTESTSUM_RETRY_FLAGS) \
|
||||
--packages="./..." -- \
|
||||
-tags=testsmallbatch \
|
||||
-timeout=20m \
|
||||
-count=1
|
||||
.PHONY: test-postgres
|
||||
|
||||
test-migrations: test-postgres-docker
|
||||
echo "--- test migrations"
|
||||
@@ -1200,13 +1357,24 @@ test-migrations: test-postgres-docker
|
||||
|
||||
# NOTE: we set --memory to the same size as a GitHub runner.
|
||||
test-postgres-docker:
|
||||
# If our container is already running, nothing to do.
|
||||
if docker ps --filter "name=test-postgres-docker-${POSTGRES_VERSION}" --format '{{.Names}}' | grep -q .; then \
|
||||
echo "test-postgres-docker-${POSTGRES_VERSION} is already running."; \
|
||||
exit 0; \
|
||||
fi
|
||||
# If something else is on 5432, warn but don't fail.
|
||||
if pg_isready -h 127.0.0.1 -q 2>/dev/null; then \
|
||||
echo "WARNING: PostgreSQL is already running on 127.0.0.1:5432 (not our container)."; \
|
||||
echo "Tests will use this instance. To use the Makefile's container, stop it first."; \
|
||||
exit 0; \
|
||||
fi
|
||||
docker rm -f test-postgres-docker-${POSTGRES_VERSION} || true
|
||||
|
||||
# Try pulling up to three times to avoid CI flakes.
|
||||
docker pull ${POSTGRES_IMAGE} || {
|
||||
retries=2
|
||||
for try in $(seq 1 ${retries}); do
|
||||
echo "Failed to pull image, retrying (${try}/${retries})..."
|
||||
for try in $$(seq 1 $${retries}); do
|
||||
echo "Failed to pull image, retrying ($${try}/$${retries})..."
|
||||
sleep 1
|
||||
if docker pull ${POSTGRES_IMAGE}; then
|
||||
break
|
||||
@@ -1247,16 +1415,11 @@ test-postgres-docker:
|
||||
-c log_statement=all
|
||||
while ! pg_isready -h 127.0.0.1
|
||||
do
|
||||
echo "$(date) - waiting for database to start"
|
||||
echo "$$(date) - waiting for database to start"
|
||||
sleep 0.5
|
||||
done
|
||||
.PHONY: test-postgres-docker
|
||||
|
||||
# Make sure to keep this in sync with test-go-race from .github/workflows/ci.yaml.
|
||||
test-race:
|
||||
$(GIT_FLAGS) gotestsum --junitfile="gotests.xml" -- -tags=testsmallbatch -race -count=1 -parallel 4 -p 4 ./...
|
||||
.PHONY: test-race
|
||||
|
||||
test-tailnet-integration:
|
||||
env \
|
||||
CODER_TAILNET_TESTS=true \
|
||||
@@ -1285,6 +1448,7 @@ site/e2e/bin/coder: go.mod go.sum $(GO_SRC_FILES)
|
||||
|
||||
test-e2e: site/e2e/bin/coder site/node_modules/.installed site/out/index.html
|
||||
cd site/
|
||||
pnpm playwright:install
|
||||
ifdef CI
|
||||
DEBUG=pw:api pnpm playwright:test --forbid-only --workers 1
|
||||
else
|
||||
@@ -1299,3 +1463,5 @@ dogfood/coder/nix.hash: flake.nix flake.lock
|
||||
count-test-databases:
|
||||
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
|
||||
.PHONY: count-test-databases
|
||||
|
||||
.PHONY: count-test-databases
|
||||
|
||||
+20
-3
@@ -39,8 +39,10 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/clistat"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentdesktop"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
@@ -102,6 +104,7 @@ type Options struct {
|
||||
Execer agentexec.Execer
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
GitAPIOptions []agentgit.Option
|
||||
Clock quartz.Clock
|
||||
SocketServerEnabled bool
|
||||
SocketPath string // Path for the agent socket server socket
|
||||
@@ -217,6 +220,7 @@ func New(options Options) Agent {
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
gitAPIOptions: options.GitAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
socketServerEnabled: options.SocketServerEnabled,
|
||||
boundaryLogProxySocketPath: options.BoundaryLogProxySocketPath,
|
||||
@@ -302,9 +306,12 @@ type agent struct {
|
||||
devcontainers bool
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
gitAPIOptions []agentgit.Option
|
||||
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
desktopAPI *agentdesktop.API
|
||||
|
||||
socketServerEnabled bool
|
||||
socketPath string
|
||||
@@ -376,9 +383,15 @@ func (a *agent) init() {
|
||||
|
||||
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
|
||||
|
||||
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv)
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
@@ -2049,6 +2062,10 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if err := a.desktopAPI.Close(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if a.boundaryLogProxy != nil {
|
||||
err = a.boundaryLogProxy.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -3040,6 +3040,62 @@ func TestAgent_Reconnect(t *testing.T) {
|
||||
closer.Close()
|
||||
}
|
||||
|
||||
func TestAgent_ReconnectNoLifecycleReemit(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
fCoordinator := tailnettest.NewFakeCoordinator()
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *proto.Stats, 50)
|
||||
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
|
||||
|
||||
client := agenttest.NewClient(t,
|
||||
logger,
|
||||
agentID,
|
||||
agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
Scripts: []codersdk.WorkspaceAgentScript{{
|
||||
Script: "echo hello",
|
||||
Timeout: 30 * time.Second,
|
||||
RunOnStart: true,
|
||||
}},
|
||||
},
|
||||
statsCh,
|
||||
fCoordinator,
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
closer := agent.New(agent.Options{
|
||||
Client: client,
|
||||
Logger: logger.Named("agent"),
|
||||
})
|
||||
defer closer.Close()
|
||||
|
||||
// Wait for the agent to reach Ready state.
|
||||
require.Eventually(t, func() bool {
|
||||
return slices.Contains(client.GetLifecycleStates(), codersdk.WorkspaceAgentLifecycleReady)
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
statesBefore := slices.Clone(client.GetLifecycleStates())
|
||||
|
||||
// Disconnect by closing the coordinator response channel.
|
||||
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
close(call1.Resps)
|
||||
|
||||
// Wait for reconnect.
|
||||
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
|
||||
// Wait for a stats report as a deterministic steady-state proof.
|
||||
testutil.RequireReceive(ctx, t, statsCh)
|
||||
|
||||
statesAfter := client.GetLifecycleStates()
|
||||
require.Equal(t, statesBefore, statesAfter,
|
||||
"lifecycle states should not be re-reported after reconnect")
|
||||
|
||||
closer.Close()
|
||||
}
|
||||
|
||||
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
@@ -0,0 +1,536 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// DesktopAction is the request body for the desktop action endpoint.
|
||||
type DesktopAction struct {
|
||||
Action string `json:"action"`
|
||||
Coordinate *[2]int `json:"coordinate,omitempty"`
|
||||
StartCoordinate *[2]int `json:"start_coordinate,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Duration *int `json:"duration,omitempty"`
|
||||
ScrollAmount *int `json:"scroll_amount,omitempty"`
|
||||
ScrollDirection *string `json:"scroll_direction,omitempty"`
|
||||
// ScaledWidth and ScaledHeight are the coordinate space the
|
||||
// model is using. When provided, coordinates are linearly
|
||||
// mapped from scaled → native before dispatching.
|
||||
ScaledWidth *int `json:"scaled_width,omitempty"`
|
||||
ScaledHeight *int `json:"scaled_height,omitempty"`
|
||||
}
|
||||
|
||||
// DesktopActionResponse is the response from the desktop action
|
||||
// endpoint.
|
||||
type DesktopActionResponse struct {
|
||||
Output string `json:"output,omitempty"`
|
||||
ScreenshotData string `json:"screenshot_data,omitempty"`
|
||||
ScreenshotWidth int `json:"screenshot_width,omitempty"`
|
||||
ScreenshotHeight int `json:"screenshot_height,omitempty"`
|
||||
}
|
||||
|
||||
// API exposes the desktop streaming HTTP routes for the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
desktop Desktop
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// NewAPI creates a new desktop streaming API.
|
||||
func NewAPI(logger slog.Logger, desktop Desktop, clock quartz.Clock) *API {
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
return &API{
|
||||
logger: logger,
|
||||
desktop: desktop,
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
// Routes returns the chi router for mounting at /api/v0/desktop.
|
||||
func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/vnc", a.handleDesktopVNC)
|
||||
r.Post("/action", a.handleAction)
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Start the desktop session (idempotent).
|
||||
_, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start desktop session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get a VNC connection.
|
||||
vncConn, err := a.desktop.VNCConn(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to connect to VNC server.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer vncConn.Close()
|
||||
|
||||
// Accept WebSocket from coderd.
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// No read limit — RFB framebuffer updates can be large.
|
||||
conn.SetReadLimit(-1)
|
||||
|
||||
wsCtx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
// Bicopy raw bytes between WebSocket and VNC TCP.
|
||||
agentssh.Bicopy(wsCtx, wsNetConn, vncConn)
|
||||
}
|
||||
|
||||
func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
handlerStart := a.clock.Now()
|
||||
|
||||
// Ensure the desktop is running and grab native dimensions.
|
||||
cfg, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: desktop.Start failed",
|
||||
slog.Error(err),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start desktop session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var action DesktopAction
|
||||
if err := json.NewDecoder(r.Body).Decode(&action); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to decode request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
a.logger.Info(ctx, "handleAction: started",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
|
||||
// Helper to scale a coordinate pair from the model's space to
|
||||
// native display pixels.
|
||||
scaleXY := func(x, y int) (int, int) {
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
|
||||
}
|
||||
return x, y
|
||||
}
|
||||
|
||||
var resp DesktopActionResponse
|
||||
|
||||
switch action.Action {
|
||||
case "key":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for key action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := a.desktop.KeyPress(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key press failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "key action performed"
|
||||
|
||||
case "type":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for type action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := a.desktop.Type(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Type action failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "type action performed"
|
||||
|
||||
case "cursor_position":
|
||||
x, y, err := a.desktop.CursorPosition(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Cursor position failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
|
||||
|
||||
case "mouse_move":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Move(ctx, x, y); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Mouse move failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "mouse_move action performed"
|
||||
|
||||
case "left_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
stepStart := a.clock.Now()
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: Click failed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step", "click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "handleAction: Click completed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
resp.Output = "left_click action performed"
|
||||
|
||||
case "left_click_drag":
|
||||
if action.Coordinate == nil || action.StartCoordinate == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"coordinate\" or \"start_coordinate\" for left_click_drag.",
|
||||
})
|
||||
return
|
||||
}
|
||||
sx, sy := scaleXY(action.StartCoordinate[0], action.StartCoordinate[1])
|
||||
ex, ey := scaleXY(action.Coordinate[0], action.Coordinate[1])
|
||||
if err := a.desktop.Drag(ctx, sx, sy, ex, ey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left click drag failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_click_drag action performed"
|
||||
|
||||
case "left_mouse_down":
|
||||
if err := a.desktop.ButtonDown(ctx, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left mouse down failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_mouse_down action performed"
|
||||
|
||||
case "left_mouse_up":
|
||||
if err := a.desktop.ButtonUp(ctx, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left mouse up failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_mouse_up action performed"
|
||||
|
||||
case "right_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonRight); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Right click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "right_click action performed"
|
||||
|
||||
case "middle_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonMiddle); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Middle click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "middle_click action performed"
|
||||
|
||||
case "double_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.DoubleClick(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Double click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "double_click action performed"
|
||||
|
||||
case "triple_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
for range 3 {
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Triple click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
resp.Output = "triple_click action performed"
|
||||
|
||||
case "scroll":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
|
||||
amount := 3
|
||||
if action.ScrollAmount != nil {
|
||||
amount = *action.ScrollAmount
|
||||
}
|
||||
direction := "down"
|
||||
if action.ScrollDirection != nil {
|
||||
direction = *action.ScrollDirection
|
||||
}
|
||||
|
||||
var dx, dy int
|
||||
switch direction {
|
||||
case "up":
|
||||
dy = -amount
|
||||
case "down":
|
||||
dy = amount
|
||||
case "left":
|
||||
dx = -amount
|
||||
case "right":
|
||||
dx = amount
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid scroll direction: " + direction,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := a.desktop.Scroll(ctx, x, y, dx, dy); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Scroll failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "scroll action performed"
|
||||
|
||||
case "hold_key":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for hold_key action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
dur := 1000
|
||||
if action.Duration != nil {
|
||||
dur = *action.Duration
|
||||
}
|
||||
if err := a.desktop.KeyDown(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key down failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
timer := a.clock.NewTimer(time.Duration(dur)*time.Millisecond, "agentdesktop", "hold_key")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context canceled; release the key immediately.
|
||||
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
|
||||
}
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key up failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "hold_key action performed"
|
||||
|
||||
case "screenshot":
|
||||
var opts ScreenshotOptions
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
opts.TargetWidth = *action.ScaledWidth
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
opts.TargetHeight = *action.ScaledHeight
|
||||
}
|
||||
result, err := a.desktop.Screenshot(ctx, opts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Screenshot failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "screenshot"
|
||||
resp.ScreenshotData = result.Data
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
|
||||
resp.ScreenshotWidth = *action.ScaledWidth
|
||||
} else {
|
||||
resp.ScreenshotWidth = cfg.Width
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
|
||||
resp.ScreenshotHeight = *action.ScaledHeight
|
||||
} else {
|
||||
resp.ScreenshotHeight = cfg.Height
|
||||
}
|
||||
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unknown action: " + action.Action,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
elapsedMs := a.clock.Since(handlerStart).Milliseconds()
|
||||
if ctx.Err() != nil {
|
||||
a.logger.Error(ctx, "handleAction: context canceled before writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
slog.Error(ctx.Err()),
|
||||
)
|
||||
return
|
||||
}
|
||||
a.logger.Info(ctx, "handleAction: writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session if one is running.
|
||||
func (a *API) Close() error {
|
||||
return a.desktop.Close()
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
// returning an error if the coordinate field is missing.
|
||||
func coordFromAction(action DesktopAction) (x, y int, err error) {
|
||||
if action.Coordinate == nil {
|
||||
return 0, 0, &missingFieldError{field: "coordinate", action: action.Action}
|
||||
}
|
||||
return action.Coordinate[0], action.Coordinate[1], nil
|
||||
}
|
||||
|
||||
// missingFieldError is returned when a required field is absent from
|
||||
// a DesktopAction.
|
||||
type missingFieldError struct {
|
||||
field string
|
||||
action string
|
||||
}
|
||||
|
||||
func (e *missingFieldError) Error() string {
|
||||
return "Missing \"" + e.field + "\" for " + e.action + " action."
|
||||
}
|
||||
|
||||
// scaleCoordinate maps a coordinate from scaled → native space.
|
||||
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
|
||||
if scaledDim == 0 || scaledDim == nativeDim {
|
||||
return scaled
|
||||
}
|
||||
native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
|
||||
// Clamp to valid range.
|
||||
native = math.Max(native, 0)
|
||||
native = math.Min(native, float64(nativeDim-1))
|
||||
return int(native)
|
||||
}
|
||||
@@ -0,0 +1,467 @@
|
||||
package agentdesktop_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentdesktop"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
|
||||
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
|
||||
|
||||
// fakeDesktop is a minimal Desktop implementation for unit tests.
|
||||
type fakeDesktop struct {
|
||||
startErr error
|
||||
startCfg agentdesktop.DisplayConfig
|
||||
vncConnErr error
|
||||
screenshotErr error
|
||||
screenshotRes agentdesktop.ScreenshotResult
|
||||
closed bool
|
||||
|
||||
// Track calls for assertions.
|
||||
lastMove [2]int
|
||||
lastClick [3]int // x, y, button
|
||||
lastScroll [4]int // x, y, dx, dy
|
||||
lastKey string
|
||||
lastTyped string
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
|
||||
return f.startCfg, f.startErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
|
||||
return nil, f.vncConnErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
|
||||
return f.screenshotRes, f.screenshotErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Move(_ context.Context, x, y int) error {
|
||||
f.lastMove = [2]int{x, y}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Click(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
|
||||
f.lastClick = [3]int{x, y, 1}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) DoubleClick(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
|
||||
f.lastClick = [3]int{x, y, 2}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) ButtonDown(context.Context, agentdesktop.MouseButton) error { return nil }
|
||||
func (*fakeDesktop) ButtonUp(context.Context, agentdesktop.MouseButton) error { return nil }
|
||||
|
||||
func (f *fakeDesktop) Scroll(_ context.Context, x, y, dx, dy int) error {
|
||||
f.lastScroll = [4]int{x, y, dx, dy}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) Drag(context.Context, int, int, int, int) error { return nil }
|
||||
|
||||
func (f *fakeDesktop) KeyPress(_ context.Context, key string) error {
|
||||
f.lastKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) KeyDown(_ context.Context, key string) error {
|
||||
f.lastKeyDown = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) KeyUp(_ context.Context, key string) error {
|
||||
f.lastKeyUp = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Type(_ context.Context, text string) error {
|
||||
f.lastTyped = text
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
|
||||
return 10, 20, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Close() error {
|
||||
f.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandleDesktopVNC_StartError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{startErr: xerrors.New("no desktop")}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Failed to start desktop session.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
|
||||
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "screenshot"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var result agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
// Dimensions come from DisplayConfig, not the screenshot CLI.
|
||||
assert.Equal(t, "screenshot", result.Output)
|
||||
assert.Equal(t, "base64data", result.ScreenshotData)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
|
||||
}
|
||||
|
||||
func TestHandleAction_LeftClick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "left_click",
|
||||
Coordinate: &[2]int{100, 200},
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "left_click action performed", resp.Output)
|
||||
assert.Equal(t, [3]int{100, 200, 1}, fake.lastClick)
|
||||
}
|
||||
|
||||
func TestHandleAction_UnknownAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "explode"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestHandleAction_KeyAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
text := "Return"
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "key",
|
||||
Text: &text,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "Return", fake.lastKey)
|
||||
}
|
||||
|
||||
func TestHandleAction_TypeAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
text := "hello world"
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "type",
|
||||
Text: &text,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "hello world", fake.lastTyped)
|
||||
}
|
||||
|
||||
func TestHandleAction_HoldKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
mClk := quartz.NewMock(t)
|
||||
trap := mClk.Trap().NewTimer("agentdesktop", "hold_key")
|
||||
defer trap.Close()
|
||||
api := agentdesktop.NewAPI(logger, fake, mClk)
|
||||
defer api.Close()
|
||||
|
||||
text := "Shift_L"
|
||||
dur := 100
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "hold_key",
|
||||
Text: &text,
|
||||
Duration: &dur,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
handler.ServeHTTP(rr, req)
|
||||
}()
|
||||
|
||||
// Wait for the timer to be created, then advance past it.
|
||||
trap.MustWait(req.Context()).MustRelease(req.Context())
|
||||
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
|
||||
|
||||
<-done
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hold_key action performed", resp.Output)
|
||||
assert.Equal(t, "Shift_L", fake.lastKeyDown)
|
||||
assert.Equal(t, "Shift_L", fake.lastKeyUp)
|
||||
}
|
||||
|
||||
func TestHandleAction_HoldKeyMissingText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "hold_key"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Missing \"text\" for hold_key action.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_ScrollDown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
dir := "down"
|
||||
amount := 5
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "scroll",
|
||||
Coordinate: &[2]int{500, 400},
|
||||
ScrollDirection: &dir,
|
||||
ScrollAmount: &amount,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// dy should be positive 5 for "down".
|
||||
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
|
||||
}
|
||||
|
||||
func TestHandleAction_CoordinateScaling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
// Native display is 1920x1080.
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
// Model is working in a 1280x720 coordinate space.
|
||||
sw := 1280
|
||||
sh := 720
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "mouse_move",
|
||||
Coordinate: &[2]int{640, 360},
|
||||
ScaledWidth: &sw,
|
||||
ScaledHeight: &sh,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// 640 in 1280-space → 960 in 1920-space (midpoint maps to
|
||||
// midpoint).
|
||||
assert.Equal(t, 960, fake.lastMove[0])
|
||||
assert.Equal(t, 540, fake.lastMove[1])
|
||||
}
|
||||
|
||||
func TestClose_DelegatesToDesktop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
err := api.Close()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, fake.closed)
|
||||
}
|
||||
|
||||
func TestClose_PreventsNewSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// After Close(), Start() will return an error because the
|
||||
// underlying Desktop is closed.
|
||||
fake := &fakeDesktop{}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
err := api.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the closed desktop returning an error on Start().
|
||||
fake.startErr = xerrors.New("desktop is closed")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Desktop abstracts a virtual desktop session running inside a workspace.
|
||||
type Desktop interface {
|
||||
// Start launches the desktop session. It is idempotent — calling
|
||||
// Start on an already-running session returns the existing
|
||||
// config. The returned DisplayConfig describes the running
|
||||
// session.
|
||||
Start(ctx context.Context) (DisplayConfig, error)
|
||||
|
||||
// VNCConn dials the desktop's VNC server and returns a raw
|
||||
// net.Conn carrying RFB binary frames. Each call returns a new
|
||||
// connection; multiple clients can connect simultaneously.
|
||||
// Start must be called before VNCConn.
|
||||
VNCConn(ctx context.Context) (net.Conn, error)
|
||||
|
||||
// Screenshot captures the current framebuffer as a PNG and
|
||||
// returns it base64-encoded. TargetWidth/TargetHeight in opts
|
||||
// are the desired output dimensions (the implementation
|
||||
// rescales); pass 0 to use native resolution.
|
||||
Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error)
|
||||
|
||||
// Mouse operations.
|
||||
|
||||
// Move moves the mouse cursor to absolute coordinates.
|
||||
Move(ctx context.Context, x, y int) error
|
||||
// Click performs a mouse button click at the given coordinates.
|
||||
Click(ctx context.Context, x, y int, button MouseButton) error
|
||||
// DoubleClick performs a double-click at the given coordinates.
|
||||
DoubleClick(ctx context.Context, x, y int, button MouseButton) error
|
||||
// ButtonDown presses and holds a mouse button.
|
||||
ButtonDown(ctx context.Context, button MouseButton) error
|
||||
// ButtonUp releases a mouse button.
|
||||
ButtonUp(ctx context.Context, button MouseButton) error
|
||||
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
|
||||
Scroll(ctx context.Context, x, y, dx, dy int) error
|
||||
// Drag moves from (startX,startY) to (endX,endY) while holding
|
||||
// the left mouse button.
|
||||
Drag(ctx context.Context, startX, startY, endX, endY int) error
|
||||
|
||||
// Keyboard operations.
|
||||
|
||||
// KeyPress sends a key-down then key-up for a key combo string
|
||||
// (e.g. "Return", "ctrl+c").
|
||||
KeyPress(ctx context.Context, keys string) error
|
||||
// KeyDown presses and holds a key.
|
||||
KeyDown(ctx context.Context, key string) error
|
||||
// KeyUp releases a key.
|
||||
KeyUp(ctx context.Context, key string) error
|
||||
// Type types a string of text character-by-character.
|
||||
Type(ctx context.Context, text string) error
|
||||
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
CursorPosition(ctx context.Context) (x, y int, err error)
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
type DisplayConfig struct {
|
||||
Width int // native width in pixels
|
||||
Height int // native height in pixels
|
||||
VNCPort int // local TCP port for the VNC server
|
||||
Display int // X11 display number (e.g. 1 for :1), -1 if N/A
|
||||
}
|
||||
|
||||
// MouseButton identifies a mouse button.
|
||||
type MouseButton string
|
||||
|
||||
const (
|
||||
MouseButtonLeft MouseButton = "left"
|
||||
MouseButtonRight MouseButton = "right"
|
||||
MouseButtonMiddle MouseButton = "middle"
|
||||
)
|
||||
|
||||
// ScreenshotOptions configures a screenshot capture.
|
||||
type ScreenshotOptions struct {
|
||||
TargetWidth int // 0 = native
|
||||
TargetHeight int // 0 = native
|
||||
}
|
||||
|
||||
// ScreenshotResult is a captured screenshot.
|
||||
type ScreenshotResult struct {
|
||||
Data string // base64-encoded PNG
|
||||
}
|
||||
@@ -0,0 +1,544 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
const (
|
||||
portableDesktopVersion = "v0.0.4"
|
||||
downloadRetries = 3
|
||||
downloadRetryDelay = time.Second
|
||||
)
|
||||
|
||||
// platformBinaries maps GOARCH to download URL and expected SHA-256
|
||||
// digest for each supported platform.
|
||||
var platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
"amd64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
|
||||
SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
|
||||
},
|
||||
"arm64": {
|
||||
URL: "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
|
||||
SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
|
||||
},
|
||||
}
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
// `portabledesktop up --json`.
|
||||
type portableDesktopOutput struct {
|
||||
VNCPort int `json:"vncPort"`
|
||||
Geometry string `json:"geometry"` // e.g. "1920x1080"
|
||||
}
|
||||
|
||||
// desktopSession tracks a running portabledesktop process.
|
||||
type desktopSession struct {
|
||||
cmd *exec.Cmd
|
||||
vncPort int
|
||||
width int // native width, parsed from geometry
|
||||
height int // native height, parsed from geometry
|
||||
display int // X11 display number, -1 if not available
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// cursorOutput is the JSON output from `portabledesktop cursor --json`.
|
||||
type cursorOutput struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
}
|
||||
|
||||
// screenshotOutput is the JSON output from
|
||||
// `portabledesktop screenshot --json`.
|
||||
type screenshotOutput struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
dataDir string // agent's ScriptDataDir, used for binary caching
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
|
||||
// httpClient is used for downloading the binary. If nil,
|
||||
// http.DefaultClient is used.
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. dataDir is used
|
||||
// to cache the downloaded binary.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
dataDir string,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
}
|
||||
|
||||
// httpDo returns the HTTP client to use for downloads.
|
||||
func (p *portableDesktop) httpDo() *http.Client {
|
||||
if p.httpClient != nil {
|
||||
return p.httpClient
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return DisplayConfig{}, xerrors.New("desktop is closed")
|
||||
}
|
||||
|
||||
if err := p.ensureBinary(ctx); err != nil {
|
||||
return DisplayConfig{}, xerrors.Errorf("ensure portabledesktop binary: %w", err)
|
||||
}
|
||||
|
||||
// If we have an existing session, check if it's still alive.
|
||||
if p.session != nil {
|
||||
if !(p.session.cmd.ProcessState != nil && p.session.cmd.ProcessState.Exited()) {
|
||||
return DisplayConfig{
|
||||
Width: p.session.width,
|
||||
Height: p.session.height,
|
||||
VNCPort: p.session.vncPort,
|
||||
Display: p.session.display,
|
||||
}, nil
|
||||
}
|
||||
// Process died — clean up and recreate.
|
||||
p.logger.Warn(ctx, "portabledesktop process died, recreating session")
|
||||
p.session.cancel()
|
||||
p.session = nil
|
||||
}
|
||||
|
||||
// Spawn portabledesktop up --json.
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
|
||||
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
|
||||
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
|
||||
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sessionCancel()
|
||||
return DisplayConfig{}, xerrors.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
sessionCancel()
|
||||
return DisplayConfig{}, xerrors.Errorf("start portabledesktop: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON output to get VNC port and geometry.
|
||||
var output portableDesktopOutput
|
||||
if err := json.NewDecoder(stdout).Decode(&output); err != nil {
|
||||
sessionCancel()
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
return DisplayConfig{}, xerrors.Errorf("parse portabledesktop output: %w", err)
|
||||
}
|
||||
|
||||
if output.VNCPort == 0 {
|
||||
sessionCancel()
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
return DisplayConfig{}, xerrors.New("portabledesktop returned port 0")
|
||||
}
|
||||
|
||||
var w, h int
|
||||
if output.Geometry != "" {
|
||||
if _, err := fmt.Sscanf(output.Geometry, "%dx%d", &w, &h); err != nil {
|
||||
p.logger.Warn(ctx, "failed to parse geometry, using defaults",
|
||||
slog.F("geometry", output.Geometry),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.Info(ctx, "started portabledesktop session",
|
||||
slog.F("vnc_port", output.VNCPort),
|
||||
slog.F("width", w),
|
||||
slog.F("height", h),
|
||||
slog.F("pid", cmd.Process.Pid),
|
||||
)
|
||||
|
||||
p.session = &desktopSession{
|
||||
cmd: cmd,
|
||||
vncPort: output.VNCPort,
|
||||
width: w,
|
||||
height: h,
|
||||
display: -1,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
|
||||
return DisplayConfig{
|
||||
Width: w,
|
||||
Height: h,
|
||||
VNCPort: output.VNCPort,
|
||||
Display: -1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VNCConn dials the desktop's VNC server and returns a raw
|
||||
// net.Conn carrying RFB binary frames.
|
||||
func (p *portableDesktop) VNCConn(_ context.Context) (net.Conn, error) {
|
||||
p.mu.Lock()
|
||||
session := p.session
|
||||
p.mu.Unlock()
|
||||
|
||||
if session == nil {
|
||||
return nil, xerrors.New("desktop session not started")
|
||||
}
|
||||
|
||||
return net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", session.vncPort))
|
||||
}
|
||||
|
||||
// Screenshot captures the current framebuffer as a base64-encoded PNG.
|
||||
func (p *portableDesktop) Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error) {
|
||||
args := []string{"screenshot", "--json"}
|
||||
if opts.TargetWidth > 0 {
|
||||
args = append(args, "--target-width", strconv.Itoa(opts.TargetWidth))
|
||||
}
|
||||
if opts.TargetHeight > 0 {
|
||||
args = append(args, "--target-height", strconv.Itoa(opts.TargetHeight))
|
||||
}
|
||||
|
||||
out, err := p.runCmd(ctx, args...)
|
||||
if err != nil {
|
||||
return ScreenshotResult{}, err
|
||||
}
|
||||
|
||||
var result screenshotOutput
|
||||
if err := json.Unmarshal([]byte(out), &result); err != nil {
|
||||
return ScreenshotResult{}, xerrors.Errorf("parse screenshot output: %w", err)
|
||||
}
|
||||
|
||||
return ScreenshotResult(result), nil
|
||||
}
|
||||
|
||||
// Move moves the mouse cursor to absolute coordinates.
|
||||
func (p *portableDesktop) Move(ctx context.Context, x, y int) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
|
||||
return err
|
||||
}
|
||||
|
||||
// Click performs a mouse button click at the given coordinates.
|
||||
func (p *portableDesktop) Click(ctx context.Context, x, y int, button MouseButton) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "click", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// DoubleClick performs a double-click at the given coordinates.
|
||||
func (p *portableDesktop) DoubleClick(ctx context.Context, x, y int, button MouseButton) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "click", string(button)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "click", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// ButtonDown presses and holds a mouse button.
|
||||
func (p *portableDesktop) ButtonDown(ctx context.Context, button MouseButton) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "down", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// ButtonUp releases a mouse button.
|
||||
func (p *portableDesktop) ButtonUp(ctx context.Context, button MouseButton) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "up", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
|
||||
func (p *portableDesktop) Scroll(ctx context.Context, x, y, dx, dy int) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "scroll", strconv.Itoa(dx), strconv.Itoa(dy))
|
||||
return err
|
||||
}
|
||||
|
||||
// Drag moves from (startX,startY) to (endX,endY) while holding the
|
||||
// left mouse button.
|
||||
func (p *portableDesktop) Drag(ctx context.Context, startX, startY, endX, endY int) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(startX), strconv.Itoa(startY)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "down", string(MouseButtonLeft)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(endX), strconv.Itoa(endY)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "up", string(MouseButtonLeft))
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyPress sends a key-down then key-up for a key combo string.
|
||||
func (p *portableDesktop) KeyPress(ctx context.Context, keys string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "key", keys)
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyDown presses and holds a key.
|
||||
func (p *portableDesktop) KeyDown(ctx context.Context, key string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "down", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyUp releases a key.
|
||||
func (p *portableDesktop) KeyUp(ctx context.Context, key string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "up", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Type types a string of text character-by-character.
|
||||
func (p *portableDesktop) Type(ctx context.Context, text string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "type", text)
|
||||
return err
|
||||
}
|
||||
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err error) {
|
||||
out, err := p.runCmd(ctx, "cursor", "--json")
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var result cursorOutput
|
||||
if err := json.Unmarshal([]byte(out), &result); err != nil {
|
||||
return 0, 0, xerrors.Errorf("parse cursor output: %w", err)
|
||||
}
|
||||
|
||||
return result.X, result.Y, nil
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
func (p *portableDesktop) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.closed = true
|
||||
if p.session != nil {
|
||||
p.session.cancel()
|
||||
// Xvnc is a child process — killing it cleans up the X
|
||||
// session.
|
||||
_ = p.session.cmd.Process.Kill()
|
||||
_ = p.session.cmd.Wait()
|
||||
p.session = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runCmd executes a portabledesktop subcommand and returns combined
|
||||
// output. The caller must have previously called ensureBinary.
|
||||
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
|
||||
start := time.Now()
|
||||
//nolint:gosec // args are constructed by the caller, not user input.
|
||||
cmd := p.execer.CommandContext(ctx, p.binPath, args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "portabledesktop command failed",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
slog.Error(err),
|
||||
slog.F("output", string(out)),
|
||||
)
|
||||
return "", xerrors.Errorf("portabledesktop %s: %w: %s", args[0], err, string(out))
|
||||
}
|
||||
if elapsed > 5*time.Second {
|
||||
p.logger.Warn(ctx, "portabledesktop command slow",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
)
|
||||
} else {
|
||||
p.logger.Debug(ctx, "portabledesktop command completed",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// ensureBinary resolves or downloads the portabledesktop binary. It
|
||||
// must be called while p.mu is held.
|
||||
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
if p.binPath != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. Check PATH.
|
||||
if path, err := exec.LookPath("portabledesktop"); err == nil {
|
||||
p.logger.Info(ctx, "found portabledesktop in PATH",
|
||||
slog.F("path", path),
|
||||
)
|
||||
p.binPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Platform checks.
|
||||
if runtime.GOOS != "linux" {
|
||||
return xerrors.New("portabledesktop is only supported on Linux")
|
||||
}
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
// 3. Check cache.
|
||||
cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
|
||||
if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
|
||||
// Verify it is executable.
|
||||
if info.Mode()&0o100 != 0 {
|
||||
p.logger.Info(ctx, "using cached portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
p.binPath = cachedPath
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Download with retry.
|
||||
p.logger.Info(ctx, "downloading portabledesktop binary",
|
||||
slog.F("url", bin.URL),
|
||||
slog.F("version", portableDesktopVersion),
|
||||
slog.F("arch", runtime.GOARCH),
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := range downloadRetries {
|
||||
if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
|
||||
lastErr = err
|
||||
p.logger.Warn(ctx, "download attempt failed",
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("max_attempts", downloadRetries),
|
||||
slog.Error(err),
|
||||
)
|
||||
if attempt < downloadRetries-1 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(downloadRetryDelay):
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
p.binPath = cachedPath
|
||||
p.logger.Info(ctx, "downloaded portabledesktop binary",
|
||||
slog.F("path", cachedPath),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
|
||||
}
|
||||
|
||||
// downloadBinary fetches a binary from url, verifies its SHA-256
|
||||
// digest matches expectedSHA256, and atomically writes it to destPath.
|
||||
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0o700); err != nil {
|
||||
return xerrors.Errorf("create cache directory: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create HTTP request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("HTTP GET %s: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Write to a temp file in the same directory so the final rename
|
||||
// is atomic on the same filesystem.
|
||||
tmpFile, err := os.CreateTemp(filepath.Dir(destPath), "portabledesktop-download-*")
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Clean up the temp file on any error path.
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Stream the response body while computing SHA-256.
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
|
||||
return xerrors.Errorf("download body: %w", err)
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return xerrors.Errorf("close temp file: %w", err)
|
||||
}
|
||||
|
||||
// Verify digest.
|
||||
actualSHA256 := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actualSHA256 != expectedSHA256 {
|
||||
return xerrors.Errorf(
|
||||
"SHA-256 mismatch: expected %s, got %s",
|
||||
expectedSHA256, actualSHA256,
|
||||
)
|
||||
}
|
||||
|
||||
if err := os.Chmod(tmpPath, 0o700); err != nil {
|
||||
return xerrors.Errorf("chmod: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
return xerrors.Errorf("rename to final path: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,713 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
)
|
||||
|
||||
// recordedExecer implements agentexec.Execer by recording every
|
||||
// invocation and delegating to a real shell command built from a
|
||||
// caller-supplied mapping of subcommand → shell script body.
|
||||
type recordedExecer struct {
|
||||
mu sync.Mutex
|
||||
commands [][]string
|
||||
// scripts maps a subcommand keyword (e.g. "up", "screenshot")
|
||||
// to a shell snippet whose stdout will be the command output.
|
||||
scripts map[string]string
|
||||
}
|
||||
|
||||
func (r *recordedExecer) record(cmd string, args ...string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.commands = append(r.commands, append([]string{cmd}, args...))
|
||||
}
|
||||
|
||||
func (r *recordedExecer) allCommands() [][]string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([][]string, len(r.commands))
|
||||
copy(out, r.commands)
|
||||
return out
|
||||
}
|
||||
|
||||
// scriptFor finds the first matching script key present in args.
|
||||
func (r *recordedExecer) scriptFor(args []string) string {
|
||||
for _, a := range args {
|
||||
if s, ok := r.scripts[a]; ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: succeed silently.
|
||||
return "true"
|
||||
}
|
||||
|
||||
func (r *recordedExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
|
||||
r.record(cmd, args...)
|
||||
script := r.scriptFor(args)
|
||||
//nolint:gosec // Test helper — script content is controlled by the test.
|
||||
return exec.CommandContext(ctx, "sh", "-c", script)
|
||||
}
|
||||
|
||||
func (r *recordedExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
|
||||
r.record(cmd, args...)
|
||||
return pty.CommandContext(ctx, "sh", "-c", r.scriptFor(args))
|
||||
}
|
||||
|
||||
// --- portableDesktop tests ---
|
||||
|
||||
func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
// The "up" script prints the JSON line then sleeps until
|
||||
// the context is canceled (simulating a long-running process).
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
cfg, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1920, cfg.Width)
|
||||
assert.Equal(t, 1080, cfg.Height)
|
||||
assert.Equal(t, 5901, cfg.VNCPort)
|
||||
assert.Equal(t, -1, cfg.Display)
|
||||
|
||||
// Clean up the long-running process.
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
cfg1, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg2, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg1, cfg2, "second Start should return the same config")
|
||||
|
||||
// The execer should have been called exactly once for "up".
|
||||
cmds := rec.allCommands()
|
||||
upCalls := 0
|
||||
for _, c := range cmds {
|
||||
for _, a := range c {
|
||||
if a == "up" {
|
||||
upCalls++
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, upCalls, "expected exactly one 'up' invocation")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"screenshot": `echo '{"data":"abc123"}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "abc123", result.Data)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
dataDir := t.TempDir()
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"screenshot": `echo '{"data":"x"}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: dataDir,
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := pd.Screenshot(ctx, ScreenshotOptions{
|
||||
TargetWidth: 800,
|
||||
TargetHeight: 600,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
|
||||
// The last command should contain the target dimension flags.
|
||||
last := cmds[len(cmds)-1]
|
||||
joined := strings.Join(last, " ")
|
||||
assert.Contains(t, joined, "--target-width 800")
|
||||
assert.Contains(t, joined, "--target-height 600")
|
||||
}
|
||||
|
||||
func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Each sub-test verifies a single mouse method dispatches the
|
||||
// correct CLI arguments.
|
||||
tests := []struct {
|
||||
name string
|
||||
invoke func(context.Context, *portableDesktop) error
|
||||
wantArgs []string // substrings expected in a recorded command
|
||||
}{
|
||||
{
|
||||
name: "Move",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Move(ctx, 42, 99)
|
||||
},
|
||||
wantArgs: []string{"mouse", "move", "42", "99"},
|
||||
},
|
||||
{
|
||||
name: "Click",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Click(ctx, 10, 20, MouseButtonLeft)
|
||||
},
|
||||
// Click does move then click.
|
||||
wantArgs: []string{"mouse", "click", "left"},
|
||||
},
|
||||
{
|
||||
name: "DoubleClick",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.DoubleClick(ctx, 5, 6, MouseButtonRight)
|
||||
},
|
||||
wantArgs: []string{"mouse", "click", "right"},
|
||||
},
|
||||
{
|
||||
name: "ButtonDown",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.ButtonDown(ctx, MouseButtonMiddle)
|
||||
},
|
||||
wantArgs: []string{"mouse", "down", "middle"},
|
||||
},
|
||||
{
|
||||
name: "ButtonUp",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.ButtonUp(ctx, MouseButtonLeft)
|
||||
},
|
||||
wantArgs: []string{"mouse", "up", "left"},
|
||||
},
|
||||
{
|
||||
name: "Scroll",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Scroll(ctx, 50, 60, 3, 4)
|
||||
},
|
||||
wantArgs: []string{"mouse", "scroll", "3", "4"},
|
||||
},
|
||||
{
|
||||
name: "Drag",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Drag(ctx, 10, 20, 30, 40)
|
||||
},
|
||||
// Drag ends with mouse up left.
|
||||
wantArgs: []string{"mouse", "up", "left"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"mouse": `echo ok`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds, "expected at least one command")
|
||||
|
||||
// Find at least one recorded command that contains
|
||||
// all expected argument substrings.
|
||||
found := false
|
||||
for _, cmd := range cmds {
|
||||
joined := strings.Join(cmd, " ")
|
||||
match := true
|
||||
for _, want := range tt.wantArgs {
|
||||
if !strings.Contains(joined, want) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found,
|
||||
"no recorded command matched %v; got %v", tt.wantArgs, cmds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
invoke func(context.Context, *portableDesktop) error
|
||||
wantArgs []string
|
||||
}{
|
||||
{
|
||||
name: "KeyPress",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyPress(ctx, "Return")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "key", "Return"},
|
||||
},
|
||||
{
|
||||
name: "KeyDown",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyDown(ctx, "shift")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "down", "shift"},
|
||||
},
|
||||
{
|
||||
name: "KeyUp",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyUp(ctx, "shift")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "up", "shift"},
|
||||
},
|
||||
{
|
||||
name: "Type",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Type(ctx, "hello world")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "type", "hello world"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"keyboard": `echo ok`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(context.Background(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
|
||||
last := cmds[len(cmds)-1]
|
||||
joined := strings.Join(last, " ")
|
||||
for _, want := range tt.wantArgs {
|
||||
assert.Contains(t, joined, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortableDesktop_CursorPosition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"cursor": `echo '{"x":100,"y":200}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
x, y, err := pd.CursorPosition(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, x)
|
||||
assert.Equal(t, 200, y)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1024x768"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Session should exist.
|
||||
pd.mu.Lock()
|
||||
require.NotNil(t, pd.session)
|
||||
pd.mu.Unlock()
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
|
||||
// Session should be cleaned up.
|
||||
pd.mu.Lock()
|
||||
assert.Nil(t, pd.session)
|
||||
assert.True(t, pd.closed)
|
||||
pd.mu.Unlock()
|
||||
|
||||
// Subsequent Start must fail.
|
||||
_, err = pd.Start(ctx)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
}
|
||||
|
||||
// --- downloadBinary tests ---
|
||||
|
||||
func TestDownloadBinary_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho portable\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the file exists and has correct content.
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
// Verify executable permissions.
|
||||
info, err := os.Stat(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, info.Mode()&0o700, "binary should be executable")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("real binary content"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "SHA-256 mismatch")
|
||||
|
||||
// The destination file should not exist (temp file cleaned up).
|
||||
_, statErr := os.Stat(destPath)
|
||||
assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
|
||||
|
||||
// No leftover temp files in the directory.
|
||||
entries, err := os.ReadDir(destDir)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, entries, "no leftover temp files should remain")
|
||||
}
|
||||
|
||||
func TestDownloadBinary_HTTPError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", destPath)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "status 404")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// When binPath is already set, ensureBinary should return
|
||||
// immediately without doing any work.
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
}
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/already/set", pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
bin, ok := platformBinaries[runtime.GOARCH]
|
||||
if !ok {
|
||||
t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
|
||||
}
|
||||
|
||||
dataDir := t.TempDir()
|
||||
cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o700))
|
||||
|
||||
cachedPath := filepath.Join(cacheDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, cachedPath, pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_Downloads(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment and we override the package-level platformBinaries.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho downloaded\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Save and restore platformBinaries for this test.
|
||||
origBinaries := platformBinaries
|
||||
platformBinaries = map[string]struct {
|
||||
URL string
|
||||
SHA256 string
|
||||
}{
|
||||
runtime.GOARCH: {
|
||||
URL: srv.URL + "/portabledesktop",
|
||||
SHA256: expectedSHA,
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() { platformBinaries = origBinaries })
|
||||
|
||||
dataDir := t.TempDir()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
dataDir: dataDir,
|
||||
httpClient: srv.Client(),
|
||||
}
|
||||
|
||||
// Ensure PATH doesn't contain a real portabledesktop binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
|
||||
assert.Equal(t, expectedPath, pd.binPath)
|
||||
|
||||
// Verify the downloaded file has correct content.
|
||||
got, err := os.ReadFile(expectedPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("portabledesktop is only supported on Linux")
|
||||
}
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho retried\n")
|
||||
hash := sha256.Sum256(binaryContent)
|
||||
expectedSHA := hex.EncodeToString(hash[:])
|
||||
|
||||
var mu sync.Mutex
|
||||
attempt := 0
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
mu.Lock()
|
||||
current := attempt
|
||||
attempt++
|
||||
mu.Unlock()
|
||||
|
||||
// Fail the first 2 attempts, succeed on the third.
|
||||
if current < 2 {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(binaryContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Test downloadBinary directly to avoid time.Sleep in
|
||||
// ensureBinary's retry loop. We call it 3 times to simulate
|
||||
// what ensureBinary would do.
|
||||
destDir := t.TempDir()
|
||||
destPath := filepath.Join(destDir, "portabledesktop")
|
||||
|
||||
var lastErr error
|
||||
for i := range 3 {
|
||||
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
if i < 2 {
|
||||
// In the real code, ensureBinary sleeps here.
|
||||
// We skip the sleep in tests.
|
||||
continue
|
||||
}
|
||||
}
|
||||
require.NoError(t, lastErr, "download should succeed on the third attempt")
|
||||
|
||||
got, err := os.ReadFile(destPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryContent, got)
|
||||
|
||||
mu.Lock()
|
||||
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
|
||||
// Silence the linter about unused imports — agentexec.DefaultExecer
|
||||
// is used in TestEnsureBinary_UsesCachedBinPath and others, and
|
||||
// fmt.Sscanf is used indirectly via the implementation.
|
||||
var (
|
||||
_ = agentexec.DefaultExecer
|
||||
_ = fmt.Sprintf
|
||||
)
|
||||
@@ -7,18 +7,21 @@ import (
|
||||
"github.com/spf13/afero"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
)
|
||||
|
||||
// API exposes file-related operations performed through the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
filesystem afero.Fs
|
||||
pathStore *agentgit.PathStore
|
||||
}
|
||||
|
||||
func NewAPI(logger slog.Logger, filesystem afero.Fs) *API {
|
||||
func NewAPI(logger slog.Logger, filesystem afero.Fs, pathStore *agentgit.PathStore) *API {
|
||||
api := &API{
|
||||
logger: logger,
|
||||
filesystem: filesystem,
|
||||
pathStore: pathStore,
|
||||
}
|
||||
return api
|
||||
}
|
||||
|
||||
@@ -13,10 +13,12 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -301,6 +303,13 @@ func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Track edited path for git watch.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), []string{path})
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: fmt.Sprintf("Successfully wrote to %q", path),
|
||||
})
|
||||
@@ -380,6 +389,17 @@ func (api *API) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Track edited paths for git watch.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
filePaths := make([]string, 0, len(req.Files))
|
||||
for _, f := range req.Files {
|
||||
filePaths = append(filePaths, f.Path)
|
||||
}
|
||||
api.pathStore.AddPaths(append([]uuid.UUID{chatID}, ancestorIDs...), filePaths)
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: "Successfully edited file(s)",
|
||||
})
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -21,6 +24,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -116,7 +120,7 @@ func TestReadFile(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "a-directory")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
@@ -296,7 +300,7 @@ func TestWriteFile(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "directory")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
@@ -414,7 +418,7 @@ func TestEditFiles(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "directory")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
@@ -838,6 +842,169 @@ func TestEditFiles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWriteFile_ChatHeaders_UpdatesPathStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
testPath := filepath.Join(os.TempDir(), "test.txt")
|
||||
|
||||
chatID := uuid.New()
|
||||
ancestorID := uuid.New()
|
||||
ancestorJSON, _ := json.Marshal([]string{ancestorID.String()})
|
||||
|
||||
body := strings.NewReader("hello world")
|
||||
req := httptest.NewRequest(http.MethodPost, "/write-file?path="+testPath, body)
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
req.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, string(ancestorJSON))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify PathStore was updated for both chat and ancestor.
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Equal(t, []string{testPath}, paths)
|
||||
|
||||
ancestorPaths := pathStore.GetPaths(ancestorID)
|
||||
require.Equal(t, []string{testPath}, ancestorPaths)
|
||||
}
|
||||
|
||||
func TestHandleWriteFile_NoChatHeaders_NoPathStoreUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
testPath := filepath.Join(os.TempDir(), "test.txt")
|
||||
|
||||
body := strings.NewReader("hello world")
|
||||
req := httptest.NewRequest(http.MethodPost, "/write-file?path="+testPath, body)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// PathStore should be globally empty since no chat headers were set.
|
||||
require.Equal(t, 0, pathStore.Len())
|
||||
}
|
||||
|
||||
func TestHandleWriteFile_Failure_NoPathStoreUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
chatID := uuid.New()
|
||||
|
||||
// Write to a relative path (should fail with 400).
|
||||
body := strings.NewReader("hello world")
|
||||
req := httptest.NewRequest(http.MethodPost, "/write-file?path=relative/path.txt", body)
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/write-file", api.HandleWriteFile)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
// PathStore should NOT be updated on failure.
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Empty(t, paths)
|
||||
}
|
||||
|
||||
func TestHandleEditFiles_ChatHeaders_UpdatesPathStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
testPath := filepath.Join(os.TempDir(), "test.txt")
|
||||
|
||||
// Create the file first.
|
||||
require.NoError(t, afero.WriteFile(fs, testPath, []byte("hello"), 0o644))
|
||||
|
||||
chatID := uuid.New()
|
||||
editReq := workspacesdk.FileEditRequest{
|
||||
Files: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: testPath,
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{Search: "hello", Replace: "world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(editReq)
|
||||
req := httptest.NewRequest(http.MethodPost, "/edit-files", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/edit-files", api.HandleEditFiles)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Equal(t, []string{testPath}, paths)
|
||||
}
|
||||
|
||||
func TestHandleEditFiles_Failure_NoPathStoreUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
logger := slogtest.Make(t, nil)
|
||||
fs := afero.NewMemMapFs()
|
||||
api := agentfiles.NewAPI(logger, fs, pathStore)
|
||||
|
||||
chatID := uuid.New()
|
||||
|
||||
// Edit a non-existent file (should fail with 404).
|
||||
editReq := workspacesdk.FileEditRequest{
|
||||
Files: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: "/nonexistent/file.txt",
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{Search: "hello", Replace: "world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(editReq)
|
||||
req := httptest.NewRequest(http.MethodPost, "/edit-files", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r := chi.NewRouter()
|
||||
r.Post("/edit-files", api.HandleEditFiles)
|
||||
r.ServeHTTP(rr, req)
|
||||
|
||||
require.NotEqual(t, http.StatusOK, rr.Code)
|
||||
|
||||
// PathStore should NOT be updated on failure.
|
||||
paths := pathStore.GetPaths(chatID)
|
||||
require.Empty(t, paths)
|
||||
}
|
||||
|
||||
func TestReadFileLines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -851,7 +1018,7 @@ func TestReadFileLines(t *testing.T) {
|
||||
}
|
||||
return nil
|
||||
})
|
||||
api := agentfiles.NewAPI(logger, fs)
|
||||
api := agentfiles.NewAPI(logger, fs, nil)
|
||||
|
||||
dirPath := filepath.Join(tmpdir, "a-directory-lines")
|
||||
err := fs.MkdirAll(dirPath, 0o755)
|
||||
|
||||
@@ -0,0 +1,441 @@
|
||||
// Package agentgit provides a WebSocket-based service for watching git
|
||||
// repository changes on the agent. It is mounted at /api/v0/git/watch
|
||||
// and allows clients to subscribe to file paths, triggering scans of
|
||||
// the corresponding git repositories.
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dustin/go-humanize"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Option configures the git watch service.
|
||||
type Option func(*Handler)
|
||||
|
||||
// WithClock sets a controllable clock for testing. Defaults to
|
||||
// quartz.NewReal().
|
||||
func WithClock(c quartz.Clock) Option {
|
||||
return func(h *Handler) {
|
||||
h.clock = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithGitBinary overrides the git binary path (for testing).
|
||||
func WithGitBinary(path string) Option {
|
||||
return func(h *Handler) {
|
||||
h.gitBin = path
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// scanCooldown is the minimum interval between successive scans.
|
||||
scanCooldown = 1 * time.Second
|
||||
// fallbackPollInterval is the safety-net poll period used when no
|
||||
// filesystem events arrive.
|
||||
fallbackPollInterval = 30 * time.Second
|
||||
// maxTotalDiffSize is the maximum size of the combined
|
||||
// unified diff for an entire repository sent over the wire.
|
||||
// This must stay under the WebSocket message size limit.
|
||||
maxTotalDiffSize = 3 * 1024 * 1024 // 3 MiB
|
||||
)
|
||||
|
||||
// Handler manages per-connection git watch state.
|
||||
type Handler struct {
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
gitBin string // path to git binary; empty means "git" (from PATH)
|
||||
|
||||
mu sync.Mutex
|
||||
repoRoots map[string]struct{} // watched repo roots
|
||||
lastSnapshots map[string]repoSnapshot // last emitted snapshot per repo
|
||||
lastScanAt time.Time // when the last scan completed
|
||||
scanTrigger chan struct{} // buffered(1), poked by triggers
|
||||
}
|
||||
|
||||
// repoSnapshot captures the last emitted state for delta comparison.
|
||||
type repoSnapshot struct {
|
||||
branch string
|
||||
remoteOrigin string
|
||||
unifiedDiff string
|
||||
}
|
||||
|
||||
// NewHandler creates a new git watch handler.
|
||||
func NewHandler(logger slog.Logger, opts ...Option) *Handler {
|
||||
h := &Handler{
|
||||
logger: logger,
|
||||
clock: quartz.NewReal(),
|
||||
gitBin: "git",
|
||||
repoRoots: make(map[string]struct{}),
|
||||
lastSnapshots: make(map[string]repoSnapshot),
|
||||
scanTrigger: make(chan struct{}, 1),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(h)
|
||||
}
|
||||
|
||||
// Check if git is available.
|
||||
if _, err := exec.LookPath(h.gitBin); err != nil {
|
||||
h.logger.Warn(context.Background(), "git binary not found, git scanning disabled")
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// gitAvailable returns true if the configured git binary can be found
|
||||
// in PATH.
|
||||
func (h *Handler) gitAvailable() bool {
|
||||
_, err := exec.LookPath(h.gitBin)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Subscribe processes a subscribe message, resolving paths to git repo
|
||||
// roots and adding new repos to the watch set. Returns true if any new
|
||||
// repo roots were added.
|
||||
func (h *Handler) Subscribe(paths []string) bool {
|
||||
if !h.gitAvailable() {
|
||||
return false
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
added := false
|
||||
for _, p := range paths {
|
||||
if !filepath.IsAbs(p) {
|
||||
continue
|
||||
}
|
||||
p = filepath.Clean(p)
|
||||
|
||||
root, err := findRepoRoot(h.gitBin, p)
|
||||
if err != nil {
|
||||
// Not a git path — silently ignore.
|
||||
continue
|
||||
}
|
||||
if _, ok := h.repoRoots[root]; ok {
|
||||
continue
|
||||
}
|
||||
h.repoRoots[root] = struct{}{}
|
||||
added = true
|
||||
}
|
||||
return added
|
||||
}
|
||||
|
||||
// RequestScan pokes the scan trigger so the run loop performs a scan.
|
||||
func (h *Handler) RequestScan() {
|
||||
select {
|
||||
case h.scanTrigger <- struct{}{}:
|
||||
default:
|
||||
// Already pending.
|
||||
}
|
||||
}
|
||||
|
||||
// Scan performs a scan of all subscribed repos and computes deltas
|
||||
// against the previously emitted snapshots.
|
||||
func (h *Handler) Scan(ctx context.Context) *codersdk.WorkspaceAgentGitServerMessage {
|
||||
if !h.gitAvailable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
roots := make([]string, 0, len(h.repoRoots))
|
||||
for r := range h.repoRoots {
|
||||
roots = append(roots, r)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if len(roots) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := h.clock.Now().UTC()
|
||||
var repos []codersdk.WorkspaceAgentRepoChanges
|
||||
|
||||
// Perform all I/O outside the lock to avoid blocking
|
||||
// AddPaths/GetPaths/Subscribe callers during disk-heavy scans.
|
||||
type scanResult struct {
|
||||
root string
|
||||
changes codersdk.WorkspaceAgentRepoChanges
|
||||
err error
|
||||
}
|
||||
results := make([]scanResult, 0, len(roots))
|
||||
for _, root := range roots {
|
||||
changes, err := getRepoChanges(ctx, h.logger, h.gitBin, root)
|
||||
results = append(results, scanResult{root: root, changes: changes, err: err})
|
||||
}
|
||||
|
||||
// Re-acquire the lock only to commit snapshot updates.
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
for _, res := range results {
|
||||
if res.err != nil {
|
||||
if isRepoDeleted(h.gitBin, res.root) {
|
||||
// Repo root or .git directory was removed.
|
||||
// Emit a removal entry, then evict from watch set.
|
||||
removal := codersdk.WorkspaceAgentRepoChanges{
|
||||
RepoRoot: res.root,
|
||||
Removed: true,
|
||||
}
|
||||
delete(h.repoRoots, res.root)
|
||||
delete(h.lastSnapshots, res.root)
|
||||
repos = append(repos, removal)
|
||||
} else {
|
||||
// Transient error — log and skip without
|
||||
// removing the repo from the watch set.
|
||||
h.logger.Warn(ctx, "scan repo failed",
|
||||
slog.F("root", res.root),
|
||||
slog.Error(res.err),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
prev, hasPrev := h.lastSnapshots[res.root]
|
||||
if hasPrev &&
|
||||
prev.branch == res.changes.Branch &&
|
||||
prev.remoteOrigin == res.changes.RemoteOrigin &&
|
||||
prev.unifiedDiff == res.changes.UnifiedDiff {
|
||||
// No change in this repo since last emit.
|
||||
continue
|
||||
}
|
||||
|
||||
// Update snapshot.
|
||||
h.lastSnapshots[res.root] = repoSnapshot{
|
||||
branch: res.changes.Branch,
|
||||
remoteOrigin: res.changes.RemoteOrigin,
|
||||
unifiedDiff: res.changes.UnifiedDiff,
|
||||
}
|
||||
|
||||
repos = append(repos, res.changes)
|
||||
}
|
||||
|
||||
h.lastScanAt = now
|
||||
|
||||
if len(repos) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &codersdk.WorkspaceAgentGitServerMessage{
|
||||
Type: codersdk.WorkspaceAgentGitServerMessageTypeChanges,
|
||||
ScannedAt: &now,
|
||||
Repositories: repos,
|
||||
}
|
||||
}
|
||||
|
||||
// RunLoop runs the main event loop that listens for refresh requests
|
||||
// and fallback poll ticks. It calls scanFn whenever a scan should
|
||||
// happen (rate-limited to scanCooldown). It blocks until ctx is
|
||||
// canceled.
|
||||
func (h *Handler) RunLoop(ctx context.Context, scanFn func()) {
|
||||
fallbackTicker := h.clock.NewTicker(fallbackPollInterval)
|
||||
defer fallbackTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-h.scanTrigger:
|
||||
h.rateLimitedScan(ctx, scanFn)
|
||||
|
||||
case <-fallbackTicker.C:
|
||||
h.rateLimitedScan(ctx, scanFn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) rateLimitedScan(ctx context.Context, scanFn func()) {
|
||||
h.mu.Lock()
|
||||
elapsed := h.clock.Since(h.lastScanAt)
|
||||
if elapsed < scanCooldown {
|
||||
h.mu.Unlock()
|
||||
|
||||
// Wait for cooldown then scan.
|
||||
remaining := scanCooldown - elapsed
|
||||
timer := h.clock.NewTimer(remaining)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
scanFn()
|
||||
return
|
||||
}
|
||||
h.mu.Unlock()
|
||||
scanFn()
|
||||
}
|
||||
|
||||
// isRepoDeleted returns true when the repo root directory or its .git
|
||||
// entry no longer represents a valid git repository. This
|
||||
// distinguishes a genuine repo deletion from a transient scan error
|
||||
// (e.g. lock contention).
|
||||
//
|
||||
// It handles three deletion cases:
|
||||
// 1. The repo root directory itself was removed.
|
||||
// 2. The .git entry (directory or file) was removed.
|
||||
// 3. The .git entry is a file (worktree/submodule) whose target
|
||||
// gitdir was removed. In this case .git exists on disk but
|
||||
// `git rev-parse --git-dir` fails because the referenced
|
||||
// directory is gone.
|
||||
func isRepoDeleted(gitBin string, repoRoot string) bool {
|
||||
if _, err := os.Stat(repoRoot); os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
gitPath := filepath.Join(repoRoot, ".git")
|
||||
fi, err := os.Stat(gitPath)
|
||||
if os.IsNotExist(err) {
|
||||
return true
|
||||
}
|
||||
// If .git is a regular file (worktree or submodule), the actual
|
||||
// git object store lives elsewhere. Validate that the target is
|
||||
// still reachable by running git rev-parse.
|
||||
if err == nil && !fi.IsDir() {
|
||||
cmd := exec.CommandContext(context.Background(), gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
|
||||
if err := cmd.Run(); err != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// findRepoRoot uses `git rev-parse --show-toplevel` to find the
|
||||
// repository root for the given path.
|
||||
func findRepoRoot(gitBin string, p string) (string, error) {
|
||||
// If p is a file, start from its parent directory.
|
||||
dir := p
|
||||
if info, err := os.Stat(dir); err != nil || !info.IsDir() {
|
||||
dir = filepath.Dir(dir)
|
||||
}
|
||||
cmd := exec.CommandContext(context.Background(), gitBin, "rev-parse", "--show-toplevel")
|
||||
cmd.Dir = dir
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("no git repo found for %s", p)
|
||||
}
|
||||
root := filepath.FromSlash(strings.TrimSpace(string(out)))
|
||||
// Resolve symlinks and short (8.3) names on Windows so the
|
||||
// returned root matches paths produced by Go's filepath APIs.
|
||||
if resolved, evalErr := filepath.EvalSymlinks(root); evalErr == nil {
|
||||
root = resolved
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// getRepoChanges reads the current state of a git repository using
|
||||
// the git CLI. It returns branch, remote origin, and a unified diff.
|
||||
func getRepoChanges(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (codersdk.WorkspaceAgentRepoChanges, error) {
|
||||
result := codersdk.WorkspaceAgentRepoChanges{
|
||||
RepoRoot: repoRoot,
|
||||
}
|
||||
|
||||
// Verify this is still a valid git repository before doing
|
||||
// anything else. This catches deleted repos early.
|
||||
verifyCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "--git-dir")
|
||||
if err := verifyCmd.Run(); err != nil {
|
||||
return result, xerrors.Errorf("not a git repository: %w", err)
|
||||
}
|
||||
|
||||
// Read branch name.
|
||||
branchCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "symbolic-ref", "--short", "HEAD")
|
||||
if out, err := branchCmd.Output(); err == nil {
|
||||
result.Branch = strings.TrimSpace(string(out))
|
||||
} else {
|
||||
logger.Debug(ctx, "failed to read HEAD", slog.F("root", repoRoot), slog.Error(err))
|
||||
}
|
||||
|
||||
// Read remote origin URL.
|
||||
remoteCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "config", "--get", "remote.origin.url")
|
||||
if out, err := remoteCmd.Output(); err == nil {
|
||||
result.RemoteOrigin = strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// Compute unified diff.
|
||||
// `git diff HEAD` shows both staged and unstaged changes vs HEAD.
|
||||
// For repos with no commits yet, fall back to showing untracked
|
||||
// files only.
|
||||
diff, err := computeGitDiff(ctx, logger, gitBin, repoRoot)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("compute diff: %w", err)
|
||||
}
|
||||
|
||||
result.UnifiedDiff = diff
|
||||
if len(result.UnifiedDiff) > maxTotalDiffSize {
|
||||
result.UnifiedDiff = "Total diff too large to show. Size: " + humanize.IBytes(uint64(len(result.UnifiedDiff))) + ". Showing branch and remote only."
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// computeGitDiff produces a unified diff string for the repository by
|
||||
// combining `git diff HEAD` (staged + unstaged changes) with diffs
|
||||
// for untracked files.
|
||||
func computeGitDiff(ctx context.Context, logger slog.Logger, gitBin string, repoRoot string) (string, error) {
|
||||
var diffParts []string
|
||||
|
||||
// Check if the repo has any commits.
|
||||
hasCommits := true
|
||||
checkCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "rev-parse", "HEAD")
|
||||
if err := checkCmd.Run(); err != nil {
|
||||
hasCommits = false
|
||||
}
|
||||
|
||||
if hasCommits {
|
||||
// `git diff HEAD` captures both staged and unstaged changes
|
||||
// relative to HEAD in a single unified diff.
|
||||
cmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "HEAD")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("git diff HEAD: %w", err)
|
||||
}
|
||||
if len(out) > 0 {
|
||||
diffParts = append(diffParts, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
// Show untracked files as diffs too.
|
||||
// `git ls-files --others --exclude-standard` lists untracked,
|
||||
// non-ignored files.
|
||||
lsCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "ls-files", "--others", "--exclude-standard")
|
||||
lsOut, err := lsCmd.Output()
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "failed to list untracked files", slog.F("root", repoRoot), slog.Error(err))
|
||||
return strings.Join(diffParts, ""), nil
|
||||
}
|
||||
|
||||
untrackedFiles := strings.Split(strings.TrimSpace(string(lsOut)), "\n")
|
||||
for _, f := range untrackedFiles {
|
||||
f = strings.TrimSpace(f)
|
||||
if f == "" {
|
||||
continue
|
||||
}
|
||||
// Use `git diff --no-index /dev/null <file>` to generate
|
||||
// a unified diff for untracked files.
|
||||
var stdout bytes.Buffer
|
||||
untrackedCmd := exec.CommandContext(ctx, gitBin, "-C", repoRoot, "diff", "--no-index", "--", "/dev/null", f)
|
||||
untrackedCmd.Stdout = &stdout
|
||||
// git diff --no-index exits with 1 when files differ,
|
||||
// which is expected. We ignore the error and check for
|
||||
// output instead.
|
||||
_ = untrackedCmd.Run()
|
||||
if stdout.Len() > 0 {
|
||||
diffParts = append(diffParts, stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(diffParts, ""), nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,147 @@
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// API exposes the git watch HTTP routes for the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
opts []Option
|
||||
pathStore *PathStore
|
||||
}
|
||||
|
||||
// NewAPI creates a new git watch API.
|
||||
func NewAPI(logger slog.Logger, pathStore *PathStore, opts ...Option) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
pathStore: pathStore,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Routes returns the chi router for mounting at /api/v0/git.
|
||||
func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/watch", a.handleWatch)
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *API) handleWatch(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionNoContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to accept WebSocket.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 4 MiB read limit — subscribe messages with many paths can exceed the
|
||||
// default 32 KB limit. Matches the SDK/proxy side.
|
||||
conn.SetReadLimit(1 << 22)
|
||||
|
||||
stream := wsjson.NewStream[
|
||||
codersdk.WorkspaceAgentGitClientMessage,
|
||||
codersdk.WorkspaceAgentGitServerMessage,
|
||||
](conn, websocket.MessageText, websocket.MessageText, a.logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, a.logger, cancel, conn)
|
||||
|
||||
handler := NewHandler(a.logger, a.opts...)
|
||||
|
||||
// scanAndSend performs a scan and sends results if there are
|
||||
// changes.
|
||||
scanAndSend := func() {
|
||||
msg := handler.Scan(ctx)
|
||||
if msg != nil {
|
||||
if err := stream.Send(*msg); err != nil {
|
||||
a.logger.Debug(ctx, "failed to send changes", slog.Error(err))
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a chat_id query parameter is provided and the PathStore is
|
||||
// available, subscribe to path updates for this chat.
|
||||
chatIDStr := r.URL.Query().Get("chat_id")
|
||||
if chatIDStr != "" && a.pathStore != nil {
|
||||
chatID, parseErr := uuid.Parse(chatIDStr)
|
||||
if parseErr == nil {
|
||||
// Subscribe to future path updates BEFORE reading
|
||||
// existing paths. This ordering guarantees no
|
||||
// notification from AddPaths is lost: any call that
|
||||
// lands before Subscribe is picked up by GetPaths
|
||||
// below, and any call after Subscribe delivers a
|
||||
// notification on the channel.
|
||||
notifyCh, unsubscribe := a.pathStore.Subscribe(chatID)
|
||||
defer unsubscribe()
|
||||
|
||||
// Load any paths that are already tracked for this chat.
|
||||
existingPaths := a.pathStore.GetPaths(chatID)
|
||||
if len(existingPaths) > 0 {
|
||||
handler.Subscribe(existingPaths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-notifyCh:
|
||||
paths := a.pathStore.GetPaths(chatID)
|
||||
handler.Subscribe(paths)
|
||||
handler.RequestScan()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Start the main run loop in a goroutine.
|
||||
go handler.RunLoop(ctx, scanAndSend)
|
||||
|
||||
// Read client messages.
|
||||
updates := stream.Chan()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = stream.Close(websocket.StatusGoingAway)
|
||||
return
|
||||
case msg, ok := <-updates:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case codersdk.WorkspaceAgentGitClientMessageTypeRefresh:
|
||||
handler.RequestScan()
|
||||
default:
|
||||
if err := stream.Send(codersdk.WorkspaceAgentGitServerMessage{
|
||||
Type: codersdk.WorkspaceAgentGitServerMessageTypeError,
|
||||
Message: "unknown message type",
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// ExtractChatContext reads chat identity headers from the request.
|
||||
// Returns zero values if headers are absent (non-chat request).
|
||||
func ExtractChatContext(r *http.Request) (chatID uuid.UUID, ancestorIDs []uuid.UUID, ok bool) {
|
||||
raw := r.Header.Get(workspacesdk.CoderChatIDHeader)
|
||||
if raw == "" {
|
||||
return uuid.Nil, nil, false
|
||||
}
|
||||
chatID, err := uuid.Parse(raw)
|
||||
if err != nil {
|
||||
return uuid.Nil, nil, false
|
||||
}
|
||||
rawAncestors := r.Header.Get(workspacesdk.CoderAncestorChatIDsHeader)
|
||||
if rawAncestors != "" {
|
||||
var ids []string
|
||||
if err := json.Unmarshal([]byte(rawAncestors), &ids); err == nil {
|
||||
for _, s := range ids {
|
||||
if id, err := uuid.Parse(s); err == nil {
|
||||
ancestorIDs = append(ancestorIDs, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return chatID, ancestorIDs, true
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package agentgit_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
func TestExtractChatContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
validID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
|
||||
ancestor1 := uuid.MustParse("11111111-2222-3333-4444-555555555555")
|
||||
ancestor2 := uuid.MustParse("66666666-7777-8888-9999-aaaaaaaaaaaa")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chatID string // empty means header not set
|
||||
setChatID bool // whether to set the chat ID header at all
|
||||
ancestors string // empty means header not set
|
||||
setAncestors bool // whether to set the ancestor header at all
|
||||
wantChatID uuid.UUID
|
||||
wantAncestorIDs []uuid.UUID
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "NoHeadersPresent",
|
||||
setChatID: false,
|
||||
setAncestors: false,
|
||||
wantChatID: uuid.Nil,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_NoAncestors",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
setAncestors: false,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_ValidAncestors",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: mustMarshalJSON(t, []string{
|
||||
ancestor1.String(),
|
||||
ancestor2.String(),
|
||||
}),
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: []uuid.UUID{ancestor1, ancestor2},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "MalformedChatID",
|
||||
chatID: "not-a-uuid",
|
||||
setChatID: true,
|
||||
setAncestors: false,
|
||||
wantChatID: uuid.Nil,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_MalformedAncestorJSON",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: `{this is not json}`,
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
// Only valid UUIDs in the array are returned; invalid
|
||||
// entries are silently skipped.
|
||||
name: "ValidChatID_PartialValidAncestorUUIDs",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: mustMarshalJSON(t, []string{
|
||||
ancestor1.String(),
|
||||
"bad-uuid",
|
||||
ancestor2.String(),
|
||||
}),
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: []uuid.UUID{ancestor1, ancestor2},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
// Header is explicitly set to an empty string, which
|
||||
// Header.Get returns as "".
|
||||
name: "EmptyChatIDHeader",
|
||||
chatID: "",
|
||||
setChatID: true,
|
||||
setAncestors: false,
|
||||
wantChatID: uuid.Nil,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "ValidChatID_EmptyAncestorHeader",
|
||||
chatID: validID.String(),
|
||||
setChatID: true,
|
||||
ancestors: "",
|
||||
setAncestors: true,
|
||||
wantChatID: validID,
|
||||
wantAncestorIDs: nil,
|
||||
wantOK: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
if tt.setChatID {
|
||||
r.Header.Set(workspacesdk.CoderChatIDHeader, tt.chatID)
|
||||
}
|
||||
if tt.setAncestors {
|
||||
r.Header.Set(workspacesdk.CoderAncestorChatIDsHeader, tt.ancestors)
|
||||
}
|
||||
|
||||
chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r)
|
||||
|
||||
require.Equal(t, tt.wantOK, ok, "ok mismatch")
|
||||
require.Equal(t, tt.wantChatID, chatID, "chatID mismatch")
|
||||
require.Equal(t, tt.wantAncestorIDs, ancestorIDs, "ancestorIDs mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mustMarshalJSON marshals v to a JSON string, failing the test on error.
|
||||
func mustMarshalJSON(t *testing.T, v any) string {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
return string(b)
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package agentgit
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// PathStore tracks which file paths each chat has touched.
|
||||
// It is safe for concurrent use.
|
||||
type PathStore struct {
|
||||
mu sync.RWMutex
|
||||
chatPaths map[uuid.UUID]map[string]struct{}
|
||||
subscribers map[uuid.UUID][]chan<- struct{}
|
||||
}
|
||||
|
||||
// NewPathStore creates a new PathStore.
|
||||
func NewPathStore() *PathStore {
|
||||
return &PathStore{
|
||||
chatPaths: make(map[uuid.UUID]map[string]struct{}),
|
||||
subscribers: make(map[uuid.UUID][]chan<- struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPaths adds paths to every chat in chatIDs and notifies
|
||||
// their subscribers. Zero-value UUIDs are silently skipped.
|
||||
func (ps *PathStore) AddPaths(chatIDs []uuid.UUID, paths []string) {
|
||||
affected := make([]uuid.UUID, 0, len(chatIDs))
|
||||
for _, id := range chatIDs {
|
||||
if id != uuid.Nil {
|
||||
affected = append(affected, id)
|
||||
}
|
||||
}
|
||||
if len(affected) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ps.mu.Lock()
|
||||
for _, id := range affected {
|
||||
m, ok := ps.chatPaths[id]
|
||||
if !ok {
|
||||
m = make(map[string]struct{})
|
||||
ps.chatPaths[id] = m
|
||||
}
|
||||
for _, p := range paths {
|
||||
m[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
ps.mu.Unlock()
|
||||
|
||||
ps.notifySubscribers(affected)
|
||||
}
|
||||
|
||||
// Notify sends a signal to all subscribers of the given chat IDs
|
||||
// without adding any paths. Zero-value UUIDs are silently skipped.
|
||||
func (ps *PathStore) Notify(chatIDs []uuid.UUID) {
|
||||
affected := make([]uuid.UUID, 0, len(chatIDs))
|
||||
for _, id := range chatIDs {
|
||||
if id != uuid.Nil {
|
||||
affected = append(affected, id)
|
||||
}
|
||||
}
|
||||
if len(affected) == 0 {
|
||||
return
|
||||
}
|
||||
ps.notifySubscribers(affected)
|
||||
}
|
||||
|
||||
// notifySubscribers sends a non-blocking signal to all subscriber
|
||||
// channels for the given chat IDs.
|
||||
func (ps *PathStore) notifySubscribers(chatIDs []uuid.UUID) {
|
||||
ps.mu.RLock()
|
||||
toNotify := make([]chan<- struct{}, 0)
|
||||
for _, id := range chatIDs {
|
||||
toNotify = append(toNotify, ps.subscribers[id]...)
|
||||
}
|
||||
ps.mu.RUnlock()
|
||||
|
||||
for _, ch := range toNotify {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetPaths returns all paths tracked for a chat, deduplicated
|
||||
// and sorted lexicographically.
|
||||
func (ps *PathStore) GetPaths(chatID uuid.UUID) []string {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
|
||||
m := ps.chatPaths[chatID]
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(m))
|
||||
for p := range m {
|
||||
out = append(out, p)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// Len returns the number of chat IDs that have tracked paths.
|
||||
func (ps *PathStore) Len() int {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
return len(ps.chatPaths)
|
||||
}
|
||||
|
||||
// Subscribe returns a channel that receives a signal whenever
|
||||
// paths change for chatID, along with an unsubscribe function
|
||||
// that removes the channel.
|
||||
func (ps *PathStore) Subscribe(chatID uuid.UUID) (<-chan struct{}, func()) {
|
||||
ch := make(chan struct{}, 1)
|
||||
|
||||
ps.mu.Lock()
|
||||
ps.subscribers[chatID] = append(ps.subscribers[chatID], ch)
|
||||
ps.mu.Unlock()
|
||||
|
||||
unsub := func() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
subs := ps.subscribers[chatID]
|
||||
for i, s := range subs {
|
||||
if s == ch {
|
||||
ps.subscribers[chatID] = append(subs[:i], subs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ch, unsub
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package agentgit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPathStore_AddPaths_StoresForChatAndAncestors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ancestor1 := uuid.New()
|
||||
ancestor2 := uuid.New()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID, ancestor1, ancestor2}, []string{"/a", "/b"})
|
||||
|
||||
// All three IDs should see the paths.
|
||||
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(chatID))
|
||||
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(ancestor1))
|
||||
require.Equal(t, []string{"/a", "/b"}, ps.GetPaths(ancestor2))
|
||||
|
||||
// An unrelated chat should see nothing.
|
||||
require.Nil(t, ps.GetPaths(uuid.New()))
|
||||
}
|
||||
|
||||
func TestPathStore_AddPaths_SkipsNilUUIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
|
||||
// A nil chatID should be a no-op.
|
||||
ps.AddPaths([]uuid.UUID{uuid.Nil}, []string{"/x"})
|
||||
require.Nil(t, ps.GetPaths(uuid.Nil))
|
||||
|
||||
// A nil ancestor should be silently skipped.
|
||||
chatID := uuid.New()
|
||||
ps.AddPaths([]uuid.UUID{chatID, uuid.Nil}, []string{"/y"})
|
||||
require.Equal(t, []string{"/y"}, ps.GetPaths(chatID))
|
||||
require.Nil(t, ps.GetPaths(uuid.Nil))
|
||||
}
|
||||
|
||||
func TestPathStore_GetPaths_DeduplicatedSorted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/z", "/a", "/m", "/a", "/z"})
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/a", "/b"})
|
||||
|
||||
got := ps.GetPaths(chatID)
|
||||
require.Equal(t, []string{"/a", "/b", "/m", "/z"}, got)
|
||||
}
|
||||
|
||||
func TestPathStore_Subscribe_ReceivesNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for notification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Subscribe_MultipleSubscribers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch1, unsub1 := ps.Subscribe(chatID)
|
||||
defer unsub1()
|
||||
ch2, unsub2 := ps.Subscribe(chatID)
|
||||
defer unsub2()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
for i, ch := range []<-chan struct{}{ch1, ch2} {
|
||||
select {
|
||||
case <-ch:
|
||||
// OK
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("subscriber %d did not receive notification", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Unsubscribe_StopsNotifications(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
unsub()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID}, []string{"/file"})
|
||||
|
||||
// AddPaths sends synchronously via a non-blocking send to the
|
||||
// buffered channel, so if a notification were going to arrive
|
||||
// it would already be in the channel by now.
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("received notification after unsubscribe")
|
||||
default:
|
||||
// Expected: no notification.
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Subscribe_AncestorNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ancestor := uuid.New()
|
||||
|
||||
// Subscribe to the ancestor, then add paths via the child.
|
||||
ch, unsub := ps.Subscribe(ancestor)
|
||||
defer unsub()
|
||||
|
||||
ps.AddPaths([]uuid.UUID{chatID, ancestor}, []string{"/file"})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("ancestor subscriber did not receive notification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPathStore_Notify_NotifiesWithoutAddingPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
ps.Notify([]uuid.UUID{chatID})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for notification")
|
||||
}
|
||||
|
||||
require.Nil(t, ps.GetPaths(chatID))
|
||||
}
|
||||
|
||||
func TestPathStore_Notify_SkipsNilUUIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
|
||||
ch, unsub := ps.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
ps.Notify([]uuid.UUID{uuid.Nil})
|
||||
|
||||
// Notify sends synchronously via a non-blocking send to the
|
||||
// buffered channel, so if a notification were going to arrive
|
||||
// it would already be in the channel by now.
|
||||
select {
|
||||
case <-ch:
|
||||
t.Fatal("received notification for nil UUID")
|
||||
default:
|
||||
// Expected: no notification.
|
||||
}
|
||||
|
||||
require.Nil(t, ps.GetPaths(chatID))
|
||||
}
|
||||
|
||||
func TestPathStore_Notify_AncestorNotification(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ancestorID := uuid.New()
|
||||
|
||||
// Subscribe to the ancestor, then notify via the child.
|
||||
ch, unsub := ps.Subscribe(ancestorID)
|
||||
defer unsub()
|
||||
|
||||
ps.Notify([]uuid.UUID{chatID, ancestorID})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
select {
|
||||
case <-ch:
|
||||
// Success.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("ancestor subscriber did not receive notification")
|
||||
}
|
||||
|
||||
require.Nil(t, ps.GetPaths(ancestorID))
|
||||
}
|
||||
|
||||
func TestPathStore_ConcurrentSafety(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := agentgit.NewPathStore()
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
|
||||
chatIDs := make([]uuid.UUID, goroutines)
|
||||
for i := range chatIDs {
|
||||
chatIDs[i] = uuid.New()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 2) // writers + readers
|
||||
|
||||
// Writers.
|
||||
for i := range goroutines {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
for j := range iterations {
|
||||
ancestors := []uuid.UUID{chatIDs[(idx+1)%goroutines]}
|
||||
path := []string{
|
||||
"/file-" + chatIDs[idx].String() + "-" + time.Now().Format(time.RFC3339Nano),
|
||||
"/iter-" + string(rune('0'+j%10)),
|
||||
}
|
||||
ps.AddPaths(append([]uuid.UUID{chatIDs[idx]}, ancestors...), path)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Readers.
|
||||
for i := range goroutines {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_ = ps.GetPaths(chatIDs[idx])
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify every chat has at least the paths it wrote.
|
||||
for _, id := range chatIDs {
|
||||
paths := ps.GetPaths(id)
|
||||
require.NotEmpty(t, paths, "chat %s should have paths", id)
|
||||
}
|
||||
}
|
||||
+55
-7
@@ -5,11 +5,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -17,15 +20,17 @@ import (
|
||||
|
||||
// API exposes process-related operations through the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
manager *manager
|
||||
logger slog.Logger
|
||||
manager *manager
|
||||
pathStore *agentgit.PathStore
|
||||
}
|
||||
|
||||
// NewAPI creates a new process API handler.
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *API {
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
manager: newManager(logger, execer, updateEnv),
|
||||
logger: logger,
|
||||
manager: newManager(logger, execer, updateEnv),
|
||||
pathStore: pathStore,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +70,12 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
proc, err := api.manager.start(req)
|
||||
var chatID string
|
||||
if id, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
chatID = id.String()
|
||||
}
|
||||
|
||||
proc, err := api.manager.start(req, chatID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start process.",
|
||||
@@ -74,6 +84,23 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Notify git watchers after the process finishes so that
|
||||
// file changes made by the command are visible in the scan.
|
||||
// If a workdir is provided, track it as a path as well.
|
||||
if api.pathStore != nil {
|
||||
if chatID, ancestorIDs, ok := agentgit.ExtractChatContext(r); ok {
|
||||
allIDs := append([]uuid.UUID{chatID}, ancestorIDs...)
|
||||
go func() {
|
||||
<-proc.done
|
||||
if req.WorkDir != "" {
|
||||
api.pathStore.AddPaths(allIDs, []string{req.WorkDir})
|
||||
} else {
|
||||
api.pathStore.Notify(allIDs)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
|
||||
ID: proc.id,
|
||||
Started: true,
|
||||
@@ -84,7 +111,28 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
infos := api.manager.list()
|
||||
var chatID string
|
||||
if id, _, ok := agentgit.ExtractChatContext(r); ok {
|
||||
chatID = id.String()
|
||||
}
|
||||
|
||||
infos := api.manager.list(chatID)
|
||||
|
||||
// Sort by running state (running first), then by started_at
|
||||
// descending so the most recent processes appear first.
|
||||
sort.Slice(infos, func(i, j int) bool {
|
||||
if infos[i].Running != infos[j].Running {
|
||||
return infos[i].Running
|
||||
}
|
||||
return infos[i].StartedAt > infos[j].StartedAt
|
||||
})
|
||||
|
||||
// Cap the response to avoid bloating LLM context.
|
||||
const maxListProcesses = 10
|
||||
if len(infos) > maxListProcesses {
|
||||
infos = infos[:maxListProcesses]
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
|
||||
Processes: infos,
|
||||
})
|
||||
|
||||
+244
-4
@@ -12,12 +12,14 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
"github.com/coder/coder/v2/agent/agentproc"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -25,7 +27,7 @@ import (
|
||||
)
|
||||
|
||||
// postStart sends a POST /start request and returns the recorder.
|
||||
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
|
||||
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest, headers ...http.Header) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
@@ -36,6 +38,13 @@ func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcess
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
|
||||
for _, h := range headers {
|
||||
for k, vals := range h {
|
||||
for _, v := range vals {
|
||||
r.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
@@ -99,7 +108,7 @@ func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, e
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
|
||||
t.Cleanup(func() {
|
||||
_ = api.Close()
|
||||
})
|
||||
@@ -138,10 +147,10 @@ func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.Pro
|
||||
|
||||
// startAndGetID is a helper that starts a process and returns
|
||||
// the process ID.
|
||||
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
|
||||
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest, headers ...http.Header) string {
|
||||
t.Helper()
|
||||
|
||||
w := postStart(t, handler, req)
|
||||
w := postStart(t, handler, req, headers...)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.StartProcessResponse
|
||||
@@ -331,6 +340,180 @@ func TestListProcesses(t *testing.T) {
|
||||
require.Empty(t, resp.Processes)
|
||||
})
|
||||
|
||||
t.Run("FilterByChatID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
chatA := uuid.New().String()
|
||||
chatB := uuid.New().String()
|
||||
headersA := http.Header{workspacesdk.CoderChatIDHeader: {chatA}}
|
||||
headersB := http.Header{workspacesdk.CoderChatIDHeader: {chatB}}
|
||||
|
||||
// Start processes with different chat IDs.
|
||||
id1 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo chat-a",
|
||||
}, headersA)
|
||||
waitForExit(t, handler, id1)
|
||||
|
||||
id2 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo chat-b",
|
||||
}, headersB)
|
||||
waitForExit(t, handler, id2)
|
||||
|
||||
id3 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo chat-a-2",
|
||||
}, headersA)
|
||||
waitForExit(t, handler, id3)
|
||||
|
||||
// List with chat A header should return 2 processes.
|
||||
w := getListWithChatHeader(t, handler, chatA)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ListProcessesResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Processes, 2)
|
||||
|
||||
ids := make(map[string]bool)
|
||||
for _, p := range resp.Processes {
|
||||
ids[p.ID] = true
|
||||
}
|
||||
require.True(t, ids[id1])
|
||||
require.True(t, ids[id3])
|
||||
|
||||
// List with chat B header should return 1 process.
|
||||
w2 := getListWithChatHeader(t, handler, chatB)
|
||||
require.Equal(t, http.StatusOK, w2.Code)
|
||||
|
||||
var resp2 workspacesdk.ListProcessesResponse
|
||||
err = json.NewDecoder(w2.Body).Decode(&resp2)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp2.Processes, 1)
|
||||
require.Equal(t, id2, resp2.Processes[0].ID)
|
||||
|
||||
// List without chat header should return all 3.
|
||||
w3 := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w3.Code)
|
||||
|
||||
var resp3 workspacesdk.ListProcessesResponse
|
||||
err = json.NewDecoder(w3.Body).Decode(&resp3)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp3.Processes, 3)
|
||||
})
|
||||
|
||||
t.Run("ChatIDFiltering", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
chatID := uuid.New().String()
|
||||
headers := http.Header{workspacesdk.CoderChatIDHeader: {chatID}}
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo with-chat",
|
||||
}, headers)
|
||||
waitForExit(t, handler, id)
|
||||
|
||||
// Listing with the same chat header should return
|
||||
// the process.
|
||||
w := getListWithChatHeader(t, handler, chatID)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ListProcessesResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Processes, 1)
|
||||
require.Equal(t, id, resp.Processes[0].ID)
|
||||
|
||||
// Listing with a different chat header should not
|
||||
// return the process.
|
||||
w2 := getListWithChatHeader(t, handler, uuid.New().String())
|
||||
require.Equal(t, http.StatusOK, w2.Code)
|
||||
|
||||
var resp2 workspacesdk.ListProcessesResponse
|
||||
err = json.NewDecoder(w2.Body).Decode(&resp2)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, resp2.Processes)
|
||||
|
||||
// Listing without a chat header should return the
|
||||
// process (no filtering).
|
||||
w3 := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w3.Code)
|
||||
|
||||
var resp3 workspacesdk.ListProcessesResponse
|
||||
err = json.NewDecoder(w3.Body).Decode(&resp3)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp3.Processes, 1)
|
||||
})
|
||||
|
||||
t.Run("SortAndLimit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Start 12 short-lived processes so we exceed the
|
||||
// limit of 10.
|
||||
for i := 0; i < 12; i++ {
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: fmt.Sprintf("echo proc-%d", i),
|
||||
})
|
||||
waitForExit(t, handler, id)
|
||||
}
|
||||
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ListProcessesResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Processes, 10, "should be capped at 10")
|
||||
|
||||
// All returned processes are exited, so they should
|
||||
// be sorted by StartedAt descending (newest first).
|
||||
for i := 1; i < len(resp.Processes); i++ {
|
||||
require.GreaterOrEqual(t, resp.Processes[i-1].StartedAt, resp.Processes[i].StartedAt,
|
||||
"processes should be sorted by started_at descending")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RunningProcessesSortedFirst", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Start an exited process first.
|
||||
exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo done",
|
||||
})
|
||||
waitForExit(t, handler, exitedID)
|
||||
|
||||
// Start a running process after.
|
||||
runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp workspacesdk.ListProcessesResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Processes, 2)
|
||||
|
||||
// Running process should come first regardless of
|
||||
// start order.
|
||||
require.Equal(t, runningID, resp.Processes[0].ID)
|
||||
require.True(t, resp.Processes[0].Running)
|
||||
require.Equal(t, exitedID, resp.Processes[1].ID)
|
||||
require.False(t, resp.Processes[1].Running)
|
||||
|
||||
// Clean up.
|
||||
postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("MixedRunningAndExited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -379,6 +562,23 @@ func TestListProcesses(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// getListWithChatHeader sends a GET /list request with the
|
||||
// Coder-Chat-Id header set and returns the recorder.
|
||||
func getListWithChatHeader(t *testing.T, handler http.Handler, chatID string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
|
||||
if chatID != "" {
|
||||
r.Header.Set(workspacesdk.CoderChatIDHeader, chatID)
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
return w
|
||||
}
|
||||
|
||||
func TestProcessOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -570,6 +770,46 @@ func TestSignalProcess(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
chatID := uuid.New()
|
||||
ch, unsub := pathStore.Subscribe(chatID)
|
||||
defer unsub()
|
||||
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
|
||||
return current, nil
|
||||
}, pathStore)
|
||||
defer api.Close()
|
||||
|
||||
routes := api.Routes()
|
||||
|
||||
body, err := json.Marshal(workspacesdk.StartProcessRequest{
|
||||
Command: "echo hello",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/start", bytes.NewReader(body))
|
||||
req.Header.Set(workspacesdk.CoderChatIDHeader, chatID.String())
|
||||
rw := httptest.NewRecorder()
|
||||
routes.ServeHTTP(rw, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
// The subscriber should be notified even though no paths
|
||||
// were added.
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatal("timed out waiting for path store notification")
|
||||
}
|
||||
|
||||
// No paths should have been stored for this chat.
|
||||
require.Nil(t, pathStore.GetPaths(chatID))
|
||||
}
|
||||
|
||||
func TestProcessLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -21,6 +21,10 @@ import (
|
||||
var (
|
||||
errProcessNotFound = xerrors.New("process not found")
|
||||
errProcessNotRunning = xerrors.New("process is not running")
|
||||
|
||||
// exitedProcessReapAge is how long an exited process is
|
||||
// kept before being automatically removed from the map.
|
||||
exitedProcessReapAge = 5 * time.Minute
|
||||
)
|
||||
|
||||
// process represents a running or completed process.
|
||||
@@ -30,6 +34,7 @@ type process struct {
|
||||
command string
|
||||
workDir string
|
||||
background bool
|
||||
chatID string
|
||||
cmd *exec.Cmd
|
||||
cancel context.CancelFunc
|
||||
buf *HeadTailBuffer
|
||||
@@ -89,7 +94,7 @@ func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(curr
|
||||
// processes use a long-lived context so the process survives
|
||||
// the HTTP request lifecycle. The background flag only affects
|
||||
// client-side polling behavior.
|
||||
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
|
||||
func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*process, error) {
|
||||
m.mu.Lock()
|
||||
if m.closed {
|
||||
m.mu.Unlock()
|
||||
@@ -154,6 +159,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error)
|
||||
command: req.Command,
|
||||
workDir: req.WorkDir,
|
||||
background: req.Background,
|
||||
chatID: chatID,
|
||||
cmd: cmd,
|
||||
cancel: cancel,
|
||||
buf: buf,
|
||||
@@ -215,14 +221,32 @@ func (m *manager) get(id string) (*process, bool) {
|
||||
return proc, ok
|
||||
}
|
||||
|
||||
// list returns info about all tracked processes.
|
||||
func (m *manager) list() []workspacesdk.ProcessInfo {
|
||||
// list returns info about all tracked processes. Exited
|
||||
// processes older than exitedProcessReapAge are removed.
|
||||
// If chatID is non-empty, only processes belonging to that
|
||||
// chat are returned.
|
||||
func (m *manager) list(chatID string) []workspacesdk.ProcessInfo {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
now := m.clock.Now()
|
||||
infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
|
||||
for _, proc := range m.procs {
|
||||
infos = append(infos, proc.info())
|
||||
for id, proc := range m.procs {
|
||||
info := proc.info()
|
||||
// Reap processes that exited more than 5 minutes ago
|
||||
// to prevent unbounded map growth.
|
||||
if !info.Running && info.ExitedAt != nil {
|
||||
exitedAt := time.Unix(*info.ExitedAt, 0)
|
||||
if now.Sub(exitedAt) > exitedProcessReapAge {
|
||||
delete(m.procs, id)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Filter by chatID if provided.
|
||||
if chatID != "" && proc.chatID != chatID {
|
||||
continue
|
||||
}
|
||||
infos = append(infos, info)
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
@@ -110,6 +110,11 @@ type Config struct {
|
||||
// X11DisplayOffset is the offset to add to the X11 display number.
|
||||
// Default is 10.
|
||||
X11DisplayOffset *int
|
||||
// X11MaxPort overrides the highest port used for X11 forwarding
|
||||
// listeners. Defaults to X11MaxPort (6200). Useful in tests
|
||||
// to shrink the port range and reduce the number of sessions
|
||||
// required.
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// ReportConnection.
|
||||
@@ -158,6 +163,10 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
offset := X11DefaultDisplayOffset
|
||||
config.X11DisplayOffset = &offset
|
||||
}
|
||||
if config.X11MaxPort == nil {
|
||||
maxPort := X11MaxPort
|
||||
config.X11MaxPort = &maxPort
|
||||
}
|
||||
if config.UpdateEnv == nil {
|
||||
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
|
||||
}
|
||||
@@ -201,6 +210,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
x11HandlerErrors: metrics.x11HandlerErrors,
|
||||
fs: fs,
|
||||
displayOffset: *config.X11DisplayOffset,
|
||||
maxPort: *config.X11MaxPort,
|
||||
sessions: make(map[*x11Session]struct{}),
|
||||
connections: make(map[net.Conn]struct{}),
|
||||
network: func() X11Network {
|
||||
|
||||
@@ -57,6 +57,7 @@ type x11Forwarder struct {
|
||||
x11HandlerErrors *prometheus.CounterVec
|
||||
fs afero.Fs
|
||||
displayOffset int
|
||||
maxPort int
|
||||
|
||||
// network creates X11 listener sockets. Defaults to osNet{}.
|
||||
network X11Network
|
||||
@@ -314,7 +315,7 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
|
||||
// the next available port starting from X11StartPort and displayOffset.
|
||||
func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) {
|
||||
// Look for an open port to listen on.
|
||||
for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ {
|
||||
for port := X11StartPort + x.displayOffset; port <= x.maxPort; port++ {
|
||||
if ctx.Err() != nil {
|
||||
return nil, -1, ctx.Err()
|
||||
}
|
||||
|
||||
@@ -142,8 +142,13 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
// Use in-process networking for X11 forwarding.
|
||||
inproc := testutil.NewInProcNet()
|
||||
|
||||
// Limit port range so we only need a handful of sessions to fill it
|
||||
// (the default 190 ports may easily timeout or conflict with other
|
||||
// ports on the system).
|
||||
maxPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset + 5
|
||||
cfg := &agentssh.Config{
|
||||
X11Net: inproc,
|
||||
X11Net: inproc,
|
||||
X11MaxPort: &maxPort,
|
||||
}
|
||||
|
||||
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
|
||||
@@ -172,7 +177,7 @@ func TestServer_X11_EvictionLRU(t *testing.T) {
|
||||
// configured port range.
|
||||
|
||||
startPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset
|
||||
maxSessions := agentssh.X11MaxPort - startPort + 1 - 1 // -1 for the blocked port
|
||||
maxSessions := maxPort - startPort + 1 - 1 // -1 for the blocked port
|
||||
require.Greater(t, maxSessions, 0, "expected a positive maxSessions value")
|
||||
|
||||
// shellSession holds references to the session and its standard streams so
|
||||
|
||||
@@ -28,7 +28,9 @@ func (a *agent) apiHandler() http.Handler {
|
||||
})
|
||||
|
||||
r.Mount("/api/v0", a.filesAPI.Routes())
|
||||
r.Mount("/api/v0/git", a.gitAPI.Routes())
|
||||
r.Mount("/api/v0/processes", a.processAPI.Routes())
|
||||
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
|
||||
|
||||
if a.devcontainers {
|
||||
r.Mount("/api/v0/containers", a.containerAPI.Routes())
|
||||
|
||||
@@ -156,7 +156,7 @@ func (fw *fsWatcher) loop(ctx context.Context) {
|
||||
|
||||
func (fw *fsWatcher) addRecursive(dir string) []FSEvent {
|
||||
var events []FSEvent
|
||||
_ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if walkErr := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil //nolint:nilerr // best-effort
|
||||
}
|
||||
@@ -176,7 +176,10 @@ func (fw *fsWatcher) addRecursive(dir string) []FSEvent {
|
||||
}
|
||||
events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: false})
|
||||
return nil
|
||||
})
|
||||
}); walkErr != nil {
|
||||
fw.logger.Warn(context.Background(), "failed to walk directory",
|
||||
slog.F("dir", dir), slog.Error(walkErr))
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
|
||||
+39
-5
@@ -2,6 +2,7 @@ package reaper
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-reap"
|
||||
|
||||
@@ -42,9 +43,42 @@ func WithLogger(logger slog.Logger) Option {
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
CatchSignals []os.Signal
|
||||
Logger slog.Logger
|
||||
// WithReaperStop sets a channel that, when closed, stops the reaper
|
||||
// goroutine. Callers that invoke ForkReap more than once in the
|
||||
// same process (e.g. tests) should use this to prevent goroutine
|
||||
// accumulation.
|
||||
func WithReaperStop(ch chan struct{}) Option {
|
||||
return func(o *options) {
|
||||
o.ReaperStop = ch
|
||||
}
|
||||
}
|
||||
|
||||
// WithReaperStopped sets a channel that is closed after the
|
||||
// reaper goroutine has fully exited.
|
||||
func WithReaperStopped(ch chan struct{}) Option {
|
||||
return func(o *options) {
|
||||
o.ReaperStopped = ch
|
||||
}
|
||||
}
|
||||
|
||||
// WithReapLock sets a mutex shared between the reaper and Wait4.
|
||||
// The reaper holds the write lock while reaping, and ForkReap
|
||||
// holds the read lock during Wait4, preventing the reaper from
|
||||
// stealing the child's exit status. This is only needed for
|
||||
// tests with instant-exit children where the race window is
|
||||
// large.
|
||||
func WithReapLock(mu *sync.RWMutex) Option {
|
||||
return func(o *options) {
|
||||
o.ReapLock = mu
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
CatchSignals []os.Signal
|
||||
Logger slog.Logger
|
||||
ReaperStop chan struct{}
|
||||
ReaperStopped chan struct{}
|
||||
ReapLock *sync.RWMutex
|
||||
}
|
||||
|
||||
+97
-29
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -18,25 +19,79 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestReap checks that's the reaper is successfully reaping
|
||||
// exited processes and passing the PIDs through the shared
|
||||
// channel.
|
||||
// subprocessEnvKey is set when a test re-execs itself as an
|
||||
// isolated subprocess. Tests that call ForkReap or send signals
|
||||
// to their own process check this to decide whether to run real
|
||||
// test logic or launch the subprocess and wait for it.
|
||||
const subprocessEnvKey = "CODER_REAPER_TEST_SUBPROCESS"
|
||||
|
||||
// runSubprocess re-execs the current test binary in a new process
|
||||
// running only the named test. This isolates ForkReap's
|
||||
// syscall.ForkExec and any process-directed signals (e.g. SIGINT)
|
||||
// from the parent test binary, making these tests safe to run in
|
||||
// CI and alongside other tests.
|
||||
//
|
||||
//nolint:paralleltest
|
||||
// Returns true inside the subprocess (caller should proceed with
|
||||
// the real test logic). Returns false in the parent after the
|
||||
// subprocess exits successfully (caller should return).
|
||||
func runSubprocess(t *testing.T) bool {
|
||||
t.Helper()
|
||||
|
||||
if os.Getenv(subprocessEnvKey) == "1" {
|
||||
return true
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
//nolint:gosec // Test-controlled arguments.
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=^"+t.Name()+"$",
|
||||
"-test.v",
|
||||
)
|
||||
cmd.Env = append(os.Environ(), subprocessEnvKey+"=1")
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
t.Logf("Subprocess output:\n%s", out)
|
||||
require.NoError(t, err, "subprocess failed")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// withDone returns options that stop the reaper goroutine when t
|
||||
// completes and wait for it to fully exit, preventing
|
||||
// overlapping reapers across sequential subtests.
|
||||
func withDone(t *testing.T) []reaper.Option {
|
||||
t.Helper()
|
||||
stop := make(chan struct{})
|
||||
stopped := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(stop)
|
||||
<-stopped
|
||||
})
|
||||
return []reaper.Option{
|
||||
reaper.WithReaperStop(stop),
|
||||
reaper.WithReaperStopped(stopped),
|
||||
}
|
||||
}
|
||||
|
||||
// TestReap checks that the reaper successfully reaps exited
|
||||
// processes and passes their PIDs through the shared channel.
|
||||
func TestReap(t *testing.T) {
|
||||
// Don't run the reaper test in CI. It does weird
|
||||
// things like forkexecing which may have unintended
|
||||
// consequences in CI.
|
||||
if testutil.InCI() {
|
||||
t.Skip("Detected CI, skipping reaper tests")
|
||||
t.Parallel()
|
||||
if !runSubprocess(t) {
|
||||
return
|
||||
}
|
||||
|
||||
pids := make(reap.PidCh, 1)
|
||||
exitCode, err := reaper.ForkReap(
|
||||
var reapLock sync.RWMutex
|
||||
opts := append([]reaper.Option{
|
||||
reaper.WithPIDCallback(pids),
|
||||
// Provide some argument that immediately exits.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
|
||||
)
|
||||
reaper.WithReapLock(&reapLock),
|
||||
}, withDone(t)...)
|
||||
reapLock.RLock()
|
||||
exitCode, err := reaper.ForkReap(opts...)
|
||||
reapLock.RUnlock()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, exitCode)
|
||||
|
||||
@@ -56,7 +111,7 @@ func TestReap(t *testing.T) {
|
||||
|
||||
expectedPIDs := []int{cmd.Process.Pid, cmd2.Process.Pid}
|
||||
|
||||
for i := 0; i < len(expectedPIDs); i++ {
|
||||
for range len(expectedPIDs) {
|
||||
select {
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatalf("Timed out waiting for process")
|
||||
@@ -66,10 +121,11 @@ func TestReap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:paralleltest
|
||||
//nolint:tparallel // Subtests must be sequential, each starts its own reaper.
|
||||
func TestForkReapExitCodes(t *testing.T) {
|
||||
if testutil.InCI() {
|
||||
t.Skip("Detected CI, skipping reaper tests")
|
||||
t.Parallel()
|
||||
if !runSubprocess(t) {
|
||||
return
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -85,24 +141,31 @@ func TestForkReapExitCodes(t *testing.T) {
|
||||
{"SIGTERM", "kill -15 $$", 128 + 15},
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Subtests must be sequential, each starts its own reaper.
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
exitCode, err := reaper.ForkReap(
|
||||
var reapLock sync.RWMutex
|
||||
opts := append([]reaper.Option{
|
||||
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
|
||||
)
|
||||
reaper.WithReapLock(&reapLock),
|
||||
}, withDone(t)...)
|
||||
reapLock.RLock()
|
||||
exitCode, err := reaper.ForkReap(opts...)
|
||||
reapLock.RUnlock()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Signal handling.
|
||||
// TestReapInterrupt verifies that ForkReap forwards caught signals
|
||||
// to the child process. The test sends SIGINT to its own process
|
||||
// and checks that the child receives it. Running in a subprocess
|
||||
// ensures SIGINT cannot kill the parent test binary.
|
||||
func TestReapInterrupt(t *testing.T) {
|
||||
// Don't run the reaper test in CI. It does weird
|
||||
// things like forkexecing which may have unintended
|
||||
// consequences in CI.
|
||||
if testutil.InCI() {
|
||||
t.Skip("Detected CI, skipping reaper tests")
|
||||
t.Parallel()
|
||||
if !runSubprocess(t) {
|
||||
return
|
||||
}
|
||||
|
||||
errC := make(chan error, 1)
|
||||
@@ -115,23 +178,28 @@ func TestReapInterrupt(t *testing.T) {
|
||||
defer signal.Stop(usrSig)
|
||||
|
||||
go func() {
|
||||
exitCode, err := reaper.ForkReap(
|
||||
opts := append([]reaper.Option{
|
||||
reaper.WithPIDCallback(pids),
|
||||
reaper.WithCatchSignals(os.Interrupt),
|
||||
// Signal propagation does not extend to children of children, so
|
||||
// we create a little bash script to ensure sleep is interrupted.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
|
||||
)
|
||||
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf(
|
||||
"pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait",
|
||||
os.Getpid(), os.Getpid(),
|
||||
)),
|
||||
}, withDone(t)...)
|
||||
exitCode, err := reaper.ForkReap(opts...)
|
||||
// The child exits with 128 + SIGTERM (15) = 143, but the trap catches
|
||||
// SIGINT and sends SIGTERM to the sleep process, so exit code varies.
|
||||
_ = exitCode
|
||||
errC <- err
|
||||
}()
|
||||
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR1)
|
||||
require.Equal(t, syscall.SIGUSR1, <-usrSig)
|
||||
|
||||
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR2)
|
||||
|
||||
require.Equal(t, syscall.SIGUSR2, <-usrSig)
|
||||
require.NoError(t, <-errC)
|
||||
}
|
||||
|
||||
+24
-14
@@ -19,31 +19,36 @@ func IsInitProcess() bool {
|
||||
return os.Getpid() == 1
|
||||
}
|
||||
|
||||
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
|
||||
// startSignalForwarding registers signal handlers synchronously
|
||||
// then forwards caught signals to the child in a background
|
||||
// goroutine. Registering before the goroutine starts ensures no
|
||||
// signal is lost between ForkExec and the handler being ready.
|
||||
func startSignalForwarding(logger slog.Logger, pid int, sigs []os.Signal) {
|
||||
if len(sigs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc, sigs...)
|
||||
defer signal.Stop(sc)
|
||||
|
||||
logger.Info(context.Background(), "reaper catching signals",
|
||||
slog.F("signals", sigs),
|
||||
slog.F("child_pid", pid),
|
||||
)
|
||||
|
||||
for {
|
||||
s := <-sc
|
||||
sig, ok := s.(syscall.Signal)
|
||||
if ok {
|
||||
logger.Info(context.Background(), "reaper caught signal, killing child process",
|
||||
slog.F("signal", sig.String()),
|
||||
slog.F("child_pid", pid),
|
||||
)
|
||||
_ = syscall.Kill(pid, sig)
|
||||
go func() {
|
||||
defer signal.Stop(sc)
|
||||
for s := range sc {
|
||||
sig, ok := s.(syscall.Signal)
|
||||
if ok {
|
||||
logger.Info(context.Background(), "reaper caught signal, killing child process",
|
||||
slog.F("signal", sig.String()),
|
||||
slog.F("child_pid", pid),
|
||||
)
|
||||
_ = syscall.Kill(pid, sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ForkReap spawns a goroutine that reaps children. In order to avoid
|
||||
@@ -64,7 +69,12 @@ func ForkReap(opt ...Option) (int, error) {
|
||||
o(opts)
|
||||
}
|
||||
|
||||
go reap.ReapChildren(opts.PIDs, nil, nil, nil)
|
||||
go func() {
|
||||
reap.ReapChildren(opts.PIDs, nil, opts.ReaperStop, opts.ReapLock)
|
||||
if opts.ReaperStopped != nil {
|
||||
close(opts.ReaperStopped)
|
||||
}
|
||||
}()
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
@@ -90,7 +100,7 @@ func ForkReap(opt ...Option) (int, error) {
|
||||
return 1, xerrors.Errorf("fork exec: %w", err)
|
||||
}
|
||||
|
||||
go catchSignals(opts.Logger, pid, opts.CatchSignals)
|
||||
startSignalForwarding(opts.Logger, pid, opts.CatchSignals)
|
||||
|
||||
var wstatus syscall.WaitStatus
|
||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
@@ -40,6 +41,18 @@ func New(t testing.TB, args ...string) (*serpent.Invocation, config.Root) {
|
||||
return NewWithCommand(t, cmd, args...)
|
||||
}
|
||||
|
||||
// NewWithClock is like New, but injects the given clock for
|
||||
// tests that are time-dependent.
|
||||
func NewWithClock(t testing.TB, clk quartz.Clock, args ...string) (*serpent.Invocation, config.Root) {
|
||||
var root cli.RootCmd
|
||||
root.SetClock(clk)
|
||||
|
||||
cmd, err := root.Command(root.AGPL())
|
||||
require.NoError(t, err)
|
||||
|
||||
return NewWithCommand(t, cmd, args...)
|
||||
}
|
||||
|
||||
type logWriter struct {
|
||||
prefix string
|
||||
log slog.Logger
|
||||
|
||||
@@ -123,6 +123,10 @@ func Select(inv *serpent.Invocation, opts SelectOptions) (string, error) {
|
||||
initialModel.height = defaultSelectModelHeight
|
||||
}
|
||||
|
||||
if idx := slices.Index(opts.Options, opts.Default); idx >= 0 {
|
||||
initialModel.cursor = idx
|
||||
}
|
||||
|
||||
initialModel.search.Prompt = ""
|
||||
initialModel.search.Focus()
|
||||
|
||||
|
||||
+3
-3
@@ -109,13 +109,13 @@ func (RootCmd) promptExample() *serpent.Command {
|
||||
Options: []string{
|
||||
"Blue", "Green", "Yellow", "Red", "Something else",
|
||||
},
|
||||
Default: "",
|
||||
Default: "Green",
|
||||
Message: "Select your favorite color:",
|
||||
Size: 5,
|
||||
HideSearch: !useSearch,
|
||||
})
|
||||
if value == "Something else" {
|
||||
_, _ = fmt.Fprint(inv.Stdout, "I would have picked blue.\n")
|
||||
_, _ = fmt.Fprint(inv.Stdout, "I would have picked green.\n")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "%s is a nice color.\n", value)
|
||||
}
|
||||
@@ -128,7 +128,7 @@ func (RootCmd) promptExample() *serpent.Command {
|
||||
Options: []string{
|
||||
"Car", "Bike", "Plane", "Boat", "Train",
|
||||
},
|
||||
Default: "Car",
|
||||
Default: "Bike",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -57,7 +57,9 @@ func (*RootCmd) scaletestLLMMock() *serpent.Command {
|
||||
return xerrors.Errorf("start mock LLM server: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = srv.Stop()
|
||||
if err := srv.Stop(); err != nil {
|
||||
logger.Error(ctx, "failed to stop mock LLM server", slog.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Mock LLM API server started on %s\n", srv.APIAddress())
|
||||
|
||||
+12
-6
@@ -19,12 +19,18 @@ func OverrideVSCodeConfigs(fs afero.Fs) error {
|
||||
return err
|
||||
}
|
||||
mutate := func(m map[string]interface{}) {
|
||||
// This prevents VS Code from overriding GIT_ASKPASS, which
|
||||
// we use to automatically authenticate Git providers.
|
||||
m["git.useIntegratedAskPass"] = false
|
||||
// This prevents VS Code from using it's own GitHub authentication
|
||||
// which would circumvent cloning with Coder-configured providers.
|
||||
m["github.gitAuthentication"] = false
|
||||
// These defaults prevent VS Code from overriding
|
||||
// GIT_ASKPASS and using its own GitHub authentication,
|
||||
// which would circumvent cloning with Coder-configured
|
||||
// providers. We only set them if they are not already
|
||||
// present so that template authors can override them
|
||||
// via module settings (e.g. the vscode-web module).
|
||||
if _, ok := m["git.useIntegratedAskPass"]; !ok {
|
||||
m["git.useIntegratedAskPass"] = false
|
||||
}
|
||||
if _, ok := m["github.gitAuthentication"]; !ok {
|
||||
m["github.gitAuthentication"] = false
|
||||
}
|
||||
}
|
||||
|
||||
for _, configPath := range []string{
|
||||
|
||||
@@ -61,4 +61,31 @@ func TestOverrideVSCodeConfigs(t *testing.T) {
|
||||
require.Equal(t, "something", mapping["hotdogs"])
|
||||
}
|
||||
})
|
||||
t.Run("NoOverwrite", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fs := afero.NewMemMapFs()
|
||||
mapping := map[string]interface{}{
|
||||
"git.useIntegratedAskPass": true,
|
||||
"github.gitAuthentication": true,
|
||||
"other.setting": "preserved",
|
||||
}
|
||||
data, err := json.Marshal(mapping)
|
||||
require.NoError(t, err)
|
||||
for _, configPath := range configPaths {
|
||||
err = afero.WriteFile(fs, configPath, data, 0o600)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
err = gitauth.OverrideVSCodeConfigs(fs)
|
||||
require.NoError(t, err)
|
||||
for _, configPath := range configPaths {
|
||||
data, err := afero.ReadFile(fs, configPath)
|
||||
require.NoError(t, err)
|
||||
mapping := map[string]interface{}{}
|
||||
err = json.Unmarshal(data, &mapping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, mapping["git.useIntegratedAskPass"])
|
||||
require.Equal(t, true, mapping["github.gitAuthentication"])
|
||||
require.Equal(t, "preserved", mapping["other.setting"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+2
-2
@@ -58,7 +58,7 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*agentsdk.Client, str
|
||||
_ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
_ = coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
|
||||
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).WithContext(ctx).Wait()
|
||||
return agentClient, r.AgentToken, pubkey
|
||||
}
|
||||
|
||||
@@ -167,7 +167,7 @@ func TestGitSSH(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
writePrivateKeyToFile(t, idFile, privkey)
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setupCtx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
client, token, coderPubkey := prepareTestGitSSH(setupCtx, t)
|
||||
|
||||
authkey := make(chan gossh.PublicKey, 1)
|
||||
|
||||
+39
-1
@@ -357,6 +357,25 @@ func (r *RootCmd) login() *serpent.Command {
|
||||
}
|
||||
|
||||
sessionToken, _ := inv.ParsedFlags().GetString(varToken)
|
||||
tokenFlagProvided := inv.ParsedFlags().Changed(varToken)
|
||||
|
||||
// If CODER_SESSION_TOKEN is set in the environment, abort
|
||||
// interactive login unless --use-token-as-session or --token
|
||||
// is specified. The env var takes precedence over a token
|
||||
// stored on disk, so even if we complete login and write a
|
||||
// new token to the session file, subsequent CLI commands
|
||||
// would still use the environment variable value. When
|
||||
// --token is provided on the command line, the user
|
||||
// explicitly wants to authenticate with that token (common
|
||||
// in CI), so we skip this check.
|
||||
if !tokenFlagProvided && inv.Environ.Get(envSessionToken) != "" && !useTokenForSession {
|
||||
return xerrors.Errorf(
|
||||
"%s is set. This environment variable takes precedence over any session token stored on disk.\n\n"+
|
||||
"To log in, unset the environment variable and re-run this command:\n\n"+
|
||||
"\tunset %s",
|
||||
envSessionToken, envSessionToken,
|
||||
)
|
||||
}
|
||||
if sessionToken == "" {
|
||||
authURL := *serverURL
|
||||
// Don't use filepath.Join, we don't want to use the os separator
|
||||
@@ -475,7 +494,26 @@ func (r *RootCmd) loginToken() *serpent.Command {
|
||||
Long: "Print the session token for use in scripts and automation.",
|
||||
Middleware: serpent.RequireNArgs(0),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
tok, err := r.ensureTokenBackend().Read(r.clientURL)
|
||||
if err := r.ensureClientURL(); err != nil {
|
||||
return err
|
||||
}
|
||||
// When using the file storage, a session token is stored for a single
|
||||
// deployment URL that the user is logged in to. They keyring can store
|
||||
// multiple deployment session tokens. Error if the requested URL doesn't
|
||||
// match the stored config URL when using file storage to avoid returning
|
||||
// a token for the wrong deployment.
|
||||
backend := r.ensureTokenBackend()
|
||||
if _, ok := backend.(*sessionstore.File); ok {
|
||||
conf := r.createConfig()
|
||||
storedURL, err := conf.URL().Read()
|
||||
if err == nil {
|
||||
storedURL = strings.TrimSpace(storedURL)
|
||||
if storedURL != r.clientURL.String() {
|
||||
return xerrors.Errorf("file session token storage only supports one server at a time: requested %s but logged into %s", r.clientURL.String(), storedURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
tok, err := backend.Read(r.clientURL)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, os.ErrNotExist) {
|
||||
return xerrors.New("no session token found - run 'coder login' first")
|
||||
|
||||
+58
-1
@@ -516,6 +516,40 @@ func TestLogin(t *testing.T) {
|
||||
require.NotEqual(t, client.SessionToken(), sessionFile)
|
||||
})
|
||||
|
||||
t.Run("SessionTokenEnvVar", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
root, _ := clitest.New(t, "login", client.URL.String())
|
||||
root.Environ.Set("CODER_SESSION_TOKEN", "invalid-token")
|
||||
err := root.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "CODER_SESSION_TOKEN is set")
|
||||
require.Contains(t, err.Error(), "unset CODER_SESSION_TOKEN")
|
||||
})
|
||||
|
||||
t.Run("SessionTokenEnvVarWithUseTokenAsSession", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
root, _ := clitest.New(t, "login", client.URL.String(), "--use-token-as-session")
|
||||
root.Environ.Set("CODER_SESSION_TOKEN", client.SessionToken())
|
||||
err := root.Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SessionTokenEnvVarWithTokenFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
// Using --token with CODER_SESSION_TOKEN set should succeed.
|
||||
// This is the standard pattern used by coder/setup-action.
|
||||
root, _ := clitest.New(t, "login", client.URL.String(), "--token", client.SessionToken())
|
||||
root.Environ.Set("CODER_SESSION_TOKEN", client.SessionToken())
|
||||
err := root.Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("KeepOrganizationContext", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
@@ -558,10 +592,33 @@ func TestLoginToken(t *testing.T) {
|
||||
|
||||
t.Run("NoTokenStored", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inv, _ := clitest.New(t, "login", "token")
|
||||
client := coderdtest.New(t, nil)
|
||||
inv, _ := clitest.New(t, "login", "token", "--url", client.URL.String())
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "no session token found")
|
||||
})
|
||||
|
||||
t.Run("NoURLProvided", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
inv, _ := clitest.New(t, "login", "token")
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "You are not logged in")
|
||||
})
|
||||
|
||||
t.Run("URLMismatchFileBackend", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
inv, root := clitest.New(t, "login", "token", "--url", "https://other.example.com")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "file session token storage only supports one server")
|
||||
})
|
||||
}
|
||||
|
||||
+39
-21
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
@@ -230,6 +231,10 @@ func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) {
|
||||
}
|
||||
|
||||
func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, error) {
|
||||
if r.clock == nil {
|
||||
r.clock = quartz.NewReal()
|
||||
}
|
||||
|
||||
fmtLong := `Coder %s — A tool for provisioning self-hosted development environments with Terraform.
|
||||
`
|
||||
hiddenAgentAuth := &AgentAuth{}
|
||||
@@ -548,32 +553,45 @@ type RootCmd struct {
|
||||
useKeyring bool
|
||||
keyringServiceName string
|
||||
useKeyringWithGlobalConfig bool
|
||||
|
||||
// clock is used for time-dependent operations. Initialized to
|
||||
// quartz.NewReal() in Command() if not set via SetClock.
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// SetClock sets the clock used for time-dependent operations.
|
||||
// Must be called before Command() to take effect.
|
||||
func (r *RootCmd) SetClock(clk quartz.Clock) {
|
||||
r.clock = clk
|
||||
}
|
||||
|
||||
// ensureClientURL loads the client URL from the config file if it
|
||||
// wasn't provided via --url or CODER_URL.
|
||||
func (r *RootCmd) ensureClientURL() error {
|
||||
if r.clientURL != nil && r.clientURL.String() != "" {
|
||||
return nil
|
||||
}
|
||||
rawURL, err := r.createConfig().URL().Read()
|
||||
// If the configuration files are absent, the user is logged out.
|
||||
if os.IsNotExist(err) {
|
||||
binPath, err := os.Executable()
|
||||
if err != nil {
|
||||
binPath = "coder"
|
||||
}
|
||||
return xerrors.Errorf(notLoggedInMessage, binPath)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.clientURL, err = url.Parse(strings.TrimSpace(rawURL))
|
||||
return err
|
||||
}
|
||||
|
||||
// InitClient creates and configures a new client with authentication, telemetry,
|
||||
// and version checks.
|
||||
func (r *RootCmd) InitClient(inv *serpent.Invocation) (*codersdk.Client, error) {
|
||||
conf := r.createConfig()
|
||||
var err error
|
||||
// Read the client URL stored on disk.
|
||||
if r.clientURL == nil || r.clientURL.String() == "" {
|
||||
rawURL, err := conf.URL().Read()
|
||||
// If the configuration files are absent, the user is logged out
|
||||
if os.IsNotExist(err) {
|
||||
binPath, err := os.Executable()
|
||||
if err != nil {
|
||||
binPath = "coder"
|
||||
}
|
||||
return nil, xerrors.Errorf(notLoggedInMessage, binPath)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.clientURL, err = url.Parse(strings.TrimSpace(rawURL))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.ensureClientURL(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.token == "" {
|
||||
tok, err := r.ensureTokenBackend().Read(r.clientURL)
|
||||
|
||||
@@ -2909,6 +2909,8 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder
|
||||
provider.MCPToolDenyRegex = v.Value
|
||||
case "PKCE_METHODS":
|
||||
provider.CodeChallengeMethodsSupported = strings.Split(v.Value, " ")
|
||||
case "API_BASE_URL":
|
||||
provider.APIBaseURL = v.Value
|
||||
}
|
||||
providers[providerNum] = provider
|
||||
}
|
||||
|
||||
@@ -188,16 +188,17 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Creating user...")
|
||||
newUser, err = tx.InsertUser(ctx, database.InsertUserParams{
|
||||
ID: uuid.New(),
|
||||
Email: newUserEmail,
|
||||
Username: newUserUsername,
|
||||
Name: "Admin User",
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
CreatedAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
RBACRoles: []string{rbac.RoleOwner().String()},
|
||||
LoginType: database.LoginTypePassword,
|
||||
Status: "",
|
||||
ID: uuid.New(),
|
||||
Email: newUserEmail,
|
||||
Username: newUserUsername,
|
||||
Name: "Admin User",
|
||||
HashedPassword: []byte(hashedPassword),
|
||||
CreatedAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
RBACRoles: []string{rbac.RoleOwner().String()},
|
||||
LoginType: database.LoginTypePassword,
|
||||
Status: "",
|
||||
IsServiceAccount: false,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert user: %w", err)
|
||||
|
||||
@@ -108,6 +108,29 @@ func TestReadExternalAuthProvidersFromEnv(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadExternalAuthProvidersFromEnv_APIBaseURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
|
||||
"CODER_EXTERNAL_AUTH_0_TYPE=github",
|
||||
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
|
||||
"CODER_EXTERNAL_AUTH_0_API_BASE_URL=https://ghes.corp.com/api/v3",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
assert.Equal(t, "https://ghes.corp.com/api/v3", providers[0].APIBaseURL)
|
||||
}
|
||||
|
||||
func TestReadExternalAuthProvidersFromEnv_APIBaseURLDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
|
||||
"CODER_EXTERNAL_AUTH_0_TYPE=github",
|
||||
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
assert.Equal(t, "", providers[0].APIBaseURL)
|
||||
}
|
||||
|
||||
// TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_`
|
||||
// environment variables are still supported.
|
||||
func TestReadGitAuthProvidersFromEnv(t *testing.T) {
|
||||
|
||||
@@ -21,9 +21,8 @@ type storedCredentials map[string]struct {
|
||||
APIToken string `json:"api_token"`
|
||||
}
|
||||
|
||||
//nolint:paralleltest, tparallel // OS keyring is flaky under concurrent access
|
||||
func TestKeyring(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "darwin" {
|
||||
t.Skip("linux is not supported yet")
|
||||
}
|
||||
@@ -37,8 +36,6 @@ func TestKeyring(t *testing.T) {
|
||||
)
|
||||
|
||||
t.Run("ReadNonExistent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
|
||||
srvURL, err := url.Parse(testURL)
|
||||
require.NoError(t, err)
|
||||
@@ -50,8 +47,6 @@ func TestKeyring(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
|
||||
srvURL, err := url.Parse(testURL)
|
||||
require.NoError(t, err)
|
||||
@@ -63,8 +58,6 @@ func TestKeyring(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("WriteAndRead", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
|
||||
srvURL, err := url.Parse(testURL)
|
||||
require.NoError(t, err)
|
||||
@@ -91,8 +84,6 @@ func TestKeyring(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("WriteAndDelete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
|
||||
srvURL, err := url.Parse(testURL)
|
||||
require.NoError(t, err)
|
||||
@@ -115,8 +106,6 @@ func TestKeyring(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("OverwriteToken", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
|
||||
srvURL, err := url.Parse(testURL)
|
||||
require.NoError(t, err)
|
||||
@@ -146,8 +135,6 @@ func TestKeyring(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("MultipleServers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
|
||||
srvURL, err := url.Parse(testURL)
|
||||
require.NoError(t, err)
|
||||
@@ -199,7 +186,6 @@ func TestKeyring(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("StorageFormat", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// The storage format must remain consistent to ensure we don't break
|
||||
// compatibility with other Coder related applications that may read
|
||||
// or decode the same credential.
|
||||
|
||||
@@ -25,9 +25,8 @@ func readRawKeychainCredential(t *testing.T, serviceName string) []byte {
|
||||
return winCred.CredentialBlob
|
||||
}
|
||||
|
||||
//nolint:paralleltest, tparallel // OS keyring is flaky under concurrent access
|
||||
func TestWindowsKeyring_WriteReadDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testURL = "http://127.0.0.1:1337"
|
||||
srvURL, err := url.Parse(testURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
+7
-1
@@ -357,7 +357,13 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
// search domain expansion, which can add 20-30s of
|
||||
// delay on corporate networks with search domains
|
||||
// configured.
|
||||
exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost+".")
|
||||
exists, ccErr := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost+".")
|
||||
if ccErr != nil {
|
||||
logger.Debug(ctx, "failed to check coder connect",
|
||||
slog.F("hostname", coderConnectHost),
|
||||
slog.Error(ccErr),
|
||||
)
|
||||
}
|
||||
if exists {
|
||||
defer cancel()
|
||||
|
||||
|
||||
+8
-16
@@ -180,15 +180,11 @@ func TestSSH(t *testing.T) {
|
||||
|
||||
// Delay until workspace is starting, otherwise the agent may be
|
||||
// booted due to outdated build.
|
||||
var err error
|
||||
for {
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
|
||||
break
|
||||
}
|
||||
time.Sleep(testutil.IntervalFast)
|
||||
}
|
||||
return err == nil && workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
// When the agent connects, the workspace was started, and we should
|
||||
// have access to the shell.
|
||||
@@ -763,15 +759,11 @@ func TestSSH(t *testing.T) {
|
||||
|
||||
// Delay until workspace is starting, otherwise the agent may be
|
||||
// booted due to outdated build.
|
||||
var err error
|
||||
for {
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
|
||||
break
|
||||
}
|
||||
time.Sleep(testutil.IntervalFast)
|
||||
}
|
||||
return err == nil && workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
// When the agent connects, the workspace was started, and we should
|
||||
// have access to the shell.
|
||||
|
||||
+12
-14
@@ -7,7 +7,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -103,13 +102,13 @@ func TestSyncCommands_Golden(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
client.Close()
|
||||
|
||||
// Start a goroutine to complete the dependency after a short delay
|
||||
// This simulates the dependency being satisfied while start is waiting
|
||||
// The delay ensures the "Waiting..." message appears in the output
|
||||
outBuf := testutil.NewWaitBuffer()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
// Wait a moment to let the start command begin waiting and print the message
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := outBuf.WaitFor(ctx, "Waiting"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
|
||||
compCtx := context.Background()
|
||||
compClient, err := agentsocket.NewClient(compCtx, agentsocket.WithPath(path))
|
||||
@@ -119,7 +118,7 @@ func TestSyncCommands_Golden(t *testing.T) {
|
||||
}
|
||||
defer compClient.Close()
|
||||
|
||||
// Start and complete the dependency unit
|
||||
// Start and complete the dependency unit.
|
||||
err = compClient.SyncStart(compCtx, "dep-unit")
|
||||
if err != nil {
|
||||
done <- err
|
||||
@@ -129,21 +128,20 @@ func TestSyncCommands_Golden(t *testing.T) {
|
||||
done <- err
|
||||
}()
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--socket-path", path)
|
||||
inv.Stdout = &outBuf
|
||||
inv.Stderr = &outBuf
|
||||
inv.Stdout = outBuf
|
||||
inv.Stderr = outBuf
|
||||
|
||||
// Run the start command - it should wait for the dependency
|
||||
// Run the start command - it should wait for the dependency.
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Ensure the completion goroutine finished
|
||||
// Ensure the completion goroutine finished.
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err, "complete dependency")
|
||||
case <-time.After(time.Second):
|
||||
// Goroutine should have finished by now
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for dependency completion goroutine")
|
||||
}
|
||||
|
||||
clitest.TestGoldenFile(t, "TestSyncCommands_Golden/start_with_dependencies", outBuf.Bytes(), nil)
|
||||
|
||||
+12
-12
@@ -41,11 +41,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.Name, "--output", "json")
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.Name, "--output", "json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -64,11 +64,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String(), "--output", "json")
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String(), "--output", "json")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -87,11 +87,11 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String())
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -141,10 +141,10 @@ func Test_TaskLogs_Golden(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
|
||||
inv, root := clitest.New(t, "task", "logs", task.ID.String())
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "logs", setup.task.ID.String())
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
|
||||
+22
-26
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -21,12 +20,12 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: Expect the task to be paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -34,7 +33,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been paused")
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -46,13 +45,13 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A different user's running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient, _, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause their task
|
||||
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
|
||||
inv, root := clitest.New(t, "task", "pause", identifier, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, adminClient, root)
|
||||
clitest.SetupConfig(t, setup.ownerClient, root)
|
||||
|
||||
// Then: We expect the task to be paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -60,7 +59,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been paused")
|
||||
|
||||
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
|
||||
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -70,11 +69,11 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// And: We confirm we want to pause the task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -88,7 +87,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
pty.ExpectMatchContext(ctx, "has been paused")
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -98,11 +97,11 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to pause the task
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// But: We say no at the confirmation screen
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -114,7 +113,7 @@ func TestExpTaskPause(t *testing.T) {
|
||||
require.Error(t, w.Wait())
|
||||
|
||||
// Then: We expect the task to not be paused
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -124,21 +123,18 @@ func TestExpTaskPause(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// And: We paused the running task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
resp, err := userClient.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, resp.WorkspaceBuild.ID)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to pause the task again
|
||||
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "pause", setup.task.Name, "--yes")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect to get an error that the task is already paused
|
||||
err = inv.WithContext(ctx).Run()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "is already paused")
|
||||
})
|
||||
}
|
||||
|
||||
+31
-43
@@ -1,7 +1,6 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
@@ -17,29 +16,18 @@ import (
|
||||
func TestExpTaskResume(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// pauseTask is a helper that pauses a task and waits for the stop
|
||||
// build to complete.
|
||||
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pauseResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
t.Run("WithYesFlag", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -47,7 +35,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -59,14 +47,14 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A different user's paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume their task
|
||||
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
identifier := fmt.Sprintf("%s/%s", setup.task.OwnerName, setup.task.Name)
|
||||
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, adminClient, root)
|
||||
clitest.SetupConfig(t, setup.ownerClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -74,7 +62,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, output.Stdout(), "has been resumed")
|
||||
|
||||
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
|
||||
updated, err := setup.ownerClient.TaskByIdentifier(ctx, identifier)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -84,13 +72,13 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task (and specify no wait)
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes", "--no-wait")
|
||||
output := clitest.Capture(inv)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect the task to be resumed in the background
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -99,11 +87,11 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.Contains(t, output.Stdout(), "in the background")
|
||||
|
||||
// And: The task to eventually be resumed
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
|
||||
require.True(t, setup.task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws := coderdtest.MustWorkspace(t, setup.userClient, setup.task.WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, setup.userClient, ws.LatestBuild.ID)
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -113,12 +101,12 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// And: We confirm we want to resume the task
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -132,7 +120,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
pty.ExpectMatchContext(ctx, "has been resumed")
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
|
||||
})
|
||||
@@ -142,12 +130,12 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A paused task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, userClient, task)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to resume the task
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name)
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// But: Say no at the confirmation screen
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
@@ -159,7 +147,7 @@ func TestExpTaskResume(t *testing.T) {
|
||||
require.Error(t, w.Wait())
|
||||
|
||||
// Then: We expect the task to still be paused
|
||||
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
|
||||
})
|
||||
@@ -169,11 +157,11 @@ func TestExpTaskResume(t *testing.T) {
|
||||
|
||||
// Given: A running task
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
// When: We attempt to resume the task that is not paused
|
||||
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
inv, root := clitest.New(t, "task", "resume", setup.task.Name, "--yes")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
// Then: We expect to get an error that the task is not paused
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
+154
-9
@@ -1,10 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -15,13 +20,15 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "send <task> [<input> | --stdin]",
|
||||
Short: "Send input to a task",
|
||||
Long: FormatExamples(Example{
|
||||
Description: "Send direct input to a task.",
|
||||
Command: "coder task send task1 \"Please also add unit tests\"",
|
||||
}, Example{
|
||||
Description: "Send input from stdin to a task.",
|
||||
Command: "echo \"Please also add unit tests\" | coder task send task1 --stdin",
|
||||
}),
|
||||
Long: `Send input to a task. If the task is paused, it will be automatically resumed before input is sent. If the task is initializing, it will wait for the task to become ready.
|
||||
` +
|
||||
FormatExamples(Example{
|
||||
Description: "Send direct input to a task",
|
||||
Command: `coder task send task1 "Please also add unit tests"`,
|
||||
}, Example{
|
||||
Description: "Send input from stdin to a task",
|
||||
Command: `echo "Please also add unit tests" | coder task send task1 --stdin`,
|
||||
}),
|
||||
Middleware: serpent.RequireRangeArgs(1, 2),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
@@ -64,8 +71,48 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
return xerrors.Errorf("resolve task: %w", err)
|
||||
}
|
||||
|
||||
if err = client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task: %w", err)
|
||||
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
|
||||
// Before attempting to send, check the task status and
|
||||
// handle non-active states.
|
||||
var workspaceBuildID uuid.UUID
|
||||
|
||||
switch task.Status {
|
||||
case codersdk.TaskStatusActive:
|
||||
// Already active, no build to watch.
|
||||
|
||||
case codersdk.TaskStatusPaused:
|
||||
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resume task %q: %w", display, err)
|
||||
} else if resp.WorkspaceBuild == nil {
|
||||
return xerrors.Errorf("resume task %q", display)
|
||||
}
|
||||
|
||||
workspaceBuildID = resp.WorkspaceBuild.ID
|
||||
|
||||
case codersdk.TaskStatusInitializing:
|
||||
if !task.WorkspaceID.Valid {
|
||||
return xerrors.Errorf("send input to task %q: task has no backing workspace", display)
|
||||
}
|
||||
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace for task %q: %w", display, err)
|
||||
}
|
||||
|
||||
workspaceBuildID = workspace.LatestBuild.ID
|
||||
|
||||
default:
|
||||
return xerrors.Errorf("task %q has status %s and cannot be sent input", display, task.Status)
|
||||
}
|
||||
|
||||
if err := waitForTaskIdle(ctx, inv, client, task, workspaceBuildID); err != nil {
|
||||
return xerrors.Errorf("wait for task %q to be idle: %w", display, err)
|
||||
}
|
||||
|
||||
if err := client.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task %q: %w", display, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -74,3 +121,101 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// waitForTaskIdle optionally watches a workspace build to completion,
|
||||
// then polls until the task becomes active and its app state is idle.
|
||||
// This merges build-watching and idle-polling into a single loop so
|
||||
// that status changes (e.g. paused) are never missed between phases.
|
||||
func waitForTaskIdle(ctx context.Context, inv *serpent.Invocation, client *codersdk.Client, task codersdk.Task, workspaceBuildID uuid.UUID) error {
|
||||
if workspaceBuildID != uuid.Nil {
|
||||
if err := cliui.WorkspaceBuild(ctx, inv.Stdout, client, workspaceBuildID); err != nil {
|
||||
return xerrors.Errorf("watch workspace build: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
cliui.Infof(inv.Stdout, "Waiting for task to become idle...")
|
||||
|
||||
// NOTE(DanielleMaywood):
|
||||
// It has been observed that the `TaskStatusError` state has
|
||||
// appeared during a typical healthy startup [^0]. To combat
|
||||
// this, we allow a 5 minute grace period where we allow
|
||||
// `TaskStatusError` to surface without immediately failing.
|
||||
//
|
||||
// TODO(DanielleMaywood):
|
||||
// Remove this grace period once the upstream agentapi health
|
||||
// check no longer reports transient error states during normal
|
||||
// startup.
|
||||
//
|
||||
// [0]: https://github.com/coder/coder/pull/22203#discussion_r2858002569
|
||||
const errorGracePeriod = 5 * time.Minute
|
||||
gracePeriodDeadline := time.Now().Add(errorGracePeriod)
|
||||
|
||||
// NOTE(DanielleMaywood):
|
||||
// On resume the MCP may not report an initial app status,
|
||||
// leaving CurrentState nil indefinitely. To avoid hanging
|
||||
// forever we treat Active with nil CurrentState as idle
|
||||
// after a grace period, giving the MCP time to report
|
||||
// during normal startup.
|
||||
const nilStateGracePeriod = 30 * time.Second
|
||||
var nilStateDeadline time.Time
|
||||
|
||||
// TODO(DanielleMaywood):
|
||||
// When we have a streaming Task API, this should be converted
|
||||
// away from polling.
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
task, err := client.TaskByID(ctx, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get task by id: %w", err)
|
||||
}
|
||||
|
||||
switch task.Status {
|
||||
case codersdk.TaskStatusInitializing,
|
||||
codersdk.TaskStatusPending:
|
||||
// Not yet active, keep polling.
|
||||
continue
|
||||
case codersdk.TaskStatusActive:
|
||||
// Task is active; check app state.
|
||||
if task.CurrentState == nil {
|
||||
// The MCP may not have reported state yet.
|
||||
// Start a grace period on first observation
|
||||
// and treat as idle once it expires.
|
||||
if nilStateDeadline.IsZero() {
|
||||
nilStateDeadline = time.Now().Add(nilStateGracePeriod)
|
||||
}
|
||||
if time.Now().After(nilStateDeadline) {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Reset nil-state deadline since we got a real
|
||||
// state report.
|
||||
nilStateDeadline = time.Time{}
|
||||
switch task.CurrentState.State {
|
||||
case codersdk.TaskStateIdle,
|
||||
codersdk.TaskStateComplete,
|
||||
codersdk.TaskStateFailed:
|
||||
return nil
|
||||
default:
|
||||
// Still working, keep polling.
|
||||
continue
|
||||
}
|
||||
case codersdk.TaskStatusError:
|
||||
if time.Now().After(gracePeriodDeadline) {
|
||||
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
|
||||
}
|
||||
case codersdk.TaskStatusPaused:
|
||||
return xerrors.Errorf("task was paused while waiting for it to become idle")
|
||||
case codersdk.TaskStatusUnknown:
|
||||
return xerrors.Errorf("task entered %s state while waiting for it to become idle", task.Status)
|
||||
default:
|
||||
return xerrors.Errorf("task entered unexpected state (%s) while waiting for it to become idle", task.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+224
-13
@@ -12,9 +12,14 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -25,12 +30,12 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "carry on with the task")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -41,12 +46,12 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.ID.String(), "carry on with the task")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.ID.String(), "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -57,13 +62,13 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "--stdin")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "--stdin")
|
||||
inv.Stdout = &stdout
|
||||
inv.Stdin = strings.NewReader("carry on with the task")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
@@ -110,17 +115,223 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(assert.AnError))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "task", "send", task.Name, "some task input")
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, assert.AnError.Error())
|
||||
})
|
||||
|
||||
t.Run("WaitsForInitializingTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Close the first agent, pause, then resume the task so the
|
||||
// workspace is started but no agent is connected.
|
||||
// This puts the task in "initializing" state.
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
resumeTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the initializing task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
// Use a pty so we can wait for the command to produce build
|
||||
// output, confirming it has entered the initializing code
|
||||
// path before we connect the agent.
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to observe the initializing state and
|
||||
// start watching the workspace build. This ensures the command
|
||||
// has entered the waiting code path.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Connect a new agent so the task can transition to active.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Report the task app as idle so waitForTaskIdle can proceed.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("ResumesPausedTask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Close the first agent before pausing so it does not conflict
|
||||
// with the agent we reconnect after the workspace is resumed.
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the paused task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
// Use a pty so we can wait for the command to produce build
|
||||
// output, confirming it has entered the paused code path and
|
||||
// triggered a resume before we connect the agent.
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to observe the paused state, trigger
|
||||
// a resume, and start watching the workspace build.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Connect a new agent so the task can transition to active.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
setup.agent = agenttest.New(t, setup.userClient.URL, setup.agentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, setup.userClient, setup.task.WorkspaceID.UUID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Report the task app as idle so waitForTaskIdle can proceed.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
|
||||
updated, err := setup.userClient.TaskByIdentifier(ctx, setup.task.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.TaskStatusActive, updated.Status)
|
||||
})
|
||||
|
||||
t.Run("PausedDuringWaitForReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: An initializing task (workspace running, no agent
|
||||
// connected).
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, nil)
|
||||
|
||||
require.NoError(t, setup.agent.Close())
|
||||
pauseTask(setupCtx, t, setup.userClient, setup.task)
|
||||
resumeTask(setupCtx, t, setup.userClient, setup.task)
|
||||
|
||||
// When: We attempt to send input to the initializing task.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Wait for the command to enter the build-watching phase
|
||||
// of waitForTaskReady.
|
||||
pty.ExpectMatchContext(ctx, "Queued")
|
||||
|
||||
// Pause the task while waitForTaskReady is polling. Since
|
||||
// no agent is connected, the task stays initializing until
|
||||
// we pause it, at which point the status becomes paused.
|
||||
pauseTask(ctx, t, setup.userClient, setup.task)
|
||||
|
||||
// Then: The command should fail because the task was paused.
|
||||
err := w.Wait()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "was paused while waiting for it to become idle")
|
||||
})
|
||||
|
||||
t.Run("WaitsForWorkingAppState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: An active task whose app is in "working" state.
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some task input", "some task response"))
|
||||
|
||||
// Move the app into "working" state before running the command.
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateWorking,
|
||||
Message: "busy",
|
||||
}))
|
||||
|
||||
// When: We send input while the app is working.
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some task input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv = inv.WithContext(ctx)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
|
||||
// Transition the app back to idle so waitForTaskIdle proceeds.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
|
||||
// Then: The command should complete successfully.
|
||||
require.NoError(t, w.Wait())
|
||||
})
|
||||
|
||||
t.Run("SendToNonIdleAppState", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, appState := range []codersdk.WorkspaceAppStatusState{
|
||||
codersdk.WorkspaceAppStatusStateComplete,
|
||||
codersdk.WorkspaceAppStatusStateFailure,
|
||||
} {
|
||||
t.Run(string(appState), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupCtx := testutil.Context(t, testutil.WaitLong)
|
||||
setup := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "some input", "some response"))
|
||||
|
||||
agentClient := agentsdk.New(setup.userClient.URL, agentsdk.WithFixedToken(setup.agentToken))
|
||||
require.NoError(t, agentClient.PatchAppStatus(setupCtx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: appState,
|
||||
Message: "done",
|
||||
}))
|
||||
|
||||
inv, root := clitest.New(t, "task", "send", setup.task.Name, "some input")
|
||||
clitest.SetupConfig(t, setup.userClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) map[string]http.HandlerFunc {
|
||||
@@ -151,7 +362,7 @@ func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) m
|
||||
}
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskSendErr(t *testing.T, returnErr error) map[string]http.HandlerFunc {
|
||||
func fakeAgentAPITaskSendErr(returnErr error) map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/status": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
+5
-5
@@ -90,7 +90,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
tsr := toStatusRow(task)
|
||||
tsr := toStatusRow(task, r.clock.Now())
|
||||
out, err := formatter.Format(ctx, []taskStatusRow{tsr})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("format task status: %w", err)
|
||||
@@ -112,7 +112,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
}
|
||||
|
||||
// Only print if something changed
|
||||
newStatusRow := toStatusRow(task)
|
||||
newStatusRow := toStatusRow(task, r.clock.Now())
|
||||
if !taskStatusRowEqual(lastStatusRow, newStatusRow) {
|
||||
out, err := formatter.Format(ctx, []taskStatusRow{newStatusRow})
|
||||
if err != nil {
|
||||
@@ -166,10 +166,10 @@ func taskStatusRowEqual(r1, r2 taskStatusRow) bool {
|
||||
taskStateEqual(r1.CurrentState, r2.CurrentState)
|
||||
}
|
||||
|
||||
func toStatusRow(task codersdk.Task) taskStatusRow {
|
||||
func toStatusRow(task codersdk.Task, now time.Time) taskStatusRow {
|
||||
tsr := taskStatusRow{
|
||||
Task: task,
|
||||
ChangedAgo: time.Since(task.UpdatedAt).Truncate(time.Second).String() + " ago",
|
||||
ChangedAgo: now.Sub(task.UpdatedAt).Truncate(time.Second).String() + " ago",
|
||||
}
|
||||
tsr.Healthy = task.WorkspaceAgentHealth != nil &&
|
||||
task.WorkspaceAgentHealth.Healthy &&
|
||||
@@ -178,7 +178,7 @@ func toStatusRow(task codersdk.Task) taskStatusRow {
|
||||
!task.WorkspaceAgentLifecycle.ShuttingDown()
|
||||
|
||||
if task.CurrentState != nil {
|
||||
tsr.ChangedAgo = time.Since(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
|
||||
tsr.ChangedAgo = now.Sub(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
|
||||
}
|
||||
return tsr
|
||||
}
|
||||
|
||||
+12
-9
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func Test_TaskStatus(t *testing.T) {
|
||||
@@ -28,12 +29,12 @@ func Test_TaskStatus(t *testing.T) {
|
||||
args []string
|
||||
expectOutput string
|
||||
expectError string
|
||||
hf func(context.Context, time.Time) func(http.ResponseWriter, *http.Request)
|
||||
hf func(context.Context, quartz.Clock) func(http.ResponseWriter, *http.Request)
|
||||
}{
|
||||
{
|
||||
args: []string{"doesnotexist"},
|
||||
expectError: httpapi.ResourceNotFoundResponse.Message,
|
||||
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
|
||||
hf: func(ctx context.Context, _ quartz.Clock) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/tasks/me/doesnotexist":
|
||||
@@ -49,7 +50,8 @@ func Test_TaskStatus(t *testing.T) {
|
||||
args: []string{"exists"},
|
||||
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
0s ago active true working Thinking furiously...`,
|
||||
hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) {
|
||||
hf: func(ctx context.Context, clk quartz.Clock) func(w http.ResponseWriter, r *http.Request) {
|
||||
now := clk.Now()
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/tasks/me/exists":
|
||||
@@ -84,7 +86,8 @@ func Test_TaskStatus(t *testing.T) {
|
||||
4s ago active true
|
||||
3s ago active true working Reticulating splines...
|
||||
2s ago active true complete Splines reticulated successfully!`,
|
||||
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
|
||||
hf: func(ctx context.Context, clk quartz.Clock) func(http.ResponseWriter, *http.Request) {
|
||||
now := clk.Now()
|
||||
var calls atomic.Int64
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
@@ -215,7 +218,7 @@ func Test_TaskStatus(t *testing.T) {
|
||||
"created_at": "2025-08-26T12:34:56Z",
|
||||
"updated_at": "2025-08-26T12:34:56Z"
|
||||
}`,
|
||||
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
|
||||
hf: func(ctx context.Context, _ quartz.Clock) func(http.ResponseWriter, *http.Request) {
|
||||
ts := time.Date(2025, 8, 26, 12, 34, 56, 0, time.UTC)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
@@ -252,8 +255,8 @@ func Test_TaskStatus(t *testing.T) {
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
now = time.Now().UTC() // TODO: replace with quartz
|
||||
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
|
||||
mClock = quartz.NewMock(t)
|
||||
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, mClock)))
|
||||
client = codersdk.New(testutil.MustURL(t, srv.URL))
|
||||
sb = strings.Builder{}
|
||||
args = []string{"task", "status", "--watch-interval", testutil.IntervalFast.String()}
|
||||
@@ -261,10 +264,10 @@ func Test_TaskStatus(t *testing.T) {
|
||||
|
||||
t.Cleanup(srv.Close)
|
||||
args = append(args, tc.args...)
|
||||
inv, root := clitest.New(t, args...)
|
||||
inv, cfgDir := clitest.NewWithClock(t, mClock, args...)
|
||||
inv.Stdout = &sb
|
||||
inv.Stderr = &sb
|
||||
clitest.SetupConfig(t, client, root)
|
||||
clitest.SetupConfig(t, client, cfgDir)
|
||||
err := inv.WithContext(ctx).Run()
|
||||
if tc.expectError == "" {
|
||||
assert.NoError(t, err)
|
||||
|
||||
+56
-5
@@ -88,6 +88,13 @@ func Test_Tasks(t *testing.T) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, tasks[0].WorkspaceID.UUID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
|
||||
// Report the task app as idle so that waitForTaskIdle
|
||||
// can proceed during the "send task message" step.
|
||||
require.NoError(t, agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
}))
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -272,10 +279,19 @@ func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Mes
|
||||
// setupCLITaskTest creates a test workspace with an AI task template and agent,
|
||||
// with a fake agent API configured with the provided set of handlers.
|
||||
// Returns the user client and workspace.
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (ownerClient *codersdk.Client, memberClient *codersdk.Client, task codersdk.Task) {
|
||||
// setupCLITaskTestResult holds the return values from setupCLITaskTest.
|
||||
type setupCLITaskTestResult struct {
|
||||
ownerClient *codersdk.Client
|
||||
userClient *codersdk.Client
|
||||
task codersdk.Task
|
||||
agentToken string
|
||||
agent agent.Agent
|
||||
}
|
||||
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) setupCLITaskTestResult {
|
||||
t.Helper()
|
||||
|
||||
ownerClient = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
userClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
|
||||
|
||||
@@ -292,21 +308,56 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the task's underlying workspace to be built
|
||||
// Wait for the task's underlying workspace to be built.
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, workspace.LatestBuild.ID)
|
||||
|
||||
agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(authToken))
|
||||
_ = agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
|
||||
agt := agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, workspace.ID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
return ownerClient, userClient, task
|
||||
// Report the task app as idle so that waitForTaskIdle can proceed.
|
||||
err = agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
|
||||
AppSlug: "task-sidebar",
|
||||
State: codersdk.WorkspaceAppStatusStateIdle,
|
||||
Message: "ready",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return setupCLITaskTestResult{
|
||||
ownerClient: ownerClient,
|
||||
userClient: userClient,
|
||||
task: task,
|
||||
agentToken: authToken,
|
||||
agent: agt,
|
||||
}
|
||||
}
|
||||
|
||||
// pauseTask pauses the task and waits for the stop build to complete.
|
||||
func pauseTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pauseResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
// resumeTask resumes the task waits for the start build to complete. The task
|
||||
// will be in "initializing" state after this returns because no agent is connected.
|
||||
func resumeTask(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
resumeResp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resumeResp.WorkspaceBuild)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, resumeResp.WorkspaceBuild.ID)
|
||||
}
|
||||
|
||||
// setupCLITaskTestWithSnapshot creates a task in the specified status with a log snapshot.
|
||||
|
||||
-5
@@ -143,11 +143,6 @@ AI BRIDGE OPTIONS:
|
||||
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
|
||||
Whether to start an in-memory aibridged instance.
|
||||
|
||||
--aibridge-inject-coder-mcp-tools bool, $CODER_AIBRIDGE_INJECT_CODER_MCP_TOOLS (default: false)
|
||||
Whether to inject Coder's MCP tools into intercepted AI Bridge
|
||||
requests (requires the "oauth2" and "mcp-server-http" experiments to
|
||||
be enabled).
|
||||
|
||||
--aibridge-max-concurrency int, $CODER_AIBRIDGE_MAX_CONCURRENCY (default: 0)
|
||||
Maximum number of concurrent AI Bridge requests per replica. Set to 0
|
||||
to disable (unlimited).
|
||||
|
||||
+5
-2
@@ -5,11 +5,14 @@ USAGE:
|
||||
|
||||
Send input to a task
|
||||
|
||||
- Send direct input to a task.:
|
||||
Send input to a task. If the task is paused, it will be automatically resumed
|
||||
before input is sent. If the task is initializing, it will wait for the task
|
||||
to become ready.
|
||||
- Send direct input to a task:
|
||||
|
||||
$ coder task send task1 "Please also add unit tests"
|
||||
|
||||
- Send input from stdin to a task.:
|
||||
- Send input from stdin to a task:
|
||||
|
||||
$ echo "Please also add unit tests" | coder task send task1 --stdin
|
||||
|
||||
|
||||
+4
-2
@@ -778,8 +778,10 @@ aibridge:
|
||||
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
|
||||
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
|
||||
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
|
||||
# Whether to inject Coder's MCP tools into intercepted AI Bridge requests
|
||||
# (requires the "oauth2" and "mcp-server-http" experiments to be enabled).
|
||||
# Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a
|
||||
# future release. Whether to inject Coder's MCP tools into intercepted AI Bridge
|
||||
# requests (requires the "oauth2" and "mcp-server-http" experiments to be
|
||||
# enabled).
|
||||
# (default: false, type: bool)
|
||||
inject_coder_mcp_tools: false
|
||||
# Length of time to retain data such as interceptions and all related records
|
||||
|
||||
+13
-2
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -184,12 +185,22 @@ func TestTokens(t *testing.T) {
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
inv.Stdout = buf
|
||||
|
||||
// Precondition: validate token is not expired before expiring
|
||||
var expiredAtBefore time.Time
|
||||
token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two")
|
||||
require.NoError(t, err)
|
||||
now := dbtime.Now()
|
||||
require.True(t, token.ExpiresAt.After(now), "token should not be expired yet (expiresAt=%s, now=%s)", token.ExpiresAt.UTC(), now)
|
||||
expiredAtBefore = token.ExpiresAt
|
||||
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
// Validate that token was expired
|
||||
if token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two"); assert.NoError(t, err) {
|
||||
now := time.Now()
|
||||
require.False(t, token.ExpiresAt.After(now), "token expiresAt should not be in the future, but was %s (now=%s)", token.ExpiresAt, now)
|
||||
now := dbtime.Now()
|
||||
require.NotEqual(t, token.ExpiresAt, expiredAtBefore, "token expiresAt is the same as before expiring, but should have been updated")
|
||||
require.False(t, token.ExpiresAt.After(now), "token expiresAt should not be in the future after expiring, but was %s (now=%s)", token.ExpiresAt.UTC(), now)
|
||||
}
|
||||
|
||||
// Delete by ID (explicit delete flag)
|
||||
|
||||
@@ -116,10 +116,10 @@ func TestWorkspaceActivityBump(t *testing.T) {
|
||||
// is required. The Activity Bump behavior is also coupled with
|
||||
// Last Used, so it would be obvious to the user if we
|
||||
// are falsely recognizing activity.
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, workspace.LatestBuild.Deadline.Time, firstDeadline)
|
||||
require.Never(t, func() bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return err == nil && !workspace.LatestBuild.Deadline.Time.Equal(firstDeadline)
|
||||
}, testutil.IntervalMedium, testutil.IntervalFast, "deadline should not change")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -134,9 +134,12 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
|
||||
case database.WorkspaceAgentLifecycleStateReady,
|
||||
database.WorkspaceAgentLifecycleStateStartTimeout,
|
||||
database.WorkspaceAgentLifecycleStateStartError:
|
||||
a.emitMetricsOnce.Do(func() {
|
||||
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
|
||||
})
|
||||
// Only emit metrics for the parent agent, this metric is not intended to measure devcontainer durations.
|
||||
if !workspaceAgent.ParentID.Valid {
|
||||
a.emitMetricsOnce.Do(func() {
|
||||
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return req.Lifecycle, nil
|
||||
|
||||
@@ -582,6 +582,64 @@ func TestUpdateLifecycle(t *testing.T) {
|
||||
require.Equal(t, uint64(1), got.GetSampleCount())
|
||||
require.Equal(t, expectedDuration, got.GetSampleSum())
|
||||
})
|
||||
|
||||
t.Run("SubAgentDoesNotEmitMetric", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
parentID := uuid.New()
|
||||
subAgent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
ParentID: uuid.NullUUID{UUID: parentID, Valid: true},
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
|
||||
StartedAt: sql.NullTime{Valid: true, Time: someTime},
|
||||
ReadyAt: sql.NullTime{Valid: false},
|
||||
}
|
||||
lifecycle := &agentproto.Lifecycle{
|
||||
State: agentproto.Lifecycle_READY,
|
||||
ChangedAt: timestamppb.New(now),
|
||||
}
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
|
||||
ID: subAgent.ID,
|
||||
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
|
||||
StartedAt: subAgent.StartedAt,
|
||||
ReadyAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
}).Return(nil)
|
||||
// GetWorkspaceBuildMetricsByResourceID should NOT be called
|
||||
// because sub-agents should be skipped before querying.
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := agentapi.NewLifecycleMetrics(reg)
|
||||
api := &agentapi.LifecycleAPI{
|
||||
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
|
||||
return subAgent, nil
|
||||
},
|
||||
WorkspaceID: workspaceID,
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
Metrics: metrics,
|
||||
PublishWorkspaceUpdateFn: nil,
|
||||
}
|
||||
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
|
||||
Lifecycle: lifecycle,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, lifecycle, resp)
|
||||
|
||||
// We don't expect the metric to be emitted for sub-agents, by default this will fail anyway but it doesn't hurt
|
||||
// to document the test explicitly.
|
||||
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), gomock.Any()).Times(0)
|
||||
|
||||
// If we were emitting the metric we would have failed by now since it would include a call to the database that we're not expecting.
|
||||
pm, err := reg.Gather()
|
||||
require.NoError(t, err)
|
||||
for _, m := range pm {
|
||||
if m.GetName() == fullMetricName {
|
||||
t.Fatal("metric should not be emitted for sub-agent")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateStartup(t *testing.T) {
|
||||
|
||||
@@ -387,9 +387,9 @@ func (b *Batcher) flush(ctx context.Context, reason string) {
|
||||
b.Metrics.BatchSize.Observe(float64(count))
|
||||
b.Metrics.MetadataTotal.Add(float64(count))
|
||||
b.Metrics.BatchesTotal.WithLabelValues(reason).Inc()
|
||||
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(time.Since(start).Seconds())
|
||||
elapsed = b.clock.Since(start)
|
||||
b.Metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
|
||||
|
||||
elapsed = time.Since(start)
|
||||
b.log.Debug(ctx, "flush complete",
|
||||
slog.F("count", count),
|
||||
slog.F("elapsed", elapsed),
|
||||
|
||||
+13
-1
@@ -315,6 +315,18 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
|
||||
}
|
||||
}
|
||||
|
||||
// appStatusStateToTaskState converts a WorkspaceAppStatusState to a
|
||||
// TaskState. The two enums mostly share values but "failure" in the
|
||||
// app status maps to "failed" in the public task API.
|
||||
func appStatusStateToTaskState(s codersdk.WorkspaceAppStatusState) codersdk.TaskState {
|
||||
switch s {
|
||||
case codersdk.WorkspaceAppStatusStateFailure:
|
||||
return codersdk.TaskStateFailed
|
||||
default:
|
||||
return codersdk.TaskState(s)
|
||||
}
|
||||
}
|
||||
|
||||
// deriveTaskCurrentState determines the current state of a task based on the
|
||||
// workspace's latest app status and initialization phase.
|
||||
// Returns nil if no valid state can be determined.
|
||||
@@ -334,7 +346,7 @@ func deriveTaskCurrentState(
|
||||
if ws.LatestBuild.Transition != codersdk.WorkspaceTransitionStart || ws.LatestAppStatus.CreatedAt.After(ws.LatestBuild.CreatedAt) {
|
||||
currentState = &codersdk.TaskStateEntry{
|
||||
Timestamp: ws.LatestAppStatus.CreatedAt,
|
||||
State: codersdk.TaskState(ws.LatestAppStatus.State),
|
||||
State: appStatusStateToTaskState(ws.LatestAppStatus.State),
|
||||
Message: ws.LatestAppStatus.Message,
|
||||
URI: ws.LatestAppStatus.URI,
|
||||
}
|
||||
|
||||
Generated
+55
-7
@@ -495,16 +495,31 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/unarchive": {
|
||||
"post": {
|
||||
"/chats/{chat}/desktop": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Unarchive a chat",
|
||||
"operationId": "unarchive-chat",
|
||||
"summary": "Watch chat desktop",
|
||||
"operationId": "watch-chat-desktop",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -897,6 +912,28 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/profile": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Debug"
|
||||
],
|
||||
"summary": "Collect debug profiles",
|
||||
"operationId": "collect-debug-profiles",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/tailnet": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -12454,6 +12491,7 @@ const docTemplate = `{
|
||||
"type": "boolean"
|
||||
},
|
||||
"inject_coder_mcp_tools": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"max_concurrency": {
|
||||
@@ -14340,7 +14378,6 @@ const docTemplate = `{
|
||||
"codersdk.CreateUserRequestWithOrgs": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"email",
|
||||
"username"
|
||||
],
|
||||
"properties": {
|
||||
@@ -14370,6 +14407,10 @@ const docTemplate = `{
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
"service_account": {
|
||||
"description": "Service accounts are admin-managed accounts that cannot login.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"user_status": {
|
||||
"description": "UserStatus defaults to UserStatusDormant.",
|
||||
"allOf": [
|
||||
@@ -15297,6 +15338,10 @@ const docTemplate = `{
|
||||
"codersdk.ExternalAuthConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api_base_url": {
|
||||
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
|
||||
"type": "string"
|
||||
},
|
||||
"app_install_url": {
|
||||
"type": "string"
|
||||
},
|
||||
@@ -15335,12 +15380,15 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
},
|
||||
"mcp_tool_allow_regex": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "string"
|
||||
},
|
||||
"mcp_tool_deny_regex": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "string"
|
||||
},
|
||||
"mcp_url": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "string"
|
||||
},
|
||||
"no_refresh": {
|
||||
|
||||
Generated
+54
-7
@@ -422,14 +422,29 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/chats/{chat}/unarchive": {
|
||||
"post": {
|
||||
"/chats/{chat}/desktop": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Unarchive a chat",
|
||||
"operationId": "unarchive-chat",
|
||||
"summary": "Watch chat desktop",
|
||||
"operationId": "watch-chat-desktop",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Chat ID",
|
||||
"name": "chat",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
"101": {
|
||||
"description": "Switching Protocols"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -776,6 +791,26 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/profile": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Debug"],
|
||||
"summary": "Collect debug profiles",
|
||||
"operationId": "collect-debug-profiles",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/tailnet": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -11062,6 +11097,7 @@
|
||||
"type": "boolean"
|
||||
},
|
||||
"inject_coder_mcp_tools": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"max_concurrency": {
|
||||
@@ -12880,7 +12916,7 @@
|
||||
},
|
||||
"codersdk.CreateUserRequestWithOrgs": {
|
||||
"type": "object",
|
||||
"required": ["email", "username"],
|
||||
"required": ["username"],
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string",
|
||||
@@ -12908,6 +12944,10 @@
|
||||
"password": {
|
||||
"type": "string"
|
||||
},
|
||||
"service_account": {
|
||||
"description": "Service accounts are admin-managed accounts that cannot login.",
|
||||
"type": "boolean"
|
||||
},
|
||||
"user_status": {
|
||||
"description": "UserStatus defaults to UserStatusDormant.",
|
||||
"allOf": [
|
||||
@@ -13816,6 +13856,10 @@
|
||||
"codersdk.ExternalAuthConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api_base_url": {
|
||||
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
|
||||
"type": "string"
|
||||
},
|
||||
"app_install_url": {
|
||||
"type": "string"
|
||||
},
|
||||
@@ -13854,12 +13898,15 @@
|
||||
"type": "string"
|
||||
},
|
||||
"mcp_tool_allow_regex": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "string"
|
||||
},
|
||||
"mcp_tool_deny_regex": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "string"
|
||||
},
|
||||
"mcp_url": {
|
||||
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
|
||||
"type": "string"
|
||||
},
|
||||
"no_refresh": {
|
||||
|
||||
+11
-11
@@ -48,8 +48,8 @@ func TestTokenCRUD(t *testing.T) {
|
||||
require.EqualValues(t, len(keys), 1)
|
||||
require.Contains(t, res.Key, keys[0].ID)
|
||||
// expires_at should default to 30 days
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
|
||||
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*8))
|
||||
require.Equal(t, codersdk.APIKeyScopeAll, keys[0].Scope)
|
||||
require.Len(t, keys[0].AllowList, 1)
|
||||
require.Equal(t, "*:*", keys[0].AllowList[0].String())
|
||||
@@ -194,8 +194,8 @@ func TestUserSetTokenDuration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
keys, err := client.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*6*24))
|
||||
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*8*24))
|
||||
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*6*24))
|
||||
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*8*24))
|
||||
}
|
||||
|
||||
func TestDefaultTokenDuration(t *testing.T) {
|
||||
@@ -210,8 +210,8 @@ func TestDefaultTokenDuration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
keys, err := client.Tokens(ctx, codersdk.Me, codersdk.TokensFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
|
||||
require.Greater(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, dbtime.Now().Add(time.Hour*24*8))
|
||||
}
|
||||
|
||||
func TestTokenUserSetMaxLifetime(t *testing.T) {
|
||||
@@ -518,7 +518,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify the token is not expired.
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.After(time.Now()))
|
||||
require.True(t, key.ExpiresAt.After(dbtime.Now()))
|
||||
|
||||
auditor.ResetLogs()
|
||||
|
||||
@@ -529,7 +529,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify the token is expired.
|
||||
key, err = adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
|
||||
|
||||
// Verify audit log.
|
||||
als := auditor.AuditLogs()
|
||||
@@ -556,7 +556,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify the token is expired.
|
||||
key, err := memberClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
|
||||
})
|
||||
|
||||
t.Run("MemberCannotExpireOtherUsersToken", func(t *testing.T) {
|
||||
@@ -607,7 +607,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Invariant: make sure it's actually expired
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.LessOrEqual(t, key.ExpiresAt, time.Now(), "key should be expired")
|
||||
require.LessOrEqual(t, key.ExpiresAt, dbtime.Now(), "key should be expired")
|
||||
|
||||
// Expire it again - should succeed (idempotent).
|
||||
err = adminClient.ExpireAPIKey(ctx, codersdk.Me, keyID)
|
||||
@@ -636,7 +636,7 @@ func TestExpireAPIKey(t *testing.T) {
|
||||
// Verify it's expired.
|
||||
key, err := adminClient.APIKeyByID(ctx, codersdk.Me, keyID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, key.ExpiresAt.Before(time.Now()))
|
||||
require.True(t, key.ExpiresAt.Before(dbtime.Now()))
|
||||
|
||||
// Delete the expired token - should succeed.
|
||||
err = adminClient.DeleteAPIKey(ctx, codersdk.Me, keyID)
|
||||
|
||||
@@ -240,9 +240,7 @@ func (c *Compressor) serveRef(w http.ResponseWriter, r *http.Request, headers ht
|
||||
}
|
||||
|
||||
for key, values := range headers {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
w.Header()[key] = values
|
||||
}
|
||||
w.Header().Set("Content-Encoding", cref.key.encoding)
|
||||
w.Header().Add("Vary", "Accept-Encoding")
|
||||
|
||||
@@ -155,6 +155,41 @@ type nopEncoder struct {
|
||||
|
||||
func (nopEncoder) Close() error { return nil }
|
||||
|
||||
func TestCompressorPresetHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
tempDir := t.TempDir()
|
||||
cacheDir := filepath.Join(tempDir, "cache")
|
||||
err := os.MkdirAll(cacheDir, 0o700)
|
||||
require.NoError(t, err)
|
||||
srcDir := filepath.Join(tempDir, "src")
|
||||
err = os.MkdirAll(srcDir, 0o700)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(srcDir, "file.html"), []byte("textstring"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
compressor := NewCompressor(logger, 5, cacheDir, http.FS(os.DirFS(srcDir)))
|
||||
|
||||
for range 2 {
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req := httptest.NewRequestWithContext(ctx, "GET", "/file.html", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
respRec := httptest.NewRecorder()
|
||||
respRec.Header().Set("X-Original-Content-Length", "10")
|
||||
respRec.Header().Set("ETag", `"abc123"`)
|
||||
|
||||
compressor.ServeHTTP(respRec, req)
|
||||
resp := respRec.Result()
|
||||
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, []string{"10"}, resp.Header.Values("X-Original-Content-Length"))
|
||||
require.Equal(t, []string{`"abc123"`}, resp.Header.Values("ETag"))
|
||||
require.NoError(t, resp.Body.Close())
|
||||
}
|
||||
}
|
||||
|
||||
// nolint: tparallel // we want to assert the state of the cache, so run synchronously
|
||||
func TestCompressorHeadings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package chatcost
|
||||
|
||||
import (
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Returns cost in micros -- millionths of a dollar, rounded up to the next
|
||||
// whole microdollar.
|
||||
// Returns nil when pricing is not configured or when all priced usage fields
|
||||
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
|
||||
func CalculateTotalCostMicros(
|
||||
usage codersdk.ChatMessageUsage,
|
||||
cost *codersdk.ModelCostConfig,
|
||||
) *int64 {
|
||||
if cost == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// A cost config with no prices set means pricing is effectively
|
||||
// unconfigured — return nil (unpriced) rather than zero.
|
||||
if cost.InputPricePerMillionTokens == nil &&
|
||||
cost.OutputPricePerMillionTokens == nil &&
|
||||
cost.CacheReadPricePerMillionTokens == nil &&
|
||||
cost.CacheWritePricePerMillionTokens == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if usage.InputTokens == nil &&
|
||||
usage.OutputTokens == nil &&
|
||||
usage.ReasoningTokens == nil &&
|
||||
usage.CacheCreationTokens == nil &&
|
||||
usage.CacheReadTokens == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OutputTokens already includes reasoning tokens per provider
|
||||
// semantics (e.g. OpenAI's completion_tokens encompasses
|
||||
// reasoning_tokens). Adding ReasoningTokens here would
|
||||
// double-count.
|
||||
|
||||
// Preserve nil when usage exists only in categories without configured
|
||||
// pricing, so callers can distinguish "unpriced" from "priced at zero".
|
||||
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
|
||||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
|
||||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
|
||||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
|
||||
if !hasMatchingPrice {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
|
||||
outputMicros := calcCost(usage.OutputTokens, cost.OutputPricePerMillionTokens)
|
||||
cacheReadMicros := calcCost(usage.CacheReadTokens, cost.CacheReadPricePerMillionTokens)
|
||||
cacheWriteMicros := calcCost(usage.CacheCreationTokens, cost.CacheWritePricePerMillionTokens)
|
||||
|
||||
total := inputMicros.
|
||||
Add(outputMicros).
|
||||
Add(cacheReadMicros).
|
||||
Add(cacheWriteMicros)
|
||||
rounded := total.Ceil().IntPart()
|
||||
return &rounded
|
||||
}
|
||||
|
||||
// calcCost returns the cost in fractional microdollars (millionths of a USD)
|
||||
// for the given token count at the specified per-million-token price.
|
||||
func calcCost(tokens *int64, pricePerMillion *decimal.Decimal) decimal.Decimal {
|
||||
return decimal.NewFromInt(ptr.NilToEmpty(tokens)).Mul(ptr.NilToEmpty(pricePerMillion))
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package chatcost_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
usage codersdk.ChatMessageUsage
|
||||
cost *codersdk.ModelCostConfig
|
||||
want *int64
|
||||
}{
|
||||
{
|
||||
name: "nil cost returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "all priced usage fields nil returns nil",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
TotalTokens: ptr.Ref[int64](1234),
|
||||
ContextLimit: ptr.Ref[int64](8192),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "sub-micro total rounds up to 1",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
|
||||
},
|
||||
want: ptr.Ref[int64](1),
|
||||
},
|
||||
{
|
||||
name: "simple input only",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
},
|
||||
{
|
||||
name: "simple output only",
|
||||
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
},
|
||||
{
|
||||
name: "reasoning tokens included in output total",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
OutputTokens: ptr.Ref[int64](500),
|
||||
ReasoningTokens: ptr.Ref[int64](200),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
},
|
||||
{
|
||||
name: "cache read tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
},
|
||||
{
|
||||
name: "cache creation tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
|
||||
},
|
||||
want: ptr.Ref[int64](18750),
|
||||
},
|
||||
{
|
||||
name: "full mixed usage totals all components exactly",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
InputTokens: ptr.Ref[int64](101),
|
||||
OutputTokens: ptr.Ref[int64](201),
|
||||
ReasoningTokens: ptr.Ref[int64](52),
|
||||
CacheReadTokens: ptr.Ref[int64](1005),
|
||||
CacheCreationTokens: ptr.Ref[int64](33),
|
||||
TotalTokens: ptr.Ref[int64](1391),
|
||||
ContextLimit: ptr.Ref[int64](4096),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")),
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")),
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
|
||||
},
|
||||
want: ptr.Ref[int64](2005),
|
||||
},
|
||||
{
|
||||
name: "partial pricing only input contributes",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
InputTokens: ptr.Ref[int64](1234),
|
||||
OutputTokens: ptr.Ref[int64](999),
|
||||
ReasoningTokens: ptr.Ref[int64](111),
|
||||
CacheReadTokens: ptr.Ref[int64](500),
|
||||
CacheCreationTokens: ptr.Ref[int64](250),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
|
||||
},
|
||||
want: ptr.Ref[int64](3085),
|
||||
},
|
||||
{
|
||||
name: "zero tokens with pricing returns zero pointer",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](0),
|
||||
},
|
||||
{
|
||||
name: "usage only in unpriced categories returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "non nil usage with empty cost config returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
|
||||
cost: &codersdk.ModelCostConfig{},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
|
||||
|
||||
if tt.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, *tt.want, *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+1046
-417
File diff suppressed because it is too large
Load Diff
+614
-76
@@ -6,21 +6,26 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
@@ -29,6 +34,8 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
proto "github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -48,7 +55,7 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-me",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -73,10 +80,11 @@ func TestInterruptChatBroadcastsStatusAcrossInstances(t *testing.T) {
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case event := <-events:
|
||||
if event.Type != codersdk.ChatStreamEventTypeStatus || event.Status == nil {
|
||||
return false
|
||||
if event.Type == codersdk.ChatStreamEventTypeStatus && event.Status != nil {
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
}
|
||||
return event.Status.Status == codersdk.ChatStatusWaiting
|
||||
t.Logf("skipping unexpected event: type=%s", event.Type)
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -175,7 +183,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
if got.Chat.Status != codersdk.ChatStatusWaiting && got.Chat.Status != codersdk.ChatStatusError {
|
||||
if got.Status != codersdk.ChatStatusWaiting && got.Status != codersdk.ChatStatusError {
|
||||
return false
|
||||
}
|
||||
// Also ensure the subagent LLM call has been made.
|
||||
@@ -250,7 +258,7 @@ func TestInterruptChatClearsWorkerInDatabase(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "db-transition",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -286,7 +294,7 @@ func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "heartbeat-ownership",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -328,7 +336,7 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "queue-when-busy",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -344,7 +352,7 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "queued"}},
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -366,7 +374,7 @@ func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -379,7 +387,7 @@ func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-when-busy",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -394,30 +402,35 @@ func TestSendMessageInterruptBehaviorSendsImmediatelyWhenBusy(t *testing.T) {
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "interrupt"}},
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("interrupt")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Queued)
|
||||
require.Equal(t, database.ChatStatusPending, result.Chat.Status)
|
||||
require.False(t, result.Chat.WorkerID.Valid)
|
||||
|
||||
// The message should be queued, not inserted directly.
|
||||
require.True(t, result.Queued)
|
||||
require.NotNil(t, result.QueuedMessage)
|
||||
|
||||
// The chat should transition to waiting (interrupt signal),
|
||||
// not pending.
|
||||
require.Equal(t, database.ChatStatusWaiting, result.Chat.Status)
|
||||
|
||||
fromDB, err := db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, fromDB.Status)
|
||||
require.False(t, fromDB.WorkerID.Valid)
|
||||
require.Equal(t, database.ChatStatusWaiting, fromDB.Status)
|
||||
|
||||
// The message should be in the queue, not in chat_messages.
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queued, 0)
|
||||
require.Len(t, queued, 1)
|
||||
|
||||
// Only the initial user message should be in chat_messages.
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
require.Equal(t, messages[len(messages)-1].ID, result.Message.ID)
|
||||
require.Len(t, messages, 1)
|
||||
}
|
||||
|
||||
func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
@@ -433,7 +446,7 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "edit-message",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "original"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("original")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -447,20 +460,24 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
|
||||
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "follow-up"}},
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("follow-up")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "another"}},
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("another")},
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("queued"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
||||
ChatID: chat.ID,
|
||||
Content: json.RawMessage(`"queued"`),
|
||||
Content: queuedContent,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -476,7 +493,7 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: editedMessageID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, editedMessageID, editResult.Message.ID)
|
||||
@@ -521,14 +538,14 @@ func TestEditMessageRejectsMissingMessage(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "missing-edited-message",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: 999999,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotFound))
|
||||
@@ -547,18 +564,21 @@ func TestEditMessageRejectsNonUserMessage(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "non-user-edited-message",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("assistant"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assistantMessage, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`"assistant"`),
|
||||
Valid: true,
|
||||
},
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: assistantContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
@@ -574,7 +594,7 @@ func TestEditMessageRejectsNonUserMessage(t *testing.T) {
|
||||
_, err = replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: assistantMessage.ID,
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: "edited"}},
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, chatd.ErrEditedMessageNotUser))
|
||||
@@ -821,7 +841,7 @@ func TestSubscribeSnapshotIncludesStatusEvent(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "status-snapshot",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -850,7 +870,7 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "no-dup-parts",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -865,15 +885,15 @@ func TestSubscribeNoPubsubNoDuplicateMessageParts(t *testing.T) {
|
||||
// events — the snapshot already contained everything. Before
|
||||
// the fix, localSnapshot was replayed into the channel,
|
||||
// causing duplicates.
|
||||
select {
|
||||
case event, ok := <-events:
|
||||
if ok {
|
||||
t.Fatalf("unexpected event from channel (would be a duplicate): type=%s", event.Type)
|
||||
require.Never(t, func() bool {
|
||||
select {
|
||||
case <-events:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
// Channel closed without events is fine.
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// No events — correct behavior.
|
||||
}
|
||||
}, 200*time.Millisecond, testutil.IntervalFast,
|
||||
"expected no duplicate events after snapshot")
|
||||
}
|
||||
|
||||
func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
@@ -890,17 +910,23 @@ func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "after-id-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "first"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("first")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert two more messages so we have three total visible
|
||||
// messages (the initial user message plus these two).
|
||||
secondContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("second"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg2, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"second"`), Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: secondContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
@@ -913,11 +939,17 @@ func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
thirdContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("third"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: "user",
|
||||
Content: pqtype.NullRawMessage{RawMessage: json.RawMessage(`"third"`), Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
Content: thirdContent,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{},
|
||||
OutputTokens: sql.NullInt64{},
|
||||
@@ -946,7 +978,7 @@ func TestSubscribeAfterMessageID(t *testing.T) {
|
||||
|
||||
partialMessages := filterMessageEvents(partialSnapshot)
|
||||
require.Len(t, partialMessages, 1, "afterMessageID=msg2.ID should return only messages after msg2")
|
||||
require.Equal(t, "user", partialMessages[0].Message.Role)
|
||||
require.Equal(t, codersdk.ChatMessageRoleUser, partialMessages[0].Message.Role)
|
||||
}
|
||||
|
||||
// filterMessageEvents returns only the Message-type events from a
|
||||
@@ -1049,33 +1081,36 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var chatWithMessages codersdk.ChatWithMessages
|
||||
var chatResult codersdk.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := client.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatWithMessages = got
|
||||
return got.Chat.Status == codersdk.ChatStatusWaiting || got.Chat.Status == codersdk.ChatStatusError
|
||||
chatResult = got
|
||||
return got.Status == codersdk.ChatStatusWaiting || got.Status == codersdk.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
if chatWithMessages.Chat.Status == codersdk.ChatStatusError {
|
||||
if chatResult.Status == codersdk.ChatStatusError {
|
||||
lastError := ""
|
||||
if chatWithMessages.Chat.LastError != nil {
|
||||
lastError = *chatWithMessages.Chat.LastError
|
||||
if chatResult.LastError != nil {
|
||||
lastError = *chatResult.LastError
|
||||
}
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
|
||||
}
|
||||
|
||||
require.NotNil(t, chatWithMessages.Chat.WorkspaceID)
|
||||
workspaceID := *chatWithMessages.Chat.WorkspaceID
|
||||
require.NotNil(t, chatResult.WorkspaceID)
|
||||
workspaceID := *chatResult.WorkspaceID
|
||||
workspace, err := client.Workspace(ctx, workspaceID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, workspaceName, workspace.Name)
|
||||
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var foundCreateWorkspaceResult bool
|
||||
for _, message := range chatWithMessages.Messages {
|
||||
if message.Role != "tool" {
|
||||
for _, message := range chatMsgs.Messages {
|
||||
if message.Role != codersdk.ChatMessageRoleTool {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
@@ -1217,34 +1252,37 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var chatWithMessages codersdk.ChatWithMessages
|
||||
var chatResult codersdk.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := client.GetChat(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatWithMessages = got
|
||||
return got.Chat.Status == codersdk.ChatStatusWaiting || got.Chat.Status == codersdk.ChatStatusError
|
||||
chatResult = got
|
||||
return got.Status == codersdk.ChatStatusWaiting || got.Status == codersdk.ChatStatusError
|
||||
}, testutil.WaitSuperLong, testutil.IntervalFast)
|
||||
|
||||
if chatWithMessages.Chat.Status == codersdk.ChatStatusError {
|
||||
if chatResult.Status == codersdk.ChatStatusError {
|
||||
lastError := ""
|
||||
if chatWithMessages.Chat.LastError != nil {
|
||||
lastError = *chatWithMessages.Chat.LastError
|
||||
if chatResult.LastError != nil {
|
||||
lastError = *chatResult.LastError
|
||||
}
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", lastError)
|
||||
}
|
||||
|
||||
// Verify the workspace was started.
|
||||
require.NotNil(t, chatWithMessages.Chat.WorkspaceID)
|
||||
require.NotNil(t, chatResult.WorkspaceID)
|
||||
updatedWorkspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.WorkspaceTransitionStart, updatedWorkspace.LatestBuild.Transition)
|
||||
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify start_workspace tool result exists in the chat messages.
|
||||
var foundStartWorkspaceResult bool
|
||||
for _, message := range chatWithMessages.Messages {
|
||||
if message.Role != "tool" {
|
||||
for _, message := range chatMsgs.Messages {
|
||||
if message.Role != codersdk.ChatMessageRoleTool {
|
||||
continue
|
||||
}
|
||||
for _, part := range message.Content {
|
||||
@@ -1419,7 +1457,7 @@ func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
|
||||
OwnerID: user.ID,
|
||||
Title: "interrupt-no-push",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1462,13 +1500,26 @@ func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) {
|
||||
// mockWebpushDispatcher implements webpush.Dispatcher and records Dispatch calls.
|
||||
type mockWebpushDispatcher struct {
|
||||
dispatchCount atomic.Int32
|
||||
mu sync.Mutex
|
||||
lastMessage codersdk.WebpushMessage
|
||||
lastUserID uuid.UUID
|
||||
}
|
||||
|
||||
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, _ uuid.UUID, _ codersdk.WebpushMessage) error {
|
||||
func (m *mockWebpushDispatcher) Dispatch(_ context.Context, userID uuid.UUID, msg codersdk.WebpushMessage) error {
|
||||
m.dispatchCount.Add(1)
|
||||
m.mu.Lock()
|
||||
m.lastMessage = msg
|
||||
m.lastUserID = userID
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockWebpushDispatcher) getLastMessage() codersdk.WebpushMessage {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastMessage
|
||||
}
|
||||
|
||||
func (*mockWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1477,6 +1528,78 @@ func (*mockWebpushDispatcher) PublicKey() string {
|
||||
return "test-vapid-public-key"
|
||||
}
|
||||
|
||||
func TestSuccessfulChatSendsWebPushWithNavigationData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Set up a mock OpenAI that returns a simple successful response.
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("done")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Mock webpush dispatcher that captures the dispatched message.
|
||||
mockPush := &mockWebpushDispatcher{}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
WebpushDispatcher: mockPush,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "push-nav-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the chat to complete and return to waiting status.
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
fromDB, dbErr := db.GetChatByID(ctx, chat.ID)
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
return fromDB.Status == database.ChatStatusWaiting && !fromDB.WorkerID.Valid && mockPush.dispatchCount.Load() == 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Verify a web push notification was dispatched exactly once.
|
||||
require.Equal(t, int32(1), mockPush.dispatchCount.Load(),
|
||||
"expected exactly one web push dispatch for a completed chat")
|
||||
|
||||
// Verify the notification was sent to the correct user.
|
||||
mockPush.mu.Lock()
|
||||
capturedMsg := mockPush.lastMessage
|
||||
capturedUserID := mockPush.lastUserID
|
||||
mockPush.mu.Unlock()
|
||||
|
||||
require.Equal(t, user.ID, capturedUserID,
|
||||
"web push should be dispatched to the chat owner")
|
||||
|
||||
// Verify the Data field contains the correct navigation URL.
|
||||
expectedURL := fmt.Sprintf("/agents/%s", chat.ID)
|
||||
require.Equal(t, expectedURL, capturedMsg.Data["url"],
|
||||
"web push Data should contain the chat navigation URL")
|
||||
}
|
||||
|
||||
func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1486,6 +1609,12 @@ func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T)
|
||||
var requestCount atomic.Int32
|
||||
streamStarted := make(chan struct{})
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
// Ignore non-streaming requests (e.g. title generation) so
|
||||
// they don't interfere with the request counter used to
|
||||
// coordinate the streaming chat flow.
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("shutdown-retry")
|
||||
}
|
||||
if requestCount.Add(1) == 1 {
|
||||
chunks := make(chan chattest.OpenAIChunk, 1)
|
||||
go func() {
|
||||
@@ -1523,7 +1652,7 @@ func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T)
|
||||
OwnerID: user.ID,
|
||||
Title: "shutdown-retry",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "hello"}},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1583,3 +1712,412 @@ func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T)
|
||||
!fromDB.LastError.Valid
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
const assistantText = "I have completed the task successfully and all tests are passing now."
|
||||
const summaryText = "Completed task and verified all tests pass."
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
// Non-streaming calls are used for title
|
||||
// generation and push summary generation.
|
||||
// Return the summary text for both — the title
|
||||
// result is irrelevant to this test.
|
||||
return chattest.OpenAINonStreamingResponse(summaryText)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks(assistantText)...,
|
||||
)
|
||||
})
|
||||
|
||||
mockPush := &mockWebpushDispatcher{}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
WebpushDispatcher: mockPush,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
setOpenAIProviderBaseURL(ctx, t, db, openAIURL)
|
||||
|
||||
_, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "summary-push-test",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do the thing")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The push notification is dispatched asynchronously after the
|
||||
// chat finishes, so we poll for it rather than checking
|
||||
// immediately after the status transitions to waiting.
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
return mockPush.dispatchCount.Load() >= 1
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
msg := mockPush.getLastMessage()
|
||||
require.Equal(t, summaryText, msg.Body,
|
||||
"push body should be the LLM-generated summary")
|
||||
require.NotEqual(t, "Agent has finished running.", msg.Body,
|
||||
"push body should not use the default fallback text")
|
||||
}
|
||||
|
||||
func TestComputerUseSubagentToolsAndModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Track tools and model from the Anthropic LLM calls (the
|
||||
// computer use child chat). We use a raw HTTP handler because
|
||||
// the chattest AnthropicRequest struct does not capture tools.
|
||||
type anthropicCall struct {
|
||||
Model string
|
||||
Tools []string
|
||||
}
|
||||
var anthropicMu sync.Mutex
|
||||
var anthropicCalls []anthropicCall
|
||||
|
||||
anthropicSrv := httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
names := make([]string, len(req.Tools))
|
||||
for i, tool := range req.Tools {
|
||||
names[i] = tool.Name
|
||||
}
|
||||
anthropicMu.Lock()
|
||||
anthropicCalls = append(anthropicCalls, anthropicCall{
|
||||
Model: req.Model,
|
||||
Tools: names,
|
||||
})
|
||||
anthropicMu.Unlock()
|
||||
|
||||
if !req.Stream {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "msg-test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": chattool.ComputerUseModelName,
|
||||
"content": []map[string]any{{"type": "text", "text": "Done."}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": map[string]any{"input_tokens": 10, "output_tokens": 5},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Stream a minimal Anthropic SSE response.
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
chunks := []map[string]any{
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg-test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": chattool.ComputerUseModelName,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]any{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]any{
|
||||
"type": "text_delta",
|
||||
"text": "Done.",
|
||||
},
|
||||
},
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{"stop_reason": "end_turn"},
|
||||
"usage": map[string]any{"output_tokens": 5},
|
||||
},
|
||||
{"type": "message_stop"},
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkBytes, _ := json.Marshal(chunk)
|
||||
eventType, _ := chunk["type"].(string)
|
||||
_, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n",
|
||||
eventType, chunkBytes)
|
||||
flusher.Flush()
|
||||
}
|
||||
},
|
||||
))
|
||||
t.Cleanup(anthropicSrv.Close)
|
||||
|
||||
// OpenAI mock for the root chat. The first streaming call
|
||||
// triggers spawn_computer_use_agent; subsequent calls reply
|
||||
// with text.
|
||||
var openAICallCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("title")
|
||||
}
|
||||
if openAICallCount.Add(1) == 1 {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk(
|
||||
"spawn_computer_use_agent",
|
||||
`{"prompt":"do the desktop thing","title":"cu-sub"}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Done.")...,
|
||||
)
|
||||
})
|
||||
|
||||
// Seed the DB: user, openai-compat provider, model config.
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai-compat",
|
||||
DisplayName: "OpenAI Compat",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: openAIURL,
|
||||
CreatedBy: uuid.NullUUID{},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{},
|
||||
UpdatedBy: uuid.NullUUID{},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an Anthropic provider pointing to our mock server.
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
APIKey: "test-anthropic-key",
|
||||
BaseUrl: anthropicSrv.URL,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build workspace + agent records so getWorkspaceConn can
|
||||
// resolve the agent for the computer use child.
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
CreatedBy: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
ActiveVersionID: tv.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
TemplateID: tpl.ID,
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
InitiatorID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
TemplateVersionID: tv.ID,
|
||||
WorkspaceID: ws.ID,
|
||||
JobID: pj.ID,
|
||||
})
|
||||
res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
JobID: pj.ID,
|
||||
})
|
||||
dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: res.ID,
|
||||
})
|
||||
|
||||
// Mock agent connection that returns valid display dimensions
|
||||
// for the initial screenshot check in the computer use path.
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
mockConn.EXPECT().
|
||||
ExecuteDesktopAction(gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.DesktopActionResponse{
|
||||
ScreenshotWidth: 1920,
|
||||
ScreenshotHeight: 1080,
|
||||
ScreenshotData: "iVBOR",
|
||||
}, nil).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
SetExtraHeaders(gomock.Any()).
|
||||
AnyTimes()
|
||||
mockConn.EXPECT().
|
||||
LS(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(workspacesdk.LSResponse{}, xerrors.New("not found")).
|
||||
AnyTimes()
|
||||
|
||||
agentConnFn := func(
|
||||
_ context.Context, agentID uuid.UUID,
|
||||
) (workspacesdk.AgentConn, func(), error) {
|
||||
require.Equal(t, dbAgent.ID, agentID)
|
||||
return mockConn, func() {}, nil
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: 10 * time.Millisecond,
|
||||
InFlightChatStaleAfter: testutil.WaitSuperLong,
|
||||
AgentConn: agentConnFn,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
// Create a root chat with a workspace so the child inherits it.
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "computer-use-detection",
|
||||
ModelConfigID: model.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Use the desktop to check the UI"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the root chat AND the computer use child to finish.
|
||||
// The root chat spawns the child, then the chatd server picks
|
||||
// up and runs the child (which hits the Anthropic mock).
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
if got.Status != database.ChatStatusWaiting &&
|
||||
got.Status != database.ChatStatusError {
|
||||
return false
|
||||
}
|
||||
// Ensure the Anthropic mock received at least one call.
|
||||
anthropicMu.Lock()
|
||||
n := len(anthropicCalls)
|
||||
anthropicMu.Unlock()
|
||||
return n >= 1
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
anthropicMu.Lock()
|
||||
calls := append([]anthropicCall(nil), anthropicCalls...)
|
||||
anthropicMu.Unlock()
|
||||
|
||||
require.NotEmpty(t, calls,
|
||||
"expected at least one Anthropic LLM call")
|
||||
|
||||
childModel := calls[0].Model
|
||||
childTools := calls[0].Tools
|
||||
|
||||
// 1. Verify the model is the computer use model.
|
||||
require.Equal(t, chattool.ComputerUseModelName, childModel,
|
||||
"computer use subagent should use %s",
|
||||
chattool.ComputerUseModelName)
|
||||
|
||||
// 2. Verify the computer tool is present.
|
||||
require.Contains(t, childTools, "computer",
|
||||
"computer use subagent should have the computer tool")
|
||||
|
||||
// 3. Verify standard workspace tools are present (the same
|
||||
// set a regular subagent gets).
|
||||
standardTools := []string{
|
||||
"read_file", "write_file", "edit_files", "execute",
|
||||
"process_output", "process_list", "process_signal",
|
||||
}
|
||||
for _, tool := range standardTools {
|
||||
require.Contains(t, childTools, tool,
|
||||
"computer use subagent should have standard tool %q",
|
||||
tool)
|
||||
}
|
||||
|
||||
// 4. Verify workspace provisioning tools are NOT present.
|
||||
workspaceProvisioningTools := []string{
|
||||
"list_templates", "read_template",
|
||||
"create_workspace", "start_workspace",
|
||||
}
|
||||
for _, tool := range workspaceProvisioningTools {
|
||||
require.NotContains(t, childTools, tool,
|
||||
"computer use subagent should NOT have workspace "+
|
||||
"provisioning tool %q", tool)
|
||||
}
|
||||
|
||||
// 5. Verify subagent tools are NOT present.
|
||||
subagentTools := []string{
|
||||
"spawn_agent", "spawn_computer_use_agent",
|
||||
"wait_agent", "message_agent", "close_agent",
|
||||
}
|
||||
for _, tool := range subagentTools {
|
||||
require.NotContains(t, childTools, tool,
|
||||
"computer use subagent should NOT have subagent "+
|
||||
"tool %q", tool)
|
||||
}
|
||||
|
||||
// 6. Verify the child chat has Mode = computer_use in
|
||||
// the DB.
|
||||
allChats, err := db.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
var children []database.Chat
|
||||
for _, c := range allChats {
|
||||
if c.ParentChatID.Valid && c.ParentChatID.UUID == chat.ID {
|
||||
children = append(children, c)
|
||||
}
|
||||
}
|
||||
require.Len(t, children, 1)
|
||||
require.True(t, children[0].Mode.Valid)
|
||||
require.Equal(t, database.ChatModeComputerUse,
|
||||
children[0].Mode.ChatMode)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -62,9 +63,16 @@ type RunOptions struct {
|
||||
// of the provider, which lives in chatd, not chatloop.
|
||||
ProviderOptions fantasy.ProviderOptions
|
||||
|
||||
// ProviderTools are provider-native tools (like web search
|
||||
// and computer use) whose definitions are passed directly
|
||||
// to the provider API. When a ProviderTool has a non-nil
|
||||
// Runner, tool calls are executed locally; otherwise the
|
||||
// provider handles execution (e.g. web search).
|
||||
ProviderTools []ProviderTool
|
||||
|
||||
PersistStep func(context.Context, PersistedStep) error
|
||||
PublishMessagePart func(
|
||||
role fantasy.MessageRole,
|
||||
role codersdk.ChatMessageRole,
|
||||
part codersdk.ChatMessagePart,
|
||||
)
|
||||
Compaction *CompactionOptions
|
||||
@@ -73,12 +81,24 @@ type RunOptions struct {
|
||||
// OnRetry is called before each retry attempt when the LLM
|
||||
// stream fails with a retryable error. It provides the attempt
|
||||
// number, error, and backoff delay so callers can publish status
|
||||
// events to connected clients.
|
||||
// events to connected clients. Callers should also clear any
|
||||
// buffered stream state from the failed attempt in this callback
|
||||
// to avoid sending duplicated content.
|
||||
OnRetry chatretry.OnRetryFn
|
||||
|
||||
OnInterruptedPersistError func(error)
|
||||
}
|
||||
|
||||
// ProviderTool pairs a provider-native tool definition with an
|
||||
// optional local executor. When Runner is nil the tool is fully
|
||||
// provider-executed (e.g. web search). When Runner is non-nil
|
||||
// the definition is sent to the API but execution is handled
|
||||
// locally (e.g. computer use).
|
||||
type ProviderTool struct {
|
||||
Definition fantasy.Tool
|
||||
Runner fantasy.AgentTool
|
||||
}
|
||||
|
||||
// stepResult holds the accumulated output of a single streaming
|
||||
// step. Since we own the stream consumer, all content is tracked
|
||||
// directly here — no shadow draft state needed.
|
||||
@@ -149,11 +169,23 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolParts = append(toolParts, fantasy.ToolResultPart{
|
||||
ToolCallID: result.ToolCallID,
|
||||
Output: result.Result,
|
||||
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
|
||||
})
|
||||
part := fantasy.ToolResultPart{
|
||||
ToolCallID: result.ToolCallID,
|
||||
Output: result.Result,
|
||||
ProviderExecuted: result.ProviderExecuted,
|
||||
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
|
||||
}
|
||||
// Provider-executed tool results (e.g. web_search)
|
||||
// must stay in the assistant message so the result
|
||||
// block appears inline after the corresponding
|
||||
// server_tool_use block. This matches the persistence
|
||||
// layer in chatd.go which keeps them in
|
||||
// assistantBlocks.
|
||||
if result.ProviderExecuted {
|
||||
assistantParts = append(assistantParts, part)
|
||||
} else {
|
||||
toolParts = append(toolParts, part)
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
@@ -195,20 +227,24 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
opts.MaxSteps = 1
|
||||
}
|
||||
|
||||
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
||||
publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
|
||||
if opts.PublishMessagePart == nil {
|
||||
return
|
||||
}
|
||||
opts.PublishMessagePart(role, part)
|
||||
}
|
||||
|
||||
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools)
|
||||
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools, opts.ProviderTools)
|
||||
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
|
||||
|
||||
messages := opts.Messages
|
||||
var lastUsage fantasy.Usage
|
||||
var lastProviderMetadata fantasy.ProviderMetadata
|
||||
|
||||
totalSteps := 0
|
||||
// When totalSteps reaches MaxSteps the inner loop exits immediately
|
||||
// (its condition is false), stoppedByModel stays false, and the
|
||||
// post-loop guard breaks the outer compaction loop.
|
||||
for compactionAttempt := 0; ; compactionAttempt++ {
|
||||
alreadyCompacted := false
|
||||
// stoppedByModel is true when the inner step loop
|
||||
@@ -222,7 +258,8 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
// agent never had a chance to use the compacted context.
|
||||
compactedOnFinalStep := false
|
||||
|
||||
for step := 0; step < opts.MaxSteps; step++ {
|
||||
for step := 0; totalSteps < opts.MaxSteps; step++ {
|
||||
totalSteps++
|
||||
// Copy messages so that provider-specific caching
|
||||
// mutations don't leak back to the caller's slice.
|
||||
// copy copies Message structs by value, so field
|
||||
@@ -289,17 +326,29 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
toolResults = executeTools(ctx, opts.Tools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
||||
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleTool,
|
||||
codersdk.ChatMessageRoleTool,
|
||||
chatprompt.PartFromContent(tr),
|
||||
)
|
||||
})
|
||||
for _, tr := range toolResults {
|
||||
result.content = append(result.content, tr)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for interruption after tool execution.
|
||||
// Tools that were canceled mid-flight produce error
|
||||
// results via ctx cancellation. Persist the full
|
||||
// step (assistant blocks + tool results) through
|
||||
// the interrupt-safe path so nothing is lost.
|
||||
if ctx.Err() != nil {
|
||||
if errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
persistInterruptedStep(ctx, opts, &result)
|
||||
return ErrInterrupted
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
// Extract context limit from provider metadata.
|
||||
contextLimit := extractContextLimit(result.providerMetadata)
|
||||
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
|
||||
@@ -308,19 +357,30 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Persist the step — errors propagate directly.
|
||||
// Persist the step. If persistence fails because
|
||||
// the chat was interrupted between the previous
|
||||
// check and here, fall back to the interrupt-safe
|
||||
// path so partial content is not lost.
|
||||
if err := opts.PersistStep(ctx, PersistedStep{
|
||||
Content: result.content,
|
||||
Usage: result.usage,
|
||||
ContextLimit: contextLimit,
|
||||
}); err != nil {
|
||||
if errors.Is(err, ErrInterrupted) {
|
||||
persistInterruptedStep(ctx, opts, &result)
|
||||
return ErrInterrupted
|
||||
}
|
||||
return xerrors.Errorf("persist step: %w", err)
|
||||
}
|
||||
|
||||
lastUsage = result.usage
|
||||
lastProviderMetadata = result.providerMetadata
|
||||
|
||||
// Append the step's response messages so that both
|
||||
// inline and post-loop compaction see the full
|
||||
// conversation including the latest assistant reply.
|
||||
stepMessages := result.toResponseMessages()
|
||||
messages = append(messages, stepMessages...)
|
||||
|
||||
// Inline compaction.
|
||||
if opts.Compaction != nil && opts.ReloadMessages != nil {
|
||||
did, compactErr := tryCompact(
|
||||
@@ -354,17 +414,11 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
// The agent is continuing with tool calls, so any
|
||||
// prior compaction has already been consumed.
|
||||
compactedOnFinalStep = false
|
||||
|
||||
// Build messages from the step for the next iteration.
|
||||
// toResponseMessages produces assistant-role content
|
||||
// (text, reasoning, tool calls) and tool-result content.
|
||||
stepMessages := result.toResponseMessages()
|
||||
messages = append(messages, stepMessages...)
|
||||
}
|
||||
|
||||
// Post-run compaction safety net: if we never compacted
|
||||
// during the loop, try once at the end.
|
||||
if !alreadyCompacted && opts.Compaction != nil {
|
||||
if !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil {
|
||||
did, err := tryCompact(
|
||||
ctx,
|
||||
opts.Model,
|
||||
@@ -383,7 +437,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
compactedOnFinalStep = true
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enter the step loop when compaction fired on the
|
||||
// model's final step. This lets the agent continue
|
||||
// working with fresh summarized context instead of
|
||||
@@ -413,7 +466,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
func processStepStream(
|
||||
ctx context.Context,
|
||||
stream fantasy.StreamResponse,
|
||||
publishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart),
|
||||
publishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart),
|
||||
) (stepResult, error) {
|
||||
var result stepResult
|
||||
|
||||
@@ -422,27 +475,6 @@ func processStepStream(
|
||||
activeReasoningContent := make(map[string]reasoningState)
|
||||
// Track tool names by ID for input delta publishing.
|
||||
toolNames := make(map[string]string)
|
||||
// Track reasoning text/titles for title extraction.
|
||||
reasoningTitles := make(map[string]string)
|
||||
reasoningText := make(map[string]string)
|
||||
|
||||
setReasoningTitleFromText := func(id string, text string) {
|
||||
if id == "" || strings.TrimSpace(text) == "" {
|
||||
return
|
||||
}
|
||||
if reasoningTitles[id] != "" {
|
||||
return
|
||||
}
|
||||
reasoningText[id] += text
|
||||
if !strings.ContainsAny(reasoningText[id], "\r\n") {
|
||||
return
|
||||
}
|
||||
title := chatprompt.ReasoningTitleFromFirstLine(reasoningText[id])
|
||||
if title == "" {
|
||||
return
|
||||
}
|
||||
reasoningTitles[id] = title
|
||||
}
|
||||
|
||||
for part := range stream {
|
||||
switch part.Type {
|
||||
@@ -453,10 +485,7 @@ func processStepStream(
|
||||
if _, exists := activeTextContent[part.ID]; exists {
|
||||
activeTextContent[part.ID] += part.Delta
|
||||
}
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: part.Delta,
|
||||
})
|
||||
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(part.Delta))
|
||||
|
||||
case fantasy.StreamPartTypeTextEnd:
|
||||
if text, exists := activeTextContent[part.ID]; exists {
|
||||
@@ -479,13 +508,7 @@ func processStepStream(
|
||||
active.options = part.ProviderMetadata
|
||||
activeReasoningContent[part.ID] = active
|
||||
}
|
||||
setReasoningTitleFromText(part.ID, part.Delta)
|
||||
title := reasoningTitles[part.ID]
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: part.Delta,
|
||||
Title: title,
|
||||
})
|
||||
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageReasoning(part.Delta))
|
||||
|
||||
case fantasy.StreamPartTypeReasoningEnd:
|
||||
if active, exists := activeReasoningContent[part.ID]; exists {
|
||||
@@ -498,23 +521,7 @@ func processStepStream(
|
||||
}
|
||||
result.content = append(result.content, content)
|
||||
delete(activeReasoningContent, part.ID)
|
||||
|
||||
// Derive reasoning title at end of reasoning
|
||||
// block if we haven't yet.
|
||||
if reasoningTitles[part.ID] == "" {
|
||||
reasoningTitles[part.ID] = chatprompt.ReasoningTitleFromFirstLine(
|
||||
reasoningText[part.ID],
|
||||
)
|
||||
}
|
||||
title := reasoningTitles[part.ID]
|
||||
if title != "" {
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Title: title,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
case fantasy.StreamPartTypeToolInputStart:
|
||||
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
|
||||
ToolCallID: part.ID,
|
||||
@@ -527,17 +534,19 @@ func processStepStream(
|
||||
}
|
||||
|
||||
case fantasy.StreamPartTypeToolInputDelta:
|
||||
var providerExecuted bool
|
||||
if toolCall, exists := activeToolCalls[part.ID]; exists {
|
||||
toolCall.Input += part.Delta
|
||||
providerExecuted = toolCall.ProviderExecuted
|
||||
}
|
||||
toolName := toolNames[part.ID]
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: part.ID,
|
||||
ToolName: toolName,
|
||||
ArgsDelta: part.Delta,
|
||||
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: part.ID,
|
||||
ToolName: toolName,
|
||||
ArgsDelta: part.Delta,
|
||||
ProviderExecuted: providerExecuted,
|
||||
})
|
||||
|
||||
case fantasy.StreamPartTypeToolInputEnd:
|
||||
// No callback needed; the full tool call arrives in
|
||||
// StreamPartTypeToolCall.
|
||||
@@ -559,7 +568,7 @@ func processStepStream(
|
||||
delete(activeToolCalls, part.ID)
|
||||
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleAssistant,
|
||||
codersdk.ChatMessageRoleAssistant,
|
||||
chatprompt.PartFromContent(tc),
|
||||
)
|
||||
|
||||
@@ -573,10 +582,28 @@ func processStepStream(
|
||||
}
|
||||
result.content = append(result.content, sourceContent)
|
||||
publishMessagePart(
|
||||
fantasy.MessageRoleAssistant,
|
||||
codersdk.ChatMessageRoleAssistant,
|
||||
chatprompt.PartFromContent(sourceContent),
|
||||
)
|
||||
|
||||
case fantasy.StreamPartTypeToolResult:
|
||||
// Provider-executed tool results (e.g. web search)
|
||||
// are emitted by the provider and added directly
|
||||
// to the step content for multi-turn round-tripping.
|
||||
// This mirrors fantasy's agent.go accumulation logic.
|
||||
if part.ProviderExecuted {
|
||||
tr := fantasy.ToolResultContent{
|
||||
ToolCallID: part.ID,
|
||||
ToolName: part.ToolCallName,
|
||||
ProviderExecuted: part.ProviderExecuted,
|
||||
ProviderMetadata: part.ProviderMetadata,
|
||||
}
|
||||
result.content = append(result.content, tr)
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
chatprompt.PartFromContent(tr),
|
||||
)
|
||||
}
|
||||
case fantasy.StreamPartTypeFinish:
|
||||
result.usage = part.Usage
|
||||
result.finishReason = part.FinishReason
|
||||
@@ -604,17 +631,26 @@ func processStepStream(
|
||||
}
|
||||
}
|
||||
|
||||
result.shouldContinue = len(result.toolCalls) > 0 &&
|
||||
hasLocalToolCalls := false
|
||||
for _, tc := range result.toolCalls {
|
||||
if !tc.ProviderExecuted {
|
||||
hasLocalToolCalls = true
|
||||
break
|
||||
}
|
||||
}
|
||||
result.shouldContinue = hasLocalToolCalls &&
|
||||
result.finishReason == fantasy.FinishReasonToolCalls
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// executeTools runs each tool call sequentially after the stream
|
||||
// completes. Results are published via onResult as each tool
|
||||
// finishes.
|
||||
// executeTools runs all tool calls concurrently after the stream
|
||||
// completes. Results are published via onResult in the original
|
||||
// tool-call order after all tools finish, preserving deterministic
|
||||
// event ordering for SSE subscribers.
|
||||
func executeTools(
|
||||
ctx context.Context,
|
||||
allTools []fantasy.AgentTool,
|
||||
providerTools []ProviderTool,
|
||||
toolCalls []fantasy.ToolCallContent,
|
||||
onResult func(fantasy.ToolResultContent),
|
||||
) []fantasy.ToolResultContent {
|
||||
@@ -622,16 +658,58 @@ func executeTools(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter out provider-executed tool calls. These were
|
||||
// handled server-side by the LLM provider (e.g., web
|
||||
// search) and their results are already in the stream
|
||||
// content.
|
||||
localToolCalls := make([]fantasy.ToolCallContent, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
if !tc.ProviderExecuted {
|
||||
localToolCalls = append(localToolCalls, tc)
|
||||
}
|
||||
}
|
||||
if len(localToolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
toolMap := make(map[string]fantasy.AgentTool, len(allTools))
|
||||
for _, t := range allTools {
|
||||
toolMap[t.Info().Name] = t
|
||||
}
|
||||
// Include runners from provider tools so locally-executed
|
||||
// provider tools (e.g. computer use) can be dispatched.
|
||||
for _, pt := range providerTools {
|
||||
if pt.Runner != nil {
|
||||
toolMap[pt.Runner.Info().Name] = pt.Runner
|
||||
}
|
||||
}
|
||||
|
||||
results := make([]fantasy.ToolResultContent, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
tr := executeSingleTool(ctx, toolMap, tc)
|
||||
results = append(results, tr)
|
||||
if onResult != nil {
|
||||
results := make([]fantasy.ToolResultContent, len(localToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(localToolCalls))
|
||||
for i, tc := range localToolCalls {
|
||||
go func(i int, tc fantasy.ToolCallContent) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
results[i] = fantasy.ToolResultContent{
|
||||
ToolCallID: tc.ToolCallID,
|
||||
ToolName: tc.ToolName,
|
||||
Result: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.Errorf("tool panicked: %v", r),
|
||||
},
|
||||
}
|
||||
}
|
||||
}()
|
||||
results[i] = executeSingleTool(ctx, toolMap, tc)
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Publish results in the original tool-call order so SSE
|
||||
// subscribers see a deterministic event sequence.
|
||||
if onResult != nil {
|
||||
for _, tr := range results {
|
||||
onResult(tr)
|
||||
}
|
||||
}
|
||||
@@ -781,8 +859,9 @@ func persistInterruptedStep(
|
||||
continue
|
||||
}
|
||||
content = append(content, fantasy.ToolResultContent{
|
||||
ToolCallID: tc.ToolCallID,
|
||||
ToolName: tc.ToolName,
|
||||
ToolCallID: tc.ToolCallID,
|
||||
ToolName: tc.ToolName,
|
||||
ProviderExecuted: tc.ProviderExecuted,
|
||||
Result: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New(interruptedToolResultErrorMessage),
|
||||
},
|
||||
@@ -802,15 +881,17 @@ func persistInterruptedStep(
|
||||
|
||||
// buildToolDefinitions converts AgentTool definitions into the
|
||||
// fantasy.Tool slice expected by fantasy.Call. When activeTools
|
||||
// is non-empty, only tools whose name appears in the list are
|
||||
// included. This mirrors fantasy's agent.prepareTools filtering.
|
||||
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string) []fantasy.Tool {
|
||||
prepared := make([]fantasy.Tool, 0, len(tools))
|
||||
// is non-empty, only function tools whose name appears in the
|
||||
// list are included. Provider tool definitions are always
|
||||
// appended unconditionally.
|
||||
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, providerTools []ProviderTool) []fantasy.Tool {
|
||||
prepared := make([]fantasy.Tool, 0, len(tools)+len(providerTools))
|
||||
for _, tool := range tools {
|
||||
info := tool.Info()
|
||||
if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
inputSchema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": info.Parameters,
|
||||
@@ -824,6 +905,9 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string) []fan
|
||||
ProviderOptions: tool.ProviderOptions(),
|
||||
})
|
||||
}
|
||||
for _, pt := range providerTools {
|
||||
prepared = append(prepared, pt.Definition)
|
||||
}
|
||||
return prepared
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package chatloop //nolint:testpackage // Uses internal symbols.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"iter"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
@@ -405,6 +407,174 @@ func TestRun_PersistStepErrorPropagates(t *testing.T) {
|
||||
require.ErrorContains(t, err, "database write failed")
|
||||
}
|
||||
|
||||
// TestRun_ShutdownDuringToolExecutionReturnsContextCanceled verifies that
|
||||
// when the parent context is canceled (simulating server shutdown) while
|
||||
// a tool is blocked, Run returns context.Canceled — not ErrInterrupted.
|
||||
// This matters because the caller uses the error type to decide whether
|
||||
// to set chat status to "pending" (retryable on another worker) vs
|
||||
// "waiting" (stuck forever).
|
||||
func TestRun_ShutdownDuringToolExecutionReturnsContextCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
toolStarted := make(chan struct{})
|
||||
|
||||
// Model returns a single tool call, then finishes.
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-block", ToolCallName: "blocking_tool"},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-block", Delta: `{}`},
|
||||
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-block"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolCall,
|
||||
ID: "tc-block",
|
||||
ToolCallName: "blocking_tool",
|
||||
ToolCallInput: `{}`,
|
||||
},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
// Tool that blocks until its context is canceled, simulating
|
||||
// a long-running operation like wait_agent.
|
||||
blockingTool := fantasy.NewAgentTool(
|
||||
"blocking_tool",
|
||||
"blocks until context canceled",
|
||||
func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
close(toolStarted)
|
||||
<-ctx.Done()
|
||||
return fantasy.ToolResponse{}, ctx.Err()
|
||||
},
|
||||
)
|
||||
|
||||
// Simulate the server context (parent) and chat context
|
||||
// (child). Canceling the parent simulates graceful shutdown.
|
||||
serverCtx, serverCancel := context.WithCancel(context.Background())
|
||||
defer serverCancel()
|
||||
|
||||
serverCancelDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverCancelDone)
|
||||
<-toolStarted
|
||||
t.Logf("tool started, canceling server context to simulate shutdown")
|
||||
serverCancel()
|
||||
}()
|
||||
|
||||
// persistStep mirrors the FIXED chatd.go code: it only returns
|
||||
// ErrInterrupted when the context was actually canceled due to
|
||||
// an interruption (cause is ErrInterrupted). For shutdown
|
||||
// (plain context.Canceled), it returns the original error so
|
||||
// callers can distinguish the two.
|
||||
persistStep := func(persistCtx context.Context, _ PersistedStep) error {
|
||||
if persistCtx.Err() != nil {
|
||||
if errors.Is(context.Cause(persistCtx), ErrInterrupted) {
|
||||
return ErrInterrupted
|
||||
}
|
||||
return persistCtx.Err()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := Run(serverCtx, RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "run the blocking tool"),
|
||||
},
|
||||
Tools: []fantasy.AgentTool{blockingTool},
|
||||
MaxSteps: 3,
|
||||
PersistStep: persistStep,
|
||||
})
|
||||
// Wait for the cancel goroutine to finish to aid flake
|
||||
// diagnosis if the test ever hangs.
|
||||
<-serverCancelDone
|
||||
|
||||
require.Error(t, err)
|
||||
// The error must NOT be ErrInterrupted — it should propagate
|
||||
// as context.Canceled so the caller can distinguish shutdown
|
||||
// from user interruption. Use assert (not require) so both
|
||||
// checks are evaluated even if the first fails.
|
||||
assert.NotErrorIs(t, err, ErrInterrupted, "shutdown cancellation must not be converted to ErrInterrupted")
|
||||
assert.ErrorIs(t, err, context.Canceled, "shutdown should propagate as context.Canceled")
|
||||
}
|
||||
|
||||
func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sr := stepResult{
|
||||
content: []fantasy.Content{
|
||||
// Provider-executed tool call (e.g. web_search).
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "provider-tc-1",
|
||||
ToolName: "web_search",
|
||||
Input: `{"query":"coder"}`,
|
||||
ProviderExecuted: true,
|
||||
},
|
||||
// Provider-executed tool result — must stay in
|
||||
// assistant message.
|
||||
fantasy.ToolResultContent{
|
||||
ToolCallID: "provider-tc-1",
|
||||
ToolName: "web_search",
|
||||
ProviderExecuted: true,
|
||||
ProviderMetadata: fantasy.ProviderMetadata{"anthropic": nil},
|
||||
},
|
||||
// Local tool call (e.g. read_file).
|
||||
fantasy.ToolCallContent{
|
||||
ToolCallID: "local-tc-1",
|
||||
ToolName: "read_file",
|
||||
Input: `{"path":"main.go"}`,
|
||||
ProviderExecuted: false,
|
||||
},
|
||||
// Local tool result — should go into tool message.
|
||||
fantasy.ToolResultContent{
|
||||
ToolCallID: "local-tc-1",
|
||||
ToolName: "read_file",
|
||||
Result: fantasy.ToolResultOutputContentText{Text: "some result"},
|
||||
ProviderExecuted: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgs := sr.toResponseMessages()
|
||||
require.Len(t, msgs, 2, "expected assistant + tool messages")
|
||||
|
||||
// First message: assistant role.
|
||||
assistantMsg := msgs[0]
|
||||
assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role)
|
||||
require.Len(t, assistantMsg.Content, 3,
|
||||
"assistant message should have provider ToolCallPart, provider ToolResultPart, and local ToolCallPart")
|
||||
|
||||
// Part 0: provider tool call.
|
||||
providerTC, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[0])
|
||||
require.True(t, ok, "part 0 should be ToolCallPart")
|
||||
assert.Equal(t, "provider-tc-1", providerTC.ToolCallID)
|
||||
assert.True(t, providerTC.ProviderExecuted)
|
||||
|
||||
// Part 1: provider tool result (inline in assistant turn).
|
||||
providerTR, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](assistantMsg.Content[1])
|
||||
require.True(t, ok, "part 1 should be ToolResultPart")
|
||||
assert.Equal(t, "provider-tc-1", providerTR.ToolCallID)
|
||||
assert.True(t, providerTR.ProviderExecuted)
|
||||
|
||||
// Part 2: local tool call.
|
||||
localTC, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[2])
|
||||
require.True(t, ok, "part 2 should be ToolCallPart")
|
||||
assert.Equal(t, "local-tc-1", localTC.ToolCallID)
|
||||
assert.False(t, localTC.ProviderExecuted)
|
||||
|
||||
// Second message: tool role.
|
||||
toolMsg := msgs[1]
|
||||
assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
|
||||
require.Len(t, toolMsg.Content, 1,
|
||||
"tool message should have only the local ToolResultPart")
|
||||
|
||||
localTR, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
|
||||
require.True(t, ok, "tool part should be ToolResultPart")
|
||||
assert.Equal(t, "local-tc-1", localTR.ToolCallID)
|
||||
assert.False(t, localTR.ProviderExecuted)
|
||||
}
|
||||
|
||||
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
|
||||
if len(message.ProviderOptions) == 0 {
|
||||
return false
|
||||
@@ -418,3 +588,179 @@ func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
|
||||
cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
|
||||
return ok && cacheOptions.CacheControl.Type == "ephemeral"
|
||||
}
|
||||
|
||||
// TestRun_InterruptedDuringToolExecutionPersistsStep verifies that when
|
||||
// tools are executing and the chat is interrupted, the accumulated step
|
||||
// content (assistant blocks + tool results) is persisted via the
|
||||
// interrupt-safe path rather than being lost.
|
||||
func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
toolStarted := make(chan struct{})
|
||||
|
||||
// Model returns a completed tool call in the stream.
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "calling tool"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeReasoningStart, ID: "reason-1"},
|
||||
{Type: fantasy.StreamPartTypeReasoningDelta, ID: "reason-1", Delta: "let me think"},
|
||||
{Type: fantasy.StreamPartTypeReasoningEnd, ID: "reason-1"},
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "slow_tool"},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"key":"value"}`},
|
||||
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolCall,
|
||||
ID: "tc-1",
|
||||
ToolCallName: "slow_tool",
|
||||
ToolCallInput: `{"key":"value"}`,
|
||||
},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
// Tool that blocks until context is canceled, simulating
|
||||
// a long-running operation interrupted by the user.
|
||||
slowTool := fantasy.NewAgentTool(
|
||||
"slow_tool",
|
||||
"blocks until canceled",
|
||||
func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
close(toolStarted)
|
||||
<-ctx.Done()
|
||||
return fantasy.ToolResponse{}, ctx.Err()
|
||||
},
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(nil)
|
||||
|
||||
go func() {
|
||||
<-toolStarted
|
||||
cancel(ErrInterrupted)
|
||||
}()
|
||||
|
||||
var persistedContent []fantasy.Content
|
||||
persistedCtxErr := xerrors.New("unset")
|
||||
|
||||
err := Run(ctx, RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "run the slow tool"),
|
||||
},
|
||||
Tools: []fantasy.AgentTool{slowTool},
|
||||
MaxSteps: 3,
|
||||
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
|
||||
persistedCtxErr = persistCtx.Err()
|
||||
persistedContent = append([]fantasy.Content(nil), step.Content...)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, ErrInterrupted)
|
||||
// persistInterruptedStep uses context.WithoutCancel, so the
|
||||
// persist callback should see a non-canceled context.
|
||||
require.NoError(t, persistedCtxErr)
|
||||
require.NotEmpty(t, persistedContent)
|
||||
|
||||
var (
|
||||
foundText bool
|
||||
foundReasoning bool
|
||||
foundToolCall bool
|
||||
foundToolResult bool
|
||||
)
|
||||
for _, block := range persistedContent {
|
||||
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
|
||||
if strings.Contains(text.Text, "calling tool") {
|
||||
foundText = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](block); ok {
|
||||
if strings.Contains(reasoning.Text, "let me think") {
|
||||
foundReasoning = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
|
||||
if toolCall.ToolCallID == "tc-1" && toolCall.ToolName == "slow_tool" {
|
||||
foundToolCall = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
|
||||
if toolResult.ToolCallID == "tc-1" {
|
||||
foundToolResult = true
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, foundText, "persisted content should include text from the stream")
|
||||
require.True(t, foundReasoning, "persisted content should include reasoning from the stream")
|
||||
require.True(t, foundToolCall, "persisted content should include the tool call")
|
||||
require.True(t, foundToolResult, "persisted content should include the tool result (error from cancellation)")
|
||||
}
|
||||
|
||||
// TestRun_PersistStepInterruptedFallback verifies that when the normal
|
||||
// PersistStep call returns ErrInterrupted (e.g., context canceled in a
|
||||
// race), the step is retried via the interrupt-safe path.
|
||||
func TestRun_PersistStepInterruptedFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello world"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
persistCalls int
|
||||
savedContent []fantasy.Content
|
||||
)
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 1,
|
||||
PersistStep: func(_ context.Context, step PersistedStep) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
persistCalls++
|
||||
if persistCalls == 1 {
|
||||
// First call: simulate an interrupt race by
|
||||
// returning ErrInterrupted without persisting.
|
||||
return ErrInterrupted
|
||||
}
|
||||
// Second call (from persistInterruptedStep fallback):
|
||||
// accept the content.
|
||||
savedContent = append([]fantasy.Content(nil), step.Content...)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, ErrInterrupted)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
require.Equal(t, 2, persistCalls, "PersistStep should be called twice: once normally (failing), once via fallback")
|
||||
require.NotEmpty(t, savedContent)
|
||||
|
||||
var foundText bool
|
||||
for _, block := range savedContent {
|
||||
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
|
||||
if strings.Contains(text.Text, "hello world") {
|
||||
foundText = true
|
||||
}
|
||||
}
|
||||
}
|
||||
require.True(t, foundText, "fallback should persist the text content")
|
||||
}
|
||||
|
||||
@@ -17,13 +17,26 @@ const (
|
||||
minCompactionThresholdPercent = int32(0)
|
||||
maxCompactionThresholdPercent = int32(100)
|
||||
|
||||
defaultCompactionSummaryPrompt = "Summarize the current chat so a " +
|
||||
"new assistant can continue seamlessly. Include the user's goals, " +
|
||||
"decisions made, concrete technical details (files, commands, APIs), " +
|
||||
"errors encountered and fixes, and open questions. Be dense and factual. " +
|
||||
"Omit pleasantries and next-step suggestions."
|
||||
defaultCompactionSystemSummaryPrefix = "Summary of earlier chat context:"
|
||||
defaultCompactionTimeout = 90 * time.Second
|
||||
defaultCompactionSummaryPrompt = "You are performing a context compaction. " +
|
||||
"Summarize the conversation so a new assistant can seamlessly " +
|
||||
"continue the work in progress.\n\n" +
|
||||
"Include:\n" +
|
||||
"- The user's overall goal and current task\n" +
|
||||
"- Key decisions made and their rationale\n" +
|
||||
"- Concrete technical details: file paths, function names, " +
|
||||
"commands, APIs, and configurations\n" +
|
||||
"- Errors encountered and how they were resolved\n" +
|
||||
"- Current state of the work: what is DONE, what is IN PROGRESS, " +
|
||||
"and what REMAINS to be done\n" +
|
||||
"- The specific action the assistant was performing or about to " +
|
||||
"perform when this summary was triggered\n\n" +
|
||||
"Be dense and factual. Every sentence should convey essential " +
|
||||
"context for continuation. Do not include pleasantries or " +
|
||||
"conversational filler."
|
||||
defaultCompactionSystemSummaryPrefix = "The following is a summary of " +
|
||||
"the earlier conversation. The assistant was actively working when " +
|
||||
"the context was compacted. Continue the work described below:"
|
||||
defaultCompactionTimeout = 90 * time.Second
|
||||
)
|
||||
|
||||
type CompactionOptions struct {
|
||||
@@ -42,7 +55,7 @@ type CompactionOptions struct {
|
||||
// PublishMessagePart publishes streaming parts to connected
|
||||
// clients so they see "Summarizing..." / "Summarized" UI
|
||||
// transitions during compaction.
|
||||
PublishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart)
|
||||
PublishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart)
|
||||
|
||||
OnError func(error)
|
||||
}
|
||||
@@ -97,12 +110,8 @@ func tryCompact(
|
||||
// connected clients see activity during summary generation.
|
||||
if config.PublishMessagePart != nil && config.ToolCallID != "" {
|
||||
config.PublishMessagePart(
|
||||
fantasy.MessageRoleAssistant,
|
||||
codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: config.ToolCallID,
|
||||
ToolName: config.ToolName,
|
||||
},
|
||||
codersdk.ChatMessageRoleAssistant,
|
||||
codersdk.ChatMessageToolCall(config.ToolCallID, config.ToolName, nil),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -123,7 +132,8 @@ func tryCompact(
|
||||
config.SystemSummaryPrefix + "\n\n" + summary,
|
||||
)
|
||||
|
||||
err = config.Persist(ctx, CompactionResult{
|
||||
persistCtx := context.WithoutCancel(ctx)
|
||||
err = config.Persist(persistCtx, CompactionResult{
|
||||
SystemSummary: systemSummary,
|
||||
SummaryReport: summary,
|
||||
ThresholdPercent: config.ThresholdPercent,
|
||||
@@ -149,13 +159,8 @@ func tryCompact(
|
||||
"context_limit_tokens": contextLimit,
|
||||
})
|
||||
config.PublishMessagePart(
|
||||
fantasy.MessageRoleTool,
|
||||
codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: config.ToolCallID,
|
||||
ToolName: config.ToolName,
|
||||
Result: resultJSON,
|
||||
},
|
||||
codersdk.ChatMessageRoleTool,
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -172,14 +177,8 @@ func publishCompactionError(config CompactionOptions, msg string) {
|
||||
"error": msg,
|
||||
})
|
||||
config.PublishMessagePart(
|
||||
fantasy.MessageRoleTool,
|
||||
codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: config.ToolCallID,
|
||||
ToolName: config.ToolName,
|
||||
Result: errJSON,
|
||||
IsError: true,
|
||||
},
|
||||
codersdk.ChatMessageRoleTool,
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -76,9 +76,20 @@ func TestRun_Compaction(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, persistCompactionCalls)
|
||||
// Compaction fires twice: once inline when the threshold is
|
||||
// reached on step 0 (the only step, since MaxSteps=1), and
|
||||
// once from the post-run safety net during the re-entry
|
||||
// iteration (where totalSteps already equals MaxSteps so the
|
||||
// inner loop doesn't execute, but lastUsage still exceeds
|
||||
// the threshold).
|
||||
require.Equal(t, 2, persistCompactionCalls)
|
||||
require.Contains(t, persistedCompaction.SystemSummary, summaryText)
|
||||
require.Equal(t, summaryText, persistedCompaction.SummaryReport)
|
||||
require.Equal(t, int64(80), persistedCompaction.ContextTokens)
|
||||
@@ -138,7 +149,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
SummaryPrompt: "summarize now",
|
||||
ToolCallID: "test-tool-call-id",
|
||||
ToolName: "chat_summarized",
|
||||
PublishMessagePart: func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
||||
PublishMessagePart: func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeToolCall:
|
||||
callOrder = append(callOrder, "publish_tool_call")
|
||||
@@ -151,13 +162,25 @@ func TestRun_Compaction(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Compaction fires twice (see PersistsWhenThresholdReached
|
||||
// for the full explanation). Each cycle follows the order:
|
||||
// publish_tool_call → generate → persist → publish_tool_result.
|
||||
require.Equal(t, []string{
|
||||
"publish_tool_call",
|
||||
"generate",
|
||||
"persist",
|
||||
"publish_tool_result",
|
||||
"publish_tool_call",
|
||||
"generate",
|
||||
"persist",
|
||||
"publish_tool_result",
|
||||
}, callOrder)
|
||||
})
|
||||
|
||||
@@ -195,7 +218,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
ThresholdPercent: 70,
|
||||
ToolCallID: "test-tool-call-id",
|
||||
ToolName: "chat_summarized",
|
||||
PublishMessagePart: func(_ fantasy.MessageRole, _ codersdk.ChatMessagePart) {
|
||||
PublishMessagePart: func(_ codersdk.ChatMessageRole, _ codersdk.ChatMessagePart) {
|
||||
publishCalled = true
|
||||
},
|
||||
Persist: func(_ context.Context, _ CompactionResult) error {
|
||||
@@ -457,6 +480,11 @@ func TestRun_Compaction(t *testing.T) {
|
||||
compactionErr = err
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Error(t, compactionErr)
|
||||
@@ -572,4 +600,117 @@ func TestRun_Compaction(t *testing.T) {
|
||||
// Two stream calls: one before compaction, one after re-entry.
|
||||
require.Equal(t, 2, streamCallCount)
|
||||
})
|
||||
|
||||
t.Run("PostRunCompactionReEntryIncludesUserSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// After compaction the summary is stored as a user-role
|
||||
// message. When the loop re-enters, the reloaded prompt
|
||||
// must contain this user message so the LLM provider
|
||||
// receives a valid prompt (providers like Anthropic
|
||||
// require at least one non-system message).
|
||||
|
||||
var mu sync.Mutex
|
||||
var streamCallCount int
|
||||
var reEntryPrompt []fantasy.Message
|
||||
persistCompactionCalls := 0
|
||||
|
||||
const summaryText = "post-run compacted summary"
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCallCount
|
||||
streamCallCount++
|
||||
mu.Unlock()
|
||||
|
||||
switch step {
|
||||
case 0:
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "initial response"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
TotalTokens: 85,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
default:
|
||||
mu.Lock()
|
||||
reEntryPrompt = append([]fantasy.Message(nil), call.Prompt...)
|
||||
mu.Unlock()
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-2"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "continued"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 20,
|
||||
TotalTokens: 25,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate real post-compaction DB state: the summary is
|
||||
// a user-role message (the only non-system content).
|
||||
compactedMessages := []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleSystem, "system prompt"),
|
||||
textMessage(fantasy.MessageRoleUser, "Summary of earlier chat context:\n\ncompacted summary"),
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 5,
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
SummaryPrompt: "summarize now",
|
||||
Persist: func(_ context.Context, _ CompactionResult) error {
|
||||
persistCompactionCalls++
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return compactedMessages, nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, persistCompactionCalls, 1)
|
||||
// Re-entry happened: stream was called at least twice.
|
||||
require.Equal(t, 2, streamCallCount)
|
||||
// The re-entry prompt must contain the user summary.
|
||||
require.NotEmpty(t, reEntryPrompt)
|
||||
hasUser := false
|
||||
for _, msg := range reEntryPrompt {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
hasUser = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)")
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -553,30 +553,33 @@ func normalizedEnumValue(value string, allowed ...string) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeMissingCallConfig fills unset call config values from defaults.
|
||||
func MergeMissingCallConfig(
|
||||
dst *codersdk.ChatModelCallConfig,
|
||||
defaults codersdk.ChatModelCallConfig,
|
||||
// MergeMissingModelCostConfig fills unset pricing metadata from defaults.
|
||||
func MergeMissingModelCostConfig(
|
||||
dst **codersdk.ModelCostConfig,
|
||||
defaults *codersdk.ModelCostConfig,
|
||||
) {
|
||||
if dst.MaxOutputTokens == nil {
|
||||
dst.MaxOutputTokens = defaults.MaxOutputTokens
|
||||
if defaults == nil {
|
||||
return
|
||||
}
|
||||
if dst.Temperature == nil {
|
||||
dst.Temperature = defaults.Temperature
|
||||
if *dst == nil {
|
||||
copied := *defaults
|
||||
*dst = &copied
|
||||
return
|
||||
}
|
||||
if dst.TopP == nil {
|
||||
dst.TopP = defaults.TopP
|
||||
|
||||
current := *dst
|
||||
if current.InputPricePerMillionTokens == nil {
|
||||
current.InputPricePerMillionTokens = defaults.InputPricePerMillionTokens
|
||||
}
|
||||
if dst.TopK == nil {
|
||||
dst.TopK = defaults.TopK
|
||||
if current.OutputPricePerMillionTokens == nil {
|
||||
current.OutputPricePerMillionTokens = defaults.OutputPricePerMillionTokens
|
||||
}
|
||||
if dst.PresencePenalty == nil {
|
||||
dst.PresencePenalty = defaults.PresencePenalty
|
||||
if current.CacheReadPricePerMillionTokens == nil {
|
||||
current.CacheReadPricePerMillionTokens = defaults.CacheReadPricePerMillionTokens
|
||||
}
|
||||
if dst.FrequencyPenalty == nil {
|
||||
dst.FrequencyPenalty = defaults.FrequencyPenalty
|
||||
if current.CacheWritePricePerMillionTokens == nil {
|
||||
current.CacheWritePricePerMillionTokens = defaults.CacheWritePricePerMillionTokens
|
||||
}
|
||||
MergeMissingProviderOptions(&dst.ProviderOptions, defaults.ProviderOptions)
|
||||
}
|
||||
|
||||
// MergeMissingProviderOptions fills unset provider option fields from defaults.
|
||||
@@ -885,11 +888,14 @@ func MergeMissingProviderOptions(
|
||||
}
|
||||
|
||||
// ModelFromConfig resolves a provider/model pair and constructs a fantasy
|
||||
// language model client using the provided provider credentials.
|
||||
// language model client using the provided provider credentials. The
|
||||
// userAgent is sent as the User-Agent header on every outgoing LLM
|
||||
// API request.
|
||||
func ModelFromConfig(
|
||||
providerHint string,
|
||||
modelName string,
|
||||
providerKeys ProviderAPIKeys,
|
||||
userAgent string,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
provider, modelID, err := ResolveModelWithProviderHint(modelName, providerHint)
|
||||
if err != nil {
|
||||
@@ -907,6 +913,7 @@ func ModelFromConfig(
|
||||
case fantasyanthropic.Name:
|
||||
options := []fantasyanthropic.Option{
|
||||
fantasyanthropic.WithAPIKey(apiKey),
|
||||
fantasyanthropic.WithUserAgent(userAgent),
|
||||
}
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyanthropic.WithBaseURL(baseURL))
|
||||
@@ -920,12 +927,17 @@ func ModelFromConfig(
|
||||
fantasyazure.WithAPIKey(apiKey),
|
||||
fantasyazure.WithBaseURL(baseURL),
|
||||
fantasyazure.WithUseResponsesAPI(),
|
||||
fantasyazure.WithUserAgent(userAgent),
|
||||
)
|
||||
case fantasybedrock.Name:
|
||||
providerClient, err = fantasybedrock.New(fantasybedrock.WithAPIKey(apiKey))
|
||||
providerClient, err = fantasybedrock.New(
|
||||
fantasybedrock.WithAPIKey(apiKey),
|
||||
fantasybedrock.WithUserAgent(userAgent),
|
||||
)
|
||||
case fantasygoogle.Name:
|
||||
options := []fantasygoogle.Option{
|
||||
fantasygoogle.WithGeminiAPIKey(apiKey),
|
||||
fantasygoogle.WithUserAgent(userAgent),
|
||||
}
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasygoogle.WithBaseURL(baseURL))
|
||||
@@ -935,6 +947,7 @@ func ModelFromConfig(
|
||||
options := []fantasyopenai.Option{
|
||||
fantasyopenai.WithAPIKey(apiKey),
|
||||
fantasyopenai.WithUseResponsesAPI(),
|
||||
fantasyopenai.WithUserAgent(userAgent),
|
||||
}
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyopenai.WithBaseURL(baseURL))
|
||||
@@ -943,16 +956,21 @@ func ModelFromConfig(
|
||||
case fantasyopenaicompat.Name:
|
||||
options := []fantasyopenaicompat.Option{
|
||||
fantasyopenaicompat.WithAPIKey(apiKey),
|
||||
fantasyopenaicompat.WithUserAgent(userAgent),
|
||||
}
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyopenaicompat.WithBaseURL(baseURL))
|
||||
}
|
||||
providerClient, err = fantasyopenaicompat.New(options...)
|
||||
case fantasyopenrouter.Name:
|
||||
providerClient, err = fantasyopenrouter.New(fantasyopenrouter.WithAPIKey(apiKey))
|
||||
providerClient, err = fantasyopenrouter.New(
|
||||
fantasyopenrouter.WithAPIKey(apiKey),
|
||||
fantasyopenrouter.WithUserAgent(userAgent),
|
||||
)
|
||||
case fantasyvercel.Name:
|
||||
options := []fantasyvercel.Option{
|
||||
fantasyvercel.WithAPIKey(apiKey),
|
||||
fantasyvercel.WithUserAgent(userAgent),
|
||||
}
|
||||
if baseURL != "" {
|
||||
options = append(options, fantasyvercel.WithBaseURL(baseURL))
|
||||
|
||||
@@ -137,43 +137,6 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
|
||||
}
|
||||
|
||||
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := codersdk.ChatModelCallConfig{
|
||||
Temperature: float64Ptr(0.2),
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("alice"),
|
||||
},
|
||||
},
|
||||
}
|
||||
defaults := codersdk.ChatModelCallConfig{
|
||||
MaxOutputTokens: int64Ptr(512),
|
||||
Temperature: float64Ptr(0.9),
|
||||
TopP: float64Ptr(0.8),
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("bob"),
|
||||
ReasoningEffort: stringPtr("medium"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatprovider.MergeMissingCallConfig(&dst, defaults)
|
||||
|
||||
require.NotNil(t, dst.MaxOutputTokens)
|
||||
require.EqualValues(t, 512, *dst.MaxOutputTokens)
|
||||
require.NotNil(t, dst.Temperature)
|
||||
require.Equal(t, 0.2, *dst.Temperature)
|
||||
require.NotNil(t, dst.TopP)
|
||||
require.Equal(t, 0.8, *dst.TopP)
|
||||
require.NotNil(t, dst.ProviderOptions)
|
||||
require.NotNil(t, dst.ProviderOptions.OpenAI)
|
||||
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
|
||||
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
|
||||
}
|
||||
|
||||
func stringPtr(value string) *string {
|
||||
return &value
|
||||
}
|
||||
@@ -185,7 +148,3 @@ func boolPtr(value bool) *bool {
|
||||
func int64Ptr(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
func float64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package chatprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
)
|
||||
|
||||
// UserAgent returns the User-Agent string sent on all outgoing LLM
|
||||
// API requests made by Coder's built-in chat (chatd). The format
|
||||
// mirrors conventions used by other coding agents so that LLM
|
||||
// providers can identify traffic originating from Coder.
|
||||
//
|
||||
// Example: coder-agents/v2.21.0 (linux/amd64)
|
||||
func UserAgent() string {
|
||||
return fmt.Sprintf("coder-agents/%s (%s/%s)",
|
||||
buildinfo.Version(), runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package chatprovider_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattest"
|
||||
)
|
||||
|
||||
func TestUserAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
ua := chatprovider.UserAgent()
|
||||
|
||||
// Must start with "coder-agents/" so LLM providers can
|
||||
// identify traffic from Coder.
|
||||
require.True(t, strings.HasPrefix(ua, "coder-agents/"),
|
||||
"User-Agent should start with 'coder-agents/', got %q", ua)
|
||||
|
||||
// Must contain the build version.
|
||||
assert.Contains(t, ua, buildinfo.Version())
|
||||
|
||||
// Must contain OS/arch.
|
||||
assert.Contains(t, ua, runtime.GOOS+"/"+runtime.GOARCH)
|
||||
}
|
||||
|
||||
func TestModelFromConfig_UserAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var mu sync.Mutex
|
||||
var capturedUA string
|
||||
|
||||
serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
mu.Lock()
|
||||
capturedUA = req.Header.Get("User-Agent")
|
||||
mu.Unlock()
|
||||
return chattest.OpenAINonStreamingResponse("hello")
|
||||
})
|
||||
|
||||
expectedUA := chatprovider.UserAgent()
|
||||
keys := chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{"openai": "test-key"},
|
||||
BaseURLByProvider: map[string]string{"openai": serverURL},
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make a real call so Fantasy sends an HTTP request to the
|
||||
// fake server, which captures the User-Agent header.
|
||||
_, err = model.Generate(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{Text: "hello"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mu.Lock()
|
||||
got := capturedUA
|
||||
mu.Unlock()
|
||||
|
||||
require.NotEmpty(t, got, "User-Agent header was not sent")
|
||||
require.Equal(t, expectedUA, got,
|
||||
"User-Agent header should match chatprovider.UserAgent()")
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,6 +20,12 @@ const (
|
||||
// MaxDelay is the upper bound for the exponential backoff
|
||||
// duration. Matches the cap used in coder/mux.
|
||||
MaxDelay = 60 * time.Second
|
||||
|
||||
// MaxAttempts is the upper bound on retry attempts before
|
||||
// giving up. With a 60s max backoff this allows roughly
|
||||
// 25 minutes of retries, which is reasonable for transient
|
||||
// LLM provider issues.
|
||||
MaxAttempts = 25
|
||||
)
|
||||
|
||||
// nonRetryablePatterns are substrings that indicate a permanent error
|
||||
@@ -131,9 +139,8 @@ type RetryFn func(ctx context.Context) error
|
||||
type OnRetryFn func(attempt int, err error, delay time.Duration)
|
||||
|
||||
// Retry calls fn repeatedly until it succeeds, returns a
|
||||
// non-retryable error, or ctx is canceled. There is no max attempt
|
||||
// limit — retries continue indefinitely with exponential backoff
|
||||
// (capped at 60s), matching the behavior of coder/mux.
|
||||
// non-retryable error, ctx is canceled, or MaxAttempts is reached.
|
||||
// Retries use exponential backoff capped at MaxDelay.
|
||||
//
|
||||
// The onRetry callback (if non-nil) is called before each retry
|
||||
// attempt, giving the caller a chance to reset state, log, or
|
||||
@@ -156,10 +163,15 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
delay := Delay(attempt)
|
||||
attempt++
|
||||
if attempt >= MaxAttempts {
|
||||
return xerrors.Errorf("max retry attempts (%d) exceeded: %w", MaxAttempts, err)
|
||||
}
|
||||
|
||||
delay := Delay(attempt - 1)
|
||||
|
||||
if onRetry != nil {
|
||||
onRetry(attempt+1, err, delay)
|
||||
onRetry(attempt, err, delay)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
@@ -169,7 +181,5 @@ func Retry(ctx context.Context, fn RetryFn, onRetry OnRetryFn) error {
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +96,7 @@ type AnthropicDeltaBlock struct {
|
||||
// anthropicServer is a test server that mocks the Anthropic API.
|
||||
type anthropicServer struct {
|
||||
mu sync.Mutex
|
||||
t testing.TB
|
||||
server *httptest.Server
|
||||
handler AnthropicHandler
|
||||
request *AnthropicRequest
|
||||
@@ -109,6 +110,7 @@ func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
|
||||
t.Helper()
|
||||
|
||||
s := &anthropicServer{
|
||||
t: t,
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
@@ -143,7 +145,7 @@ func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
|
||||
if resp.Error != nil {
|
||||
writeErrorResponse(w, resp.Error)
|
||||
writeErrorResponse(s.t, w, resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -223,7 +225,6 @@ func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
response := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"type": resp.Type,
|
||||
@@ -241,7 +242,9 @@ func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
s.t.Errorf("writeNonStreamingResponse: failed to encode response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// AnthropicStreamingResponse creates a streaming response from chunks.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user