Compare commits

..

1 Commit

Author SHA1 Message Date
Danielle Maywood 48fd7b372b refactor(site): extract AgentsLayout from AgentsPageView
Rename AgentsPageView to AgentsLayout to follow the convention used
by HealthLayout, TemplateLayout, and other layout components. Move
the scrollbar gutter lock effect from AgentsPage into the layout
so all viewport-level concerns live in one place.
2026-03-10 22:03:41 +00:00
426 changed files with 10157 additions and 37950 deletions
+1 -1
View File
@@ -113,7 +113,7 @@ Coder emphasizes clear error handling, with specific patterns required:
All tests should run in parallel using `t.Parallel()` to ensure efficient testing and expose potential race conditions. The codebase is rigorously linted with golangci-lint to maintain consistent code quality.
Git contributions follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/). See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Git contributions follow a standard format with commit messages structured as `type: <message>`, where type is one of `feat`, `fix`, or `chore`.
## Development Workflow
+14 -5
View File
@@ -4,13 +4,22 @@ This guide documents the PR description style used in the Coder repository, base
## PR Title Format
Format: `type(scope): description`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/) format:
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
```text
type(scope): brief description
```
Examples:
**Common types:**
- `feat`: New features
- `fix`: Bug fixes
- `refactor`: Code refactoring without behavior change
- `perf`: Performance improvements
- `docs`: Documentation changes
- `chore`: Dependency updates, tooling changes
**Examples:**
- `feat: add tracing to aibridge`
- `fix: move contexts to appropriate locations`
+3 -5
View File
@@ -136,11 +136,9 @@ Then make your changes and push normally. Don't use `git push --force` unless th
## Commit Style
Format: `type(scope): message`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
- Scopes must be a real path (directory or file stem) containing all changed files
- Omit scope if changes span multiple top-level directories
- Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/)
- Format: `type(scope): message`
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
- Keep message titles concise (~70 characters)
- Use imperative, present tense in commit titles
-1
View File
@@ -64,7 +64,6 @@ runs:
TEST_PACKAGES: ${{ inputs.test-packages }}
RACE_DETECTION: ${{ inputs.race-detection }}
TS_DEBUG_DISCO: "true"
TS_DEBUG_DERP: "true"
LC_CTYPE: "en_US.UTF-8"
LC_ALL: "en_US.UTF-8"
run: |
+7 -69
View File
@@ -1198,7 +1198,7 @@ jobs:
make -j \
build/coder_linux_{amd64,arm64,armv7} \
build/coder_"$version"_windows_amd64.zip \
build/coder_"$version"_linux_{amd64,arm64,armv7}.{tar.gz,deb}
build/coder_"$version"_linux_amd64.{tar.gz,deb}
env:
# The Windows and Darwin slim binaries must be signed for Coder
# Desktop to accept them.
@@ -1216,28 +1216,11 @@ jobs:
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
JSIGN_PATH: /tmp/jsign-6.0.jar
# Free up disk space before building Docker images. The preceding
# Build step produces ~2 GB of binaries and packages, the Go build
# cache is ~1.3 GB, and node_modules is ~500 MB. Docker image
# builds, pushes, and SBOM generation need headroom that isn't
# available without reclaiming some of that space.
- name: Clean up build cache
run: |
set -euxo pipefail
# Go caches are no longer needed — binaries are already compiled.
go clean -cache -modcache
# Remove .apk and .rpm packages that are not uploaded as
# artifacts and were only built as make prerequisites.
rm -f ./build/*.apk ./build/*.rpm
- name: Build Linux Docker images
id: build-docker
env:
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
DOCKER_CLI_EXPERIMENTAL: "enabled"
# Skip building .deb/.rpm/.apk/.tar.gz as prerequisites for
# the Docker image targets — they were already built above.
DOCKER_IMAGE_NO_PREREQUISITES: "true"
run: |
set -euxo pipefail
@@ -1455,60 +1438,15 @@ jobs:
^v
prune-untagged: true
- name: Upload build artifact (coder-linux-amd64.tar.gz)
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-amd64.tar.gz
path: ./build/*_linux_amd64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-amd64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-amd64.deb
path: ./build/*_linux_amd64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-arm64.tar.gz
path: ./build/*_linux_arm64.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-arm64.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-arm64.deb
path: ./build/*_linux_arm64.deb
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.tar.gz)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-armv7.tar.gz
path: ./build/*_linux_armv7.tar.gz
retention-days: 7
- name: Upload build artifact (coder-linux-armv7.deb)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-linux-armv7.deb
path: ./build/*_linux_armv7.deb
retention-days: 7
- name: Upload build artifact (coder-windows-amd64.zip)
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: coder-windows-amd64.zip
path: ./build/*_windows_amd64.zip
name: coder
path: |
./build/*.zip
./build/*.tar.gz
./build/*.deb
retention-days: 7
# Deploy is handled in deploy.yaml so we can apply concurrency limits.
-103
View File
@@ -45,109 +45,6 @@ jobs:
# Some users have signed a corporate CLA with Coder so are exempt from signing our community one.
allowlist: "coryb,aaronlehmann,dependabot*,blink-so*,blinkagent*"
title:
runs-on: ubuntu-latest
if: ${{ github.event_name == 'pull_request_target' }}
steps:
- name: Validate PR title
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
// Validates that the PR title follows Conventional Commits and, when a
// scope is given, that the scope matches the files changed in the PR.
// Runs inside actions/github-script: `context`, `core`, and `github`
// are globals provided by that action.
const { pull_request } = context.payload;
const title = pull_request.title;
const repo = { owner: context.repo.owner, repo: context.repo.repo };
const allowedTypes = [
  "feat", "fix", "docs", "style", "refactor",
  "perf", "test", "build", "ci", "chore", "revert",
];
const expectedFormat = `"type(scope): description" or "type: description"`;
const guidelinesLink = `See: https://github.com/coder/coder/blob/main/docs/about/contributing/CONTRIBUTING.md#commit-messages`;
// Shared suffix for scope-related failure messages.
const scopeHint = (type) =>
  `Use a broader scope or no scope (e.g., "${type}: ...") for cross-cutting changes.\n` +
  guidelinesLink;
console.log("Title: %s", title);
// Parse conventional commit format: type(scope)!: description
// Group 1 = type, group 3 = scope (undefined when no parentheses),
// group 4 = optional breaking-change "!" marker.
const match = title.match(/^(\w+)(\(([^)]*)\))?(!)?\s*:\s*.+/);
if (!match) {
  core.setFailed(
    `PR title does not match conventional commit format.\n` +
    `Expected: ${expectedFormat}\n` +
    `Allowed types: ${allowedTypes.join(", ")}\n` +
    guidelinesLink
  );
  return;
}
const type = match[1];
const scope = match[3]; // undefined if no parentheses
// Validate type.
if (!allowedTypes.includes(type)) {
  core.setFailed(
    `PR title has invalid type "${type}".\n` +
    `Expected: ${expectedFormat}\n` +
    `Allowed types: ${allowedTypes.join(", ")}\n` +
    guidelinesLink
  );
  return;
}
// If no scope, we're done.
if (!scope) {
  console.log("No scope provided, title is valid.");
  return;
}
console.log("Scope: %s", scope);
// Fetch changed files. Paginated: PRs can change more than 100 files.
const files = await github.paginate(github.rest.pulls.listFiles, {
  ...repo,
  pull_number: pull_request.number,
  per_page: 100,
});
const changedPaths = files.map(f => f.filename);
console.log("Changed files: %d", changedPaths.length);
// Derive scope type from the changed files. The diff is the
// source of truth: if files exist under the scope, the path
// exists on the PR branch. No need for Contents API calls.
// A scope may be a directory ("scope/..."), an exact file path,
// or a file stem ("scope.*").
const isDir = changedPaths.some(f => f.startsWith(scope + "/"));
const isFile = changedPaths.some(f => f === scope);
const isStem = changedPaths.some(f => f.startsWith(scope + "."));
if (!isDir && !isFile && !isStem) {
  core.setFailed(
    `PR title scope "${scope}" does not match any files changed in this PR.\n` +
    `Scopes must reference a path (directory or file stem) that contains changed files.\n` +
    scopeHint(type)
  );
  return;
}
// Verify all changed files fall under the scope.
const outsideFiles = changedPaths.filter(f => {
  if (isDir && f.startsWith(scope + "/")) return false;
  if (f === scope) return false;
  if (isStem && f.startsWith(scope + ".")) return false;
  return true;
});
if (outsideFiles.length > 0) {
  const listed = outsideFiles.map(f => " - " + f).join("\n");
  core.setFailed(
    `PR title scope "${scope}" does not contain all changed files.\n` +
    `Files outside scope:\n${listed}\n\n` +
    scopeHint(type)
  );
  return;
}
console.log("PR title is valid.");
release-labels:
runs-on: ubuntu-latest
permissions:
+15 -11
View File
@@ -61,7 +61,7 @@ jobs:
if: needs.should-deploy.outputs.verdict == 'DEPLOY'
permissions:
contents: read
id-token: write # to authenticate to EKS cluster
id-token: write
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
@@ -82,23 +82,27 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Configure AWS Credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
with:
role-to-assume: ${{ vars.AWS_DOGFOOD_DEPLOY_ROLE }}
aws-region: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
- name: Get Cluster Credentials
run: aws eks update-kubeconfig --name "$AWS_DOGFOOD_CLUSTER_NAME" --region "$AWS_DOGFOOD_DEPLOY_REGION"
env:
AWS_DOGFOOD_CLUSTER_NAME: ${{ vars.AWS_DOGFOOD_CLUSTER_NAME }}
AWS_DOGFOOD_DEPLOY_REGION: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Set up Flux CLI
uses: fluxcd/flux2/action@8454b02a32e48d775b9f563cb51fdcb1787b5b93 # v2.7.5
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.8.2"
version: "2.7.0"
- name: Get Cluster Credentials
uses: google-github-actions/get-gke-credentials@3da1e46a907576cefaa90c484278bb5b259dd395 # v3.0.0
with:
cluster_name: dogfood-v2
location: us-central1-a
project_id: coder-dogfood-v2
# Retag image as dogfood while maintaining the multi-arch manifest
- name: Tag image as dogfood
-16
View File
@@ -30,22 +30,6 @@ jobs:
with:
persist-credentials: false
- name: Rewrite same-repo links for PR branch
if: github.event_name == 'pull_request'
env:
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
run: |
# Rewrite same-repo blob/tree main links to the PR head SHA
# so that files or directories introduced in the PR are
# reachable during link checking.
{
echo 'replacementPatterns:'
echo " - pattern: \"https://github.com/coder/coder/blob/main/\""
echo " replacement: \"https://github.com/coder/coder/blob/${HEAD_SHA}/\""
echo " - pattern: \"https://github.com/coder/coder/tree/main/\""
echo " replacement: \"https://github.com/coder/coder/tree/${HEAD_SHA}/\""
} >> .github/.linkspector.yml
- name: Check Markdown links
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
id: markdown-link-check
+4 -44
View File
@@ -50,7 +50,7 @@ Only pause to ask for confirmation when:
| **Format** | `make fmt` | Auto-format code |
| **Clean** | `make clean` | Clean build artifacts |
| **Pre-commit** | `make pre-commit` | Fast CI checks (gen/fmt/lint/build) |
| **Pre-push** | `make pre-push` | Heavier CI checks (allowlisted) |
| **Pre-push** | `make pre-push` | All CI checks including tests |
### Documentation Commands
@@ -100,31 +100,6 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
```
### API Design
- Add swagger annotations when introducing new HTTP endpoints. Do this in
the same change as the handler so the docs do not get missed before
release.
- For user-scoped or resource-scoped routes, prefer path parameters over
query parameters when that matches existing route patterns.
- For experimental or unstable API paths, skip public doc generation with
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
keeps them out of the published API reference until they stabilize.
### Database Query Naming
- Use `ByX` when `X` is the lookup or filter column.
- Use `PerX` or `GroupedByX` when `X` is the aggregation or grouping
dimension.
- Avoid `ByX` names for grouped queries.
### Database-to-SDK Conversions
- Extract explicit db-to-SDK conversion helpers instead of inlining large
conversion blocks inside handlers.
- Keep nullable-field handling, type coercion, and response shaping in the
converter so handlers stay focused on request flow and authorization.
## Quick Reference
### Full workflows available in imported WORKFLOWS.md
@@ -148,9 +123,9 @@ Two hooks run automatically:
- **pre-commit**: `make pre-commit` (gen, fmt, lint, typos, build).
Fast checks that catch most CI failures. Allow at least 5 minutes.
- **pre-push**: `make pre-push` (heavier checks including tests).
Allowlisted in `scripts/githooks/pre-push`. Runs only for developers
who opt in. Allow at least 15 minutes.
- **pre-push**: `make pre-push` (full CI suite including tests).
Runs before pushing to catch everything CI would. Allow at least
15 minutes (race tests are slow without cache).
`git commit` and `git push` will appear to hang while hooks run.
This is normal. Do not interrupt, retry, or reduce the timeout.
@@ -209,21 +184,6 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- Commit format: `type(scope): message`
### Frontend Patterns
- Prefer existing shared UI components and utilities over custom
implementations. Reuse common primitives such as loading, table, and error
handling components when they fit the use case.
- Use Storybook stories for all component and page testing, including
visual presentation, user interactions, keyboard navigation, focus
management, and accessibility behavior. Do not create standalone
vitest/RTL test files for components or pages. Stories double as living
documentation, visual regression coverage, and interaction test suites
via `play` functions. Reserve plain vitest files for pure logic only:
utility functions, data transformations, hooks tested via
`renderHook()` that do not require DOM assertions, and query/cache
operations with no rendered output.
### Writing Comments
Code comments should be clear, well-formatted, and add meaningful context.
+55 -48
View File
@@ -27,7 +27,6 @@ ifdef MAKE_TIMED
SHELL := $(CURDIR)/scripts/lib/timed-shell.sh
.SHELLFLAGS = $@ -ceu
export MAKE_TIMED
export MAKE_LOGDIR
endif
# This doesn't work on directories.
@@ -115,7 +114,7 @@ POSTGRES_VERSION ?= 17
POSTGRES_IMAGE ?= us-docker.pkg.dev/coder-v2-images-public/public/postgres:$(POSTGRES_VERSION)
# Limit parallel Make jobs in pre-commit/pre-push. Defaults to
# nproc/4 (min 2) since test, lint, and build targets have internal
# nproc/4 (min 2) since test and lint targets have internal
# parallelism. Override: make pre-push PARALLEL_JOBS=8
PARALLEL_JOBS ?= $(shell n=$$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 8); echo $$(( n / 4 > 2 ? n / 4 : 2 )))
@@ -514,14 +513,8 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
build/.bin/develop: go.mod go.sum $(GO_SRC_FILES)
CGO_ENABLED=0 go build -o $@ ./scripts/develop
BOLD := $(shell tput bold 2>/dev/null)
GREEN := $(shell tput setaf 2 2>/dev/null)
RED := $(shell tput setaf 1 2>/dev/null)
YELLOW := $(shell tput setaf 3 2>/dev/null)
DIM := $(shell tput dim 2>/dev/null || tput setaf 8 2>/dev/null)
RESET := $(shell tput sgr0 2>/dev/null)
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/biome fmt/markdown
@@ -720,73 +713,89 @@ lint/typos: build/typos-$(TYPOS_VERSION)
build/typos-$(TYPOS_VERSION) --config .github/workflows/typos.toml
.PHONY: lint/typos
# pre-commit and pre-push mirror CI checks locally.
# pre-commit and pre-push mirror CI "required" jobs locally.
# See the "required" job's needs list in .github/workflows/ci.yaml.
#
# pre-commit runs checks that don't need external services (Docker,
# Playwright). This is the git pre-commit hook default since Docker
# and browser issues in the local environment would otherwise block
# Playwright). This is the git pre-commit hook default since test
# and Docker failures in the local environment would otherwise block
# all commits.
#
# pre-push adds heavier checks: Go tests, JS tests, and site build.
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
# pre-push runs the full CI suite including tests. This is the git
# pre-push hook default, catching everything CI would before pushing.
#
# pre-commit uses two phases: gen+fmt first, then lint+build. This
# avoids races where gen's `go run` creates temporary .go files that
# lint's find-based checks pick up. Within each phase, targets run in
# parallel via -j. It fails if any tracked files have unstaged
# changes afterward.
# pre-push uses two-phase execution: gen+fmt+test-postgres-docker
# first (writes files, starts Docker), then lint+build+test in
# parallel. pre-commit uses two phases: gen+fmt first, then
# lint+build. This avoids races where gen's `go run` creates
# temporary .go files that lint's find-based checks pick up.
# Within each phase, targets run in parallel via -j. Both fail if
# any tracked files have unstaged changes afterward.
#
# Both pre-commit and pre-push:
# gen, fmt, lint, lint/typos, slim binary (local arch)
#
# pre-push only (need external services or are slow):
# site/out/index.html (pnpm build)
# test-postgres-docker + test (needs Docker)
# test-js, test-e2e (needs Playwright)
# sqlc-vet (needs Docker)
# offlinedocs/check
#
# Omitted:
# test-go-pg-17 (same tests, different PG version)
define check-unstaged
unstaged="$$(git diff --name-only)"
if [[ -n $$unstaged ]]; then
echo "$(RED)✗ check unstaged changes$(RESET)"
echo "$$unstaged" | sed 's/^/ - /'
echo ""
echo "$(DIM) Verify generated changes are correct before staging:$(RESET)"
echo "$(DIM) git diff$(RESET)"
echo "$(DIM) git add -u && git commit$(RESET)"
echo "ERROR: unstaged changes in tracked files:"
echo "$$unstaged"
echo
echo "Review each change (git diff), verify correctness, then stage:"
echo " git add -u && git commit"
exit 1
fi
endef
define check-untracked
untracked=$$(git ls-files --other --exclude-standard)
if [[ -n $$untracked ]]; then
echo "$(YELLOW)? check untracked files$(RESET)"
echo "$$untracked" | sed 's/^/ - /'
echo ""
echo "$(DIM) Review if these should be committed or added to .gitignore.$(RESET)"
echo "WARNING: untracked files (not in this commit, won't be in CI):"
echo "$$untracked"
echo
fi
endef
pre-commit:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-commit.XXXXXX")
echo "$(BOLD)pre-commit$(RESET) ($$logdir)"
echo "gen + fmt:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir gen fmt
echo "=== Phase 1/2: gen + fmt ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt
$(check-unstaged)
echo "lint + build:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
echo "=== Phase 2/2: lint + build ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
lint \
lint/typos \
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
$(check-unstaged)
$(check-untracked)
rm -rf $$logdir
echo "$(GREEN)✓ pre-commit passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
echo "$(BOLD)$(GREEN)=== pre-commit passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
.PHONY: pre-commit
pre-push:
start=$$(date +%s)
logdir=$$(mktemp -d "$${TMPDIR:-/tmp}/coder-pre-push.XXXXXX")
echo "$(BOLD)pre-push$(RESET) ($$logdir)"
echo "test + build site:"
$(MAKE) --no-print-directory -j$(PARALLEL_JOBS) MAKE_TIMED=1 MAKE_LOGDIR=$$logdir \
echo "=== Phase 1/2: gen + fmt + postgres ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 gen fmt test-postgres-docker
$(check-unstaged)
echo "=== Phase 2/2: lint + build + test ==="
$(MAKE) -j$(PARALLEL_JOBS) --output-sync=target MAKE_TIMED=1 \
lint \
lint/typos \
build/coder-slim_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT) \
site/out/index.html \
test \
test-js \
site/out/index.html
rm -rf $$logdir
echo "$(GREEN)✓ pre-push passed$(RESET) ($$(( $$(date +%s) - $$start ))s)"
test-e2e \
test-race \
sqlc-vet \
offlinedocs/check
$(check-unstaged)
echo "$(BOLD)$(GREEN)=== pre-push passed in $$(( $$(date +%s) - $$start ))s ===$(RESET)"
.PHONY: pre-push
offlinedocs/check: offlinedocs/node_modules/.installed
@@ -1466,5 +1475,3 @@ dogfood/coder/nix.hash: flake.nix flake.lock
count-test-databases:
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
.PHONY: count-test-databases
.PHONY: count-test-databases
+1 -10
View File
@@ -39,7 +39,6 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/clistat"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentdesktop"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentgit"
@@ -311,7 +310,6 @@ type agent struct {
filesAPI *agentfiles.API
gitAPI *agentgit.API
processAPI *agentproc.API
desktopAPI *agentdesktop.API
socketServerEnabled bool
socketPath string
@@ -388,10 +386,7 @@ func (a *agent) init() {
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
desktop := agentdesktop.NewPortableDesktop(
a.logger.Named("desktop"), a.execer, a.scriptDataDir,
)
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
@@ -2062,10 +2057,6 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
}
if err := a.desktopAPI.Close(); err != nil {
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
}
if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {
-536
View File
@@ -1,536 +0,0 @@
package agentdesktop
import (
"encoding/json"
"math"
"net/http"
"strconv"
"time"
"github.com/go-chi/chi/v5"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
"github.com/coder/websocket"
)
// DesktopAction is the request body for the desktop action endpoint.
// Field requirements vary per Action; the handler validates the
// fields each action needs and rejects requests missing them.
type DesktopAction struct {
	// Action selects the operation, e.g. "key", "type", "left_click",
	// "left_click_drag", "scroll", "hold_key", "screenshot".
	Action string `json:"action"`
	// Coordinate is the target (x, y) position for pointer actions.
	Coordinate *[2]int `json:"coordinate,omitempty"`
	// StartCoordinate is the drag origin for left_click_drag.
	StartCoordinate *[2]int `json:"start_coordinate,omitempty"`
	// Text is the key chord or literal text for key/type/hold_key.
	Text *string `json:"text,omitempty"`
	// Duration is how long to hold the key for hold_key, in
	// milliseconds.
	Duration *int `json:"duration,omitempty"`
	// ScrollAmount is the number of scroll steps for the scroll
	// action (defaults to 3 in the handler).
	ScrollAmount *int `json:"scroll_amount,omitempty"`
	// ScrollDirection is one of "up", "down", "left", "right"
	// (defaults to "down" in the handler).
	ScrollDirection *string `json:"scroll_direction,omitempty"`
	// ScaledWidth and ScaledHeight are the coordinate space the
	// model is using. When provided, coordinates are linearly
	// mapped from scaled → native before dispatching.
	ScaledWidth *int `json:"scaled_width,omitempty"`
	ScaledHeight *int `json:"scaled_height,omitempty"`
}
// DesktopActionResponse is the response from the desktop action
// endpoint. Output is set for every successful action; the
// Screenshot* fields are only populated by screenshot results.
type DesktopActionResponse struct {
	// Output is a short human-readable result summary.
	Output string `json:"output,omitempty"`
	// ScreenshotData is the captured image payload.
	// NOTE(review): encoding is not visible here — presumably
	// base64; confirm against the Screenshot implementation.
	ScreenshotData string `json:"screenshot_data,omitempty"`
	// ScreenshotWidth and ScreenshotHeight are the dimensions of
	// the returned screenshot.
	ScreenshotWidth int `json:"screenshot_width,omitempty"`
	ScreenshotHeight int `json:"screenshot_height,omitempty"`
}
// API exposes the desktop streaming HTTP routes for the agent.
type API struct {
	logger slog.Logger
	// desktop performs the underlying input and screen operations.
	desktop Desktop
	// clock is injectable for tests; NewAPI defaults it to the
	// real clock when nil.
	clock quartz.Clock
}
// NewAPI creates a new desktop streaming API. A nil clock is
// replaced with the real wall clock, so callers may pass nil
// outside of tests.
func NewAPI(logger slog.Logger, desktop Desktop, clock quartz.Clock) *API {
	api := &API{
		logger:  logger,
		desktop: desktop,
		clock:   clock,
	}
	if api.clock == nil {
		api.clock = quartz.NewReal()
	}
	return api
}
// Routes returns the chi router for mounting at /api/v0/desktop.
func (a *API) Routes() http.Handler {
	mux := chi.NewRouter()
	mux.Get("/vnc", a.handleDesktopVNC)
	mux.Post("/action", a.handleAction)
	return mux
}
// handleDesktopVNC upgrades the request to a WebSocket and proxies
// raw RFB bytes between the client and the local VNC server.
//
// Ordering matters: the desktop session and VNC connection are
// established *before* the WebSocket upgrade so that failures can
// still be reported as plain HTTP error responses — once Accept has
// hijacked the connection, httpapi.Write is no longer possible.
func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	// Start the desktop session (idempotent).
	_, err := a.desktop.Start(ctx)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to start desktop session.",
			Detail: err.Error(),
		})
		return
	}
	// Get a VNC connection.
	vncConn, err := a.desktop.VNCConn(ctx)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to connect to VNC server.",
			Detail: err.Error(),
		})
		return
	}
	defer vncConn.Close()
	// Accept WebSocket from coderd. Compression is disabled since
	// RFB traffic is already binary framebuffer data.
	conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
		CompressionMode: websocket.CompressionDisabled,
	})
	if err != nil {
		// Upgrade failed; Accept has already written its own
		// response, so only log here.
		a.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
		return
	}
	// No read limit — RFB framebuffer updates can be large.
	conn.SetReadLimit(-1)
	wsCtx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
	defer wsNetConn.Close()
	// Bicopy raw bytes between WebSocket and VNC TCP. Blocks until
	// either side closes or the context is canceled.
	agentssh.Bicopy(wsCtx, wsNetConn, vncConn)
}
func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
handlerStart := a.clock.Now()
// Ensure the desktop is running and grab native dimensions.
cfg, err := a.desktop.Start(ctx)
if err != nil {
a.logger.Warn(ctx, "handleAction: desktop.Start failed",
slog.Error(err),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start desktop session.",
Detail: err.Error(),
})
return
}
var action DesktopAction
if err := json.NewDecoder(r.Body).Decode(&action); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to decode request body.",
Detail: err.Error(),
})
return
}
a.logger.Info(ctx, "handleAction: started",
slog.F("action", action.Action),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
// Helper to scale a coordinate pair from the model's space to
// native display pixels.
scaleXY := func(x, y int) (int, int) {
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
}
return x, y
}
var resp DesktopActionResponse
switch action.Action {
case "key":
if action.Text == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"text\" for key action.",
})
return
}
if err := a.desktop.KeyPress(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Key press failed.",
Detail: err.Error(),
})
return
}
resp.Output = "key action performed"
case "type":
if action.Text == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"text\" for type action.",
})
return
}
if err := a.desktop.Type(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Type action failed.",
Detail: err.Error(),
})
return
}
resp.Output = "type action performed"
case "cursor_position":
x, y, err := a.desktop.CursorPosition(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Cursor position failed.",
Detail: err.Error(),
})
return
}
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
case "mouse_move":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.Move(ctx, x, y); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Mouse move failed.",
Detail: err.Error(),
})
return
}
resp.Output = "mouse_move action performed"
case "left_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
stepStart := a.clock.Now()
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
a.logger.Warn(ctx, "handleAction: Click failed",
slog.F("action", "left_click"),
slog.F("step", "click"),
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
slog.Error(err),
)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left click failed.",
Detail: err.Error(),
})
return
}
a.logger.Debug(ctx, "handleAction: Click completed",
slog.F("action", "left_click"),
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
)
resp.Output = "left_click action performed"
case "left_click_drag":
if action.Coordinate == nil || action.StartCoordinate == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"coordinate\" or \"start_coordinate\" for left_click_drag.",
})
return
}
sx, sy := scaleXY(action.StartCoordinate[0], action.StartCoordinate[1])
ex, ey := scaleXY(action.Coordinate[0], action.Coordinate[1])
if err := a.desktop.Drag(ctx, sx, sy, ex, ey); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left click drag failed.",
Detail: err.Error(),
})
return
}
resp.Output = "left_click_drag action performed"
case "left_mouse_down":
if err := a.desktop.ButtonDown(ctx, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left mouse down failed.",
Detail: err.Error(),
})
return
}
resp.Output = "left_mouse_down action performed"
case "left_mouse_up":
if err := a.desktop.ButtonUp(ctx, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Left mouse up failed.",
Detail: err.Error(),
})
return
}
resp.Output = "left_mouse_up action performed"
case "right_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.Click(ctx, x, y, MouseButtonRight); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Right click failed.",
Detail: err.Error(),
})
return
}
resp.Output = "right_click action performed"
case "middle_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.Click(ctx, x, y, MouseButtonMiddle); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Middle click failed.",
Detail: err.Error(),
})
return
}
resp.Output = "middle_click action performed"
case "double_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
if err := a.desktop.DoubleClick(ctx, x, y, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Double click failed.",
Detail: err.Error(),
})
return
}
resp.Output = "double_click action performed"
case "triple_click":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
for range 3 {
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Triple click failed.",
Detail: err.Error(),
})
return
}
}
resp.Output = "triple_click action performed"
case "scroll":
x, y, err := coordFromAction(action)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: err.Error(),
})
return
}
x, y = scaleXY(x, y)
amount := 3
if action.ScrollAmount != nil {
amount = *action.ScrollAmount
}
direction := "down"
if action.ScrollDirection != nil {
direction = *action.ScrollDirection
}
var dx, dy int
switch direction {
case "up":
dy = -amount
case "down":
dy = amount
case "left":
dx = -amount
case "right":
dx = amount
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid scroll direction: " + direction,
})
return
}
if err := a.desktop.Scroll(ctx, x, y, dx, dy); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Scroll failed.",
Detail: err.Error(),
})
return
}
resp.Output = "scroll action performed"
case "hold_key":
if action.Text == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Missing \"text\" for hold_key action.",
})
return
}
dur := 1000
if action.Duration != nil {
dur = *action.Duration
}
if err := a.desktop.KeyDown(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Key down failed.",
Detail: err.Error(),
})
return
}
timer := a.clock.NewTimer(time.Duration(dur)*time.Millisecond, "agentdesktop", "hold_key")
defer timer.Stop()
select {
case <-ctx.Done():
// Context canceled; release the key immediately.
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
}
return
case <-timer.C:
}
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Key up failed.",
Detail: err.Error(),
})
return
}
resp.Output = "hold_key action performed"
case "screenshot":
var opts ScreenshotOptions
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
opts.TargetWidth = *action.ScaledWidth
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
opts.TargetHeight = *action.ScaledHeight
}
result, err := a.desktop.Screenshot(ctx, opts)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Screenshot failed.",
Detail: err.Error(),
})
return
}
resp.Output = "screenshot"
resp.ScreenshotData = result.Data
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
resp.ScreenshotWidth = *action.ScaledWidth
} else {
resp.ScreenshotWidth = cfg.Width
}
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
resp.ScreenshotHeight = *action.ScaledHeight
} else {
resp.ScreenshotHeight = cfg.Height
}
default:
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Unknown action: " + action.Action,
})
return
}
elapsedMs := a.clock.Since(handlerStart).Milliseconds()
if ctx.Err() != nil {
a.logger.Error(ctx, "handleAction: context canceled before writing response",
slog.F("action", action.Action),
slog.F("elapsed_ms", elapsedMs),
slog.Error(ctx.Err()),
)
return
}
a.logger.Info(ctx, "handleAction: writing response",
slog.F("action", action.Action),
slog.F("elapsed_ms", elapsedMs),
)
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
// Close tears down the API by closing the underlying desktop session,
// if one is running; the Desktop implementation handles the case where
// nothing was ever started.
func (a *API) Close() error {
	err := a.desktop.Close()
	return err
}
// coordFromAction extracts the coordinate pair from a DesktopAction,
// returning a missingFieldError when the coordinate field is absent
// from the request body.
func coordFromAction(action DesktopAction) (x, y int, err error) {
	coord := action.Coordinate
	if coord == nil {
		return 0, 0, &missingFieldError{field: "coordinate", action: action.Action}
	}
	return coord[0], coord[1], nil
}
// missingFieldError reports that a required field was absent from a
// DesktopAction request body.
type missingFieldError struct {
	field  string // name of the missing JSON field
	action string // action that required the field
}

// Error renders the message in the shape the HTTP handlers return,
// e.g. `Missing "coordinate" for left_click action.`
func (e *missingFieldError) Error() string {
	msg := "Missing \"" + e.field + "\" for " + e.action + " action."
	return msg
}
// scaleCoordinate maps a coordinate from the model's scaled space into
// the native display space. A scaledDim of 0 (unknown) or equal to
// nativeDim means no scaling is required and the input passes through.
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
	if scaledDim == 0 || scaledDim == nativeDim {
		return scaled
	}
	// Treat the coordinate as a pixel center (+0.5), rescale into the
	// native space, then recover the integer pixel index (-0.5).
	native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
	// Clamp into the valid pixel range [0, nativeDim-1].
	return int(math.Min(math.Max(native, 0), float64(nativeDim-1)))
}
-467
View File
@@ -1,467 +0,0 @@
package agentdesktop_test
import (
"bytes"
"context"
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentdesktop"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)

// fakeDesktop is a minimal Desktop implementation for unit tests.
// The error/result fields configure what the fake returns; the last*
// fields record the arguments of the most recent call so tests can
// assert on them. Not safe for concurrent use.
type fakeDesktop struct {
	startErr      error                         // returned by Start
	startCfg      agentdesktop.DisplayConfig    // returned by Start on success
	vncConnErr    error                         // returned by VNCConn
	screenshotErr error                         // returned by Screenshot
	screenshotRes agentdesktop.ScreenshotResult // returned by Screenshot on success
	closed        bool                          // set by Close

	// Track calls for assertions.
	lastMove    [2]int
	lastClick   [3]int // x, y, button
	lastScroll  [4]int // x, y, dx, dy
	lastKey     string
	lastTyped   string
	lastKeyDown string
	lastKeyUp   string
}

// Start returns the preconfigured config/error; it never blocks.
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
	return f.startCfg, f.startErr
}

// VNCConn returns a nil conn plus the preconfigured error.
func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
	return nil, f.vncConnErr
}

func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
	return f.screenshotRes, f.screenshotErr
}

func (f *fakeDesktop) Move(_ context.Context, x, y int) error {
	f.lastMove = [2]int{x, y}
	return nil
}

// Click records the coordinates; the third slot is fixed to 1 to mark
// a single click (the button argument itself is ignored).
func (f *fakeDesktop) Click(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
	f.lastClick = [3]int{x, y, 1}
	return nil
}

// DoubleClick records the coordinates with 2 in the third slot to
// distinguish it from a single Click.
func (f *fakeDesktop) DoubleClick(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
	f.lastClick = [3]int{x, y, 2}
	return nil
}

// Button and drag operations are untracked no-ops.
func (*fakeDesktop) ButtonDown(context.Context, agentdesktop.MouseButton) error { return nil }
func (*fakeDesktop) ButtonUp(context.Context, agentdesktop.MouseButton) error   { return nil }

func (f *fakeDesktop) Scroll(_ context.Context, x, y, dx, dy int) error {
	f.lastScroll = [4]int{x, y, dx, dy}
	return nil
}

func (*fakeDesktop) Drag(context.Context, int, int, int, int) error { return nil }

func (f *fakeDesktop) KeyPress(_ context.Context, key string) error {
	f.lastKey = key
	return nil
}

func (f *fakeDesktop) KeyDown(_ context.Context, key string) error {
	f.lastKeyDown = key
	return nil
}

func (f *fakeDesktop) KeyUp(_ context.Context, key string) error {
	f.lastKeyUp = key
	return nil
}

func (f *fakeDesktop) Type(_ context.Context, text string) error {
	f.lastTyped = text
	return nil
}

// CursorPosition always reports a fixed (10, 20) position.
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
	return 10, 20, nil
}

// Close flips the closed flag so tests can assert delegation happened.
func (f *fakeDesktop) Close() error {
	f.closed = true
	return nil
}
// TestHandleDesktopVNC_StartError verifies that a failure to start the
// desktop session surfaces to the client as a 500 with a descriptive
// message.
func TestHandleDesktopVNC_StartError(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	desktop := &fakeDesktop{startErr: xerrors.New("no desktop")}
	api := agentdesktop.NewAPI(logger, desktop, nil)
	defer api.Close()

	rec := httptest.NewRecorder()
	api.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/vnc", nil))

	assert.Equal(t, http.StatusInternalServerError, rec.Code)
	var resp codersdk.Response
	require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
	assert.Equal(t, "Failed to start desktop session.", resp.Message)
}
// TestHandleAction_Screenshot exercises the "screenshot" action and
// checks that the response carries the fake's image data plus the
// display dimensions taken from DisplayConfig.
func TestHandleAction_Screenshot(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		startCfg:      agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
		screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()
	// POST the action as JSON, the same way a real client would.
	body := agentdesktop.DesktopAction{Action: "screenshot"}
	b, err := json.Marshal(body)
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")
	handler := api.Routes()
	handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code)
	var result agentdesktop.DesktopActionResponse
	err = json.NewDecoder(rr.Body).Decode(&result)
	require.NoError(t, err)
	// Dimensions come from DisplayConfig, not the screenshot CLI.
	assert.Equal(t, "screenshot", result.Output)
	assert.Equal(t, "base64data", result.ScreenshotData)
	assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
	assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
}
// TestHandleAction_LeftClick verifies the "left_click" action reaches
// the Desktop with unmodified coordinates when no scaled dimensions
// are supplied (scaling is a no-op in that case).
func TestHandleAction_LeftClick(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()
	body := agentdesktop.DesktopAction{
		Action:     "left_click",
		Coordinate: &[2]int{100, 200},
	}
	b, err := json.Marshal(body)
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")
	handler := api.Routes()
	handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code)
	var resp agentdesktop.DesktopActionResponse
	err = json.NewDecoder(rr.Body).Decode(&resp)
	require.NoError(t, err)
	assert.Equal(t, "left_click action performed", resp.Output)
	// fakeDesktop.Click records (x, y, 1): coordinates pass through.
	assert.Equal(t, [3]int{100, 200, 1}, fake.lastClick)
}
// TestHandleAction_UnknownAction verifies that an unrecognized action
// name is rejected with 400 Bad Request.
func TestHandleAction_UnknownAction(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	desktop := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, desktop, nil)
	defer api.Close()

	payload, err := json.Marshal(agentdesktop.DesktopAction{Action: "explode"})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	api.Routes().ServeHTTP(rec, req)

	assert.Equal(t, http.StatusBadRequest, rec.Code)
}
// TestHandleAction_KeyAction verifies the "key" action forwards the key
// combo string to Desktop.KeyPress.
func TestHandleAction_KeyAction(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	desktop := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, desktop, nil)
	defer api.Close()

	key := "Return"
	payload, err := json.Marshal(agentdesktop.DesktopAction{
		Action: "key",
		Text:   &key,
	})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	api.Routes().ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "Return", desktop.lastKey)
}
// TestHandleAction_TypeAction verifies the "type" action forwards the
// literal text to Desktop.Type.
func TestHandleAction_TypeAction(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	desktop := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, desktop, nil)
	defer api.Close()

	text := "hello world"
	payload, err := json.Marshal(agentdesktop.DesktopAction{
		Action: "type",
		Text:   &text,
	})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	api.Routes().ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "hello world", desktop.lastTyped)
}
// TestHandleAction_HoldKey drives the "hold_key" action with a mock
// clock: the handler presses the key, waits on a timer, then releases.
// The quartz trap intercepts timer creation so the test can advance
// time deterministically instead of sleeping.
func TestHandleAction_HoldKey(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	mClk := quartz.NewMock(t)
	// Trap matches the timer tags used by the handler ("agentdesktop",
	// "hold_key") so we block until that specific timer exists.
	trap := mClk.Trap().NewTimer("agentdesktop", "hold_key")
	defer trap.Close()
	api := agentdesktop.NewAPI(logger, fake, mClk)
	defer api.Close()
	text := "Shift_L"
	dur := 100
	body := agentdesktop.DesktopAction{
		Action:   "hold_key",
		Text:     &text,
		Duration: &dur,
	}
	b, err := json.Marshal(body)
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")
	handler := api.Routes()
	// Serve in a goroutine: the handler blocks on the mock timer until
	// we advance the clock below.
	done := make(chan struct{})
	go func() {
		defer close(done)
		handler.ServeHTTP(rr, req)
	}()
	// Wait for the timer to be created, then advance past it.
	trap.MustWait(req.Context()).MustRelease(req.Context())
	mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
	<-done
	assert.Equal(t, http.StatusOK, rr.Code)
	var resp agentdesktop.DesktopActionResponse
	err = json.NewDecoder(rr.Body).Decode(&resp)
	require.NoError(t, err)
	assert.Equal(t, "hold_key action performed", resp.Output)
	// Both the press and the release must have reached the Desktop.
	assert.Equal(t, "Shift_L", fake.lastKeyDown)
	assert.Equal(t, "Shift_L", fake.lastKeyUp)
}
// TestHandleAction_HoldKeyMissingText verifies that "hold_key" without
// a "text" field is rejected with 400 and the exact error message.
func TestHandleAction_HoldKeyMissingText(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	desktop := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, desktop, nil)
	defer api.Close()

	payload, err := json.Marshal(agentdesktop.DesktopAction{Action: "hold_key"})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	api.Routes().ServeHTTP(rec, req)

	assert.Equal(t, http.StatusBadRequest, rec.Code)
	var resp codersdk.Response
	require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
	assert.Equal(t, "Missing \"text\" for hold_key action.", resp.Message)
}
// TestHandleAction_ScrollDown verifies that a "scroll" action with
// direction "down" is translated into a positive dy delta.
func TestHandleAction_ScrollDown(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	desktop := &fakeDesktop{
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, desktop, nil)
	defer api.Close()

	direction := "down"
	clicks := 5
	payload, err := json.Marshal(agentdesktop.DesktopAction{
		Action:          "scroll",
		Coordinate:      &[2]int{500, 400},
		ScrollDirection: &direction,
		ScrollAmount:    &clicks,
	})
	require.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	api.Routes().ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	// dy should be positive 5 for "down".
	assert.Equal(t, [4]int{500, 400, 0, 5}, desktop.lastScroll)
}
// TestHandleAction_CoordinateScaling verifies that coordinates sent in
// a smaller "scaled" space are remapped into the native display space
// before reaching the Desktop (pixel-center mapping, see
// scaleCoordinate).
func TestHandleAction_CoordinateScaling(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	fake := &fakeDesktop{
		// Native display is 1920x1080.
		startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
	}
	api := agentdesktop.NewAPI(logger, fake, nil)
	defer api.Close()
	// Model is working in a 1280x720 coordinate space.
	sw := 1280
	sh := 720
	body := agentdesktop.DesktopAction{
		Action:       "mouse_move",
		Coordinate:   &[2]int{640, 360},
		ScaledWidth:  &sw,
		ScaledHeight: &sh,
	}
	b, err := json.Marshal(body)
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
	req.Header.Set("Content-Type", "application/json")
	handler := api.Routes()
	handler.ServeHTTP(rr, req)
	assert.Equal(t, http.StatusOK, rr.Code)
	// 640 in 1280-space → 960 in 1920-space (midpoint maps to
	// midpoint).
	assert.Equal(t, 960, fake.lastMove[0])
	assert.Equal(t, 540, fake.lastMove[1])
}
// TestClose_DelegatesToDesktop verifies API.Close forwards to the
// underlying Desktop implementation.
func TestClose_DelegatesToDesktop(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	desktop := &fakeDesktop{}
	api := agentdesktop.NewAPI(logger, desktop, nil)

	require.NoError(t, api.Close())
	assert.True(t, desktop.closed)
}
// TestClose_PreventsNewSessions verifies that once the API has been
// closed, a new VNC request fails because the underlying Desktop
// errors on Start.
func TestClose_PreventsNewSessions(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)
	// After Close(), Start() will return an error because the
	// underlying Desktop is closed.
	desktop := &fakeDesktop{}
	api := agentdesktop.NewAPI(logger, desktop, nil)
	require.NoError(t, api.Close())

	// Simulate the closed desktop returning an error on Start().
	desktop.startErr = xerrors.New("desktop is closed")

	rec := httptest.NewRecorder()
	api.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/vnc", nil))
	assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
-91
View File
@@ -1,91 +0,0 @@
package agentdesktop
import (
"context"
"net"
)
// Desktop abstracts a virtual desktop session running inside a workspace.
// All coordinates are absolute pixels in the native display space
// (origin presumed top-left — confirm against the implementation).
type Desktop interface {
	// Start launches the desktop session. It is idempotent — calling
	// Start on an already-running session returns the existing
	// config. The returned DisplayConfig describes the running
	// session.
	Start(ctx context.Context) (DisplayConfig, error)

	// VNCConn dials the desktop's VNC server and returns a raw
	// net.Conn carrying RFB binary frames. Each call returns a new
	// connection; multiple clients can connect simultaneously.
	// Start must be called before VNCConn.
	VNCConn(ctx context.Context) (net.Conn, error)

	// Screenshot captures the current framebuffer as a PNG and
	// returns it base64-encoded. TargetWidth/TargetHeight in opts
	// are the desired output dimensions (the implementation
	// rescales); pass 0 to use native resolution.
	Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error)

	// Mouse operations.

	// Move moves the mouse cursor to absolute coordinates.
	Move(ctx context.Context, x, y int) error
	// Click performs a mouse button click at the given coordinates.
	Click(ctx context.Context, x, y int, button MouseButton) error
	// DoubleClick performs a double-click at the given coordinates.
	DoubleClick(ctx context.Context, x, y int, button MouseButton) error
	// ButtonDown presses and holds a mouse button.
	ButtonDown(ctx context.Context, button MouseButton) error
	// ButtonUp releases a mouse button.
	ButtonUp(ctx context.Context, button MouseButton) error
	// Scroll scrolls by (dx, dy) clicks at the given coordinates.
	Scroll(ctx context.Context, x, y, dx, dy int) error
	// Drag moves from (startX,startY) to (endX,endY) while holding
	// the left mouse button.
	Drag(ctx context.Context, startX, startY, endX, endY int) error

	// Keyboard operations.

	// KeyPress sends a key-down then key-up for a key combo string
	// (e.g. "Return", "ctrl+c").
	KeyPress(ctx context.Context, keys string) error
	// KeyDown presses and holds a key.
	KeyDown(ctx context.Context, key string) error
	// KeyUp releases a key.
	KeyUp(ctx context.Context, key string) error
	// Type types a string of text character-by-character.
	Type(ctx context.Context, text string) error

	// CursorPosition returns the current cursor coordinates.
	CursorPosition(ctx context.Context) (x, y int, err error)

	// Close shuts down the desktop session and cleans up resources.
	Close() error
}
// DisplayConfig describes a running desktop session as reported by
// Desktop.Start.
type DisplayConfig struct {
	Width   int // native width in pixels
	Height  int // native height in pixels
	VNCPort int // local TCP port for the VNC server
	Display int // X11 display number (e.g. 1 for :1), -1 if N/A
}

// MouseButton identifies a mouse button. The string values are passed
// verbatim to the desktop backend.
type MouseButton string

const (
	MouseButtonLeft   MouseButton = "left"
	MouseButtonRight  MouseButton = "right"
	MouseButtonMiddle MouseButton = "middle"
)

// ScreenshotOptions configures a screenshot capture.
type ScreenshotOptions struct {
	TargetWidth  int // desired output width in pixels; 0 = native
	TargetHeight int // desired output height in pixels; 0 = native
}

// ScreenshotResult is a captured screenshot.
type ScreenshotResult struct {
	Data string // base64-encoded PNG
}
-544
View File
@@ -1,544 +0,0 @@
package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
	// portableDesktopVersion pins the release of the portabledesktop
	// CLI that gets downloaded; the SHA-256 digests below must match
	// this version's release artifacts.
	portableDesktopVersion = "v0.0.4"
	// downloadRetries/downloadRetryDelay control the binary download
	// retry loop in ensureBinary.
	downloadRetries    = 3
	downloadRetryDelay = time.Second
)

// platformBinaries maps GOARCH to download URL and expected SHA-256
// digest for each supported platform. Only Linux builds exist; the OS
// check lives in ensureBinary.
var platformBinaries = map[string]struct {
	URL    string
	SHA256 string
}{
	"amd64": {
		URL:    "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-x64",
		SHA256: "a04e05e6c7d6f2e6b3acbf1729a7b21271276300b4fee321f4ffee6136538317",
	},
	"arm64": {
		URL:    "https://github.com/coder/portabledesktop/releases/download/" + portableDesktopVersion + "/portabledesktop-linux-arm64",
		SHA256: "b8cb9142dc32d46a608f25229cbe8168ff2a3aadc54253c74ff54cd347e16ca6",
	},
}
// portableDesktopOutput is the JSON output from
// `portabledesktop up --json`.
type portableDesktopOutput struct {
	VNCPort  int    `json:"vncPort"`
	Geometry string `json:"geometry"` // e.g. "1920x1080"
}

// desktopSession tracks a running portabledesktop process.
type desktopSession struct {
	cmd     *exec.Cmd
	vncPort int
	width   int // native width, parsed from geometry
	height  int // native height, parsed from geometry
	display int // X11 display number, -1 if not available
	// cancel tears down the context driving cmd; called when the
	// session is replaced or closed.
	cancel context.CancelFunc
}

// cursorOutput is the JSON output from `portabledesktop cursor --json`.
type cursorOutput struct {
	X int `json:"x"`
	Y int `json:"y"`
}

// screenshotOutput is the JSON output from
// `portabledesktop screenshot --json`.
type screenshotOutput struct {
	Data string `json:"data"` // base64-encoded PNG
}
// portableDesktop implements Desktop by shelling out to the
// portabledesktop CLI via agentexec.Execer.
type portableDesktop struct {
	logger  slog.Logger
	execer  agentexec.Execer
	dataDir string // agent's ScriptDataDir, used for binary caching

	// mu guards session, binPath, and closed.
	mu      sync.Mutex
	session *desktopSession // nil until started
	binPath string          // resolved path to binary, cached
	closed  bool

	// httpClient is used for downloading the binary. If nil,
	// http.DefaultClient is used.
	httpClient *http.Client
}
// NewPortableDesktop creates a Desktop backed by the portabledesktop
// CLI binary, using execer to spawn child processes. dataDir is used
// to cache the downloaded binary.
func NewPortableDesktop(
	logger slog.Logger,
	execer agentexec.Execer,
	dataDir string,
) Desktop {
	desktop := &portableDesktop{
		logger:  logger,
		execer:  execer,
		dataDir: dataDir,
	}
	return desktop
}
// httpDo returns the HTTP client used for binary downloads, falling
// back to http.DefaultClient when none was injected.
func (p *portableDesktop) httpDo() *http.Client {
	if client := p.httpClient; client != nil {
		return client
	}
	return http.DefaultClient
}
// Start launches the desktop session (idempotent). It resolves the
// portabledesktop binary, reuses a live session if one exists, and
// otherwise spawns `portabledesktop up --json`, parsing the JSON it
// prints for the VNC port and display geometry.
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return DisplayConfig{}, xerrors.New("desktop is closed")
	}
	if err := p.ensureBinary(ctx); err != nil {
		return DisplayConfig{}, xerrors.Errorf("ensure portabledesktop binary: %w", err)
	}
	// If we have an existing session, check if it's still alive.
	if p.session != nil {
		// ProcessState is only populated after Wait observes the exit,
		// so a nil ProcessState means the process is still running.
		if !(p.session.cmd.ProcessState != nil && p.session.cmd.ProcessState.Exited()) {
			return DisplayConfig{
				Width:   p.session.width,
				Height:  p.session.height,
				VNCPort: p.session.vncPort,
				Display: p.session.display,
			}, nil
		}
		// Process died — clean up and recreate.
		p.logger.Warn(ctx, "portabledesktop process died, recreating session")
		p.session.cancel()
		p.session = nil
	}
	// Spawn portabledesktop up --json.
	// The session context is detached from the request ctx so the
	// desktop outlives the request that started it; session.cancel
	// tears it down later.
	sessionCtx, sessionCancel := context.WithCancel(context.Background())
	//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
	cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
		"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		sessionCancel()
		return DisplayConfig{}, xerrors.Errorf("create stdout pipe: %w", err)
	}
	if err := cmd.Start(); err != nil {
		sessionCancel()
		return DisplayConfig{}, xerrors.Errorf("start portabledesktop: %w", err)
	}
	// Parse the JSON output to get VNC port and geometry. On failure,
	// kill and reap the child so it does not leak.
	var output portableDesktopOutput
	if err := json.NewDecoder(stdout).Decode(&output); err != nil {
		sessionCancel()
		_ = cmd.Process.Kill()
		_ = cmd.Wait()
		return DisplayConfig{}, xerrors.Errorf("parse portabledesktop output: %w", err)
	}
	if output.VNCPort == 0 {
		sessionCancel()
		_ = cmd.Process.Kill()
		_ = cmd.Wait()
		return DisplayConfig{}, xerrors.New("portabledesktop returned port 0")
	}
	// Geometry is "WIDTHxHEIGHT"; on parse failure w/h remain 0 and we
	// only log a warning.
	var w, h int
	if output.Geometry != "" {
		if _, err := fmt.Sscanf(output.Geometry, "%dx%d", &w, &h); err != nil {
			p.logger.Warn(ctx, "failed to parse geometry, using defaults",
				slog.F("geometry", output.Geometry),
				slog.Error(err),
			)
		}
	}
	p.logger.Info(ctx, "started portabledesktop session",
		slog.F("vnc_port", output.VNCPort),
		slog.F("width", w),
		slog.F("height", h),
		slog.F("pid", cmd.Process.Pid),
	)
	p.session = &desktopSession{
		cmd:     cmd,
		vncPort: output.VNCPort,
		width:   w,
		height:  h,
		display: -1, // X11 display number is not reported by the CLI.
		cancel:  sessionCancel,
	}
	return DisplayConfig{
		Width:   w,
		Height:  h,
		VNCPort: output.VNCPort,
		Display: -1,
	}, nil
}
// VNCConn dials the desktop's VNC server and returns a raw net.Conn
// carrying RFB binary frames. Each call returns a new connection.
// Start must have been called first; otherwise an error is returned.
//
// The dial honors ctx so callers can cancel or apply a deadline
// (previously the context was ignored and net.Dial could block
// indefinitely).
func (p *portableDesktop) VNCConn(ctx context.Context) (net.Conn, error) {
	p.mu.Lock()
	session := p.session
	p.mu.Unlock()
	if session == nil {
		return nil, xerrors.New("desktop session not started")
	}
	var dialer net.Dialer
	return dialer.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", session.vncPort))
}
// Screenshot captures the current framebuffer via
// `portabledesktop screenshot --json` and returns the base64-encoded
// PNG. Positive TargetWidth/TargetHeight are forwarded as rescale
// flags; zero values mean native resolution.
func (p *portableDesktop) Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error) {
	args := []string{"screenshot", "--json"}
	if w := opts.TargetWidth; w > 0 {
		args = append(args, "--target-width", strconv.Itoa(w))
	}
	if h := opts.TargetHeight; h > 0 {
		args = append(args, "--target-height", strconv.Itoa(h))
	}
	raw, err := p.runCmd(ctx, args...)
	if err != nil {
		return ScreenshotResult{}, err
	}
	var parsed screenshotOutput
	if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
		return ScreenshotResult{}, xerrors.Errorf("parse screenshot output: %w", err)
	}
	return ScreenshotResult(parsed), nil
}
// Move repositions the mouse cursor to the absolute coordinates (x, y).
func (p *portableDesktop) Move(ctx context.Context, x, y int) error {
	if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
		return err
	}
	return nil
}
// Click moves the cursor to (x, y) and then clicks the given button.
func (p *portableDesktop) Click(ctx context.Context, x, y int, button MouseButton) error {
	_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
	if err == nil {
		_, err = p.runCmd(ctx, "mouse", "click", string(button))
	}
	return err
}
// DoubleClick moves the cursor to (x, y) and clicks the given button
// twice in succession.
func (p *portableDesktop) DoubleClick(ctx context.Context, x, y int, button MouseButton) error {
	if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
		return err
	}
	for range 2 {
		if _, err := p.runCmd(ctx, "mouse", "click", string(button)); err != nil {
			return err
		}
	}
	return nil
}
// ButtonDown presses and holds a mouse button at the current cursor
// position.
func (p *portableDesktop) ButtonDown(ctx context.Context, button MouseButton) error {
	if _, err := p.runCmd(ctx, "mouse", "down", string(button)); err != nil {
		return err
	}
	return nil
}

// ButtonUp releases a previously pressed mouse button.
func (p *portableDesktop) ButtonUp(ctx context.Context, button MouseButton) error {
	if _, err := p.runCmd(ctx, "mouse", "up", string(button)); err != nil {
		return err
	}
	return nil
}
// Scroll moves the cursor to (x, y) and scrolls by (dx, dy) clicks.
func (p *portableDesktop) Scroll(ctx context.Context, x, y, dx, dy int) error {
	_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
	if err == nil {
		_, err = p.runCmd(ctx, "mouse", "scroll", strconv.Itoa(dx), strconv.Itoa(dy))
	}
	return err
}
// Drag moves to (startX,startY), presses the left button, moves to
// (endX,endY), and releases — aborting on the first failed step.
func (p *portableDesktop) Drag(ctx context.Context, startX, startY, endX, endY int) error {
	steps := [][]string{
		{"mouse", "move", strconv.Itoa(startX), strconv.Itoa(startY)},
		{"mouse", "down", string(MouseButtonLeft)},
		{"mouse", "move", strconv.Itoa(endX), strconv.Itoa(endY)},
		{"mouse", "up", string(MouseButtonLeft)},
	}
	for _, step := range steps {
		if _, err := p.runCmd(ctx, step...); err != nil {
			return err
		}
	}
	return nil
}
// KeyPress sends a key-down then key-up for a key combo string
// (e.g. "Return", "ctrl+c").
func (p *portableDesktop) KeyPress(ctx context.Context, keys string) error {
	if _, err := p.runCmd(ctx, "keyboard", "key", keys); err != nil {
		return err
	}
	return nil
}

// KeyDown presses and holds a key until KeyUp is called.
func (p *portableDesktop) KeyDown(ctx context.Context, key string) error {
	if _, err := p.runCmd(ctx, "keyboard", "down", key); err != nil {
		return err
	}
	return nil
}

// KeyUp releases a previously held key.
func (p *portableDesktop) KeyUp(ctx context.Context, key string) error {
	if _, err := p.runCmd(ctx, "keyboard", "up", key); err != nil {
		return err
	}
	return nil
}

// Type types a string of text character-by-character.
func (p *portableDesktop) Type(ctx context.Context, text string) error {
	if _, err := p.runCmd(ctx, "keyboard", "type", text); err != nil {
		return err
	}
	return nil
}
// CursorPosition returns the current cursor coordinates as reported by
// `portabledesktop cursor --json`.
func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err error) {
	raw, err := p.runCmd(ctx, "cursor", "--json")
	if err != nil {
		return 0, 0, err
	}
	var pos cursorOutput
	if err := json.Unmarshal([]byte(raw), &pos); err != nil {
		return 0, 0, xerrors.Errorf("parse cursor output: %w", err)
	}
	return pos.X, pos.Y, nil
}
// Close shuts down the desktop session and cleans up resources. It
// marks the Desktop closed so subsequent Start calls fail fast, then
// kills and reaps the portabledesktop process if one is running.
// Always returns nil.
func (p *portableDesktop) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.closed = true
	if p.session != nil {
		p.session.cancel()
		// Xvnc is a child process — killing it cleans up the X
		// session.
		_ = p.session.cmd.Process.Kill()
		// Wait reaps the process so it does not linger as a zombie.
		_ = p.session.cmd.Wait()
		p.session = nil
	}
	return nil
}
// runCmd executes a portabledesktop subcommand and returns combined
// output. The caller must have previously called ensureBinary.
//
// Failures embed the command output in the returned error; invocations
// slower than 5s are logged at Warn to surface a sluggish desktop.
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
	start := time.Now()
	//nolint:gosec // args are constructed by the caller, not user input.
	cmd := p.execer.CommandContext(ctx, p.binPath, args...)
	out, err := cmd.CombinedOutput()
	elapsed := time.Since(start)
	if err != nil {
		p.logger.Warn(ctx, "portabledesktop command failed",
			slog.F("args", args),
			slog.F("elapsed_ms", elapsed.Milliseconds()),
			slog.Error(err),
			slog.F("output", string(out)),
		)
		return "", xerrors.Errorf("portabledesktop %s: %w: %s", args[0], err, string(out))
	}
	if elapsed > 5*time.Second {
		p.logger.Warn(ctx, "portabledesktop command slow",
			slog.F("args", args),
			slog.F("elapsed_ms", elapsed.Milliseconds()),
		)
	} else {
		p.logger.Debug(ctx, "portabledesktop command completed",
			slog.F("args", args),
			slog.F("elapsed_ms", elapsed.Milliseconds()),
		)
	}
	return string(out), nil
}
// ensureBinary resolves or downloads the portabledesktop binary. It
// must be called while p.mu is held. Resolution order: an already
// cached binPath, then PATH, then a content-addressed cache directory
// under dataDir (keyed by the expected SHA-256), and finally a
// download with bounded retries.
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
	if p.binPath != "" {
		return nil
	}
	// 1. Check PATH.
	if path, err := exec.LookPath("portabledesktop"); err == nil {
		p.logger.Info(ctx, "found portabledesktop in PATH",
			slog.F("path", path),
		)
		p.binPath = path
		return nil
	}
	// 2. Platform checks.
	if runtime.GOOS != "linux" {
		return xerrors.New("portabledesktop is only supported on Linux")
	}
	bin, ok := platformBinaries[runtime.GOARCH]
	if !ok {
		return xerrors.Errorf("unsupported architecture for portabledesktop: %s", runtime.GOARCH)
	}
	// 3. Check cache. The directory name embeds the expected digest,
	// so a version bump naturally produces a fresh cache entry.
	cacheDir := filepath.Join(p.dataDir, "portabledesktop", bin.SHA256)
	cachedPath := filepath.Join(cacheDir, "portabledesktop")
	if info, err := os.Stat(cachedPath); err == nil && !info.IsDir() {
		// Verify it is executable (owner execute bit).
		if info.Mode()&0o100 != 0 {
			p.logger.Info(ctx, "using cached portabledesktop binary",
				slog.F("path", cachedPath),
			)
			p.binPath = cachedPath
			return nil
		}
	}
	// 4. Download with retry.
	p.logger.Info(ctx, "downloading portabledesktop binary",
		slog.F("url", bin.URL),
		slog.F("version", portableDesktopVersion),
		slog.F("arch", runtime.GOARCH),
	)
	var lastErr error
	for attempt := range downloadRetries {
		if err := downloadBinary(ctx, p.httpDo(), bin.URL, bin.SHA256, cachedPath); err != nil {
			lastErr = err
			p.logger.Warn(ctx, "download attempt failed",
				slog.F("attempt", attempt+1),
				slog.F("max_attempts", downloadRetries),
				slog.Error(err),
			)
			// Sleep between attempts, but stop immediately if the
			// caller's context is canceled.
			if attempt < downloadRetries-1 {
				select {
				case <-ctx.Done():
					return ctx.Err()
				case <-time.After(downloadRetryDelay):
				}
			}
			continue
		}
		p.binPath = cachedPath
		p.logger.Info(ctx, "downloaded portabledesktop binary",
			slog.F("path", cachedPath),
		)
		return nil
	}
	return xerrors.Errorf("download portabledesktop after %d attempts: %w", downloadRetries, lastErr)
}
// downloadBinary fetches a binary from url, verifies its SHA-256
// digest matches expectedSHA256, and atomically writes it to destPath.
//
// The body is streamed to a temp file in destPath's directory while the
// digest is computed in the same pass; the file is only chmod-ed and
// renamed into place after the digest check succeeds, so a corrupt or
// truncated download never becomes visible at destPath.
func downloadBinary(ctx context.Context, client *http.Client, url, expectedSHA256, destPath string) error {
	destDir := filepath.Dir(destPath)
	if err := os.MkdirAll(destDir, 0o700); err != nil {
		return xerrors.Errorf("create cache directory: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return xerrors.Errorf("create HTTP request: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return xerrors.Errorf("HTTP GET %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return xerrors.Errorf("HTTP GET %s: status %d", url, resp.StatusCode)
	}
	// Write to a temp file in the same directory so the final rename
	// is atomic on the same filesystem.
	tmp, err := os.CreateTemp(destDir, "portabledesktop-download-*")
	if err != nil {
		return xerrors.Errorf("create temp file: %w", err)
	}
	tmpPath := tmp.Name()
	committed := false
	defer func() {
		// Remove the temp file on every error path; a double Close
		// after success is harmless (the error is discarded).
		if !committed {
			_ = tmp.Close()
			_ = os.Remove(tmpPath)
		}
	}()
	// Stream the response body to disk while hashing it in one pass.
	digest := sha256.New()
	if _, err := io.Copy(io.MultiWriter(tmp, digest), resp.Body); err != nil {
		return xerrors.Errorf("download body: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return xerrors.Errorf("close temp file: %w", err)
	}
	// Refuse to install anything whose digest does not match.
	if got := hex.EncodeToString(digest.Sum(nil)); got != expectedSHA256 {
		return xerrors.Errorf(
			"SHA-256 mismatch: expected %s, got %s",
			expectedSHA256, got,
		)
	}
	if err := os.Chmod(tmpPath, 0o700); err != nil {
		return xerrors.Errorf("chmod: %w", err)
	}
	if err := os.Rename(tmpPath, destPath); err != nil {
		return xerrors.Errorf("rename to final path: %w", err)
	}
	committed = true
	return nil
}
@@ -1,713 +0,0 @@
package agentdesktop
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/pty"
)
// recordedExecer implements agentexec.Execer by recording every
// invocation and delegating to a real shell command built from a
// caller-supplied mapping of subcommand → shell script body.
type recordedExecer struct {
	// mu guards commands; recorded invocations may arrive from
	// parallel sub-tests.
	mu sync.Mutex
	// commands holds every recorded invocation as [cmd, args...].
	commands [][]string
	// scripts maps a subcommand keyword (e.g. "up", "screenshot")
	// to a shell snippet whose stdout will be the command output.
	scripts map[string]string
}
// record appends one invocation (command followed by its arguments) to
// the shared log under the mutex.
func (r *recordedExecer) record(cmd string, args ...string) {
	invocation := append([]string{cmd}, args...)
	r.mu.Lock()
	r.commands = append(r.commands, invocation)
	r.mu.Unlock()
}
// allCommands returns a snapshot copy of every invocation recorded so
// far, safe to inspect without holding the mutex.
func (r *recordedExecer) allCommands() [][]string {
	r.mu.Lock()
	defer r.mu.Unlock()
	return append([][]string(nil), r.commands...)
}
// scriptFor returns the script mapped to the first argument that
// matches a key in r.scripts. When nothing matches it falls back to a
// script that succeeds silently.
func (r *recordedExecer) scriptFor(args []string) string {
	for _, arg := range args {
		if script, ok := r.scripts[arg]; ok {
			return script
		}
	}
	return "true"
}
// CommandContext records the invocation and returns a shell command
// that runs whichever script matches the given arguments.
func (r *recordedExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
	r.record(cmd, args...)
	//nolint:gosec // Test helper — script content is controlled by the test.
	return exec.CommandContext(ctx, "sh", "-c", r.scriptFor(args))
}
// PTYCommandContext is the pseudo-terminal analogue of CommandContext:
// it records the invocation and runs the matching script under a PTY.
func (r *recordedExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
	r.record(cmd, args...)
	script := r.scriptFor(args)
	return pty.CommandContext(ctx, "sh", "-c", script)
}
// --- portableDesktop tests ---
// TestPortableDesktop_Start_ParsesOutput verifies that Start parses the
// JSON line emitted by the `up` subcommand into width, height, and VNC
// port, and that Display is reported as -1 for this implementation.
func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	dataDir := t.TempDir()
	// The "up" script prints the JSON line then sleeps until
	// the context is canceled (simulating a long-running process).
	rec := &recordedExecer{
		scripts: map[string]string{
			"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
		},
	}
	pd := &portableDesktop{
		logger:  logger,
		execer:  rec,
		dataDir: dataDir,
		binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
	}
	ctx := context.Background()
	cfg, err := pd.Start(ctx)
	require.NoError(t, err)
	// The geometry string "1920x1080" must be split into width/height.
	assert.Equal(t, 1920, cfg.Width)
	assert.Equal(t, 1080, cfg.Height)
	assert.Equal(t, 5901, cfg.VNCPort)
	assert.Equal(t, -1, cfg.Display)
	// Clean up the long-running process.
	require.NoError(t, pd.Close())
}
// TestPortableDesktop_Start_Idempotent verifies that a second Start call
// returns the same config as the first and does not launch a second
// desktop process.
func TestPortableDesktop_Start_Idempotent(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	dataDir := t.TempDir()
	rec := &recordedExecer{
		scripts: map[string]string{
			"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
		},
	}
	pd := &portableDesktop{
		logger:  logger,
		execer:  rec,
		dataDir: dataDir,
		binPath: "portabledesktop",
	}
	ctx := context.Background()
	first, err := pd.Start(ctx)
	require.NoError(t, err)
	second, err := pd.Start(ctx)
	require.NoError(t, err)
	assert.Equal(t, first, second, "second Start should return the same config")
	// Count how many recorded arguments mention "up"; the desktop
	// process must have been launched exactly once.
	var upCalls int
	for _, invocation := range rec.allCommands() {
		for _, arg := range invocation {
			if arg == "up" {
				upCalls++
			}
		}
	}
	assert.Equal(t, 1, upCalls, "expected exactly one 'up' invocation")
	require.NoError(t, pd.Close())
}
func TestPortableDesktop_Screenshot(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
dataDir := t.TempDir()
rec := &recordedExecer{
scripts: map[string]string{
"screenshot": `echo '{"data":"abc123"}'`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: dataDir,
binPath: "portabledesktop",
}
ctx := context.Background()
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
require.NoError(t, err)
assert.Equal(t, "abc123", result.Data)
}
// TestPortableDesktop_Screenshot_WithTargetDimensions verifies that
// TargetWidth/TargetHeight are forwarded to the CLI as
// --target-width/--target-height flags.
func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	dataDir := t.TempDir()
	rec := &recordedExecer{
		scripts: map[string]string{
			"screenshot": `echo '{"data":"x"}'`,
		},
	}
	pd := &portableDesktop{
		logger:  logger,
		execer:  rec,
		dataDir: dataDir,
		binPath: "portabledesktop",
	}
	ctx := context.Background()
	_, err := pd.Screenshot(ctx, ScreenshotOptions{
		TargetWidth:  800,
		TargetHeight: 600,
	})
	require.NoError(t, err)
	cmds := rec.allCommands()
	require.NotEmpty(t, cmds)
	// The last command should contain the target dimension flags.
	// Joining with spaces lets us match flag+value across separate
	// argv entries.
	last := cmds[len(cmds)-1]
	joined := strings.Join(last, " ")
	assert.Contains(t, joined, "--target-width 800")
	assert.Contains(t, joined, "--target-height 600")
}
// TestPortableDesktop_MouseMethods verifies, via a table of sub-tests,
// that each mouse method dispatches a CLI invocation containing the
// expected argument substrings.
func TestPortableDesktop_MouseMethods(t *testing.T) {
	t.Parallel()
	// Each sub-test verifies a single mouse method dispatches the
	// correct CLI arguments.
	tests := []struct {
		name     string
		invoke   func(context.Context, *portableDesktop) error
		wantArgs []string // substrings expected in a recorded command
	}{
		{
			name: "Move",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.Move(ctx, 42, 99)
			},
			wantArgs: []string{"mouse", "move", "42", "99"},
		},
		{
			name: "Click",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.Click(ctx, 10, 20, MouseButtonLeft)
			},
			// Click does move then click.
			wantArgs: []string{"mouse", "click", "left"},
		},
		{
			name: "DoubleClick",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.DoubleClick(ctx, 5, 6, MouseButtonRight)
			},
			wantArgs: []string{"mouse", "click", "right"},
		},
		{
			name: "ButtonDown",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.ButtonDown(ctx, MouseButtonMiddle)
			},
			wantArgs: []string{"mouse", "down", "middle"},
		},
		{
			name: "ButtonUp",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.ButtonUp(ctx, MouseButtonLeft)
			},
			wantArgs: []string{"mouse", "up", "left"},
		},
		{
			name: "Scroll",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.Scroll(ctx, 50, 60, 3, 4)
			},
			wantArgs: []string{"mouse", "scroll", "3", "4"},
		},
		{
			name: "Drag",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.Drag(ctx, 10, 20, 30, 40)
			},
			// Drag ends with mouse up left.
			wantArgs: []string{"mouse", "up", "left"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			logger := slogtest.Make(t, nil)
			rec := &recordedExecer{
				scripts: map[string]string{
					"mouse": `echo ok`,
				},
			}
			pd := &portableDesktop{
				logger:  logger,
				execer:  rec,
				dataDir: t.TempDir(),
				binPath: "portabledesktop",
			}
			err := tt.invoke(context.Background(), pd)
			require.NoError(t, err)
			cmds := rec.allCommands()
			require.NotEmpty(t, cmds, "expected at least one command")
			// Find at least one recorded command that contains
			// all expected argument substrings. Methods may issue
			// several commands (e.g. Drag), so any single match is
			// sufficient.
			found := false
			for _, cmd := range cmds {
				joined := strings.Join(cmd, " ")
				match := true
				for _, want := range tt.wantArgs {
					if !strings.Contains(joined, want) {
						match = false
						break
					}
				}
				if match {
					found = true
					break
				}
			}
			assert.True(t, found,
				"no recorded command matched %v; got %v", tt.wantArgs, cmds)
		})
	}
}
// TestPortableDesktop_KeyboardMethods verifies, via a table of
// sub-tests, that each keyboard method dispatches the expected CLI
// arguments in its final invocation.
func TestPortableDesktop_KeyboardMethods(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		invoke   func(context.Context, *portableDesktop) error
		wantArgs []string
	}{
		{
			name: "KeyPress",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.KeyPress(ctx, "Return")
			},
			wantArgs: []string{"keyboard", "key", "Return"},
		},
		{
			name: "KeyDown",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.KeyDown(ctx, "shift")
			},
			wantArgs: []string{"keyboard", "down", "shift"},
		},
		{
			name: "KeyUp",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.KeyUp(ctx, "shift")
			},
			wantArgs: []string{"keyboard", "up", "shift"},
		},
		{
			name: "Type",
			invoke: func(ctx context.Context, pd *portableDesktop) error {
				return pd.Type(ctx, "hello world")
			},
			wantArgs: []string{"keyboard", "type", "hello world"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			logger := slogtest.Make(t, nil)
			rec := &recordedExecer{
				scripts: map[string]string{
					"keyboard": `echo ok`,
				},
			}
			pd := &portableDesktop{
				logger:  logger,
				execer:  rec,
				dataDir: t.TempDir(),
				binPath: "portabledesktop",
			}
			err := tt.invoke(context.Background(), pd)
			require.NoError(t, err)
			cmds := rec.allCommands()
			require.NotEmpty(t, cmds)
			// Unlike the mouse table, only the last invocation is
			// inspected here.
			last := cmds[len(cmds)-1]
			joined := strings.Join(last, " ")
			for _, want := range tt.wantArgs {
				assert.Contains(t, joined, want)
			}
		})
	}
}
func TestPortableDesktop_CursorPosition(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
rec := &recordedExecer{
scripts: map[string]string{
"cursor": `echo '{"x":100,"y":200}'`,
},
}
pd := &portableDesktop{
logger: logger,
execer: rec,
dataDir: t.TempDir(),
binPath: "portabledesktop",
}
x, y, err := pd.CursorPosition(context.Background())
require.NoError(t, err)
assert.Equal(t, 100, x)
assert.Equal(t, 200, y)
}
// TestPortableDesktop_Close verifies that Close tears down the running
// session, marks the desktop closed, and that Start refuses to run
// again afterwards.
func TestPortableDesktop_Close(t *testing.T) {
	t.Parallel()
	logger := slogtest.Make(t, nil)
	rec := &recordedExecer{
		scripts: map[string]string{
			"up": `printf '{"vncPort":5901,"geometry":"1024x768"}\n' && sleep 120`,
		},
	}
	pd := &portableDesktop{
		logger:  logger,
		execer:  rec,
		dataDir: t.TempDir(),
		binPath: "portabledesktop",
	}
	ctx := context.Background()
	_, err := pd.Start(ctx)
	require.NoError(t, err)
	// Session should exist. Internal state is inspected under the
	// desktop's own mutex.
	pd.mu.Lock()
	require.NotNil(t, pd.session)
	pd.mu.Unlock()
	require.NoError(t, pd.Close())
	// Session should be cleaned up.
	pd.mu.Lock()
	assert.Nil(t, pd.session)
	assert.True(t, pd.closed)
	pd.mu.Unlock()
	// Subsequent Start must fail.
	_, err = pd.Start(ctx)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "desktop is closed")
}
// --- downloadBinary tests ---
// TestDownloadBinary_Success verifies that downloadBinary writes the
// fetched bytes to destPath and marks the result executable.
func TestDownloadBinary_Success(t *testing.T) {
	t.Parallel()
	binaryContent := []byte("#!/bin/sh\necho portable\n")
	hash := sha256.Sum256(binaryContent)
	expectedSHA := hex.EncodeToString(hash[:])
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write(binaryContent)
	}))
	defer srv.Close()
	destDir := t.TempDir()
	destPath := filepath.Join(destDir, "portabledesktop")
	err := downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
	require.NoError(t, err)
	// Verify the file exists and has correct content.
	got, err := os.ReadFile(destPath)
	require.NoError(t, err)
	assert.Equal(t, binaryContent, got)
	// Verify executable permissions. Check the owner execute bit
	// specifically: the old mode&0o700 != 0 check also passed for
	// rw-only files, and ensureBinary requires 0o100 when deciding
	// whether a cached binary is usable.
	info, err := os.Stat(destPath)
	require.NoError(t, err)
	assert.NotZero(t, info.Mode()&0o100, "binary should be executable")
}
// TestDownloadBinary_ChecksumMismatch verifies that a digest mismatch
// is surfaced as an error and that neither the destination file nor any
// temp file is left behind.
func TestDownloadBinary_ChecksumMismatch(t *testing.T) {
	t.Parallel()
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("real binary content"))
	}))
	defer srv.Close()
	destDir := t.TempDir()
	destPath := filepath.Join(destDir, "portabledesktop")
	// All-zero digest cannot match the served content.
	wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
	err := downloadBinary(context.Background(), srv.Client(), srv.URL, wrongSHA, destPath)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "SHA-256 mismatch")
	// The destination file should not exist (temp file cleaned up).
	_, statErr := os.Stat(destPath)
	assert.True(t, os.IsNotExist(statErr), "dest file should not exist after checksum failure")
	// No leftover temp files in the directory.
	entries, err := os.ReadDir(destDir)
	require.NoError(t, err)
	assert.Empty(t, entries, "no leftover temp files should remain")
}
// TestDownloadBinary_HTTPError verifies that a non-200 response is
// reported as an error including the status code.
func TestDownloadBinary_HTTPError(t *testing.T) {
	t.Parallel()
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	}))
	defer srv.Close()
	dest := filepath.Join(t.TempDir(), "portabledesktop")
	err := downloadBinary(context.Background(), srv.Client(), srv.URL, "irrelevant", dest)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "status 404")
}
// --- ensureBinary tests ---
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
t.Parallel()
// When binPath is already set, ensureBinary should return
// immediately without doing any work.
logger := slogtest.Make(t, nil)
pd := &portableDesktop{
logger: logger,
execer: agentexec.DefaultExecer,
dataDir: t.TempDir(),
binPath: "/already/set",
}
err := pd.ensureBinary(context.Background())
require.NoError(t, err)
assert.Equal(t, "/already/set", pd.binPath)
}
// TestEnsureBinary_UsesCachedBinary verifies that a pre-populated,
// executable cache entry is used without any download.
func TestEnsureBinary_UsesCachedBinary(t *testing.T) {
	// Cannot use t.Parallel because t.Setenv modifies the process
	// environment.
	if runtime.GOOS != "linux" {
		t.Skip("portabledesktop is only supported on Linux")
	}
	bin, ok := platformBinaries[runtime.GOARCH]
	if !ok {
		t.Skipf("no platformBinary entry for %s", runtime.GOARCH)
	}
	dataDir := t.TempDir()
	cacheDir := filepath.Join(dataDir, "portabledesktop", bin.SHA256)
	require.NoError(t, os.MkdirAll(cacheDir, 0o700))
	cachedPath := filepath.Join(cacheDir, "portabledesktop")
	// The fixture must be executable: ensureBinary only accepts a
	// cached binary whose owner execute bit is set, so a 0o600 file
	// would be ignored and trigger a (failing) real download.
	require.NoError(t, os.WriteFile(cachedPath, []byte("#!/bin/sh\n"), 0o700))
	logger := slogtest.Make(t, nil)
	pd := &portableDesktop{
		logger:  logger,
		execer:  agentexec.DefaultExecer,
		dataDir: dataDir,
	}
	// Clear PATH so LookPath won't find a real binary.
	t.Setenv("PATH", "")
	err := pd.ensureBinary(context.Background())
	require.NoError(t, err)
	assert.Equal(t, cachedPath, pd.binPath)
}
// TestEnsureBinary_Downloads verifies the full download path: with an
// empty PATH and no cache entry, ensureBinary fetches the binary from
// the (stubbed) release URL and installs it under the SHA-keyed cache
// directory.
func TestEnsureBinary_Downloads(t *testing.T) {
	// Cannot use t.Parallel because t.Setenv modifies the process
	// environment and we override the package-level platformBinaries.
	if runtime.GOOS != "linux" {
		t.Skip("portabledesktop is only supported on Linux")
	}
	binaryContent := []byte("#!/bin/sh\necho downloaded\n")
	hash := sha256.Sum256(binaryContent)
	expectedSHA := hex.EncodeToString(hash[:])
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write(binaryContent)
	}))
	defer srv.Close()
	// Save and restore platformBinaries for this test.
	origBinaries := platformBinaries
	platformBinaries = map[string]struct {
		URL    string
		SHA256 string
	}{
		runtime.GOARCH: {
			URL:    srv.URL + "/portabledesktop",
			SHA256: expectedSHA,
		},
	}
	t.Cleanup(func() { platformBinaries = origBinaries })
	dataDir := t.TempDir()
	logger := slogtest.Make(t, nil)
	pd := &portableDesktop{
		logger:     logger,
		execer:     agentexec.DefaultExecer,
		dataDir:    dataDir,
		httpClient: srv.Client(),
	}
	// Ensure PATH doesn't contain a real portabledesktop binary.
	t.Setenv("PATH", "")
	err := pd.ensureBinary(context.Background())
	require.NoError(t, err)
	// The binary must land in the SHA-keyed cache location.
	expectedPath := filepath.Join(dataDir, "portabledesktop", expectedSHA, "portabledesktop")
	assert.Equal(t, expectedPath, pd.binPath)
	// Verify the downloaded file has correct content.
	got, err := os.ReadFile(expectedPath)
	require.NoError(t, err)
	assert.Equal(t, binaryContent, got)
}
func TestEnsureBinary_RetriesOnFailure(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" {
t.Skip("portabledesktop is only supported on Linux")
}
binaryContent := []byte("#!/bin/sh\necho retried\n")
hash := sha256.Sum256(binaryContent)
expectedSHA := hex.EncodeToString(hash[:])
var mu sync.Mutex
attempt := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
mu.Lock()
current := attempt
attempt++
mu.Unlock()
// Fail the first 2 attempts, succeed on the third.
if current < 2 {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write(binaryContent)
}))
defer srv.Close()
// Test downloadBinary directly to avoid time.Sleep in
// ensureBinary's retry loop. We call it 3 times to simulate
// what ensureBinary would do.
destDir := t.TempDir()
destPath := filepath.Join(destDir, "portabledesktop")
var lastErr error
for i := range 3 {
lastErr = downloadBinary(context.Background(), srv.Client(), srv.URL, expectedSHA, destPath)
if lastErr == nil {
break
}
if i < 2 {
// In the real code, ensureBinary sleeps here.
// We skip the sleep in tests.
continue
}
}
require.NoError(t, lastErr, "download should succeed on the third attempt")
got, err := os.ReadFile(destPath)
require.NoError(t, err)
assert.Equal(t, binaryContent, got)
mu.Lock()
assert.Equal(t, 3, attempt, "server should have been hit 3 times")
mu.Unlock()
}
// Compile-time assertion that portableDesktop satisfies the Desktop
// interface. It uses the unexported type, so it lives in the internal
// test package.
var _ Desktop = (*portableDesktop)(nil)

// Blank assignments keeping the agentexec and fmt imports referenced.
// NOTE(review): these look redundant — agentexec.DefaultExecer is used
// directly by several tests above, and nothing visible here needs fmt;
// consider dropping the assignments (and the fmt import) instead.
var (
	_ = agentexec.DefaultExecer
	_ = fmt.Sprintf
)
+38 -89
View File
@@ -447,10 +447,13 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
content := string(data)
for _, edit := range edits {
var err error
content, err = fuzzyReplace(content, edit)
if err != nil {
return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err)
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
}
@@ -477,92 +480,51 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace it
// with `replace`. It uses a cascading match strategy inspired by
// fuzzyReplace attempts to find `search` inside `content` and replace its first
// occurrence with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace
// (indentation-tolerant).
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When edit.ReplaceAll is false (the default), the search string must
// match exactly one location. If multiple matches are found, an error
// is returned asking the caller to include more context or set
// replace_all.
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still
// applied at the byte offsets of the original content so that surrounding
// text (including indentation of untouched lines) is preserved.
func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
search := edit.Search
replace := edit.Replace
// Pass 1 exact substring match.
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1 exact substring (replace all occurrences).
if strings.Contains(content, search) {
if edit.ReplaceAll {
return strings.ReplaceAll(content, search, replace), nil
}
count := strings.Count(content, search)
if count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
// Exactly one match.
return strings.Replace(content, search, replace, 1), nil
return strings.ReplaceAll(content, search, replace), true
}
// For line-level fuzzy matching we split both content and search
// into lines.
// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element
// from SplitAfter. Drop it so it doesn't interfere with line
// matching.
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
trimRight := func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}
trimAll := func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3 trim all leading and trailing whitespace
// (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
if !edit.ReplaceAll {
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
return "", xerrors.Errorf("search string matches %d occurrences "+
"(expected exactly 1). Include more surrounding "+
"context to make the match unique, or set "+
"replace_all to true", count)
}
}
return spliceLines(contentLines, start, end, replace), nil
// Pass 3 trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return "", xerrors.New("search string not found in file. Verify the search " +
"string matches the file content exactly, including whitespace " +
"and indentation")
return content, false
}
// seekLines scans contentLines looking for a contiguous subsequence that matches
@@ -587,26 +549,6 @@ outer:
return 0, 0, false
}
// countLineMatches counts how many non-overlapping contiguous
// subsequences of contentLines match searchLines according to eq.
func countLineMatches(contentLines, searchLines []string, eq func(a, b string) bool) int {
count := 0
if len(searchLines) == 0 || len(searchLines) > len(contentLines) {
return count
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
count++
i += len(searchLines) - 1 // skip past this match
}
return count
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
@@ -620,3 +562,10 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
+3 -68
View File
@@ -576,9 +576,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
},
{
// When the second edit creates ambiguity (two "bar"
// occurrences), it should fail.
name: "EditEditAmbiguous",
name: "EditEdit", // Edits affect previous edits.
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
@@ -595,33 +593,7 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 2 occurrences"},
// File should not be modified on error.
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
},
{
// With replace_all the cascading edit replaces
// both occurrences.
name: "EditEditReplaceAll",
contents: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "edit-edit-ra"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
{
Search: "bar",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "qux qux"},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
},
{
name: "Multiline",
@@ -748,7 +720,7 @@ func TestEditFiles(t *testing.T) {
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchErrors",
name: "NoMatchStillSucceeds",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
@@ -761,46 +733,9 @@ func TestEditFiles(t *testing.T) {
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"search string not found in file"},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "AmbiguousExactMatch",
contents: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ambig-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"matches 3 occurrences"},
expected: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
},
{
name: "ReplaceAllExact",
contents: map[string]string{filepath.Join(tmpdir, "ra-exact"): "foo bar foo baz foo"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "ra-exact"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "qux",
ReplaceAll: true,
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
+2 -29
View File
@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/http"
"sort"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
@@ -70,12 +69,7 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
return
}
var chatID string
if id, _, ok := agentgit.ExtractChatContext(r); ok {
chatID = id.String()
}
proc, err := api.manager.start(req, chatID)
proc, err := api.manager.start(req)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start process.",
@@ -111,28 +105,7 @@ func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var chatID string
if id, _, ok := agentgit.ExtractChatContext(r); ok {
chatID = id.String()
}
infos := api.manager.list(chatID)
// Sort by running state (running first), then by started_at
// descending so the most recent processes appear first.
sort.Slice(infos, func(i, j int) bool {
if infos[i].Running != infos[j].Running {
return infos[i].Running
}
return infos[i].StartedAt > infos[j].StartedAt
})
// Cap the response to avoid bloating LLM context.
const maxListProcesses = 10
if len(infos) > maxListProcesses {
infos = infos[:maxListProcesses]
}
infos := api.manager.list()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
Processes: infos,
})
+3 -201
View File
@@ -27,7 +27,7 @@ import (
)
// postStart sends a POST /start request and returns the recorder.
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest, headers ...http.Header) *httptest.ResponseRecorder {
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@@ -38,13 +38,6 @@ func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcess
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
for _, h := range headers {
for k, vals := range h {
for _, v := range vals {
r.Header.Add(k, v)
}
}
}
handler.ServeHTTP(w, r)
return w
}
@@ -147,10 +140,10 @@ func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.Pro
// startAndGetID is a helper that starts a process and returns
// the process ID.
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest, headers ...http.Header) string {
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
t.Helper()
w := postStart(t, handler, req, headers...)
w := postStart(t, handler, req)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
@@ -340,180 +333,6 @@ func TestListProcesses(t *testing.T) {
require.Empty(t, resp.Processes)
})
	// FilterByChatID verifies that /list honors the Coder-Chat-Id
	// request header: each chat sees only its own processes, and an
	// unfiltered request sees everything.
	t.Run("FilterByChatID", func(t *testing.T) {
		t.Parallel()
		handler := newTestAPI(t)
		chatA := uuid.New().String()
		chatB := uuid.New().String()
		headersA := http.Header{workspacesdk.CoderChatIDHeader: {chatA}}
		headersB := http.Header{workspacesdk.CoderChatIDHeader: {chatB}}
		// Start processes with different chat IDs: two tagged with
		// chat A and one with chat B.
		id1 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo chat-a",
		}, headersA)
		waitForExit(t, handler, id1)
		id2 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo chat-b",
		}, headersB)
		waitForExit(t, handler, id2)
		id3 := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo chat-a-2",
		}, headersA)
		waitForExit(t, handler, id3)
		// List with chat A header should return 2 processes.
		w := getListWithChatHeader(t, handler, chatA)
		require.Equal(t, http.StatusOK, w.Code)
		var resp workspacesdk.ListProcessesResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Len(t, resp.Processes, 2)
		// Response order is not asserted here, so collect IDs into a
		// set before checking membership.
		ids := make(map[string]bool)
		for _, p := range resp.Processes {
			ids[p.ID] = true
		}
		require.True(t, ids[id1])
		require.True(t, ids[id3])
		// List with chat B header should return 1 process.
		w2 := getListWithChatHeader(t, handler, chatB)
		require.Equal(t, http.StatusOK, w2.Code)
		var resp2 workspacesdk.ListProcessesResponse
		err = json.NewDecoder(w2.Body).Decode(&resp2)
		require.NoError(t, err)
		require.Len(t, resp2.Processes, 1)
		require.Equal(t, id2, resp2.Processes[0].ID)
		// List without chat header should return all 3.
		w3 := getList(t, handler)
		require.Equal(t, http.StatusOK, w3.Code)
		var resp3 workspacesdk.ListProcessesResponse
		err = json.NewDecoder(w3.Body).Decode(&resp3)
		require.NoError(t, err)
		require.Len(t, resp3.Processes, 3)
	})
	// ChatIDFiltering verifies the single-process case: a process
	// started under one chat ID is visible to that chat, invisible to
	// a different chat, and visible to unfiltered listings.
	t.Run("ChatIDFiltering", func(t *testing.T) {
		t.Parallel()
		handler := newTestAPI(t)
		chatID := uuid.New().String()
		headers := http.Header{workspacesdk.CoderChatIDHeader: {chatID}}
		// Start a process tagged with chatID via the request header.
		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo with-chat",
		}, headers)
		waitForExit(t, handler, id)
		// Listing with the same chat header should return
		// the process.
		w := getListWithChatHeader(t, handler, chatID)
		require.Equal(t, http.StatusOK, w.Code)
		var resp workspacesdk.ListProcessesResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Len(t, resp.Processes, 1)
		require.Equal(t, id, resp.Processes[0].ID)
		// Listing with a different chat header should not
		// return the process.
		w2 := getListWithChatHeader(t, handler, uuid.New().String())
		require.Equal(t, http.StatusOK, w2.Code)
		var resp2 workspacesdk.ListProcessesResponse
		err = json.NewDecoder(w2.Body).Decode(&resp2)
		require.NoError(t, err)
		require.Empty(t, resp2.Processes)
		// Listing without a chat header should return the
		// process (no filtering).
		w3 := getList(t, handler)
		require.Equal(t, http.StatusOK, w3.Code)
		var resp3 workspacesdk.ListProcessesResponse
		err = json.NewDecoder(w3.Body).Decode(&resp3)
		require.NoError(t, err)
		require.Len(t, resp3.Processes, 1)
	})
	// SortAndLimit verifies that /list caps the response at 10
	// processes and orders exited processes newest-first.
	t.Run("SortAndLimit", func(t *testing.T) {
		t.Parallel()
		handler := newTestAPI(t)
		// Start 12 short-lived processes so we exceed the
		// limit of 10.
		for i := 0; i < 12; i++ {
			id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
				Command: fmt.Sprintf("echo proc-%d", i),
			})
			waitForExit(t, handler, id)
		}
		w := getList(t, handler)
		require.Equal(t, http.StatusOK, w.Code)
		var resp workspacesdk.ListProcessesResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Len(t, resp.Processes, 10, "should be capped at 10")
		// All returned processes are exited, so they should
		// be sorted by StartedAt descending (newest first).
		for i := 1; i < len(resp.Processes); i++ {
			require.GreaterOrEqual(t, resp.Processes[i-1].StartedAt, resp.Processes[i].StartedAt,
				"processes should be sorted by started_at descending")
		}
	})
	// RunningProcessesSortedFirst verifies that /list places running
	// processes before exited ones, regardless of start order.
	t.Run("RunningProcessesSortedFirst", func(t *testing.T) {
		t.Parallel()
		handler := newTestAPI(t)
		// Start an exited process first.
		exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo done",
		})
		waitForExit(t, handler, exitedID)
		// Start a running process after. The long sleep keeps it
		// alive until the explicit kill at the end of the subtest.
		runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command:    "sleep 300",
			Background: true,
		})
		w := getList(t, handler)
		require.Equal(t, http.StatusOK, w.Code)
		var resp workspacesdk.ListProcessesResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Len(t, resp.Processes, 2)
		// Running process should come first regardless of
		// start order.
		require.Equal(t, runningID, resp.Processes[0].ID)
		require.True(t, resp.Processes[0].Running)
		require.Equal(t, exitedID, resp.Processes[1].ID)
		require.False(t, resp.Processes[1].Running)
		// Clean up.
		postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
			Signal: "kill",
		})
	})
t.Run("MixedRunningAndExited", func(t *testing.T) {
t.Parallel()
@@ -562,23 +381,6 @@ func TestListProcesses(t *testing.T) {
})
}
// getListWithChatHeader issues a GET /list request against handler,
// attaching the Coder-Chat-Id header when chatID is non-empty, and
// returns the response recorder for the caller to inspect. An empty
// chatID omits the header entirely (no filtering).
func getListWithChatHeader(t *testing.T, handler http.Handler, chatID string) *httptest.ResponseRecorder {
	t.Helper()
	// Bound the request lifetime so a hung handler fails the test
	// instead of blocking forever.
	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()
	req := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
	if chatID != "" {
		req.Header.Set(workspacesdk.CoderChatIDHeader, chatID)
	}
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)
	return rec
}
func TestProcessOutput(t *testing.T) {
t.Parallel()
-26
View File
@@ -1,26 +0,0 @@
//go:build !windows
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Unix, Setpgid creates a new process group so
// that signals can be delivered to the entire group (the shell
// and all its children).
func procSysProcAttr() *syscall.SysProcAttr {
	return &syscall.SysProcAttr{
		// Setpgid makes the child the leader of a fresh process
		// group whose ID equals the child's PID; signalProcess
		// relies on this when signaling the group via -pid.
		Setpgid: true,
	}
}
// signalProcess sends a signal to the process group rooted at p.
// Using the negative PID sends the signal to every process in the
// group, ensuring child processes (e.g. from shell pipelines) are
// also signaled.
func signalProcess(p *os.Process, sig syscall.Signal) error {
	// kill(2) with a negative PID targets the process group whose
	// group ID is -pid; this matches the group created by Setpgid
	// in procSysProcAttr.
	return syscall.Kill(-p.Pid, sig)
}
-20
View File
@@ -1,20 +0,0 @@
package agentproc
import (
"os"
"syscall"
)
// procSysProcAttr returns the SysProcAttr to use when spawning
// processes. On Windows, process groups are not supported in the
// same way as Unix, so this returns an empty struct and the
// process is spawned with default attributes.
func procSysProcAttr() *syscall.SysProcAttr {
	return &syscall.SysProcAttr{}
}
// signalProcess sends a signal directly to the process. Windows
// does not support process group signaling, so we fall back to
// sending the signal to the process itself.
//
// NOTE: the signal argument is ignored entirely; every signal is
// translated to a hard kill via p.Kill(), so "terminate"-style
// requests are not graceful on Windows.
func signalProcess(p *os.Process, _ syscall.Signal) error {
	return p.Kill()
}
+9 -36
View File
@@ -21,10 +21,6 @@ import (
var (
errProcessNotFound = xerrors.New("process not found")
errProcessNotRunning = xerrors.New("process is not running")
// exitedProcessReapAge is how long an exited process is
// kept before being automatically removed from the map.
exitedProcessReapAge = 5 * time.Minute
)
// process represents a running or completed process.
@@ -34,7 +30,6 @@ type process struct {
command string
workDir string
background bool
chatID string
cmd *exec.Cmd
cancel context.CancelFunc
buf *HeadTailBuffer
@@ -94,7 +89,7 @@ func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(curr
// processes use a long-lived context so the process survives
// the HTTP request lifecycle. The background flag only affects
// client-side polling behavior.
func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*process, error) {
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
@@ -113,7 +108,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
cmd.SysProcAttr = procSysProcAttr()
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
@@ -160,7 +154,6 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
command: req.Command,
workDir: req.WorkDir,
background: req.Background,
chatID: chatID,
cmd: cmd,
cancel: cancel,
buf: buf,
@@ -222,32 +215,14 @@ func (m *manager) get(id string) (*process, bool) {
return proc, ok
}
// list returns info about all tracked processes. Exited
// processes older than exitedProcessReapAge are removed.
// If chatID is non-empty, only processes belonging to that
// chat are returned.
func (m *manager) list(chatID string) []workspacesdk.ProcessInfo {
// list returns info about all tracked processes.
func (m *manager) list() []workspacesdk.ProcessInfo {
m.mu.Lock()
defer m.mu.Unlock()
now := m.clock.Now()
infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
for id, proc := range m.procs {
info := proc.info()
// Reap processes that exited more than 5 minutes ago
// to prevent unbounded map growth.
if !info.Running && info.ExitedAt != nil {
exitedAt := time.Unix(*info.ExitedAt, 0)
if now.Sub(exitedAt) > exitedProcessReapAge {
delete(m.procs, id)
continue
}
}
// Filter by chatID if provided.
if chatID != "" && proc.chatID != chatID {
continue
}
infos = append(infos, info)
for _, proc := range m.procs {
infos = append(infos, proc.info())
}
return infos
}
@@ -273,15 +248,13 @@ func (m *manager) signal(id string, sig string) error {
switch sig {
case "kill":
// Use process group kill to ensure child processes
// (e.g. from shell pipelines) are also killed.
if err := signalProcess(proc.cmd.Process, syscall.SIGKILL); err != nil {
if err := proc.cmd.Process.Kill(); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
// Use process group signal to ensure child processes
// are also terminated.
if err := signalProcess(proc.cmd.Process, syscall.SIGTERM); err != nil {
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
-1
View File
@@ -30,7 +30,6 @@ func (a *agent) apiHandler() http.Handler {
r.Mount("/api/v0", a.filesAPI.Routes())
r.Mount("/api/v0/git", a.gitAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
+8 -31
View File
@@ -2,7 +2,6 @@ package reaper
import (
"os"
"sync"
"github.com/hashicorp/go-reap"
@@ -43,42 +42,20 @@ func WithLogger(logger slog.Logger) Option {
}
}
// WithReaperStop sets a channel that, when closed, stops the reaper
// WithDone sets a channel that, when closed, stops the reaper
// goroutine. Callers that invoke ForkReap more than once in the
// same process (e.g. tests) should use this to prevent goroutine
// accumulation.
func WithReaperStop(ch chan struct{}) Option {
func WithDone(ch chan struct{}) Option {
return func(o *options) {
o.ReaperStop = ch
}
}
// WithReaperStopped sets a channel that is closed after the
// reaper goroutine has fully exited. Callers that start reapers
// sequentially can block on this channel to ensure the previous
// reaper is gone before starting the next one.
func WithReaperStopped(ch chan struct{}) Option {
	return func(o *options) {
		o.ReaperStopped = ch
	}
}
// WithReapLock sets a mutex shared between the reaper and Wait4.
// The reaper holds the write lock while reaping, and ForkReap
// holds the read lock during Wait4, preventing the reaper from
// stealing the child's exit status. This is only needed for
// tests with instant-exit children where the race window is
// large.
func WithReapLock(mu *sync.RWMutex) Option {
return func(o *options) {
o.ReapLock = mu
o.Done = ch
}
}
type options struct {
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
Logger slog.Logger
ReaperStop chan struct{}
ReaperStopped chan struct{}
ReapLock *sync.RWMutex
ExecArgs []string
PIDs reap.PidCh
CatchSignals []os.Signal
Logger slog.Logger
Done chan struct{}
}
+34 -99
View File
@@ -7,7 +7,6 @@ import (
"os"
"os/exec"
"os/signal"
"sync"
"syscall"
"testing"
"time"
@@ -19,82 +18,35 @@ import (
"github.com/coder/coder/v2/testutil"
)
// subprocessEnvKey is set when a test re-execs itself as an
// isolated subprocess. Tests that call ForkReap or send signals
// to their own process check this to decide whether to run real
// test logic or launch the subprocess and wait for it.
const subprocessEnvKey = "CODER_REAPER_TEST_SUBPROCESS"
// withDone returns a reaper option whose done channel is closed via
// t.Cleanup, so the reaper goroutine stops once the test finishes
// and goroutines do not accumulate across subtests.
func withDone(t *testing.T) reaper.Option {
	t.Helper()
	stopCh := make(chan struct{})
	t.Cleanup(func() {
		close(stopCh)
	})
	return reaper.WithDone(stopCh)
}
// runSubprocess re-execs the current test binary in a new process
// running only the named test. This isolates ForkReap's
// syscall.ForkExec and any process-directed signals (e.g. SIGINT)
// from the parent test binary, making these tests safe to run in
// CI and alongside other tests.
// TestReap checks that the reaper successfully reaps exited
// processes and passes their PIDs through the shared
// channel.
//
// Returns true inside the subprocess (caller should proceed with
// the real test logic). Returns false in the parent after the
// subprocess exits successfully (caller should return).
func runSubprocess(t *testing.T) bool {
t.Helper()
if os.Getenv(subprocessEnvKey) == "1" {
return true
}
ctx := testutil.Context(t, testutil.WaitMedium)
//nolint:gosec // Test-controlled arguments.
cmd := exec.CommandContext(ctx, os.Args[0],
"-test.run=^"+t.Name()+"$",
"-test.v",
)
cmd.Env = append(os.Environ(), subprocessEnvKey+"=1")
out, err := cmd.CombinedOutput()
t.Logf("Subprocess output:\n%s", out)
require.NoError(t, err, "subprocess failed")
return false
}
// withDone returns options that stop the reaper goroutine when t
// completes and wait for it to fully exit, preventing
// overlapping reapers across sequential subtests.
func withDone(t *testing.T) []reaper.Option {
	t.Helper()
	stopCh := make(chan struct{})
	exitedCh := make(chan struct{})
	t.Cleanup(func() {
		// Ask the reaper to stop, then block until it confirms
		// it has fully exited.
		close(stopCh)
		<-exitedCh
	})
	opts := []reaper.Option{
		reaper.WithReaperStop(stopCh),
		reaper.WithReaperStopped(exitedCh),
	}
	return opts
}
// TestReap checks that the reaper successfully reaps exited
// processes and passes their PIDs through the shared channel.
//nolint:paralleltest
func TestReap(t *testing.T) {
t.Parallel()
// Don't run the reaper test in CI. It does weird
// things like forkexecing which may have unintended
// consequences in CI.
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
pids := make(reap.PidCh, 1)
var reapLock sync.RWMutex
opts := append([]reaper.Option{
exitCode, err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
// Provide some argument that immediately exits.
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
reaper.WithReapLock(&reapLock),
}, withDone(t)...)
reapLock.RLock()
exitCode, err := reaper.ForkReap(opts...)
reapLock.RUnlock()
withDone(t),
)
require.NoError(t, err)
require.Equal(t, 0, exitCode)
@@ -114,7 +66,7 @@ func TestReap(t *testing.T) {
expectedPIDs := []int{cmd.Process.Pid, cmd2.Process.Pid}
for range len(expectedPIDs) {
for i := 0; i < len(expectedPIDs); i++ {
select {
case <-time.After(testutil.WaitShort):
t.Fatalf("Timed out waiting for process")
@@ -124,15 +76,11 @@ func TestReap(t *testing.T) {
}
}
//nolint:tparallel // Subtests must be sequential, each starts its own reaper.
//nolint:paralleltest
func TestForkReapExitCodes(t *testing.T) {
t.Parallel()
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
tests := []struct {
name string
@@ -147,35 +95,26 @@ func TestForkReapExitCodes(t *testing.T) {
{"SIGTERM", "kill -15 $$", 128 + 15},
}
//nolint:paralleltest // Subtests must be sequential, each starts its own reaper.
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var reapLock sync.RWMutex
opts := append([]reaper.Option{
exitCode, err := reaper.ForkReap(
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
reaper.WithReapLock(&reapLock),
}, withDone(t)...)
reapLock.RLock()
exitCode, err := reaper.ForkReap(opts...)
reapLock.RUnlock()
withDone(t),
)
require.NoError(t, err)
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
})
}
}
// TestReapInterrupt verifies that ForkReap forwards caught signals
// to the child process. The test sends SIGINT to its own process
// and checks that the child receives it. Running in a subprocess
// ensures SIGINT cannot kill the parent test binary.
//nolint:paralleltest // Signal handling.
func TestReapInterrupt(t *testing.T) {
t.Parallel()
// Don't run the reaper test in CI. It does weird
// things like forkexecing which may have unintended
// consequences in CI.
if testutil.InCI() {
t.Skip("Detected CI, skipping reaper tests")
}
if !runSubprocess(t) {
return
}
errC := make(chan error, 1)
pids := make(reap.PidCh, 1)
@@ -187,28 +126,24 @@ func TestReapInterrupt(t *testing.T) {
defer signal.Stop(usrSig)
go func() {
opts := append([]reaper.Option{
exitCode, err := reaper.ForkReap(
reaper.WithPIDCallback(pids),
reaper.WithCatchSignals(os.Interrupt),
withDone(t),
// Signal propagation does not extend to children of children, so
// we create a little bash script to ensure sleep is interrupted.
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf(
"pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait",
os.Getpid(), os.Getpid(),
)),
}, withDone(t)...)
exitCode, err := reaper.ForkReap(opts...)
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
)
// The child exits with 128 + SIGTERM (15) = 143, but the trap catches
// SIGINT and sends SIGTERM to the sleep process, so exit code varies.
_ = exitCode
errC <- err
}()
require.Equal(t, syscall.SIGUSR1, <-usrSig)
require.Equal(t, <-usrSig, syscall.SIGUSR1)
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
require.NoError(t, err)
require.Equal(t, <-usrSig, syscall.SIGUSR2)
require.Equal(t, syscall.SIGUSR2, <-usrSig)
require.NoError(t, <-errC)
}
+14 -24
View File
@@ -19,36 +19,31 @@ func IsInitProcess() bool {
return os.Getpid() == 1
}
// startSignalForwarding registers signal handlers synchronously
// then forwards caught signals to the child in a background
// goroutine. Registering before the goroutine starts ensures no
// signal is lost between ForkExec and the handler being ready.
func startSignalForwarding(logger slog.Logger, pid int, sigs []os.Signal) {
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
if len(sigs) == 0 {
return
}
sc := make(chan os.Signal, 1)
signal.Notify(sc, sigs...)
defer signal.Stop(sc)
logger.Info(context.Background(), "reaper catching signals",
slog.F("signals", sigs),
slog.F("child_pid", pid),
)
go func() {
defer signal.Stop(sc)
for s := range sc {
sig, ok := s.(syscall.Signal)
if ok {
logger.Info(context.Background(), "reaper caught signal, killing child process",
slog.F("signal", sig.String()),
slog.F("child_pid", pid),
)
_ = syscall.Kill(pid, sig)
}
for {
s := <-sc
sig, ok := s.(syscall.Signal)
if ok {
logger.Info(context.Background(), "reaper caught signal, killing child process",
slog.F("signal", sig.String()),
slog.F("child_pid", pid),
)
_ = syscall.Kill(pid, sig)
}
}()
}
}
// ForkReap spawns a goroutine that reaps children. In order to avoid
@@ -69,12 +64,7 @@ func ForkReap(opt ...Option) (int, error) {
o(opts)
}
go func() {
reap.ReapChildren(opts.PIDs, nil, opts.ReaperStop, opts.ReapLock)
if opts.ReaperStopped != nil {
close(opts.ReaperStopped)
}
}()
go reap.ReapChildren(opts.PIDs, nil, opts.Done, nil)
pwd, err := os.Getwd()
if err != nil {
@@ -100,7 +90,7 @@ func ForkReap(opt ...Option) (int, error) {
return 1, xerrors.Errorf("fork exec: %w", err)
}
startSignalForwarding(opts.Logger, pid, opts.CatchSignals)
go catchSignals(opts.Logger, pid, opts.CatchSignals)
var wstatus syscall.WaitStatus
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
-13
View File
@@ -24,7 +24,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
"github.com/coder/serpent"
)
@@ -41,18 +40,6 @@ func New(t testing.TB, args ...string) (*serpent.Invocation, config.Root) {
return NewWithCommand(t, cmd, args...)
}
// NewWithClock is like New, but injects the given clock for
// tests that are time-dependent.
func NewWithClock(t testing.TB, clk quartz.Clock, args ...string) (*serpent.Invocation, config.Root) {
var root cli.RootCmd
root.SetClock(clk)
cmd, err := root.Command(root.AGPL())
require.NoError(t, err)
return NewWithCommand(t, cmd, args...)
}
type logWriter struct {
prefix string
log slog.Logger
-15
View File
@@ -46,7 +46,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
autoUpdates string
copyParametersFrom string
useParameterDefaults bool
noWait bool
// Organization context is only required if more than 1 template
// shares the same name across multiple organizations.
orgContext = NewOrganizationContext()
@@ -373,14 +372,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
cliutil.WarnMatchedProvisioners(inv.Stderr, workspace.LatestBuild.MatchedProvisioners, workspace.LatestBuild.Job)
if noWait {
_, _ = fmt.Fprintf(inv.Stdout,
"\nThe %s workspace has been created and is building in the background.\n",
cliui.Keyword(workspace.Name),
)
return nil
}
err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
if err != nil {
return xerrors.Errorf("watch build: %w", err)
@@ -454,12 +445,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
Description: "Automatically accept parameter defaults when no value is provided.",
Value: serpent.BoolOf(&useParameterDefaults),
},
serpent.Option{
Flag: "no-wait",
Env: "CODER_CREATE_NO_WAIT",
Description: "Return immediately after creating the workspace. The build will run in the background.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
)
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
-75
View File
@@ -603,81 +603,6 @@ func TestCreate(t *testing.T) {
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
}
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was actually created.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
})
t.Run("NoWaitWithParameterDefaults", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses([]*proto.RichParameter{
{Name: "region", Type: "string", DefaultValue: "us-east-1"},
{Name: "instance_type", Type: "string", DefaultValue: "t3.micro"},
}))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
ctx := testutil.Context(t, testutil.WaitLong)
inv, root := clitest.New(t, "create", "my-workspace",
"--template", template.Name,
"-y",
"--use-parameter-defaults",
"--no-wait",
)
clitest.SetupConfig(t, member, root)
doneChan := make(chan struct{})
pty := ptytest.New(t).Attach(inv)
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatchContext(ctx, "building in the background")
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify workspace was created and parameters were applied.
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
assert.Equal(t, ws.TemplateName, template.Name)
buildParams, err := member.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
require.NoError(t, err)
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "region", Value: "us-east-1"})
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "instance_type", Value: "t3.micro"})
})
}
func prepareEchoResponses(parameters []*proto.RichParameter, presets ...*proto.Preset) *echo.Responses {
+45 -74
View File
@@ -1732,18 +1732,19 @@ const (
func (r *RootCmd) scaletestAutostart() *serpent.Command {
var (
workspaceCount int64
workspaceJobTimeout time.Duration
autostartBuildTimeout time.Duration
autostartDelay time.Duration
template string
noCleanup bool
workspaceCount int64
workspaceJobTimeout time.Duration
autostartDelay time.Duration
autostartTimeout time.Duration
template string
noCleanup bool
parameterFlags workspaceParameterFlags
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
@@ -1771,7 +1772,7 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("parse output flags: %w", err)
return xerrors.Errorf("could not parse --output flags")
}
tpl, err := parseTemplate(ctx, client, me.OrganizationIDs, template)
@@ -1802,41 +1803,15 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := autostart.NewMetrics(reg)
setupBarrier := new(sync.WaitGroup)
setupBarrier.Add(int(workspaceCount))
// The workspace-build-updates experiment must be enabled to use
// the centralized pubsub channel for coordinating workspace builds.
experiments, err := client.Experiments(ctx)
if err != nil {
return xerrors.Errorf("get experiments: %w", err)
}
if !experiments.Enabled(codersdk.ExperimentWorkspaceBuildUpdates) {
return xerrors.New("the workspace-build-updates experiment must be enabled to run the autostart scaletest")
}
workspaceNames := make([]string, 0, workspaceCount)
resultSink := make(chan autostart.RunResult, workspaceCount)
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i := range workspaceCount {
id := strconv.Itoa(int(i))
workspaceNames = append(workspaceNames, loadtestutil.GenerateDeterministicWorkspaceName(id))
}
dispatcher := autostart.NewWorkspaceDispatcher(workspaceNames)
decoder, err := client.WatchAllWorkspaceBuilds(ctx)
if err != nil {
return xerrors.Errorf("watch all workspace builds: %w", err)
}
defer decoder.Close()
// Start the dispatcher. It will run in a goroutine and automatically
// close all workspace channels when the build updates channel closes.
dispatcher.Start(ctx, decoder.Chan())
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for workspaceName, buildUpdatesChannel := range dispatcher.Channels {
id := strings.TrimPrefix(workspaceName, loadtestutil.ScaleTestPrefix+"-")
config := autostart.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
@@ -1846,16 +1821,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Request: codersdk.CreateWorkspaceRequest{
TemplateID: tpl.ID,
RichParameterValues: richParameters,
// Use deterministic workspace name so we can pre-create the channel.
Name: workspaceName,
},
},
WorkspaceJobTimeout: workspaceJobTimeout,
AutostartBuildTimeout: autostartBuildTimeout,
AutostartDelay: autostartDelay,
SetupBarrier: setupBarrier,
BuildUpdates: buildUpdatesChannel,
ResultSink: resultSink,
WorkspaceJobTimeout: workspaceJobTimeout,
AutostartDelay: autostartDelay,
AutostartTimeout: autostartTimeout,
Metrics: metrics,
SetupBarrier: setupBarrier,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
@@ -1877,11 +1849,18 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
th.AddRun(autostartTestName, id, runner)
}
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Running autostart load test...")
@@ -1892,40 +1871,31 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// Collect all metrics from the channel.
close(resultSink)
var runResults []autostart.RunResult
for r := range resultSink {
runResults = append(runResults, r)
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
_, _ = fmt.Fprintf(inv.Stderr, "\nAll %d autostart builds completed successfully (elapsed: %s)\n", res.TotalRuns, time.Duration(res.Elapsed).Round(time.Millisecond))
if len(runResults) > 0 {
results := autostart.NewRunResults(runResults)
for _, out := range outputs {
if err := out.write(results.ToHarnessResults(), inv.Stdout); err != nil {
return xerrors.Errorf("write output: %w", err)
}
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(context.Background())
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
_, _ = fmt.Fprintln(inv.Stderr, "Cleanup complete")
} else {
_, _ = fmt.Fprintln(inv.Stderr, "\nSkipping cleanup (--no-cleanup specified). Resources left running.")
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
@@ -1948,13 +1918,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Description: "Timeout for workspace jobs (e.g. build, start).",
Value: serpent.DurationOf(&workspaceJobTimeout),
},
{
Flag: "autostart-build-timeout",
Env: "CODER_SCALETEST_AUTOSTART_BUILD_TIMEOUT",
Default: "15m",
Description: "Timeout for the autostart build to complete. Must be longer than workspace-job-timeout to account for queueing time in high-load scenarios.",
Value: serpent.DurationOf(&autostartBuildTimeout),
},
{
Flag: "autostart-delay",
Env: "CODER_SCALETEST_AUTOSTART_DELAY",
@@ -1962,6 +1925,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
Description: "How long after all the workspaces have been stopped to schedule them to be started again.",
Value: serpent.DurationOf(&autostartDelay),
},
{
Flag: "autostart-timeout",
Env: "CODER_SCALETEST_AUTOSTART_TIMEOUT",
Default: "5m",
Description: "Timeout for the autostart build to be initiated after the scheduled start time.",
Value: serpent.DurationOf(&autostartTimeout),
},
{
Flag: "template",
FlagShorthand: "t",
@@ -1980,9 +1950,10 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
tracingFlags.attach(&cmd.Options)
output.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
+6 -12
View File
@@ -19,18 +19,12 @@ func OverrideVSCodeConfigs(fs afero.Fs) error {
return err
}
mutate := func(m map[string]interface{}) {
// These defaults prevent VS Code from overriding
// GIT_ASKPASS and using its own GitHub authentication,
// which would circumvent cloning with Coder-configured
// providers. We only set them if they are not already
// present so that template authors can override them
// via module settings (e.g. the vscode-web module).
if _, ok := m["git.useIntegratedAskPass"]; !ok {
m["git.useIntegratedAskPass"] = false
}
if _, ok := m["github.gitAuthentication"]; !ok {
m["github.gitAuthentication"] = false
}
// This prevents VS Code from overriding GIT_ASKPASS, which
// we use to automatically authenticate Git providers.
m["git.useIntegratedAskPass"] = false
// This prevents VS Code from using it's own GitHub authentication
// which would circumvent cloning with Coder-configured providers.
m["github.gitAuthentication"] = false
}
for _, configPath := range []string{
-27
View File
@@ -61,31 +61,4 @@ func TestOverrideVSCodeConfigs(t *testing.T) {
require.Equal(t, "something", mapping["hotdogs"])
}
})
t.Run("NoOverwrite", func(t *testing.T) {
t.Parallel()
fs := afero.NewMemMapFs()
mapping := map[string]interface{}{
"git.useIntegratedAskPass": true,
"github.gitAuthentication": true,
"other.setting": "preserved",
}
data, err := json.Marshal(mapping)
require.NoError(t, err)
for _, configPath := range configPaths {
err = afero.WriteFile(fs, configPath, data, 0o600)
require.NoError(t, err)
}
err = gitauth.OverrideVSCodeConfigs(fs)
require.NoError(t, err)
for _, configPath := range configPaths {
data, err := afero.ReadFile(fs, configPath)
require.NoError(t, err)
mapping := map[string]interface{}{}
err = json.Unmarshal(data, &mapping)
require.NoError(t, err)
require.Equal(t, true, mapping["git.useIntegratedAskPass"])
require.Equal(t, true, mapping["github.gitAuthentication"])
require.Equal(t, "preserved", mapping["other.setting"])
}
})
}
-15
View File
@@ -39,7 +39,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/pretty"
"github.com/coder/quartz"
"github.com/coder/serpent"
)
@@ -231,10 +230,6 @@ func (r *RootCmd) RunWithSubcommands(subcommands []*serpent.Command) {
}
func (r *RootCmd) Command(subcommands []*serpent.Command) (*serpent.Command, error) {
if r.clock == nil {
r.clock = quartz.NewReal()
}
fmtLong := `Coder %s A tool for provisioning self-hosted development environments with Terraform.
`
hiddenAgentAuth := &AgentAuth{}
@@ -553,16 +548,6 @@ type RootCmd struct {
useKeyring bool
keyringServiceName string
useKeyringWithGlobalConfig bool
// clock is used for time-dependent operations. Initialized to
// quartz.NewReal() in Command() if not set via SetClock.
clock quartz.Clock
}
// SetClock sets the clock used for time-dependent operations.
// Must be called before Command() to take effect.
func (r *RootCmd) SetClock(clk quartz.Clock) {
r.clock = clk
}
// ensureClientURL loads the client URL from the config file if it
+10 -11
View File
@@ -188,17 +188,16 @@ func (r *RootCmd) newCreateAdminUserCommand() *serpent.Command {
_, _ = fmt.Fprintln(inv.Stderr, "Creating user...")
newUser, err = tx.InsertUser(ctx, database.InsertUserParams{
ID: uuid.New(),
Email: newUserEmail,
Username: newUserUsername,
Name: "Admin User",
HashedPassword: []byte(hashedPassword),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
RBACRoles: []string{rbac.RoleOwner().String()},
LoginType: database.LoginTypePassword,
Status: "",
IsServiceAccount: false,
ID: uuid.New(),
Email: newUserEmail,
Username: newUserUsername,
Name: "Admin User",
HashedPassword: []byte(hashedPassword),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
RBACRoles: []string{rbac.RoleOwner().String()},
LoginType: database.LoginTypePassword,
Status: "",
})
if err != nil {
return xerrors.Errorf("insert user: %w", err)
-1
View File
@@ -324,7 +324,6 @@ func TestServer(t *testing.T) {
ignoreLines := []string{
"isn't externally reachable",
"open install.sh: file does not exist",
"open install.ps1: file does not exist",
"telemetry disabled, unable to notify of security issues",
"installed terraform version newer than expected",
"report generator",
+15 -1
View File
@@ -21,8 +21,9 @@ type storedCredentials map[string]struct {
APIToken string `json:"api_token"`
}
//nolint:paralleltest, tparallel // OS keyring is flaky under concurrent access
func TestKeyring(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" && runtime.GOOS != "darwin" {
t.Skip("linux is not supported yet")
}
@@ -36,6 +37,8 @@ func TestKeyring(t *testing.T) {
)
t.Run("ReadNonExistent", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -47,6 +50,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("DeleteNonExistent", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -58,6 +63,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("WriteAndRead", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -84,6 +91,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("WriteAndDelete", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -106,6 +115,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("OverwriteToken", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -135,6 +146,8 @@ func TestKeyring(t *testing.T) {
})
t.Run("MultipleServers", func(t *testing.T) {
t.Parallel()
backend := sessionstore.NewKeyringWithService(testhelpers.KeyringServiceName(t))
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
@@ -186,6 +199,7 @@ func TestKeyring(t *testing.T) {
})
t.Run("StorageFormat", func(t *testing.T) {
t.Parallel()
// The storage format must remain consistent to ensure we don't break
// compatibility with other Coder related applications that may read
// or decode the same credential.
@@ -25,8 +25,9 @@ func readRawKeychainCredential(t *testing.T, serviceName string) []byte {
return winCred.CredentialBlob
}
//nolint:paralleltest, tparallel // OS keyring is flaky under concurrent access
func TestWindowsKeyring_WriteReadDelete(t *testing.T) {
t.Parallel()
const testURL = "http://127.0.0.1:1337"
srvURL, err := url.Parse(testURL)
require.NoError(t, err)
+16 -8
View File
@@ -180,11 +180,15 @@ func TestSSH(t *testing.T) {
// Delay until workspace is starting, otherwise the agent may be
// booted due to outdated build.
require.Eventually(t, func() bool {
var err error
var err error
for {
workspace, err = client.Workspace(ctx, workspace.ID)
return err == nil && workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, err)
if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
break
}
time.Sleep(testutil.IntervalFast)
}
// When the agent connects, the workspace was started, and we should
// have access to the shell.
@@ -759,11 +763,15 @@ func TestSSH(t *testing.T) {
// Delay until workspace is starting, otherwise the agent may be
// booted due to outdated build.
require.Eventually(t, func() bool {
var err error
var err error
for {
workspace, err = client.Workspace(ctx, workspace.ID)
return err == nil && workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, err)
if workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
break
}
time.Sleep(testutil.IntervalFast)
}
// When the agent connects, the workspace was started, and we should
// have access to the shell.
+47 -3
View File
@@ -6,6 +6,8 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/require"
@@ -102,11 +104,20 @@ func TestSyncCommands_Golden(t *testing.T) {
require.NoError(t, err)
client.Close()
outBuf := testutil.NewWaitBuffer()
// Use a writer that signals when the "Waiting" message has been
// written, so the goroutine can complete the dependency at the
// right time without relying on time.Sleep.
outBuf := newSyncWriter("Waiting")
// Start a goroutine to complete the dependency once the start
// command has printed its waiting message.
done := make(chan error, 1)
go func() {
if err := outBuf.WaitFor(ctx, "Waiting"); err != nil {
done <- err
// Block until the command prints the waiting message.
select {
case <-outBuf.matched:
case <-ctx.Done():
done <- ctx.Err()
return
}
@@ -328,3 +339,36 @@ func TestSyncCommands_Golden(t *testing.T) {
clitest.TestGoldenFile(t, "TestSyncCommands_Golden/status_json_format", outBuf.Bytes(), nil)
})
}
// syncWriter is a thread-safe io.Writer that wraps a bytes.Buffer and
// closes a channel when the written content contains a signal string.
type syncWriter struct {
mu sync.Mutex
buf bytes.Buffer
signal string
matched chan struct{}
closeOnce sync.Once
}
func newSyncWriter(signal string) *syncWriter {
return &syncWriter{
signal: signal,
matched: make(chan struct{}),
}
}
func (w *syncWriter) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
n, err := w.buf.Write(p)
if w.signal != "" && strings.Contains(w.buf.String(), w.signal) {
w.closeOnce.Do(func() { close(w.matched) })
}
return n, err
}
func (w *syncWriter) Bytes() []byte {
w.mu.Lock()
defer w.mu.Unlock()
return w.buf.Bytes()
}
+5 -5
View File
@@ -90,7 +90,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
return err
}
tsr := toStatusRow(task, r.clock.Now())
tsr := toStatusRow(task)
out, err := formatter.Format(ctx, []taskStatusRow{tsr})
if err != nil {
return xerrors.Errorf("format task status: %w", err)
@@ -112,7 +112,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
}
// Only print if something changed
newStatusRow := toStatusRow(task, r.clock.Now())
newStatusRow := toStatusRow(task)
if !taskStatusRowEqual(lastStatusRow, newStatusRow) {
out, err := formatter.Format(ctx, []taskStatusRow{newStatusRow})
if err != nil {
@@ -166,10 +166,10 @@ func taskStatusRowEqual(r1, r2 taskStatusRow) bool {
taskStateEqual(r1.CurrentState, r2.CurrentState)
}
func toStatusRow(task codersdk.Task, now time.Time) taskStatusRow {
func toStatusRow(task codersdk.Task) taskStatusRow {
tsr := taskStatusRow{
Task: task,
ChangedAgo: now.Sub(task.UpdatedAt).Truncate(time.Second).String() + " ago",
ChangedAgo: time.Since(task.UpdatedAt).Truncate(time.Second).String() + " ago",
}
tsr.Healthy = task.WorkspaceAgentHealth != nil &&
task.WorkspaceAgentHealth.Healthy &&
@@ -178,7 +178,7 @@ func toStatusRow(task codersdk.Task, now time.Time) taskStatusRow {
!task.WorkspaceAgentLifecycle.ShuttingDown()
if task.CurrentState != nil {
tsr.ChangedAgo = now.Sub(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
tsr.ChangedAgo = time.Since(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
}
return tsr
}
+9 -12
View File
@@ -19,7 +19,6 @@ import (
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func Test_TaskStatus(t *testing.T) {
@@ -29,12 +28,12 @@ func Test_TaskStatus(t *testing.T) {
args []string
expectOutput string
expectError string
hf func(context.Context, quartz.Clock) func(http.ResponseWriter, *http.Request)
hf func(context.Context, time.Time) func(http.ResponseWriter, *http.Request)
}{
{
args: []string{"doesnotexist"},
expectError: httpapi.ResourceNotFoundResponse.Message,
hf: func(ctx context.Context, _ quartz.Clock) func(w http.ResponseWriter, r *http.Request) {
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/tasks/me/doesnotexist":
@@ -50,8 +49,7 @@ func Test_TaskStatus(t *testing.T) {
args: []string{"exists"},
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
0s ago active true working Thinking furiously...`,
hf: func(ctx context.Context, clk quartz.Clock) func(w http.ResponseWriter, r *http.Request) {
now := clk.Now()
hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/tasks/me/exists":
@@ -86,8 +84,7 @@ func Test_TaskStatus(t *testing.T) {
4s ago active true
3s ago active true working Reticulating splines...
2s ago active true complete Splines reticulated successfully!`,
hf: func(ctx context.Context, clk quartz.Clock) func(http.ResponseWriter, *http.Request) {
now := clk.Now()
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
var calls atomic.Int64
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@@ -218,7 +215,7 @@ func Test_TaskStatus(t *testing.T) {
"created_at": "2025-08-26T12:34:56Z",
"updated_at": "2025-08-26T12:34:56Z"
}`,
hf: func(ctx context.Context, _ quartz.Clock) func(http.ResponseWriter, *http.Request) {
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
ts := time.Date(2025, 8, 26, 12, 34, 56, 0, time.UTC)
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
@@ -255,8 +252,8 @@ func Test_TaskStatus(t *testing.T) {
var (
ctx = testutil.Context(t, testutil.WaitShort)
mClock = quartz.NewMock(t)
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, mClock)))
now = time.Now().UTC() // TODO: replace with quartz
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
client = codersdk.New(testutil.MustURL(t, srv.URL))
sb = strings.Builder{}
args = []string{"task", "status", "--watch-interval", testutil.IntervalFast.String()}
@@ -264,10 +261,10 @@ func Test_TaskStatus(t *testing.T) {
t.Cleanup(srv.Close)
args = append(args, tc.args...)
inv, cfgDir := clitest.NewWithClock(t, mClock, args...)
inv, root := clitest.New(t, args...)
inv.Stdout = &sb
inv.Stderr = &sb
clitest.SetupConfig(t, client, cfgDir)
clitest.SetupConfig(t, client, root)
err := inv.WithContext(ctx).Run()
if tc.expectError == "" {
assert.NoError(t, err)
-4
View File
@@ -20,10 +20,6 @@ OPTIONS:
--copy-parameters-from string, $CODER_WORKSPACE_COPY_PARAMETERS_FROM
Specify the source workspace name to copy parameters from.
--no-wait bool, $CODER_CREATE_NO_WAIT
Return immediately after creating the workspace. The build will run in
the background.
--parameter string-array, $CODER_RICH_PARAMETER
Rich parameter value in the format "name=value".
+5
View File
@@ -143,6 +143,11 @@ AI BRIDGE OPTIONS:
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
Whether to start an in-memory aibridged instance.
--aibridge-inject-coder-mcp-tools bool, $CODER_AIBRIDGE_INJECT_CODER_MCP_TOOLS (default: false)
Whether to inject Coder's MCP tools into intercepted AI Bridge
requests (requires the "oauth2" and "mcp-server-http" experiments to
be enabled).
--aibridge-max-concurrency int, $CODER_AIBRIDGE_MAX_CONCURRENCY (default: 0)
Maximum number of concurrent AI Bridge requests per replica. Set to 0
to disable (unlimited).
+2 -4
View File
@@ -778,10 +778,8 @@ aibridge:
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
# Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a
# future release. Whether to inject Coder's MCP tools into intercepted AI Bridge
# requests (requires the "oauth2" and "mcp-server-http" experiments to be
# enabled).
# Whether to inject Coder's MCP tools into intercepted AI Bridge requests
# (requires the "oauth2" and "mcp-server-http" experiments to be enabled).
# (default: false, type: bool)
inject_coder_mcp_tools: false
# Length of time to retain data such as interceptions and all related records
+4 -4
View File
@@ -116,10 +116,10 @@ func TestWorkspaceActivityBump(t *testing.T) {
// is required. The Activity Bump behavior is also coupled with
// Last Used, so it would be obvious to the user if we
// are falsely recognizing activity.
require.Never(t, func() bool {
workspace, err = client.Workspace(ctx, workspace.ID)
return err == nil && !workspace.LatestBuild.Deadline.Time.Equal(firstDeadline)
}, testutil.IntervalMedium, testutil.IntervalFast, "deadline should not change")
time.Sleep(testutil.IntervalMedium)
workspace, err = client.Workspace(ctx, workspace.ID)
require.NoError(t, err)
require.Equal(t, workspace.LatestBuild.Deadline.Time, firstDeadline)
return
}
+3 -6
View File
@@ -134,12 +134,9 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
case database.WorkspaceAgentLifecycleStateReady,
database.WorkspaceAgentLifecycleStateStartTimeout,
database.WorkspaceAgentLifecycleStateStartError:
// Only emit metrics for the parent agent, this metric is not intended to measure devcontainer durations.
if !workspaceAgent.ParentID.Valid {
a.emitMetricsOnce.Do(func() {
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
})
}
a.emitMetricsOnce.Do(func() {
a.emitBuildDurationMetric(ctx, workspaceAgent.ResourceID)
})
}
return req.Lifecycle, nil
-58
View File
@@ -582,64 +582,6 @@ func TestUpdateLifecycle(t *testing.T) {
require.Equal(t, uint64(1), got.GetSampleCount())
require.Equal(t, expectedDuration, got.GetSampleSum())
})
t.Run("SubAgentDoesNotEmitMetric", func(t *testing.T) {
t.Parallel()
parentID := uuid.New()
subAgent := database.WorkspaceAgent{
ID: uuid.New(),
ParentID: uuid.NullUUID{UUID: parentID, Valid: true},
LifecycleState: database.WorkspaceAgentLifecycleStateStarting,
StartedAt: sql.NullTime{Valid: true, Time: someTime},
ReadyAt: sql.NullTime{Valid: false},
}
lifecycle := &agentproto.Lifecycle{
State: agentproto.Lifecycle_READY,
ChangedAt: timestamppb.New(now),
}
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: subAgent.ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
StartedAt: subAgent.StartedAt,
ReadyAt: sql.NullTime{
Time: now,
Valid: true,
},
}).Return(nil)
// GetWorkspaceBuildMetricsByResourceID should NOT be called
// because sub-agents should be skipped before querying.
reg := prometheus.NewRegistry()
metrics := agentapi.NewLifecycleMetrics(reg)
api := &agentapi.LifecycleAPI{
AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) {
return subAgent, nil
},
WorkspaceID: workspaceID,
Database: dbM,
Log: testutil.Logger(t),
Metrics: metrics,
PublishWorkspaceUpdateFn: nil,
}
resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{
Lifecycle: lifecycle,
})
require.NoError(t, err)
require.Equal(t, lifecycle, resp)
// We don't expect the metric to be emitted for sub-agents, by default this will fail anyway but it doesn't hurt
// to document the test explicitly.
dbM.EXPECT().GetWorkspaceBuildMetricsByResourceID(gomock.Any(), gomock.Any()).Times(0)
// If we were emitting the metric we would have failed by now since it would include a call to the database that we're not expecting.
pm, err := reg.Gather()
require.NoError(t, err)
for _, m := range pm {
if m.GetName() == fullMetricName {
t.Fatal("metric should not be emitted for sub-agent")
}
}
})
}
func TestUpdateStartup(t *testing.T) {
-38
View File
@@ -1,38 +0,0 @@
// Package aiseats is the AGPL version the package.
// The actual implementation is in `enterprise/aiseats`.
package aiseats
import (
"context"
"github.com/google/uuid"
"github.com/coder/coder/v2/coderd/database"
)
type Reason struct {
EventType database.AiSeatUsageReason
Description string
}
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
func ReasonAIBridge(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
}
// ReasonTask constructs a reason for usage originating from tasks.
func ReasonTask(description string) Reason {
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
}
// SeatTracker records AI seat consumption state.
type SeatTracker interface {
// RecordUsage does not return an error to prevent blocking the user from using
// AI features. This method is used to record usage, not enforce it.
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
}
// Noop is an AGPL seat tracker that does nothing.
type Noop struct{}
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
+6 -66
View File
@@ -869,28 +869,6 @@ const docTemplate = `{
}
}
},
"/debug/profile": {
"post": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": [
"Debug"
],
"summary": "Collect debug profiles",
"operationId": "collect-debug-profiles",
"responses": {
"200": {
"description": "OK"
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/debug/tailnet": {
"get": {
"security": [
@@ -1091,31 +1069,6 @@ const docTemplate = `{
}
}
},
"/experimental/watch-all-workspacebuilds": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Workspaces"
],
"summary": "Watch all workspace builds",
"operationId": "watch-all-workspace-builds",
"responses": {
"101": {
"description": "Switching Protocols"
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/experiments": {
"get": {
"security": [
@@ -12473,7 +12426,6 @@ const docTemplate = `{
"type": "boolean"
},
"inject_coder_mcp_tools": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "boolean"
},
"max_concurrency": {
@@ -14360,6 +14312,7 @@ const docTemplate = `{
"codersdk.CreateUserRequestWithOrgs": {
"type": "object",
"required": [
"email",
"username"
],
"properties": {
@@ -14389,10 +14342,6 @@ const docTemplate = `{
"password": {
"type": "string"
},
"service_account": {
"description": "Service accounts are admin-managed accounts that cannot login.",
"type": "boolean"
},
"user_status": {
"description": "UserStatus defaults to UserStatusDormant.",
"allOf": [
@@ -15208,8 +15157,7 @@ const docTemplate = `{
"web-push",
"oauth2",
"agents",
"mcp-server-http",
"workspace-build-updates"
"mcp-server-http"
],
"x-enum-comments": {
"ExperimentAgents": "Enables agent-powered chat functionality.",
@@ -15219,7 +15167,6 @@ const docTemplate = `{
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
"ExperimentWebPush": "Enables web push notifications through the browser.",
"ExperimentWorkspaceBuildUpdates": "Enables publishing workspace build updates to the all builds pubsub channel.",
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
},
"x-enum-descriptions": [
@@ -15230,8 +15177,7 @@ const docTemplate = `{
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables agent-powered chat functionality.",
"Enables the MCP HTTP server functionality.",
"Enables publishing workspace build updates to the all builds pubsub channel."
"Enables the MCP HTTP server functionality."
],
"x-enum-varnames": [
"ExperimentExample",
@@ -15241,8 +15187,7 @@ const docTemplate = `{
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentAgents",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceBuildUpdates"
"ExperimentMCPServerHTTP"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -15366,15 +15311,12 @@ const docTemplate = `{
"type": "string"
},
"mcp_tool_allow_regex": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "string"
},
"mcp_tool_deny_regex": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "string"
},
"mcp_url": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "string"
},
"no_refresh": {
@@ -18478,8 +18420,7 @@ const docTemplate = `{
"idp_sync_settings_role",
"workspace_agent",
"workspace_app",
"task",
"ai_seat"
"task"
],
"x-enum-varnames": [
"ResourceTypeTemplate",
@@ -18507,8 +18448,7 @@ const docTemplate = `{
"ResourceTypeIdpSyncSettingsRole",
"ResourceTypeWorkspaceAgent",
"ResourceTypeWorkspaceApp",
"ResourceTypeTask",
"ResourceTypeAISeat"
"ResourceTypeTask"
]
},
"codersdk.Response": {
+6 -61
View File
@@ -752,26 +752,6 @@
}
}
},
"/debug/profile": {
"post": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": ["Debug"],
"summary": "Collect debug profiles",
"operationId": "collect-debug-profiles",
"responses": {
"200": {
"description": "OK"
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/debug/tailnet": {
"get": {
"security": [
@@ -944,27 +924,6 @@
}
}
},
"/experimental/watch-all-workspacebuilds": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Workspaces"],
"summary": "Watch all workspace builds",
"operationId": "watch-all-workspace-builds",
"responses": {
"101": {
"description": "Switching Protocols"
}
},
"x-apidocgen": {
"skip": true
}
}
},
"/experiments": {
"get": {
"security": [
@@ -11079,7 +11038,6 @@
"type": "boolean"
},
"inject_coder_mcp_tools": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "boolean"
},
"max_concurrency": {
@@ -12898,7 +12856,7 @@
},
"codersdk.CreateUserRequestWithOrgs": {
"type": "object",
"required": ["username"],
"required": ["email", "username"],
"properties": {
"email": {
"type": "string",
@@ -12926,10 +12884,6 @@
"password": {
"type": "string"
},
"service_account": {
"description": "Service accounts are admin-managed accounts that cannot login.",
"type": "boolean"
},
"user_status": {
"description": "UserStatus defaults to UserStatusDormant.",
"allOf": [
@@ -13726,8 +13680,7 @@
"web-push",
"oauth2",
"agents",
"mcp-server-http",
"workspace-build-updates"
"mcp-server-http"
],
"x-enum-comments": {
"ExperimentAgents": "Enables agent-powered chat functionality.",
@@ -13737,7 +13690,6 @@
"ExperimentNotifications": "Sends notifications via SMTP and webhooks following certain events.",
"ExperimentOAuth2": "Enables OAuth2 provider functionality.",
"ExperimentWebPush": "Enables web push notifications through the browser.",
"ExperimentWorkspaceBuildUpdates": "Enables publishing workspace build updates to the all builds pubsub channel.",
"ExperimentWorkspaceUsage": "Enables the new workspace usage tracking."
},
"x-enum-descriptions": [
@@ -13748,8 +13700,7 @@
"Enables web push notifications through the browser.",
"Enables OAuth2 provider functionality.",
"Enables agent-powered chat functionality.",
"Enables the MCP HTTP server functionality.",
"Enables publishing workspace build updates to the all builds pubsub channel."
"Enables the MCP HTTP server functionality."
],
"x-enum-varnames": [
"ExperimentExample",
@@ -13759,8 +13710,7 @@
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentAgents",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceBuildUpdates"
"ExperimentMCPServerHTTP"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -13884,15 +13834,12 @@
"type": "string"
},
"mcp_tool_allow_regex": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "string"
},
"mcp_tool_deny_regex": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "string"
},
"mcp_url": {
"description": "Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a future release.",
"type": "string"
},
"no_refresh": {
@@ -16871,8 +16818,7 @@
"idp_sync_settings_role",
"workspace_agent",
"workspace_app",
"task",
"ai_seat"
"task"
],
"x-enum-varnames": [
"ResourceTypeTemplate",
@@ -16900,8 +16846,7 @@
"ResourceTypeIdpSyncSettingsRole",
"ResourceTypeWorkspaceAgent",
"ResourceTypeWorkspaceApp",
"ResourceTypeTask",
"ResourceTypeAISeat"
"ResourceTypeTask"
]
},
"codersdk.Response": {
+1 -2
View File
@@ -32,8 +32,7 @@ type Auditable interface {
idpsync.OrganizationSyncSettings |
idpsync.GroupSyncSettings |
idpsync.RoleSyncSettings |
database.TaskTable |
database.AiSeatState
database.TaskTable
}
// Map is a map of changed fields in an audited resource. It maps field names to
-8
View File
@@ -132,8 +132,6 @@ func ResourceTarget[T Auditable](tgt T) string {
return "Organization Role Sync"
case database.TaskTable:
return typed.Name
case database.AiSeatState:
return "AI Seat"
default:
panic(fmt.Sprintf("unknown resource %T for ResourceTarget", tgt))
}
@@ -198,8 +196,6 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
return noID // Org field on audit log has org id
case database.TaskTable:
return typed.ID
case database.AiSeatState:
return typed.UserID
default:
panic(fmt.Sprintf("unknown resource %T for ResourceID", tgt))
}
@@ -255,8 +251,6 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
return database.ResourceTypeIdpSyncSettingsGroup
case database.TaskTable:
return database.ResourceTypeTask
case database.AiSeatState:
return database.ResourceTypeAiSeat
default:
panic(fmt.Sprintf("unknown resource %T for ResourceType", typed))
}
@@ -315,8 +309,6 @@ func ResourceRequiresOrgID[T Auditable]() bool {
return true
case database.TaskTable:
return true
case database.AiSeatState:
return false
default:
panic(fmt.Sprintf("unknown resource %T for ResourceRequiresOrgID", tgt))
}
+3 -1
View File
@@ -240,7 +240,9 @@ func (c *Compressor) serveRef(w http.ResponseWriter, r *http.Request, headers ht
}
for key, values := range headers {
w.Header()[key] = values
for _, value := range values {
w.Header().Add(key, value)
}
}
w.Header().Set("Content-Encoding", cref.key.encoding)
w.Header().Add("Vary", "Accept-Encoding")
@@ -155,41 +155,6 @@ type nopEncoder struct {
func (nopEncoder) Close() error { return nil }
func TestCompressorPresetHeaders(t *testing.T) {
t.Parallel()
logger := testutil.Logger(t)
tempDir := t.TempDir()
cacheDir := filepath.Join(tempDir, "cache")
err := os.MkdirAll(cacheDir, 0o700)
require.NoError(t, err)
srcDir := filepath.Join(tempDir, "src")
err = os.MkdirAll(srcDir, 0o700)
require.NoError(t, err)
err = os.WriteFile(filepath.Join(srcDir, "file.html"), []byte("textstring"), 0o600)
require.NoError(t, err)
compressor := NewCompressor(logger, 5, cacheDir, http.FS(os.DirFS(srcDir)))
for range 2 {
ctx := testutil.Context(t, testutil.WaitShort)
req := httptest.NewRequestWithContext(ctx, "GET", "/file.html", nil)
req.Header.Set("Accept-Encoding", "gzip")
respRec := httptest.NewRecorder()
respRec.Header().Set("X-Original-Content-Length", "10")
respRec.Header().Set("ETag", `"abc123"`)
compressor.ServeHTTP(respRec, req)
resp := respRec.Result()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, []string{"10"}, resp.Header.Values("X-Original-Content-Length"))
require.Equal(t, []string{`"abc123"`}, resp.Header.Values("ETag"))
require.NoError(t, resp.Body.Close())
}
}
// nolint: tparallel // we want to assert the state of the cache, so run synchronously
func TestCompressorHeadings(t *testing.T) {
t.Parallel()
-71
View File
@@ -1,71 +0,0 @@
package chatcost
import (
"github.com/shopspring/decimal"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
// Returns cost in micros -- millionths of a dollar, rounded up to the next
// whole microdollar.
// Returns nil when pricing is not configured or when all priced usage fields
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
func CalculateTotalCostMicros(
usage codersdk.ChatMessageUsage,
cost *codersdk.ModelCostConfig,
) *int64 {
if cost == nil {
return nil
}
// A cost config with no prices set means pricing is effectively
// unconfigured — return nil (unpriced) rather than zero.
if cost.InputPricePerMillionTokens == nil &&
cost.OutputPricePerMillionTokens == nil &&
cost.CacheReadPricePerMillionTokens == nil &&
cost.CacheWritePricePerMillionTokens == nil {
return nil
}
if usage.InputTokens == nil &&
usage.OutputTokens == nil &&
usage.ReasoningTokens == nil &&
usage.CacheCreationTokens == nil &&
usage.CacheReadTokens == nil {
return nil
}
// OutputTokens already includes reasoning tokens per provider
// semantics (e.g. OpenAI's completion_tokens encompasses
// reasoning_tokens). Adding ReasoningTokens here would
// double-count.
// Preserve nil when usage exists only in categories without configured
// pricing, so callers can distinguish "unpriced" from "priced at zero".
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
if !hasMatchingPrice {
return nil
}
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
outputMicros := calcCost(usage.OutputTokens, cost.OutputPricePerMillionTokens)
cacheReadMicros := calcCost(usage.CacheReadTokens, cost.CacheReadPricePerMillionTokens)
cacheWriteMicros := calcCost(usage.CacheCreationTokens, cost.CacheWritePricePerMillionTokens)
total := inputMicros.
Add(outputMicros).
Add(cacheReadMicros).
Add(cacheWriteMicros)
rounded := total.Ceil().IntPart()
return &rounded
}
// calcCost returns the cost in fractional microdollars (millionths of a USD)
// for the given token count at the specified per-million-token price.
func calcCost(tokens *int64, pricePerMillion *decimal.Decimal) decimal.Decimal {
return decimal.NewFromInt(ptr.NilToEmpty(tokens)).Mul(ptr.NilToEmpty(pricePerMillion))
}
-163
View File
@@ -1,163 +0,0 @@
package chatcost_test
import (
"testing"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
func TestCalculateTotalCostMicros(t *testing.T) {
t.Parallel()
tests := []struct {
name string
usage codersdk.ChatMessageUsage
cost *codersdk.ModelCostConfig
want *int64
}{
{
name: "nil cost returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: nil,
want: nil,
},
{
name: "all priced usage fields nil returns nil",
usage: codersdk.ChatMessageUsage{
TotalTokens: ptr.Ref[int64](1234),
ContextLimit: ptr.Ref[int64](8192),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: nil,
},
{
name: "sub-micro total rounds up to 1",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
},
want: ptr.Ref[int64](1),
},
{
name: "simple input only",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](3000),
},
{
name: "simple output only",
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
},
{
name: "reasoning tokens included in output total",
usage: codersdk.ChatMessageUsage{
OutputTokens: ptr.Ref[int64](500),
ReasoningTokens: ptr.Ref[int64](200),
},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
},
{
name: "cache read tokens",
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
cost: &codersdk.ModelCostConfig{
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
},
want: ptr.Ref[int64](3000),
},
{
name: "cache creation tokens",
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
cost: &codersdk.ModelCostConfig{
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
},
want: ptr.Ref[int64](18750),
},
{
name: "full mixed usage totals all components exactly",
usage: codersdk.ChatMessageUsage{
InputTokens: ptr.Ref[int64](101),
OutputTokens: ptr.Ref[int64](201),
ReasoningTokens: ptr.Ref[int64](52),
CacheReadTokens: ptr.Ref[int64](1005),
CacheCreationTokens: ptr.Ref[int64](33),
TotalTokens: ptr.Ref[int64](1391),
ContextLimit: ptr.Ref[int64](4096),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")),
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")),
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
},
want: ptr.Ref[int64](2005),
},
{
name: "partial pricing only input contributes",
usage: codersdk.ChatMessageUsage{
InputTokens: ptr.Ref[int64](1234),
OutputTokens: ptr.Ref[int64](999),
ReasoningTokens: ptr.Ref[int64](111),
CacheReadTokens: ptr.Ref[int64](500),
CacheCreationTokens: ptr.Ref[int64](250),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
},
want: ptr.Ref[int64](3085),
},
{
name: "zero tokens with pricing returns zero pointer",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](0),
},
{
name: "usage only in unpriced categories returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: nil,
},
{
name: "non nil usage with empty cost config returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
cost: &codersdk.ModelCostConfig{},
want: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
if tt.want == nil {
require.Nil(t, got)
} else {
require.NotNil(t, got)
require.Equal(t, *tt.want, *got)
}
})
}
}
+304 -635
View File
File diff suppressed because it is too large Load Diff
+61 -1011
View File
File diff suppressed because it is too large Load Diff
+85 -164
View File
@@ -8,7 +8,6 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
"charm.land/fantasy"
@@ -63,16 +62,9 @@ type RunOptions struct {
// of the provider, which lives in chatd, not chatloop.
ProviderOptions fantasy.ProviderOptions
// ProviderTools are provider-native tools (like web search
// and computer use) whose definitions are passed directly
// to the provider API. When a ProviderTool has a non-nil
// Runner, tool calls are executed locally; otherwise the
// provider handles execution (e.g. web search).
ProviderTools []ProviderTool
PersistStep func(context.Context, PersistedStep) error
PublishMessagePart func(
role codersdk.ChatMessageRole,
role fantasy.MessageRole,
part codersdk.ChatMessagePart,
)
Compaction *CompactionOptions
@@ -89,16 +81,6 @@ type RunOptions struct {
OnInterruptedPersistError func(error)
}
// ProviderTool pairs a provider-native tool definition with an
// optional local executor. When Runner is nil the tool is fully
// provider-executed (e.g. web search). When Runner is non-nil
// the definition is sent to the API but execution is handled
// locally (e.g. computer use).
type ProviderTool struct {
Definition fantasy.Tool
Runner fantasy.AgentTool
}
// stepResult holds the accumulated output of a single streaming
// step. Since we own the stream consumer, all content is tracked
// directly here — no shadow draft state needed.
@@ -169,23 +151,11 @@ func (r stepResult) toResponseMessages() []fantasy.Message {
if !ok {
continue
}
part := fantasy.ToolResultPart{
ToolCallID: result.ToolCallID,
Output: result.Result,
ProviderExecuted: result.ProviderExecuted,
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
}
// Provider-executed tool results (e.g. web_search)
// must stay in the assistant message so the result
// block appears inline after the corresponding
// server_tool_use block. This matches the persistence
// layer in chatd.go which keeps them in
// assistantBlocks.
if result.ProviderExecuted {
assistantParts = append(assistantParts, part)
} else {
toolParts = append(toolParts, part)
}
toolParts = append(toolParts, fantasy.ToolResultPart{
ToolCallID: result.ToolCallID,
Output: result.Result,
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
})
default:
continue
}
@@ -227,14 +197,14 @@ func Run(ctx context.Context, opts RunOptions) error {
opts.MaxSteps = 1
}
publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
if opts.PublishMessagePart == nil {
return
}
opts.PublishMessagePart(role, part)
}
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools, opts.ProviderTools)
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools)
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
messages := opts.Messages
@@ -326,29 +296,17 @@ func Run(ctx context.Context, opts RunOptions) error {
return ctx.Err()
}
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) {
toolResults = executeTools(ctx, opts.Tools, result.toolCalls, func(tr fantasy.ToolResultContent) {
publishMessagePart(
codersdk.ChatMessageRoleTool,
fantasy.MessageRoleTool,
chatprompt.PartFromContent(tr),
)
})
for _, tr := range toolResults {
result.content = append(result.content, tr)
}
// Check for interruption after tool execution.
// Tools that were canceled mid-flight produce error
// results via ctx cancellation. Persist the full
// step (assistant blocks + tool results) through
// the interrupt-safe path so nothing is lost.
if ctx.Err() != nil {
if errors.Is(context.Cause(ctx), ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return ctx.Err()
}
}
// Extract context limit from provider metadata.
contextLimit := extractContextLimit(result.providerMetadata)
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
@@ -357,21 +315,16 @@ func Run(ctx context.Context, opts RunOptions) error {
Valid: true,
}
}
// Persist the step. If persistence fails because
// the chat was interrupted between the previous
// check and here, fall back to the interrupt-safe
// path so partial content is not lost.
// Persist the step — errors propagate directly.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return xerrors.Errorf("persist step: %w", err)
}
lastUsage = result.usage
lastProviderMetadata = result.providerMetadata
@@ -466,7 +419,7 @@ func Run(ctx context.Context, opts RunOptions) error {
func processStepStream(
ctx context.Context,
stream fantasy.StreamResponse,
publishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart),
publishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart),
) (stepResult, error) {
var result stepResult
@@ -475,6 +428,27 @@ func processStepStream(
activeReasoningContent := make(map[string]reasoningState)
// Track tool names by ID for input delta publishing.
toolNames := make(map[string]string)
// Track reasoning text/titles for title extraction.
reasoningTitles := make(map[string]string)
reasoningText := make(map[string]string)
setReasoningTitleFromText := func(id string, text string) {
if id == "" || strings.TrimSpace(text) == "" {
return
}
if reasoningTitles[id] != "" {
return
}
reasoningText[id] += text
if !strings.ContainsAny(reasoningText[id], "\r\n") {
return
}
title := chatprompt.ReasoningTitleFromFirstLine(reasoningText[id])
if title == "" {
return
}
reasoningTitles[id] = title
}
for part := range stream {
switch part.Type {
@@ -485,7 +459,10 @@ func processStepStream(
if _, exists := activeTextContent[part.ID]; exists {
activeTextContent[part.ID] += part.Delta
}
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(part.Delta))
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeText,
Text: part.Delta,
})
case fantasy.StreamPartTypeTextEnd:
if text, exists := activeTextContent[part.ID]; exists {
@@ -508,7 +485,13 @@ func processStepStream(
active.options = part.ProviderMetadata
activeReasoningContent[part.ID] = active
}
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageReasoning(part.Delta))
setReasoningTitleFromText(part.ID, part.Delta)
title := reasoningTitles[part.ID]
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Text: part.Delta,
Title: title,
})
case fantasy.StreamPartTypeReasoningEnd:
if active, exists := activeReasoningContent[part.ID]; exists {
@@ -521,6 +504,21 @@ func processStepStream(
}
result.content = append(result.content, content)
delete(activeReasoningContent, part.ID)
// Derive reasoning title at end of reasoning
// block if we haven't yet.
if reasoningTitles[part.ID] == "" {
reasoningTitles[part.ID] = chatprompt.ReasoningTitleFromFirstLine(
reasoningText[part.ID],
)
}
title := reasoningTitles[part.ID]
if title != "" {
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeReasoning,
Title: title,
})
}
}
case fantasy.StreamPartTypeToolInputStart:
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
@@ -534,19 +532,17 @@ func processStepStream(
}
case fantasy.StreamPartTypeToolInputDelta:
var providerExecuted bool
if toolCall, exists := activeToolCalls[part.ID]; exists {
toolCall.Input += part.Delta
providerExecuted = toolCall.ProviderExecuted
}
toolName := toolNames[part.ID]
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: part.ID,
ToolName: toolName,
ArgsDelta: part.Delta,
ProviderExecuted: providerExecuted,
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: part.ID,
ToolName: toolName,
ArgsDelta: part.Delta,
})
case fantasy.StreamPartTypeToolInputEnd:
// No callback needed; the full tool call arrives in
// StreamPartTypeToolCall.
@@ -568,7 +564,7 @@ func processStepStream(
delete(activeToolCalls, part.ID)
publishMessagePart(
codersdk.ChatMessageRoleAssistant,
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(tc),
)
@@ -582,28 +578,10 @@ func processStepStream(
}
result.content = append(result.content, sourceContent)
publishMessagePart(
codersdk.ChatMessageRoleAssistant,
fantasy.MessageRoleAssistant,
chatprompt.PartFromContent(sourceContent),
)
case fantasy.StreamPartTypeToolResult:
// Provider-executed tool results (e.g. web search)
// are emitted by the provider and added directly
// to the step content for multi-turn round-tripping.
// This mirrors fantasy's agent.go accumulation logic.
if part.ProviderExecuted {
tr := fantasy.ToolResultContent{
ToolCallID: part.ID,
ToolName: part.ToolCallName,
ProviderExecuted: part.ProviderExecuted,
ProviderMetadata: part.ProviderMetadata,
}
result.content = append(result.content, tr)
publishMessagePart(
codersdk.ChatMessageRoleTool,
chatprompt.PartFromContent(tr),
)
}
case fantasy.StreamPartTypeFinish:
result.usage = part.Usage
result.finishReason = part.FinishReason
@@ -631,26 +609,17 @@ func processStepStream(
}
}
hasLocalToolCalls := false
for _, tc := range result.toolCalls {
if !tc.ProviderExecuted {
hasLocalToolCalls = true
break
}
}
result.shouldContinue = hasLocalToolCalls &&
result.shouldContinue = len(result.toolCalls) > 0 &&
result.finishReason == fantasy.FinishReasonToolCalls
return result, nil
}
// executeTools runs all tool calls concurrently after the stream
// completes. Results are published via onResult in the original
// tool-call order after all tools finish, preserving deterministic
// event ordering for SSE subscribers.
// executeTools runs each tool call sequentially after the stream
// completes. Results are published via onResult as each tool
// finishes.
func executeTools(
ctx context.Context,
allTools []fantasy.AgentTool,
providerTools []ProviderTool,
toolCalls []fantasy.ToolCallContent,
onResult func(fantasy.ToolResultContent),
) []fantasy.ToolResultContent {
@@ -658,58 +627,16 @@ func executeTools(
return nil
}
// Filter out provider-executed tool calls. These were
// handled server-side by the LLM provider (e.g., web
// search) and their results are already in the stream
// content.
localToolCalls := make([]fantasy.ToolCallContent, 0, len(toolCalls))
for _, tc := range toolCalls {
if !tc.ProviderExecuted {
localToolCalls = append(localToolCalls, tc)
}
}
if len(localToolCalls) == 0 {
return nil
}
toolMap := make(map[string]fantasy.AgentTool, len(allTools))
for _, t := range allTools {
toolMap[t.Info().Name] = t
}
// Include runners from provider tools so locally-executed
// provider tools (e.g. computer use) can be dispatched.
for _, pt := range providerTools {
if pt.Runner != nil {
toolMap[pt.Runner.Info().Name] = pt.Runner
}
}
results := make([]fantasy.ToolResultContent, len(localToolCalls))
var wg sync.WaitGroup
wg.Add(len(localToolCalls))
for i, tc := range localToolCalls {
go func(i int, tc fantasy.ToolCallContent) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
results[i] = fantasy.ToolResultContent{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.Errorf("tool panicked: %v", r),
},
}
}
}()
results[i] = executeSingleTool(ctx, toolMap, tc)
}(i, tc)
}
wg.Wait()
// Publish results in the original tool-call order so SSE
// subscribers see a deterministic event sequence.
if onResult != nil {
for _, tr := range results {
results := make([]fantasy.ToolResultContent, 0, len(toolCalls))
for _, tc := range toolCalls {
tr := executeSingleTool(ctx, toolMap, tc)
results = append(results, tr)
if onResult != nil {
onResult(tr)
}
}
@@ -859,9 +786,8 @@ func persistInterruptedStep(
continue
}
content = append(content, fantasy.ToolResultContent{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
ProviderExecuted: tc.ProviderExecuted,
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.New(interruptedToolResultErrorMessage),
},
@@ -881,17 +807,15 @@ func persistInterruptedStep(
// buildToolDefinitions converts AgentTool definitions into the
// fantasy.Tool slice expected by fantasy.Call. When activeTools
// is non-empty, only function tools whose name appears in the
// list are included. Provider tool definitions are always
// appended unconditionally.
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, providerTools []ProviderTool) []fantasy.Tool {
prepared := make([]fantasy.Tool, 0, len(tools)+len(providerTools))
// is non-empty, only tools whose name appears in the list are
// included. This mirrors fantasy's agent.prepareTools filtering.
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string) []fantasy.Tool {
prepared := make([]fantasy.Tool, 0, len(tools))
for _, tool := range tools {
info := tool.Info()
if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) {
continue
}
inputSchema := map[string]any{
"type": "object",
"properties": info.Parameters,
@@ -905,9 +829,6 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, provi
ProviderOptions: tool.ProviderOptions(),
})
}
for _, pt := range providerTools {
prepared = append(prepared, pt.Definition)
}
return prepared
}
-252
View File
@@ -499,82 +499,6 @@ func TestRun_ShutdownDuringToolExecutionReturnsContextCanceled(t *testing.T) {
assert.ErrorIs(t, err, context.Canceled, "shutdown should propagate as context.Canceled")
}
func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *testing.T) {
t.Parallel()
sr := stepResult{
content: []fantasy.Content{
// Provider-executed tool call (e.g. web_search).
fantasy.ToolCallContent{
ToolCallID: "provider-tc-1",
ToolName: "web_search",
Input: `{"query":"coder"}`,
ProviderExecuted: true,
},
// Provider-executed tool result — must stay in
// assistant message.
fantasy.ToolResultContent{
ToolCallID: "provider-tc-1",
ToolName: "web_search",
ProviderExecuted: true,
ProviderMetadata: fantasy.ProviderMetadata{"anthropic": nil},
},
// Local tool call (e.g. read_file).
fantasy.ToolCallContent{
ToolCallID: "local-tc-1",
ToolName: "read_file",
Input: `{"path":"main.go"}`,
ProviderExecuted: false,
},
// Local tool result — should go into tool message.
fantasy.ToolResultContent{
ToolCallID: "local-tc-1",
ToolName: "read_file",
Result: fantasy.ToolResultOutputContentText{Text: "some result"},
ProviderExecuted: false,
},
},
}
msgs := sr.toResponseMessages()
require.Len(t, msgs, 2, "expected assistant + tool messages")
// First message: assistant role.
assistantMsg := msgs[0]
assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role)
require.Len(t, assistantMsg.Content, 3,
"assistant message should have provider ToolCallPart, provider ToolResultPart, and local ToolCallPart")
// Part 0: provider tool call.
providerTC, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[0])
require.True(t, ok, "part 0 should be ToolCallPart")
assert.Equal(t, "provider-tc-1", providerTC.ToolCallID)
assert.True(t, providerTC.ProviderExecuted)
// Part 1: provider tool result (inline in assistant turn).
providerTR, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](assistantMsg.Content[1])
require.True(t, ok, "part 1 should be ToolResultPart")
assert.Equal(t, "provider-tc-1", providerTR.ToolCallID)
assert.True(t, providerTR.ProviderExecuted)
// Part 2: local tool call.
localTC, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[2])
require.True(t, ok, "part 2 should be ToolCallPart")
assert.Equal(t, "local-tc-1", localTC.ToolCallID)
assert.False(t, localTC.ProviderExecuted)
// Second message: tool role.
toolMsg := msgs[1]
assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
require.Len(t, toolMsg.Content, 1,
"tool message should have only the local ToolResultPart")
localTR, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
require.True(t, ok, "tool part should be ToolResultPart")
assert.Equal(t, "local-tc-1", localTR.ToolCallID)
assert.False(t, localTR.ProviderExecuted)
}
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
if len(message.ProviderOptions) == 0 {
return false
@@ -588,179 +512,3 @@ func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
return ok && cacheOptions.CacheControl.Type == "ephemeral"
}
// TestRun_InterruptedDuringToolExecutionPersistsStep verifies that when
// tools are executing and the chat is interrupted, the accumulated step
// content (assistant blocks + tool results) is persisted via the
// interrupt-safe path rather than being lost.
func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
t.Parallel()
toolStarted := make(chan struct{})
// Model returns a completed tool call in the stream.
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "calling tool"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeReasoningStart, ID: "reason-1"},
{Type: fantasy.StreamPartTypeReasoningDelta, ID: "reason-1", Delta: "let me think"},
{Type: fantasy.StreamPartTypeReasoningEnd, ID: "reason-1"},
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "slow_tool"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"key":"value"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "slow_tool",
ToolCallInput: `{"key":"value"}`,
},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
}), nil
},
}
// Tool that blocks until context is canceled, simulating
// a long-running operation interrupted by the user.
slowTool := fantasy.NewAgentTool(
"slow_tool",
"blocks until canceled",
func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
close(toolStarted)
<-ctx.Done()
return fantasy.ToolResponse{}, ctx.Err()
},
)
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(nil)
go func() {
<-toolStarted
cancel(ErrInterrupted)
}()
var persistedContent []fantasy.Content
persistedCtxErr := xerrors.New("unset")
err := Run(ctx, RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "run the slow tool"),
},
Tools: []fantasy.AgentTool{slowTool},
MaxSteps: 3,
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
persistedCtxErr = persistCtx.Err()
persistedContent = append([]fantasy.Content(nil), step.Content...)
return nil
},
})
require.ErrorIs(t, err, ErrInterrupted)
// persistInterruptedStep uses context.WithoutCancel, so the
// persist callback should see a non-canceled context.
require.NoError(t, persistedCtxErr)
require.NotEmpty(t, persistedContent)
var (
foundText bool
foundReasoning bool
foundToolCall bool
foundToolResult bool
)
for _, block := range persistedContent {
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
if strings.Contains(text.Text, "calling tool") {
foundText = true
}
continue
}
if reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](block); ok {
if strings.Contains(reasoning.Text, "let me think") {
foundReasoning = true
}
continue
}
if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
if toolCall.ToolCallID == "tc-1" && toolCall.ToolName == "slow_tool" {
foundToolCall = true
}
continue
}
if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
if toolResult.ToolCallID == "tc-1" {
foundToolResult = true
}
}
}
require.True(t, foundText, "persisted content should include text from the stream")
require.True(t, foundReasoning, "persisted content should include reasoning from the stream")
require.True(t, foundToolCall, "persisted content should include the tool call")
require.True(t, foundToolResult, "persisted content should include the tool result (error from cancellation)")
}
// TestRun_PersistStepInterruptedFallback verifies that when the normal
// PersistStep call returns ErrInterrupted (e.g., context canceled in a
// race), the step is retried via the interrupt-safe path.
func TestRun_PersistStepInterruptedFallback(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello world"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
},
}
var (
mu sync.Mutex
persistCalls int
savedContent []fantasy.Content
)
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, step PersistedStep) error {
mu.Lock()
defer mu.Unlock()
persistCalls++
if persistCalls == 1 {
// First call: simulate an interrupt race by
// returning ErrInterrupted without persisting.
return ErrInterrupted
}
// Second call (from persistInterruptedStep fallback):
// accept the content.
savedContent = append([]fantasy.Content(nil), step.Content...)
return nil
},
})
require.ErrorIs(t, err, ErrInterrupted)
mu.Lock()
defer mu.Unlock()
require.Equal(t, 2, persistCalls, "PersistStep should be called twice: once normally (failing), once via fallback")
require.NotEmpty(t, savedContent)
var foundText bool
for _, block := range savedContent {
if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
if strings.Contains(text.Text, "hello world") {
foundText = true
}
}
}
require.True(t, foundText, "fallback should persist the text content")
}
+29 -27
View File
@@ -17,26 +17,13 @@ const (
minCompactionThresholdPercent = int32(0)
maxCompactionThresholdPercent = int32(100)
defaultCompactionSummaryPrompt = "You are performing a context compaction. " +
"Summarize the conversation so a new assistant can seamlessly " +
"continue the work in progress.\n\n" +
"Include:\n" +
"- The user's overall goal and current task\n" +
"- Key decisions made and their rationale\n" +
"- Concrete technical details: file paths, function names, " +
"commands, APIs, and configurations\n" +
"- Errors encountered and how they were resolved\n" +
"- Current state of the work: what is DONE, what is IN PROGRESS, " +
"and what REMAINS to be done\n" +
"- The specific action the assistant was performing or about to " +
"perform when this summary was triggered\n\n" +
"Be dense and factual. Every sentence should convey essential " +
"context for continuation. Do not include pleasantries or " +
"conversational filler."
defaultCompactionSystemSummaryPrefix = "The following is a summary of " +
"the earlier conversation. The assistant was actively working when " +
"the context was compacted. Continue the work described below:"
defaultCompactionTimeout = 90 * time.Second
defaultCompactionSummaryPrompt = "Summarize the current chat so a " +
"new assistant can continue seamlessly. Include the user's goals, " +
"decisions made, concrete technical details (files, commands, APIs), " +
"errors encountered and fixes, and open questions. Be dense and factual. " +
"Omit pleasantries and next-step suggestions."
defaultCompactionSystemSummaryPrefix = "Summary of earlier chat context:"
defaultCompactionTimeout = 90 * time.Second
)
type CompactionOptions struct {
@@ -55,7 +42,7 @@ type CompactionOptions struct {
// PublishMessagePart publishes streaming parts to connected
// clients so they see "Summarizing..." / "Summarized" UI
// transitions during compaction.
PublishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart)
PublishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart)
OnError func(error)
}
@@ -110,8 +97,12 @@ func tryCompact(
// connected clients see activity during summary generation.
if config.PublishMessagePart != nil && config.ToolCallID != "" {
config.PublishMessagePart(
codersdk.ChatMessageRoleAssistant,
codersdk.ChatMessageToolCall(config.ToolCallID, config.ToolName, nil),
fantasy.MessageRoleAssistant,
codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: config.ToolCallID,
ToolName: config.ToolName,
},
)
}
@@ -159,8 +150,13 @@ func tryCompact(
"context_limit_tokens": contextLimit,
})
config.PublishMessagePart(
codersdk.ChatMessageRoleTool,
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false),
fantasy.MessageRoleTool,
codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: config.ToolCallID,
ToolName: config.ToolName,
Result: resultJSON,
},
)
}
@@ -177,8 +173,14 @@ func publishCompactionError(config CompactionOptions, msg string) {
"error": msg,
})
config.PublishMessagePart(
codersdk.ChatMessageRoleTool,
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true),
fantasy.MessageRoleTool,
codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: config.ToolCallID,
ToolName: config.ToolName,
Result: errJSON,
IsError: true,
},
)
}
+2 -2
View File
@@ -149,7 +149,7 @@ func TestRun_Compaction(t *testing.T) {
SummaryPrompt: "summarize now",
ToolCallID: "test-tool-call-id",
ToolName: "chat_summarized",
PublishMessagePart: func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
PublishMessagePart: func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
switch part.Type {
case codersdk.ChatMessagePartTypeToolCall:
callOrder = append(callOrder, "publish_tool_call")
@@ -218,7 +218,7 @@ func TestRun_Compaction(t *testing.T) {
ThresholdPercent: 70,
ToolCallID: "test-tool-call-id",
ToolName: "chat_summarized",
PublishMessagePart: func(_ codersdk.ChatMessageRole, _ codersdk.ChatMessagePart) {
PublishMessagePart: func(_ fantasy.MessageRole, _ codersdk.ChatMessagePart) {
publishCalled = true
},
Persist: func(_ context.Context, _ CompactionResult) error {
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+20 -38
View File
@@ -553,33 +553,30 @@ func normalizedEnumValue(value string, allowed ...string) *string {
return nil
}
// MergeMissingModelCostConfig fills unset pricing metadata from defaults.
func MergeMissingModelCostConfig(
dst **codersdk.ModelCostConfig,
defaults *codersdk.ModelCostConfig,
// MergeMissingCallConfig fills unset call config values from defaults.
func MergeMissingCallConfig(
dst *codersdk.ChatModelCallConfig,
defaults codersdk.ChatModelCallConfig,
) {
if defaults == nil {
return
if dst.MaxOutputTokens == nil {
dst.MaxOutputTokens = defaults.MaxOutputTokens
}
if *dst == nil {
copied := *defaults
*dst = &copied
return
if dst.Temperature == nil {
dst.Temperature = defaults.Temperature
}
current := *dst
if current.InputPricePerMillionTokens == nil {
current.InputPricePerMillionTokens = defaults.InputPricePerMillionTokens
if dst.TopP == nil {
dst.TopP = defaults.TopP
}
if current.OutputPricePerMillionTokens == nil {
current.OutputPricePerMillionTokens = defaults.OutputPricePerMillionTokens
if dst.TopK == nil {
dst.TopK = defaults.TopK
}
if current.CacheReadPricePerMillionTokens == nil {
current.CacheReadPricePerMillionTokens = defaults.CacheReadPricePerMillionTokens
if dst.PresencePenalty == nil {
dst.PresencePenalty = defaults.PresencePenalty
}
if current.CacheWritePricePerMillionTokens == nil {
current.CacheWritePricePerMillionTokens = defaults.CacheWritePricePerMillionTokens
if dst.FrequencyPenalty == nil {
dst.FrequencyPenalty = defaults.FrequencyPenalty
}
MergeMissingProviderOptions(&dst.ProviderOptions, defaults.ProviderOptions)
}
// MergeMissingProviderOptions fills unset provider option fields from defaults.
@@ -888,14 +885,11 @@ func MergeMissingProviderOptions(
}
// ModelFromConfig resolves a provider/model pair and constructs a fantasy
// language model client using the provided provider credentials. The
// userAgent is sent as the User-Agent header on every outgoing LLM
// API request.
// language model client using the provided provider credentials.
func ModelFromConfig(
providerHint string,
modelName string,
providerKeys ProviderAPIKeys,
userAgent string,
) (fantasy.LanguageModel, error) {
provider, modelID, err := ResolveModelWithProviderHint(modelName, providerHint)
if err != nil {
@@ -913,7 +907,6 @@ func ModelFromConfig(
case fantasyanthropic.Name:
options := []fantasyanthropic.Option{
fantasyanthropic.WithAPIKey(apiKey),
fantasyanthropic.WithUserAgent(userAgent),
}
if baseURL != "" {
options = append(options, fantasyanthropic.WithBaseURL(baseURL))
@@ -927,17 +920,12 @@ func ModelFromConfig(
fantasyazure.WithAPIKey(apiKey),
fantasyazure.WithBaseURL(baseURL),
fantasyazure.WithUseResponsesAPI(),
fantasyazure.WithUserAgent(userAgent),
)
case fantasybedrock.Name:
providerClient, err = fantasybedrock.New(
fantasybedrock.WithAPIKey(apiKey),
fantasybedrock.WithUserAgent(userAgent),
)
providerClient, err = fantasybedrock.New(fantasybedrock.WithAPIKey(apiKey))
case fantasygoogle.Name:
options := []fantasygoogle.Option{
fantasygoogle.WithGeminiAPIKey(apiKey),
fantasygoogle.WithUserAgent(userAgent),
}
if baseURL != "" {
options = append(options, fantasygoogle.WithBaseURL(baseURL))
@@ -947,7 +935,6 @@ func ModelFromConfig(
options := []fantasyopenai.Option{
fantasyopenai.WithAPIKey(apiKey),
fantasyopenai.WithUseResponsesAPI(),
fantasyopenai.WithUserAgent(userAgent),
}
if baseURL != "" {
options = append(options, fantasyopenai.WithBaseURL(baseURL))
@@ -956,21 +943,16 @@ func ModelFromConfig(
case fantasyopenaicompat.Name:
options := []fantasyopenaicompat.Option{
fantasyopenaicompat.WithAPIKey(apiKey),
fantasyopenaicompat.WithUserAgent(userAgent),
}
if baseURL != "" {
options = append(options, fantasyopenaicompat.WithBaseURL(baseURL))
}
providerClient, err = fantasyopenaicompat.New(options...)
case fantasyopenrouter.Name:
providerClient, err = fantasyopenrouter.New(
fantasyopenrouter.WithAPIKey(apiKey),
fantasyopenrouter.WithUserAgent(userAgent),
)
providerClient, err = fantasyopenrouter.New(fantasyopenrouter.WithAPIKey(apiKey))
case fantasyvercel.Name:
options := []fantasyvercel.Option{
fantasyvercel.WithAPIKey(apiKey),
fantasyvercel.WithUserAgent(userAgent),
}
if baseURL != "" {
options = append(options, fantasyvercel.WithBaseURL(baseURL))
@@ -137,6 +137,43 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
}
// TestMergeMissingCallConfig_FillsUnsetFields verifies that
// MergeMissingCallConfig copies defaults only into fields that are
// unset (nil) on dst, leaves dst's own values intact, and merges
// nested provider options field-by-field.
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
	t.Parallel()
	// dst sets Temperature and one nested OpenAI option; everything
	// else is nil and should be filled from defaults.
	dst := codersdk.ChatModelCallConfig{
		Temperature: float64Ptr(0.2),
		ProviderOptions: &codersdk.ChatModelProviderOptions{
			OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
				User: stringPtr("alice"),
			},
		},
	}
	defaults := codersdk.ChatModelCallConfig{
		MaxOutputTokens: int64Ptr(512),
		Temperature:     float64Ptr(0.9),
		TopP:            float64Ptr(0.8),
		ProviderOptions: &codersdk.ChatModelProviderOptions{
			OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
				User:            stringPtr("bob"),
				ReasoningEffort: stringPtr("medium"),
			},
		},
	}
	chatprovider.MergeMissingCallConfig(&dst, defaults)
	// Unset fields picked up the defaults.
	require.NotNil(t, dst.MaxOutputTokens)
	require.EqualValues(t, 512, *dst.MaxOutputTokens)
	// Pre-set Temperature (0.2) must NOT be overwritten by 0.9.
	require.NotNil(t, dst.Temperature)
	require.Equal(t, 0.2, *dst.Temperature)
	require.NotNil(t, dst.TopP)
	require.Equal(t, 0.8, *dst.TopP)
	// Nested merge: "alice" wins over "bob", but ReasoningEffort
	// (unset on dst) is filled from defaults.
	require.NotNil(t, dst.ProviderOptions)
	require.NotNil(t, dst.ProviderOptions.OpenAI)
	require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
	require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
}
// stringPtr returns a pointer to a copy of value; test helper for
// building optional (pointer-typed) config fields.
func stringPtr(value string) *string {
	v := value
	return &v
}
@@ -148,3 +185,7 @@ func boolPtr(value bool) *bool {
// int64Ptr returns a pointer to a copy of value; test helper for
// building optional (pointer-typed) config fields.
func int64Ptr(value int64) *int64 {
	v := value
	return &v
}
// float64Ptr returns a pointer to a copy of value; test helper for
// building optional (pointer-typed) config fields.
func float64Ptr(value float64) *float64 {
	v := value
	return &v
}
-19
View File
@@ -1,19 +0,0 @@
package chatprovider
import (
"fmt"
"runtime"
"github.com/coder/coder/v2/buildinfo"
)
// UserAgent returns the User-Agent string sent on all outgoing LLM
// API requests made by Coder's built-in chat (chatd). The format
// mirrors conventions used by other coding agents so that LLM
// providers can identify traffic originating from Coder.
//
// Example: coder-agents/v2.21.0 (linux/amd64)
func UserAgent() string {
	platform := runtime.GOOS + "/" + runtime.GOARCH
	return fmt.Sprintf("coder-agents/%s (%s)", buildinfo.Version(), platform)
}
@@ -1,78 +0,0 @@
package chatprovider_test
import (
"context"
"runtime"
"strings"
"sync"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattest"
)
// TestUserAgent checks the structure of the User-Agent string:
// "coder-agents/" prefix, build version, and "GOOS/GOARCH" platform.
func TestUserAgent(t *testing.T) {
	t.Parallel()
	ua := chatprovider.UserAgent()
	// Must start with "coder-agents/" so LLM providers can
	// identify traffic from Coder.
	require.True(t, strings.HasPrefix(ua, "coder-agents/"),
		"User-Agent should start with 'coder-agents/', got %q", ua)
	// Must contain the build version.
	assert.Contains(t, ua, buildinfo.Version())
	// Must contain OS/arch.
	assert.Contains(t, ua, runtime.GOOS+"/"+runtime.GOARCH)
}
// TestModelFromConfig_UserAgent verifies end-to-end that the
// userAgent passed to ModelFromConfig is sent as the HTTP
// User-Agent header on a real request, using a fake OpenAI server
// that captures request headers.
func TestModelFromConfig_UserAgent(t *testing.T) {
	t.Parallel()
	// mu guards capturedUA; the fake server handler may run on a
	// different goroutine than the test body.
	var mu sync.Mutex
	var capturedUA string
	serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		mu.Lock()
		capturedUA = req.Header.Get("User-Agent")
		mu.Unlock()
		return chattest.OpenAINonStreamingResponse("hello")
	})
	expectedUA := chatprovider.UserAgent()
	// Route the "openai" provider at the fake server.
	keys := chatprovider.ProviderAPIKeys{
		ByProvider:        map[string]string{"openai": "test-key"},
		BaseURLByProvider: map[string]string{"openai": serverURL},
	}
	model, err := chatprovider.ModelFromConfig("openai", "gpt-4", keys, expectedUA)
	require.NoError(t, err)
	// Make a real call so Fantasy sends an HTTP request to the
	// fake server, which captures the User-Agent header.
	_, err = model.Generate(context.Background(), fantasy.Call{
		Prompt: []fantasy.Message{
			{
				Role: fantasy.MessageRoleUser,
				Content: []fantasy.MessagePart{
					fantasy.TextPart{Text: "hello"},
				},
			},
		},
	})
	require.NoError(t, err)
	mu.Lock()
	got := capturedUA
	mu.Unlock()
	require.NotEmpty(t, got, "User-Agent header was not sent")
	require.Equal(t, expectedUA, got,
		"User-Agent header should match chatprovider.UserAgent()")
}
+3 -6
View File
@@ -96,7 +96,6 @@ type AnthropicDeltaBlock struct {
// anthropicServer is a test server that mocks the Anthropic API.
type anthropicServer struct {
mu sync.Mutex
t testing.TB
server *httptest.Server
handler AnthropicHandler
request *AnthropicRequest
@@ -110,7 +109,6 @@ func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
t.Helper()
s := &anthropicServer{
t: t,
handler: handler,
}
@@ -145,7 +143,7 @@ func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request)
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
if resp.Error != nil {
writeErrorResponse(s.t, w, resp.Error)
writeErrorResponse(w, resp.Error)
return
}
@@ -225,6 +223,7 @@ func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <
}
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
_ = s // receiver unused but kept for consistency
response := map[string]interface{}{
"id": resp.ID,
"type": resp.Type,
@@ -242,9 +241,7 @@ func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp
w.Header().Set("Content-Type", "application/json")
w.Header().Set("anthropic-version", "2023-06-01")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.t.Errorf("writeNonStreamingResponse: failed to encode response: %v", err)
}
_ = json.NewEncoder(w).Encode(response)
}
// AnthropicStreamingResponse creates a streaming response from chunks.
+2 -5
View File
@@ -3,7 +3,6 @@ package chattest
import (
"encoding/json"
"net/http"
"testing"
)
// ErrorResponse describes an HTTP error that a test server should return
@@ -16,7 +15,7 @@ type ErrorResponse struct {
// writeErrorResponse writes a JSON error response matching the common
// provider error format used by both Anthropic and OpenAI.
func writeErrorResponse(t testing.TB, w http.ResponseWriter, errResp *ErrorResponse) {
func writeErrorResponse(w http.ResponseWriter, errResp *ErrorResponse) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(errResp.StatusCode)
body := map[string]interface{}{
@@ -25,9 +24,7 @@ func writeErrorResponse(t testing.TB, w http.ResponseWriter, errResp *ErrorRespo
"message": errResp.Message,
},
}
if err := json.NewEncoder(w).Encode(body); err != nil {
t.Errorf("writeErrorResponse: failed to encode error response: %v", err)
}
_ = json.NewEncoder(w).Encode(body)
}
// AnthropicErrorResponse returns an AnthropicResponse that causes the
+13 -29
View File
@@ -113,7 +113,6 @@ type OpenAICompletion struct {
// openAIServer is a test server that mocks the OpenAI API.
type openAIServer struct {
mu sync.Mutex
t testing.TB
server *httptest.Server
handler OpenAIHandler
request *OpenAIRequest
@@ -127,7 +126,6 @@ func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
t.Helper()
s := &openAIServer{
t: t,
handler: handler,
}
@@ -178,7 +176,7 @@ func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
if resp.Error != nil {
writeErrorResponse(s.t, w, resp.Error)
writeErrorResponse(w, resp.Error)
return
}
@@ -207,7 +205,7 @@ func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
if resp.Error != nil {
writeErrorResponse(s.t, w, resp.Error)
writeErrorResponse(w, resp.Error)
return
}
@@ -228,7 +226,7 @@ func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *Ope
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
return
case hasStreaming:
writeResponsesAPIStreaming(s.t, w, req.Request, resp.StreamingChunks)
writeResponsesAPIStreaming(w, req.Request, resp.StreamingChunks)
default:
s.writeResponsesAPINonStreaming(w, resp.Response)
}
@@ -320,7 +318,7 @@ func writeSSEEvent(w http.ResponseWriter, v interface{}) error {
return err
}
func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
@@ -347,28 +345,19 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
// the fantasy client closes open text
// blocks and persists the step content.
for outputIndex, itemID := range itemIDs {
if err := writeSSEEvent(w, responses.ResponseTextDoneEvent{
_ = writeSSEEvent(w, responses.ResponseTextDoneEvent{
ItemID: itemID,
OutputIndex: int64(outputIndex),
}); err != nil {
t.Logf("writeResponsesAPIStreaming: failed to write ResponseTextDoneEvent: %v", err)
return
}
if err := writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{
})
_ = writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{
OutputIndex: int64(outputIndex),
Item: responses.ResponseOutputItemUnion{
ID: itemID,
Type: "message",
},
}); err != nil {
t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemDoneEvent: %v", err)
return
}
}
if err := writeSSEEvent(w, responses.ResponseCompletedEvent{}); err != nil {
t.Logf("writeResponsesAPIStreaming: failed to write ResponseCompletedEvent: %v", err)
return
})
}
_ = writeSSEEvent(w, responses.ResponseCompletedEvent{})
flusher.Flush()
return
}
@@ -393,7 +382,6 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
Type: "message",
},
}); err != nil {
t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemAddedEvent: %v", err)
return
}
flusher.Flush()
@@ -411,12 +399,10 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
chunkBytes, err := json.Marshal(chunkData)
if err != nil {
t.Logf("writeResponsesAPIStreaming: failed to marshal chunk data: %v", err)
return
}
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
t.Logf("writeResponsesAPIStreaming: failed to write chunk data: %v", err)
return
}
flusher.Flush()
@@ -425,13 +411,13 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
}
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.t.Errorf("writeChatCompletionsNonStreaming: failed to encode response: %v", err)
}
_ = json.NewEncoder(w).Encode(resp)
}
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
_ = s // receiver unused but kept for consistency
// Convert all choices to output format
outputs := make([]map[string]interface{}, len(resp.Choices))
for i, choice := range resp.Choices {
@@ -457,9 +443,7 @@ func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp
"usage": resp.Usage,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
s.t.Errorf("writeResponsesAPINonStreaming: failed to encode response: %v", err)
}
_ = json.NewEncoder(w).Encode(response)
}
// OpenAIStreamingResponse creates a streaming response from chunks.
-220
View File
@@ -1,220 +0,0 @@
package chattool
import (
"context"
"fmt"
"math"
"time"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
const (
// ComputerUseModelProvider is the provider for the computer
// use model.
ComputerUseModelProvider = "anthropic"
// ComputerUseModelName is the model used for computer use
// subagents.
ComputerUseModelName = "claude-opus-4-6"
)
// computerUseTool implements fantasy.AgentTool and
// chatloop.ToolDefiner for Anthropic computer use.
type computerUseTool struct {
	// displayWidth/displayHeight are the desktop dimensions used to
	// compute the scaled screenshot size sent back to the model.
	displayWidth  int
	displayHeight int
	// getWorkspaceConn obtains the workspace agent connection on
	// which desktop actions are executed.
	getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error)
	// providerOptions is set via SetProviderOptions and read back
	// by ProviderOptions.
	providerOptions fantasy.ProviderOptions
	// clock drives the "wait" action's timer (quartz allows a fake
	// clock in tests).
	clock quartz.Clock
}
// NewComputerUseTool creates a computer use AgentTool that
// delegates to the agent's desktop endpoints.
func NewComputerUseTool(
	displayWidth, displayHeight int,
	getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error),
	clock quartz.Clock,
) fantasy.AgentTool {
	tool := computerUseTool{
		displayWidth:     displayWidth,
		displayHeight:    displayHeight,
		getWorkspaceConn: getWorkspaceConn,
		clock:            clock,
	}
	return &tool
}
// Info describes the tool to the agent loop. The JSON-schema
// parameters are intentionally empty: the actual wire format is
// the provider-defined tool (see ComputerUseProviderTool).
func (*computerUseTool) Info() fantasy.ToolInfo {
	info := fantasy.ToolInfo{
		Name:        "computer",
		Description: "Control the desktop: take screenshots, move the mouse, click, type, and scroll.",
	}
	info.Parameters = map[string]any{}
	info.Required = []string{}
	return info
}
// ComputerUseProviderTool creates the provider-defined tool
// definition for Anthropic computer use. This is passed via
// ProviderTools so the API receives the correct wire format.
func ComputerUseProviderTool(displayWidth, displayHeight int) fantasy.Tool {
	opts := fantasyanthropic.ComputerUseToolOptions{
		DisplayWidthPx:  int64(displayWidth),
		DisplayHeightPx: int64(displayHeight),
		ToolVersion:     fantasyanthropic.ComputerUse20251124,
	}
	return fantasyanthropic.NewComputerUseTool(opts)
}
// ProviderOptions returns the options previously stored via
// SetProviderOptions.
func (t *computerUseTool) ProviderOptions() fantasy.ProviderOptions {
	return t.providerOptions
}
// SetProviderOptions stores provider options for later retrieval
// via ProviderOptions.
func (t *computerUseTool) SetProviderOptions(opts fantasy.ProviderOptions) {
	t.providerOptions = opts
}
// Run executes a single computer-use tool call: it parses the
// Anthropic computer-use input, connects to the workspace agent,
// performs the requested desktop action, and returns a screenshot
// as the tool result (the Anthropic pattern of screenshotting
// after every action). Errors are reported to the model as text
// error responses rather than as Go errors.
func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
	input, err := fantasyanthropic.ParseComputerUseInput(call.Input)
	if err != nil {
		return fantasy.NewTextErrorResponse(
			fmt.Sprintf("invalid computer use input: %v", err),
		), nil
	}
	conn, err := t.getWorkspaceConn(ctx)
	if err != nil {
		return fantasy.NewTextErrorResponse(
			fmt.Sprintf("failed to connect to workspace: %v", err),
		), nil
	}
	// Compute scaled screenshot size for Anthropic constraints.
	scaledW, scaledH := computeScaledScreenshotSize(
		t.displayWidth, t.displayHeight,
	)

	switch input.Action {
	case fantasyanthropic.ActionWait:
		// For wait actions, sleep then return a screenshot.
		d := input.Duration
		if d <= 0 {
			// Default wait is 1000ms when unspecified/non-positive.
			d = 1000
		}
		timer := t.clock.NewTimer(time.Duration(d)*time.Millisecond, "computeruse", "wait")
		defer timer.Stop()
		select {
		case <-ctx.Done():
		case <-timer.C:
		}
		return captureScreenshot(ctx, conn, scaledW, scaledH)
	case fantasyanthropic.ActionScreenshot:
		// A screenshot request needs no preceding action.
		return captureScreenshot(ctx, conn, scaledW, scaledH)
	}

	// Build the action request, copying only the optional fields
	// the model actually supplied.
	action := workspacesdk.DesktopAction{
		Action:       string(input.Action),
		ScaledWidth:  &scaledW,
		ScaledHeight: &scaledH,
	}
	if input.Coordinate != ([2]int64{}) {
		coord := [2]int{int(input.Coordinate[0]), int(input.Coordinate[1])}
		action.Coordinate = &coord
	}
	if input.StartCoordinate != ([2]int64{}) {
		coord := [2]int{int(input.StartCoordinate[0]), int(input.StartCoordinate[1])}
		action.StartCoordinate = &coord
	}
	if input.Text != "" {
		action.Text = &input.Text
	}
	if input.Duration > 0 {
		d := int(input.Duration)
		action.Duration = &d
	}
	if input.ScrollAmount > 0 {
		s := int(input.ScrollAmount)
		action.ScrollAmount = &s
	}
	if input.ScrollDirection != "" {
		action.ScrollDirection = &input.ScrollDirection
	}
	// Execute the action.
	if _, err = conn.ExecuteDesktopAction(ctx, action); err != nil {
		return fantasy.NewTextErrorResponse(
			fmt.Sprintf("action %q failed: %v", input.Action, err),
		), nil
	}
	// Take a screenshot after every action (Anthropic pattern).
	return captureScreenshot(ctx, conn, scaledW, scaledH)
}

// captureScreenshot executes a "screenshot" desktop action at the
// given scaled dimensions and wraps the PNG data as an image tool
// response; failures become text error responses for the model.
func captureScreenshot(
	ctx context.Context,
	conn workspacesdk.AgentConn,
	scaledW, scaledH int,
) (fantasy.ToolResponse, error) {
	action := workspacesdk.DesktopAction{
		Action:       "screenshot",
		ScaledWidth:  &scaledW,
		ScaledHeight: &scaledH,
	}
	resp, err := conn.ExecuteDesktopAction(ctx, action)
	if err != nil {
		return fantasy.NewTextErrorResponse(
			fmt.Sprintf("screenshot failed: %v", err),
		), nil
	}
	return fantasy.NewImageResponse(
		[]byte(resp.ScreenshotData), "image/png",
	), nil
}
// computeScaledScreenshotSize computes the target screenshot
// dimensions to fit within Anthropic's constraints: the long edge
// must not exceed 1568px and the total pixel count must not exceed
// 1,150,000. Aspect ratio is preserved and each edge stays >= 1;
// images already within both limits are returned unchanged.
func computeScaledScreenshotSize(width, height int) (scaledWidth int, scaledHeight int) {
	const (
		maxLongEdge    = 1568
		maxTotalPixels = 1_150_000
	)
	// The binding constraint is whichever scale factor is smaller.
	edgeScale := float64(maxLongEdge) / float64(max(width, height))
	areaScale := math.Sqrt(float64(maxTotalPixels) / float64(width*height))
	factor := min(1.0, edgeScale, areaScale)
	if factor >= 1.0 {
		// Already within both limits; never upscale.
		return width, height
	}
	scaledWidth = max(1, int(float64(width)*factor))
	scaledHeight = max(1, int(float64(height)*factor))
	return scaledWidth, scaledHeight
}
@@ -1,81 +0,0 @@
package chattool
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestComputeScaledScreenshotSize exercises representative display
// sizes and additionally asserts the Anthropic constraint invariants
// (long edge <= 1568, total pixels <= 1,150,000) on every result.
func TestComputeScaledScreenshotSize(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name          string
		width, height int
		wantW, wantH  int
	}{
		{
			name:   "1920x1080_scales_down",
			width:  1920,
			height: 1080,
			wantW:  1429,
			wantH:  804,
		},
		{
			name:   "1280x800_no_scaling",
			width:  1280,
			height: 800,
			wantW:  1280,
			wantH:  800,
		},
		{
			name:   "3840x2160_large_display",
			width:  3840,
			height: 2160,
			wantW:  1429,
			wantH:  804,
		},
		{
			name:   "1568x1000_pixel_cap_applies",
			width:  1568,
			height: 1000,
			wantW:  1342,
			wantH:  856,
		},
		{
			name:   "100x100_small_display",
			width:  100,
			height: 100,
			wantW:  100,
			wantH:  100,
		},
		{
			name:  "4000x3000_stays_within_limits",
			width: 4000,
			// Both constraints apply. The function should keep
			// the result within maxLongEdge=1568 and
			// totalPixels<=1,150,000.
			height: 3000,
			wantW:  1238,
			wantH:  928,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			gotW, gotH := computeScaledScreenshotSize(tt.width, tt.height)
			assert.Equal(t, tt.wantW, gotW)
			assert.Equal(t, tt.wantH, gotH)
			// Invariant: results must respect Anthropic constraints.
			const maxLongEdge = 1568
			const maxTotalPixels = 1_150_000
			longEdge := max(gotW, gotH)
			assert.LessOrEqual(t, longEdge, maxLongEdge,
				"long edge %d exceeds max %d", longEdge, maxLongEdge)
			assert.LessOrEqual(t, gotW*gotH, maxTotalPixels,
				"total pixels %d exceeds max %d", gotW*gotH, maxTotalPixels)
		})
	}
}
-186
View File
@@ -1,186 +0,0 @@
package chattool_test
import (
"context"
"testing"
"charm.land/fantasy"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
"github.com/coder/quartz"
)
// TestComputerUseTool_Info checks the tool's advertised name and
// that a description is present.
func TestComputerUseTool_Info(t *testing.T) {
	t.Parallel()
	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, nil, quartz.NewReal())
	info := tool.Info()
	assert.Equal(t, "computer", info.Name)
	assert.NotEmpty(t, info.Description)
}
// TestComputerUseProviderTool checks that the provider-defined tool
// carries the "computer" identity and the configured display size.
func TestComputerUseProviderTool(t *testing.T) {
	t.Parallel()
	def := chattool.ComputerUseProviderTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight)
	pdt, ok := def.(fantasy.ProviderDefinedTool)
	require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
	assert.Contains(t, pdt.ID, "computer")
	assert.Equal(t, "computer", pdt.Name)
	// Verify display dimensions are passed through.
	assert.Equal(t, int64(workspacesdk.DesktopDisplayWidth), pdt.Args["display_width_px"])
	assert.Equal(t, int64(workspacesdk.DesktopDisplayHeight), pdt.Args["display_height_px"])
}
// TestComputerUseTool_Run_Screenshot verifies a screenshot action
// returns the agent's PNG data as a non-error image response.
func TestComputerUseTool_Run_Screenshot(t *testing.T) {
	t.Parallel()
	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)
	// A single desktop-action call: the screenshot itself.
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "base64png",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)
	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())
	call := fantasy.ToolCall{
		ID:    "test-1",
		Name:  "computer",
		Input: `{"action":"screenshot"}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, "image/png", resp.MediaType)
	assert.Equal(t, []byte("base64png"), resp.Data)
	assert.False(t, resp.IsError)
}
// TestComputerUseTool_Run_LeftClick verifies that a click action is
// followed by an automatic screenshot whose image is returned.
func TestComputerUseTool_Run_LeftClick(t *testing.T) {
	t.Parallel()
	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)
	// Expect the action call first.
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output: "left_click performed",
	}, nil)
	// Then expect a screenshot (auto-screenshot after action).
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "after-click",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)
	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())
	call := fantasy.ToolCall{
		ID:    "test-2",
		Name:  "computer",
		Input: `{"action":"left_click","coordinate":[100,200]}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, []byte("after-click"), resp.Data)
}
// TestComputerUseTool_Run_Wait verifies that a wait action sleeps
// and then returns a screenshot image response.
func TestComputerUseTool_Run_Wait(t *testing.T) {
	t.Parallel()
	ctrl := gomock.NewController(t)
	mockConn := agentconnmock.NewMockAgentConn(ctrl)
	// Expect a screenshot after the wait completes.
	mockConn.EXPECT().ExecuteDesktopAction(
		gomock.Any(),
		gomock.Any(),
	).Return(workspacesdk.DesktopActionResponse{
		Output:           "screenshot",
		ScreenshotData:   "after-wait",
		ScreenshotWidth:  1024,
		ScreenshotHeight: 768,
	}, nil)
	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return mockConn, nil
	}, quartz.NewReal())
	// 10ms keeps the real-clock wait fast.
	call := fantasy.ToolCall{
		ID:    "test-3",
		Name:  "computer",
		Input: `{"action":"wait","duration":10}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.Equal(t, "image", resp.Type)
	assert.Equal(t, "image/png", resp.MediaType)
	assert.Equal(t, []byte("after-wait"), resp.Data)
	assert.False(t, resp.IsError)
}
// TestComputerUseTool_Run_ConnError verifies that a workspace
// connection failure is reported to the model as a text error
// response, not a Go error.
func TestComputerUseTool_Run_ConnError(t *testing.T) {
	t.Parallel()
	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return nil, xerrors.New("workspace not available")
	}, quartz.NewReal())
	call := fantasy.ToolCall{
		ID:    "test-4",
		Name:  "computer",
		Input: `{"action":"screenshot"}`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.True(t, resp.IsError)
	assert.Contains(t, resp.Content, "workspace not available")
}
// TestComputerUseTool_Run_InvalidInput verifies that malformed JSON
// input is rejected before any workspace connection is attempted.
func TestComputerUseTool_Run_InvalidInput(t *testing.T) {
	t.Parallel()
	tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
		return nil, xerrors.New("should not be called")
	}, quartz.NewReal())
	call := fantasy.ToolCall{
		ID:    "test-5",
		Name:  "computer",
		Input: `{invalid json`,
	}
	resp, err := tool.Run(context.Background(), call)
	require.NoError(t, err)
	assert.True(t, resp.IsError)
	assert.Contains(t, resp.Content, "invalid computer use input")
}
+8 -14
View File
@@ -2,6 +2,7 @@ package chattool
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
@@ -12,7 +13,6 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/namesgenerator"
"github.com/coder/coder/v2/codersdk"
@@ -68,7 +68,6 @@ type CreateWorkspaceOptions struct {
CreateFn CreateWorkspaceFn
AgentConnFn AgentConnFunc
WorkspaceMu *sync.Mutex
Logger slog.Logger
}
type createWorkspaceArgs struct {
@@ -194,19 +193,13 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
// Persist workspace + agent association on the chat.
if options.DB != nil && options.ChatID != uuid.Nil {
if _, err := options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
_, _ = options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
ID: options.ChatID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
Valid: true,
},
}); err != nil {
options.Logger.Error(ctx, "failed to persist chat workspace association",
slog.F("chat_id", options.ChatID),
slog.F("workspace_id", workspace.ID),
slog.Error(err),
)
}
})
}
// Wait for the agent to come online and startup scripts to finish.
@@ -248,14 +241,15 @@ func checkExistingWorkspace(
return nil, false, nil
}
// Check if workspace still exists.
ws, err := db.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
// Workspace was deleted — allow creation.
return nil, false, nil
}
return nil, false, xerrors.Errorf("load workspace: %w", err)
}
// Workspace was soft-deleted — allow creation.
if ws.Deleted {
return nil, false, nil
}
// Check the latest build status.
build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
@@ -108,35 +108,3 @@ func TestWaitForAgentReady(t *testing.T) {
require.Empty(t, result)
})
}
// TestCheckExistingWorkspace_DeletedWorkspace verifies that
// checkExistingWorkspace treats a soft-deleted workspace
// (Deleted=true but row still present) as absent, returning
// (nil, false, nil) so the caller is allowed to create a fresh
// workspace for the chat.
func TestCheckExistingWorkspace_DeletedWorkspace(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)

	chatID := uuid.New()
	workspaceID := uuid.New()

	// Mock GetChatByID returns a chat linked to a workspace.
	db.EXPECT().
		GetChatByID(gomock.Any(), chatID).
		Return(database.Chat{
			ID:          chatID,
			WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true},
		}, nil)

	// Mock GetWorkspaceByID returns a soft-deleted workspace.
	db.EXPECT().
		GetWorkspaceByID(gomock.Any(), workspaceID).
		Return(database.Workspace{
			ID:      workspaceID,
			Deleted: true,
		}, nil)

	result, done, err := checkExistingWorkspace(
		context.Background(), db, chatID, nil,
	)
	// done=false signals "no usable workspace; creation may proceed".
	require.NoError(t, err)
	require.False(t, done, "should allow creation for deleted workspace")
	require.Nil(t, result)
}
+13 -24
View File
@@ -3,14 +3,12 @@ package chattool
import (
"context"
"encoding/json"
"errors"
"fmt"
"regexp"
"strings"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
@@ -67,6 +65,7 @@ type ExecuteResult struct {
type ExecuteOptions struct {
GetWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error)
DefaultTimeout time.Duration
ChatID string
}
// ProcessToolOptions configures a process management tool
@@ -78,10 +77,10 @@ type ProcessToolOptions struct {
// ExecuteArgs are the parameters accepted by the execute tool.
type ExecuteArgs struct {
Command string `json:"command" description:"The shell command to execute."`
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
Command string `json:"command"`
Timeout *string `json:"timeout,omitempty"`
WorkDir *string `json:"workdir,omitempty"`
RunInBackground *bool `json:"run_in_background,omitempty"`
}
// Execute returns an AgentTool that runs a shell command in the
@@ -89,7 +88,7 @@ type ExecuteArgs struct {
func Execute(options ExecuteOptions) fantasy.AgentTool {
return fantasy.NewAgentTool(
"execute",
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding.",
"Execute a shell command in the workspace.",
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if options.GetWorkspaceConn == nil {
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
@@ -98,7 +97,7 @@ func Execute(options ExecuteOptions) fantasy.AgentTool {
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return executeTool(ctx, conn, args, options.DefaultTimeout), nil
return executeTool(ctx, conn, args, options.DefaultTimeout, options.ChatID), nil
},
)
}
@@ -108,6 +107,7 @@ func executeTool(
conn workspacesdk.AgentConn,
args ExecuteArgs,
optTimeout time.Duration,
chatID string,
) fantasy.ToolResponse {
if args.Command == "" {
return fantasy.NewTextErrorResponse("command is required")
@@ -116,22 +116,15 @@ func executeTool(
// Build the environment map for the process request.
env := make(map[string]string, len(nonInteractiveEnvVars)+1)
env["CODER_CHAT_AGENT"] = "true"
if chatID != "" {
env["CODER_CHAT_ID"] = chatID
}
for k, v := range nonInteractiveEnvVars {
env[k] = v
}
background := args.RunInBackground != nil && *args.RunInBackground
// Detect shell-style backgrounding (trailing &) and promote to
// background mode. Models sometimes use "cmd &" instead of the
// run_in_background parameter, which causes the shell to fork
// and exit immediately, leaving an untracked orphan process.
trimmed := strings.TrimSpace(args.Command)
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") {
background = true
args.Command = strings.TrimSpace(strings.TrimSuffix(trimmed, "&"))
}
var workDir string
if args.WorkDir != nil {
workDir = *args.WorkDir
@@ -257,18 +250,14 @@ func pollProcess(
context.Background(),
5*time.Second,
)
outputResp, outputErr := conn.ProcessOutput(bgCtx, processID)
outputResp, _ := conn.ProcessOutput(bgCtx, processID)
bgCancel()
output := truncateOutput(outputResp.Output)
timeoutErr := xerrors.Errorf("command timed out after %s", timeout)
if outputErr != nil {
timeoutErr = errors.Join(timeoutErr, xerrors.Errorf("failed to get output: %w", outputErr))
}
return ExecuteResult{
Success: false,
Output: output,
ExitCode: -1,
Error: timeoutErr.Error(),
Error: fmt.Sprintf("command timed out after %s", timeout),
Truncated: outputResp.Truncated,
}
case <-ticker.C:
+6 -5
View File
@@ -2,6 +2,7 @@ package chattool
import (
"context"
"database/sql"
"sync"
"charm.land/fantasy"
@@ -70,15 +71,15 @@ func StartWorkspace(options StartWorkspaceOptions) fantasy.AgentTool {
ws, err := options.DB.GetWorkspaceByID(ctx, chat.WorkspaceID.UUID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
return fantasy.NewTextErrorResponse(
"workspace was deleted; use create_workspace to make a new one",
), nil
}
return fantasy.NewTextErrorResponse(
xerrors.Errorf("load workspace: %w", err).Error(),
), nil
}
if ws.Deleted {
return fantasy.NewTextErrorResponse(
"workspace was deleted; use create_workspace to make a new one",
), nil
}
build, err := options.DB.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID)
if err != nil {
@@ -174,51 +174,6 @@ func TestStartWorkspace(t *testing.T) {
require.True(t, ok)
require.True(t, started)
})
t.Run("DeletedWorkspace", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
modelCfg := seedModelConfig(ctx, t, db, user.ID)
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
// Create a workspace that has been soft-deleted.
wsResp := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
Deleted: true,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionDelete,
}).Do()
ws := wsResp.Workspace
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
LastModelConfigID: modelCfg.ID,
Title: "test-deleted-workspace",
})
require.NoError(t, err)
tool := chattool.StartWorkspace(chattool.StartWorkspaceOptions{
DB: db,
ChatID: chat.ID,
StartFn: func(_ context.Context, _ uuid.UUID, _ uuid.UUID, _ codersdk.CreateWorkspaceBuildRequest) (codersdk.WorkspaceBuild, error) {
t.Fatal("StartFn should not be called for deleted workspace")
return codersdk.WorkspaceBuild{}, nil
},
WorkspaceMu: &sync.Mutex{},
})
resp, err := tool.Run(ctx, fantasy.ToolCall{ID: "call-1", Name: "start_workspace", Input: "{}"})
require.NoError(t, err)
require.Contains(t, resp.Content, "workspace was deleted")
})
}
// seedModelConfig inserts a provider and model config for testing.
-282
View File
@@ -1,282 +0,0 @@
package chatd_test
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestAnthropicWebSearchRoundTrip is an integration test that verifies
// provider-executed tool results (web_search) survive the full
// persist → reconstruct → re-send cycle. It sends a query that
// triggers Anthropic's web_search server tool, waits for completion,
// then sends a follow-up message. If the PE tool result was lost or
// corrupted during persistence, Anthropic rejects the second request:
//
// web_search tool use with id srvtoolu_... was found without a
// corresponding web_search_tool_result block
//
// The test requires ANTHROPIC_API_KEY to be set.
func TestAnthropicWebSearchRoundTrip(t *testing.T) {
	t.Parallel()

	// Live-API test: skipped unless ANTHROPIC_API_KEY is present.
	apiKey := os.Getenv("ANTHROPIC_API_KEY")
	if apiKey == "" {
		t.Skip("ANTHROPIC_API_KEY not set; skipping Anthropic integration test")
	}
	// Optional override, e.g. for a proxy or regional endpoint.
	baseURL := os.Getenv("ANTHROPIC_BASE_URL")

	ctx := testutil.Context(t, testutil.WaitSuperLong)

	// Stand up a full coderd with the agents experiment.
	deploymentValues := coderdtest.DeploymentValues(t)
	deploymentValues.Experiments = []string{string(codersdk.ExperimentAgents)}
	client := coderdtest.New(t, &coderdtest.Options{
		DeploymentValues: deploymentValues,
	})
	_ = coderdtest.CreateFirstUser(t, client)

	// Configure an Anthropic provider with the real API key.
	_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
		Provider: "anthropic",
		APIKey:   apiKey,
		BaseURL:  baseURL,
	})
	require.NoError(t, err)

	// Create a model config that enables web_search.
	contextLimit := int64(200000)
	isDefault := true
	_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
		Provider:     "anthropic",
		Model:        "claude-sonnet-4-20250514",
		ContextLimit: &contextLimit,
		IsDefault:    &isDefault,
		ModelConfig: &codersdk.ChatModelCallConfig{
			ProviderOptions: &codersdk.ChatModelProviderOptions{
				Anthropic: &codersdk.ChatModelAnthropicProviderOptions{
					WebSearchEnabled: ptr.Ref(true),
				},
			},
		},
	})
	require.NoError(t, err)

	// --- Step 1: Send a message that triggers web_search ---
	t.Log("Creating chat with web search query...")
	chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
		Content: []codersdk.ChatInputPart{
			{
				Type: codersdk.ChatInputPartTypeText,
				Text: "What is the current weather in San Francisco right now? Use web search to find out.",
			},
		},
	})
	require.NoError(t, err)
	t.Logf("Chat created: %s (status=%s)", chat.ID, chat.Status)

	// Stream events until the chat reaches a terminal status.
	events, closer, err := client.StreamChat(ctx, chat.ID, nil)
	require.NoError(t, err)
	defer closer.Close()
	waitForChatDone(ctx, t, events, "step 1")

	// Verify the chat completed and messages were persisted.
	chatData, err := client.GetChat(ctx, chat.ID)
	require.NoError(t, err)
	chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
	require.NoError(t, err)
	t.Logf("Chat status after step 1: %s, messages: %d",
		chatData.Status, len(chatMsgs.Messages))
	logMessages(t, chatMsgs.Messages)
	require.Equal(t, codersdk.ChatStatusWaiting, chatData.Status,
		"chat should be in waiting status after step 1")

	// Find the first assistant message and verify it has the
	// content parts the UI needs to render web search results:
	// tool-call(PE), source, tool-result(PE), and text.
	assistantMsg := findAssistantWithText(t, chatMsgs.Messages)
	require.NotNil(t, assistantMsg,
		"expected an assistant message with text content after step 1")
	partTypes := partTypeSet(assistantMsg.Content)
	require.Contains(t, partTypes, codersdk.ChatMessagePartTypeToolCall,
		"assistant message should contain a PE tool-call part")
	require.Contains(t, partTypes, codersdk.ChatMessagePartTypeSource,
		"assistant message should contain source parts for UI citations")
	require.Contains(t, partTypes, codersdk.ChatMessagePartTypeToolResult,
		"assistant message should contain a PE tool-result part")
	require.Contains(t, partTypes, codersdk.ChatMessagePartTypeText,
		"assistant message should contain a text part")

	// Verify the PE tool-call is marked as provider-executed.
	// Only the first tool-call part is inspected.
	for _, part := range assistantMsg.Content {
		if part.Type == codersdk.ChatMessagePartTypeToolCall {
			require.True(t, part.ProviderExecuted,
				"web_search tool-call should be provider-executed")
			break
		}
	}

	// --- Step 2: Send a follow-up message ---
	// This is the critical test: if PE tool results were lost during
	// persistence, the reconstructed conversation will be rejected
	// by Anthropic because server_tool_use has no matching
	// web_search_tool_result.
	t.Log("Sending follow-up message...")
	_, err = client.CreateChatMessage(ctx, chat.ID,
		codersdk.CreateChatMessageRequest{
			Content: []codersdk.ChatInputPart{
				{
					Type: codersdk.ChatInputPartTypeText,
					Text: "Thanks! What about New York?",
				},
			},
		})
	require.NoError(t, err)

	// Stream the follow-up response.
	events2, closer2, err := client.StreamChat(ctx, chat.ID, nil)
	require.NoError(t, err)
	defer closer2.Close()
	waitForChatDone(ctx, t, events2, "step 2")

	// Verify the follow-up completed and produced content.
	chatData2, err := client.GetChat(ctx, chat.ID)
	require.NoError(t, err)
	chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
	require.NoError(t, err)
	t.Logf("Chat status after step 2: %s, messages: %d",
		chatData2.Status, len(chatMsgs2.Messages))
	logMessages(t, chatMsgs2.Messages)
	require.Equal(t, codersdk.ChatStatusWaiting, chatData2.Status,
		"chat should be in waiting status after step 2")
	require.Greater(t, len(chatMsgs2.Messages), len(chatMsgs.Messages),
		"follow-up should have added more messages")

	// The last assistant message should have text.
	lastAssistant := findLastAssistantWithText(t, chatMsgs2.Messages)
	require.NotNil(t, lastAssistant,
		"expected an assistant message with text in the follow-up")

	t.Log("Anthropic web_search round-trip test passed.")
}
// waitForChatDone drains the event stream until the chat reaches
// a terminal status (waiting, completed, or error).
func waitForChatDone(
	ctx context.Context,
	t *testing.T,
	events <-chan codersdk.ChatStreamEvent,
	label string,
) {
	t.Helper()
	for {
		var (
			event codersdk.ChatStreamEvent
			open  bool
		)
		// Block until the next event arrives, the stream closes,
		// or the test context expires.
		select {
		case <-ctx.Done():
			require.FailNow(t, "timed out waiting for "+label+" completion")
		case event, open = <-events:
		}
		if !open {
			// Stream closed by the server; nothing more to wait for.
			return
		}
		switch event.Type {
		case codersdk.ChatStreamEventTypeError:
			if event.Error != nil {
				t.Logf("[%s] stream error: %s", label, event.Error.Message)
			}
		case codersdk.ChatStreamEventTypeStatus:
			if event.Status == nil {
				break
			}
			status := event.Status.Status
			t.Logf("[%s] status → %s", label, status)
			// Waiting/completed are terminal success states;
			// error status fails the test immediately.
			if status == codersdk.ChatStatusWaiting ||
				status == codersdk.ChatStatusCompleted {
				return
			}
			if status == codersdk.ChatStatusError {
				require.FailNow(t, label+" ended with error status")
			}
		case codersdk.ChatStreamEventTypeMessage:
			if msg := event.Message; msg != nil {
				t.Logf("[%s] persisted message: role=%s parts=%d",
					label, msg.Role, len(msg.Content))
			}
		case codersdk.ChatStreamEventTypeMessagePart:
			// Streaming delta — just note it.
			if mp := event.MessagePart; mp != nil {
				t.Logf("[%s] part: type=%s",
					label, mp.Part.Type)
			}
		}
	}
}
// findAssistantWithText returns a pointer to the first assistant
// message containing at least one non-empty text part, or nil if
// no such message exists.
func findAssistantWithText(t *testing.T, msgs []codersdk.ChatMessage) *codersdk.ChatMessage {
	t.Helper()
	for i := range msgs {
		msg := &msgs[i]
		if msg.Role != "assistant" {
			continue
		}
		for _, p := range msg.Content {
			if p.Type == codersdk.ChatMessagePartTypeText && p.Text != "" {
				return msg
			}
		}
	}
	return nil
}
// findLastAssistantWithText walks the messages in reverse and returns
// a pointer to the most recent assistant message containing a
// non-empty text part, or nil if none is found.
func findLastAssistantWithText(t *testing.T, msgs []codersdk.ChatMessage) *codersdk.ChatMessage {
	t.Helper()
	for i := len(msgs) - 1; i >= 0; i-- {
		msg := &msgs[i]
		if msg.Role != "assistant" {
			continue
		}
		for _, p := range msg.Content {
			if p.Type == codersdk.ChatMessagePartTypeText && p.Text != "" {
				return msg
			}
		}
	}
	return nil
}
// logMessages prints a one-line summary per message for debugging,
// tagging provider-executed parts with a "(PE)" suffix.
func logMessages(t *testing.T, msgs []codersdk.ChatMessage) {
	t.Helper()
	for idx, msg := range msgs {
		summary := make([]string, 0, len(msg.Content))
		for _, part := range msg.Content {
			label := string(part.Type)
			if part.ProviderExecuted {
				label += "(PE)"
			}
			summary = append(summary, label)
		}
		t.Logf(" msg[%d] role=%s parts=%v", idx, msg.Role, summary)
	}
}
// partTypeSet collects the distinct part types present in a message
// into a set for Contains-style assertions.
func partTypeSet(parts []codersdk.ChatMessagePart) map[codersdk.ChatMessagePartType]struct{} {
	present := make(map[codersdk.ChatMessagePartType]struct{}, len(parts))
	for _, part := range parts {
		present[part.Type] = struct{}{}
	}
	return present
}
+19 -23
View File
@@ -21,7 +21,6 @@ import (
"github.com/coder/coder/v2/coderd/chatd/chatretry"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
)
const titleGenerationPrompt = "You are a title generator. Your ONLY job is to output a short title (2-8 words) " +
@@ -77,7 +76,7 @@ func (p *Server) maybeGenerateChatTitle(
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
for _, c := range preferredTitleModels {
m, err := chatprovider.ModelFromConfig(
c.provider, c.model, keys, chatprovider.UserAgent(),
c.provider, c.model, keys,
)
if err == nil {
candidates = append(candidates, m)
@@ -111,7 +110,7 @@ func (p *Server) maybeGenerateChatTitle(
return
}
chat.Title = title
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
return
}
@@ -159,12 +158,14 @@ func titleInput(
}
switch message.Role {
case database.ChatMessageRoleAssistant, database.ChatMessageRoleTool:
case string(fantasy.MessageRoleAssistant), string(fantasy.MessageRoleTool):
return "", false
case database.ChatMessageRoleUser:
case string(fantasy.MessageRoleUser):
userCount++
if firstUserText == "" {
parsed, err := chatprompt.ParseContent(message)
parsed, err := chatprompt.ParseContent(
string(fantasy.MessageRoleUser), message.Content,
)
if err != nil {
return "", false
}
@@ -225,21 +226,22 @@ func fallbackChatTitle(message string) string {
return truncateRunes(title, maxRunes)
}
// contentBlocksToText concatenates the text parts of SDK chat
// message parts into a single space-separated string.
func contentBlocksToText(parts []codersdk.ChatMessagePart) string {
texts := make([]string, 0, len(parts))
for _, part := range parts {
if part.Type != codersdk.ChatMessagePartTypeText {
// contentBlocksToText concatenates the text parts of content blocks
// into a single space-separated string.
func contentBlocksToText(content []fantasy.Content) string {
parts := make([]string, 0, len(content))
for _, block := range content {
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
if !ok {
continue
}
text := strings.TrimSpace(part.Text)
text := strings.TrimSpace(textBlock.Text)
if text == "" {
continue
}
texts = append(texts, text)
parts = append(parts, text)
}
return strings.Join(texts, " ")
return strings.Join(parts, " ")
}
func truncateRunes(value string, maxLen int) string {
@@ -279,7 +281,7 @@ func generatePushSummary(
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
for _, c := range preferredTitleModels {
m, err := chatprovider.ModelFromConfig(
c.provider, c.model, keys, chatprovider.UserAgent(),
c.provider, c.model, keys,
)
if err == nil {
candidates = append(candidates, m)
@@ -341,13 +343,7 @@ func generateShortText(
return "", xerrors.Errorf("generate short text: %w", err)
}
responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content))
for _, block := range response.Content {
if p := chatprompt.PartFromContent(block); p.Type != "" {
responseParts = append(responseParts, p)
}
}
text := strings.TrimSpace(contentBlocksToText(responseParts))
text := strings.TrimSpace(contentBlocksToText(response.Content))
text = strings.Trim(text, "\"'`")
return text, nil
}
+60 -226
View File
@@ -2,7 +2,6 @@ package chatd
import (
"context"
"database/sql"
"encoding/json"
"sort"
"strings"
@@ -13,44 +12,21 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/database"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
)
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
const (
subagentAwaitPollInterval = 200 * time.Millisecond
subagentAwaitFallbackPoll = 5 * time.Second
defaultSubagentWaitTimeout = 5 * time.Minute
)
// computerUseSubagentSystemPrompt is the system prompt prepended to
// every computer use subagent chat. It instructs the model on how to
// interact with the desktop environment via the computer tool.
// NOTE(review): this raw string is sent verbatim to the model — any
// edit here changes subagent behavior directly.
const computerUseSubagentSystemPrompt = `You are a computer use agent with access to a desktop environment. You can see the screen, move the mouse, click, type, scroll, and drag.
Your primary tool is the "computer" tool which lets you interact with the desktop. After every action you take, you will receive a screenshot showing the current state of the screen. Use these screenshots to verify your actions and plan next steps.
Guidelines:
- Always start by taking a screenshot to see the current state of the desktop.
- Be precise with coordinates when clicking or typing.
- Wait for UI elements to load before interacting with them.
- If an action doesn't produce the expected result, try alternative approaches.
- Report what you accomplished when done.`
// spawnAgentArgs are the JSON arguments accepted by the spawn_agent
// tool: the task prompt for the child agent and an optional chat title.
type spawnAgentArgs struct {
	Prompt string `json:"prompt"`
	Title  string `json:"title,omitempty"` // falls back to a generated title when empty
}
// spawnComputerUseAgentArgs are the JSON arguments accepted by the
// spawn_computer_use_agent tool; mirrors spawnAgentArgs.
type spawnComputerUseAgentArgs struct {
	Prompt string `json:"prompt"`
	Title  string `json:"title,omitempty"` // falls back to a generated title when empty
}
type waitAgentArgs struct {
ChatID string `json:"chat_id"`
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
@@ -66,26 +42,8 @@ type closeAgentArgs struct {
ChatID string `json:"chat_id"`
}
// isAnthropicConfigured reports whether an Anthropic API key is
// available, either from the statically-configured provider keys or
// from an enabled provider row in the database. DB lookup errors are
// treated as "not configured".
func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
	// Static keys take precedence and need no DB round trip.
	if p.providerAPIKeys.APIKey("anthropic") != "" {
		return true
	}
	providers, err := p.db.GetEnabledChatProviders(ctx)
	if err != nil {
		return false
	}
	for _, provider := range providers {
		isAnthropic := chatprovider.NormalizeProvider(provider.Provider) == "anthropic"
		hasKey := strings.TrimSpace(provider.APIKey) != ""
		if isAnthropic && hasKey {
			return true
		}
	}
	return false
}
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
tools := []fantasy.AgentTool{
func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.AgentTool {
return []fantasy.AgentTool{
fantasy.NewAgentTool(
"spawn_agent",
"Spawn a delegated child agent to work on a clearly scoped, "+
@@ -251,89 +209,6 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
},
),
}
// Only include the computer use tool when an Anthropic
// provider is configured, since it requires an Anthropic
// model.
if p.isAnthropicConfigured(ctx) {
tools = append(tools, fantasy.NewAgentTool(
"spawn_computer_use_agent",
"Spawn a dedicated computer use agent that can see the desktop "+
"(take screenshots) and interact with it (mouse, keyboard, "+
"scroll). The agent runs on a model optimized for computer "+
"use and has the same workspace tools as a standard subagent "+
"plus the native Anthropic computer tool. Use this for tasks "+
"that require visual interaction with a desktop GUI (e.g. "+
"browser automation, GUI testing, visual inspection). After "+
"spawning, use wait_agent to collect the result.",
func(ctx context.Context, args spawnComputerUseAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
if currentChat == nil {
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
}
parent := currentChat()
if parent.ParentChatID.Valid {
return fantasy.NewTextErrorResponse("delegated chats cannot create child subagents"), nil
}
parent, err := p.db.GetChatByID(ctx, parent.ID)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
prompt := strings.TrimSpace(args.Prompt)
if prompt == "" {
return fantasy.NewTextErrorResponse("prompt is required"), nil
}
title := strings.TrimSpace(args.Title)
if title == "" {
title = subagentFallbackChatTitle(prompt)
}
rootChatID := parent.ID
if parent.RootChatID.Valid {
rootChatID = parent.RootChatID.UUID
}
if parent.LastModelConfigID == uuid.Nil {
return fantasy.NewTextErrorResponse("parent chat model config id is required"), nil
}
// Create the child chat with Mode set to
// computer_use. This signals runChat to use the
// predefined computer use model and include the
// computer tool.
childChat, err := p.CreateChat(ctx, CreateOptions{
OwnerID: parent.OwnerID,
WorkspaceID: parent.WorkspaceID,
ParentChatID: uuid.NullUUID{
UUID: parent.ID,
Valid: true,
},
RootChatID: uuid.NullUUID{
UUID: rootChatID,
Valid: true,
},
ModelConfigID: parent.LastModelConfigID,
Title: title,
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
SystemPrompt: computerUseSubagentSystemPrompt + "\n\n" + prompt,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
})
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return toolJSONResponse(map[string]any{
"chat_id": childChat.ID.String(),
"title": childChat.Title,
"status": string(childChat.Status),
}), nil
},
))
}
return tools
}
func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
@@ -385,7 +260,7 @@ func (p *Server) createChildSubagentChat(
},
ModelConfigID: parent.LastModelConfigID,
Title: title,
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: prompt}},
})
if err != nil {
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
@@ -414,16 +289,9 @@ func (p *Server) sendSubagentMessage(
return database.Chat{}, ErrSubagentNotDescendant
}
// Look up the target chat to get the owner for CreatedBy.
targetChat, err := p.db.GetChatByID(ctx, targetChatID)
if err != nil {
return database.Chat{}, xerrors.Errorf("get target chat: %w", err)
}
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
ChatID: targetChatID,
CreatedBy: targetChat.OwnerID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText(message)},
Content: []fantasy.Content{fantasy.TextContent{Text: message}},
BusyBehavior: busyBehavior,
})
if err != nil {
@@ -447,90 +315,41 @@ func (p *Server) awaitSubagentCompletion(
return database.Chat{}, "", ErrSubagentNotDescendant
}
// Check immediately before entering the poll loop.
targetChat, report, done, checkErr := p.checkSubagentCompletion(ctx, targetChatID)
if checkErr != nil {
return database.Chat{}, "", checkErr
}
if done {
return handleSubagentDone(targetChat, report)
}
if timeout <= 0 {
timeout = defaultSubagentWaitTimeout
}
timer := time.NewTimer(timeout)
defer timer.Stop()
// When pubsub is available, subscribe for fast status
// notifications and use a less aggressive fallback poll.
// Without pubsub (single-instance / in-memory) fall back
// to the original 200ms polling.
pollInterval := subagentAwaitPollInterval
var notifyCh <-chan struct{}
if p.pubsub != nil {
pollInterval = subagentAwaitFallbackPoll
ch := make(chan struct{}, 1)
notifyCh = ch
cancel, subErr := p.pubsub.SubscribeWithErr(
coderdpubsub.ChatStreamNotifyChannel(targetChatID),
func(_ context.Context, _ []byte, _ error) {
// Non-blocking send so we never stall the
// pubsub dispatch goroutine.
select {
case ch <- struct{}{}:
default:
}
},
)
if subErr == nil {
defer cancel()
} else {
// Subscription failed; fall back to fast polling.
pollInterval = subagentAwaitPollInterval
notifyCh = nil
}
}
ticker := time.NewTicker(pollInterval)
ticker := time.NewTicker(subagentAwaitPollInterval)
defer ticker.Stop()
for {
targetChat, report, done, checkErr := p.checkSubagentCompletion(ctx, targetChatID)
if checkErr != nil {
return database.Chat{}, "", checkErr
}
if done {
if targetChat.Status == database.ChatStatusError {
reason := strings.TrimSpace(report)
if reason == "" {
reason = "agent reached error status"
}
return database.Chat{}, "", xerrors.New(reason)
}
return targetChat, report, nil
}
select {
case <-notifyCh:
case <-ticker.C:
case <-timer.C:
return database.Chat{}, "", xerrors.New("timed out waiting for delegated subagent completion")
case <-ctx.Done():
return database.Chat{}, "", ctx.Err()
}
targetChat, report, done, checkErr = p.checkSubagentCompletion(ctx, targetChatID)
if checkErr != nil {
return database.Chat{}, "", checkErr
}
if done {
return handleSubagentDone(targetChat, report)
}
}
}
// handleSubagentDone translates a completed subagent check into the
// appropriate return value, surfacing error-status chats as errors
// (using the report text as the reason when available).
func handleSubagentDone(
	chat database.Chat,
	report string,
) (database.Chat, string, error) {
	if chat.Status != database.ChatStatusError {
		return chat, report, nil
	}
	// Error status: prefer the trimmed report as the failure reason.
	reason := strings.TrimSpace(report)
	if reason == "" {
		reason = "agent reached error status"
	}
	return database.Chat{}, "", xerrors.New(reason)
}
func (p *Server) closeSubagent(
ctx context.Context,
parentChatID uuid.UUID,
@@ -603,12 +422,12 @@ func latestSubagentAssistantMessage(
for i := len(messages) - 1; i >= 0; i-- {
message := messages[i]
if message.Role != database.ChatMessageRoleAssistant ||
if message.Role != string(fantasy.MessageRoleAssistant) ||
message.Visibility == database.ChatMessageVisibilityModel {
continue
}
content, parseErr := chatprompt.ParseContent(message)
content, parseErr := chatprompt.ParseContent(message.Role, message.Content)
if parseErr != nil {
continue
}
@@ -622,9 +441,6 @@ func latestSubagentAssistantMessage(
return "", nil
}
// isSubagentDescendant reports whether targetChatID is a descendant
// of ancestorChatID by walking up the parent chain from the target.
// This is O(depth) DB queries instead of O(nodes) BFS.
func isSubagentDescendant(
ctx context.Context,
store database.Store,
@@ -635,29 +451,47 @@ func isSubagentDescendant(
return false, nil
}
currentID := targetChatID
visited := map[uuid.UUID]struct{}{} // cycle protection
for {
if _, seen := visited[currentID]; seen {
return false, nil
}
visited[currentID] = struct{}{}
chat, err := store.GetChatByID(ctx, currentID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
return false, nil // chain broken; not a confirmed descendant
}
return false, xerrors.Errorf("get chat %s: %w", currentID, err)
}
if !chat.ParentChatID.Valid {
return false, nil // reached root without finding ancestor
}
if chat.ParentChatID.UUID == ancestorChatID {
descendants, err := listSubagentDescendants(ctx, store, ancestorChatID)
if err != nil {
return false, err
}
for _, descendant := range descendants {
if descendant.ID == targetChatID {
return true, nil
}
currentID = chat.ParentChatID.UUID
}
return false, nil
}
// listSubagentDescendants performs a breadth-first walk of the chat
// tree rooted at chatID and returns every descendant chat. The root
// itself is excluded. A visited set guards against cycles.
func listSubagentDescendants(
	ctx context.Context,
	store database.Store,
	chatID uuid.UUID,
) ([]database.Chat, error) {
	descendants := make([]database.Chat, 0)
	frontier := []uuid.UUID{chatID}
	seen := map[uuid.UUID]struct{}{chatID: {}}
	for len(frontier) > 0 {
		parentID := frontier[0]
		frontier = frontier[1:]
		children, err := store.ListChildChatsByParentID(ctx, parentID)
		if err != nil {
			return nil, xerrors.Errorf("list child chats for %s: %w", parentID, err)
		}
		for _, child := range children {
			if _, dup := seen[child.ID]; dup {
				continue
			}
			seen[child.ID] = struct{}{}
			descendants = append(descendants, child)
			frontier = append(frontier, child.ID)
		}
	}
	return descendants, nil
}
func subagentFallbackChatTitle(message string) string {
-300
View File
@@ -1,300 +0,0 @@
package chatd
import (
"context"
"database/sql"
"encoding/json"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/chatd/chattool"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestComputerUseSubagentSystemPrompt is a smoke test guarding against
// the prompt constant being emptied or losing its core instructions.
func TestComputerUseSubagentSystemPrompt(t *testing.T) {
	t.Parallel()
	// Verify the system prompt constant is non-empty and contains
	// key instructions for the computer use agent.
	assert.NotEmpty(t, computerUseSubagentSystemPrompt)
	assert.Contains(t, computerUseSubagentSystemPrompt, "computer")
	assert.Contains(t, computerUseSubagentSystemPrompt, "screenshot")
}
func TestSubagentFallbackChatTitle(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{
name: "EmptyPrompt",
input: "",
want: "New Chat",
},
{
name: "ShortPrompt",
input: "Open Firefox",
want: "Open Firefox",
},
{
name: "LongPrompt",
input: "Please open the Firefox browser and navigate to the settings page",
want: "Please open the Firefox browser and...",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := subagentFallbackChatTitle(tt.input)
assert.Equal(t, tt.want, got)
})
}
}
// newInternalTestServer constructs a Server for internal (same-package)
// tests using the supplied provider API keys. Close is registered via
// t.Cleanup so the server shuts down when the test finishes.
func newInternalTestServer(
	t *testing.T,
	db database.Store,
	ps pubsub.Pubsub,
	keys chatprovider.ProviderAPIKeys,
) *Server {
	t.Helper()
	srv := New(Config{
		Logger:    slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		Database:  db,
		ReplicaID: uuid.New(),
		Pubsub:    ps,
		// A very long acquire interval keeps the background loop from
		// racing with the assertions made by the tests.
		PendingChatAcquireInterval: testutil.WaitLong,
		ProviderAPIKeys:            keys,
	})
	t.Cleanup(func() {
		require.NoError(t, srv.Close())
	})
	return srv
}
// seedInternalChatDeps seeds the database with an enabled OpenAI
// provider plus a default model config, returning the owning user and
// the model. An Anthropic provider is deliberately never created so
// tests can exercise the "Anthropic unavailable" code paths.
func seedInternalChatDeps(
	ctx context.Context,
	t *testing.T,
	db database.Store,
) (database.User, database.ChatModelConfig) {
	t.Helper()
	owner := dbgen.User(t, db, database.User{})
	createdBy := uuid.NullUUID{UUID: owner.ID, Valid: true}
	_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
		Provider:    "openai",
		DisplayName: "OpenAI",
		APIKey:      "test-key",
		BaseUrl:     "",
		ApiKeyKeyID: sql.NullString{},
		CreatedBy:   createdBy,
		Enabled:     true,
	})
	require.NoError(t, err)
	modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
		Provider:             "openai",
		Model:                "gpt-4o-mini",
		DisplayName:          "Test Model",
		CreatedBy:            createdBy,
		UpdatedBy:            createdBy,
		Enabled:              true,
		IsDefault:            true,
		ContextLimit:         128000,
		CompressionThreshold: 70,
		Options:              json.RawMessage(`{}`),
	})
	require.NoError(t, err)
	return owner, modelCfg
}
// findToolByName scans tools for the one whose reported name equals
// name, returning nil when there is no match.
func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
	for i := range tools {
		if tools[i].Info().Name == name {
			return tools[i]
		}
	}
	return nil
}
// TestSpawnComputerUseAgent_NoAnthropicProvider verifies that the
// spawn_computer_use_agent tool is not offered at all when no Anthropic
// API key is configured — presumably because the computer-use model
// requires one (confirmed below by the assertion message).
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	// No Anthropic key in ProviderAPIKeys.
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedInternalChatDeps(ctx, t, db)
	// Create a root parent chat.
	parent, err := server.CreateChat(ctx, CreateOptions{
		OwnerID:            user.ID,
		Title:              "parent-no-anthropic",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	// Re-fetch so LastModelConfigID is populated from the DB.
	parentChat, err := db.GetChatByID(ctx, parent.ID)
	require.NoError(t, err)
	tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
	tool := findToolByName(tools, "spawn_computer_use_agent")
	assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when Anthropic is not configured")
}
// TestSpawnComputerUseAgent_NotAvailableForChildChats verifies that a
// delegated (child) chat cannot itself spawn another computer-use
// subagent: the tool is still listed, but invoking it returns an error
// response rather than creating a grandchild chat.
func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	// Provide an Anthropic key so the provider check passes.
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
		Anthropic: "test-anthropic-key",
	})
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedInternalChatDeps(ctx, t, db)
	// Create a root parent chat.
	parent, err := server.CreateChat(ctx, CreateOptions{
		OwnerID:            user.ID,
		Title:              "root-parent",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	// Create a child chat under the parent.
	child, err := server.CreateChat(ctx, CreateOptions{
		OwnerID: user.ID,
		ParentChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		RootChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		Title:              "child-subagent",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do something")},
	})
	require.NoError(t, err)
	// Re-fetch the child so ParentChatID is populated.
	childChat, err := db.GetChatByID(ctx, child.ID)
	require.NoError(t, err)
	require.True(t, childChat.ParentChatID.Valid,
		"child chat must have a parent")
	// Get tools as if the child chat is the current chat.
	tools := server.subagentTools(ctx, func() database.Chat { return childChat })
	tool := findToolByName(tools, "spawn_computer_use_agent")
	require.NotNil(t, tool, "spawn_computer_use_agent tool must be present")
	resp, err := tool.Run(ctx, fantasy.ToolCall{
		ID:    "call-2",
		Name:  "spawn_computer_use_agent",
		Input: `{"prompt":"open browser"}`,
	})
	require.NoError(t, err)
	// The Run call itself succeeds; the refusal is reported inside the
	// tool response payload, not as a Go error.
	assert.True(t, resp.IsError, "expected an error response")
	assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
}
// TestSpawnComputerUseAgent_UsesComputerUseModelNotParent verifies that
// a spawned child chat is persisted with Mode=computer_use, so at run
// time it uses the predefined (Anthropic) computer-use model rather
// than inheriting the parent chat's OpenAI model config.
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	// Provide an Anthropic key so the tool can proceed.
	server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
		Anthropic: "test-anthropic-key",
	})
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedInternalChatDeps(ctx, t, db)
	// The parent uses an OpenAI model.
	require.Equal(t, "openai", model.Provider,
		"seed helper must create an OpenAI model")
	parent, err := server.CreateChat(ctx, CreateOptions{
		OwnerID:            user.ID,
		Title:              "parent-openai",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	parentChat, err := db.GetChatByID(ctx, parent.ID)
	require.NoError(t, err)
	tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
	tool := findToolByName(tools, "spawn_computer_use_agent")
	require.NotNil(t, tool)
	resp, err := tool.Run(ctx, fantasy.ToolCall{
		ID:    "call-3",
		Name:  "spawn_computer_use_agent",
		Input: `{"prompt":"take a screenshot"}`,
	})
	require.NoError(t, err)
	require.False(t, resp.IsError, "expected success but got: %s", resp.Content)
	// Parse the response to get the child chat ID.
	var result map[string]any
	require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
	childIDStr, ok := result["chat_id"].(string)
	require.True(t, ok, "response must contain chat_id")
	childID, err := uuid.Parse(childIDStr)
	require.NoError(t, err)
	childChat, err := db.GetChatByID(ctx, childID)
	require.NoError(t, err)
	// The child must have Mode=computer_use which causes
	// runChat to override the model to the predefined computer
	// use model instead of using the parent's model config.
	require.True(t, childChat.Mode.Valid)
	assert.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode)
	// The predefined computer use model is Anthropic, which
	// differs from the parent's OpenAI model. This confirms
	// that the child will not inherit the parent's model at
	// runtime.
	assert.NotEqual(t, model.Provider, chattool.ComputerUseModelProvider,
		"computer use model provider must differ from parent model provider")
	assert.Equal(t, "anthropic", chattool.ComputerUseModelProvider)
	assert.NotEmpty(t, chattool.ComputerUseModelName)
}
-218
View File
@@ -1,218 +0,0 @@
package chatd_test
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestSpawnComputerUseAgent_CreatesChildWithChatMode verifies that a
// child chat created the way spawn_computer_use_agent does (parent and
// root IDs set, ChatMode=computer_use) is persisted with both the
// parent link and the mode intact.
func TestSpawnComputerUseAgent_CreatesChildWithChatMode(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedChatDependencies(ctx, t, db)
	// Create a parent chat.
	parent, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID:            user.ID,
		Title:              "parent",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	// Simulate what spawn_computer_use_agent does: set ChatMode
	// to computer_use and provide a system prompt.
	prompt := "Use the desktop to open Firefox"
	child, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID: parent.OwnerID,
		ParentChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		RootChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		ModelConfigID:      model.ID,
		Title:              "computer-use",
		ChatMode:           database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
		SystemPrompt:       "Computer use instructions\n\n" + prompt,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
	})
	require.NoError(t, err)
	// Verify parent-child relationship.
	require.True(t, child.ParentChatID.Valid)
	require.Equal(t, parent.ID, child.ParentChatID.UUID)
	// Verify the chat type is set correctly.
	require.True(t, child.Mode.Valid)
	assert.Equal(t, database.ChatModeComputerUse, child.Mode.ChatMode)
	// Confirm via a fresh DB read as well.
	got, err := db.GetChatByID(ctx, child.ID)
	require.NoError(t, err)
	require.True(t, got.Mode.Valid)
	assert.Equal(t, database.ChatModeComputerUse, got.Mode.ChatMode)
}
// TestSpawnComputerUseAgent_SystemPromptFormat verifies that the system
// prompt persisted for a computer-use child chat embeds the user's
// original prompt text.
func TestSpawnComputerUseAgent_SystemPromptFormat(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedChatDependencies(ctx, t, db)
	parent, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID:            user.ID,
		Title:              "parent",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	prompt := "Navigate to settings page"
	systemPrompt := "Computer use instructions\n\n" + prompt
	child, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID: parent.OwnerID,
		ParentChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		RootChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		ModelConfigID:      model.ID,
		Title:              "computer-use-format",
		ChatMode:           database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
		SystemPrompt:       systemPrompt,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
	})
	require.NoError(t, err)
	messages, err := db.GetChatMessagesForPromptByChatID(ctx, child.ID)
	require.NoError(t, err)
	// The system message raw content is a JSON-encoded string.
	// It should contain the system prompt with the user prompt.
	// Only the first system message with valid content is inspected.
	var rawSystemContent string
	for _, msg := range messages {
		if msg.Role != "system" {
			continue
		}
		if msg.Content.Valid {
			rawSystemContent = string(msg.Content.RawMessage)
			break
		}
	}
	assert.Contains(t, rawSystemContent, prompt,
		"system prompt raw content should contain the user prompt")
}
// TestSpawnComputerUseAgent_ChildIsListedUnderParent verifies that a
// computer-use child chat is persisted with a ParentChatID linking it
// back to the chat that spawned it.
func TestSpawnComputerUseAgent_ChildIsListedUnderParent(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedChatDependencies(ctx, t, db)
	parent, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID:            user.ID,
		Title:              "parent",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	prompt := "Check the UI layout"
	child, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID: parent.OwnerID,
		ParentChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		RootChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		ModelConfigID:      model.ID,
		Title:              "computer-use-child",
		ChatMode:           database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
		SystemPrompt:       "Computer use instructions\n\n" + prompt,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
	})
	require.NoError(t, err)
	// Verify the child is linked to the parent.
	fetchedChild, err := db.GetChatByID(ctx, child.ID)
	require.NoError(t, err)
	require.True(t, fetchedChild.ParentChatID.Valid)
	assert.Equal(t, parent.ID, fetchedChild.ParentChatID.UUID)
}
// TestSpawnComputerUseAgent_RootChatIDPropagation verifies that when
// the parent is itself a root chat (no parent of its own), the child's
// RootChatID points at that parent, both on the returned value and on
// a fresh database read.
func TestSpawnComputerUseAgent_RootChatIDPropagation(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	server := newTestServer(t, db, ps, uuid.New())
	ctx := testutil.Context(t, testutil.WaitLong)
	user, model := seedChatDependencies(ctx, t, db)
	// Create a root parent chat (no parent of its own).
	parent, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID:            user.ID,
		Title:              "root-parent",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)
	prompt := "Take a screenshot"
	child, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID: parent.OwnerID,
		ParentChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		RootChatID: uuid.NullUUID{
			UUID:  parent.ID,
			Valid: true,
		},
		ModelConfigID:      model.ID,
		Title:              "computer-use-root-test",
		ChatMode:           database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
		SystemPrompt:       "Computer use instructions\n\n" + prompt,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
	})
	require.NoError(t, err)
	// When the parent has no RootChatID, the child's RootChatID
	// should point to the parent.
	require.True(t, child.RootChatID.Valid)
	assert.Equal(t, parent.ID, child.RootChatID.UUID)
	// Verify chat was retrieved correctly from the DB.
	got, err := db.GetChatByID(ctx, child.ID)
	require.NoError(t, err)
	assert.True(t, got.RootChatID.Valid)
	assert.Equal(t, parent.ID, got.RootChatID.UUID)
}
-128
View File
@@ -1,128 +0,0 @@
package chatd
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
)
// ComputeUsagePeriodBounds returns the UTC-aligned start and end bounds
// for the active usage-limit period containing now. The returned range
// is half-open: [start, end).
func ComputeUsagePeriodBounds(now time.Time, period codersdk.ChatUsageLimitPeriod) (start, end time.Time) {
	utcNow := now.UTC()
	// Midnight (UTC) of the day containing now; the day and week cases
	// both anchor on it.
	midnight := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
	switch period {
	case codersdk.ChatUsageLimitPeriodDay:
		start = midnight
		end = start.AddDate(0, 0, 1)
	case codersdk.ChatUsageLimitPeriodWeek:
		// ISO 8601 weeks begin on Monday. time.Weekday numbers Sunday
		// as 0 and Monday as 1, so (weekday+6)%7 is the count of days
		// elapsed since the most recent Monday.
		daysSinceMonday := (int(midnight.Weekday()) + 6) % 7
		start = midnight.AddDate(0, 0, -daysSinceMonday)
		end = start.AddDate(0, 0, 7)
	case codersdk.ChatUsageLimitPeriodMonth:
		start = time.Date(utcNow.Year(), utcNow.Month(), 1, 0, 0, 0, 0, time.UTC)
		end = start.AddDate(0, 1, 0)
	default:
		panic(fmt.Sprintf("unknown chat usage limit period: %q", period))
	}
	return start, end
}
// ResolveUsageLimitStatus resolves the current usage-limit status for userID.
// It returns (nil, nil) whenever limits are disabled or unconfigured.
//
// Note: There is a potential race condition where two concurrent messages
// from the same user can both pass the limit check if processed in
// parallel, allowing brief overage. This is acceptable because:
//   - Cost is only known after the LLM API returns.
//   - Overage is bounded by message cost × concurrency.
//   - Fail-open is the deliberate design choice for this feature.
//
// Architecture note: today this path enforces one period globally
// (day/week/month) from config.
// To support simultaneous periods, add nullable
// daily/weekly/monthly_limit_micros columns on override tables, where NULL
// means no limit for that period.
// Then scan spend once over the widest active window with conditional SUMs
// for each period and compare each spend/limit pair Go-side, blocking on
// whichever period is tightest.
func ResolveUsageLimitStatus(ctx context.Context, db database.Store, userID uuid.UUID, now time.Time) (*codersdk.ChatUsageLimitStatus, error) {
	//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
	// deployment config reads and cross-user chat spend aggregation.
	authCtx := dbauthz.AsChatd(ctx)
	config, err := db.GetChatUsageLimitConfig(authCtx)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// No config row has ever been written: limits are unset.
			return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
		}
		return nil, err
	}
	if !config.Enabled {
		return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
	}
	period, ok := mapDBPeriodToSDK(config.Period)
	if !ok {
		return nil, xerrors.Errorf("invalid chat usage limit period %q", config.Period)
	}
	// Resolve effective limit in a single query:
	// individual override > group limit > global default.
	effectiveLimit, err := db.ResolveUserChatSpendLimit(authCtx, userID)
	if err != nil {
		return nil, err
	}
	// -1 means limits are disabled (shouldn't happen since we checked above,
	// but handle gracefully).
	if effectiveLimit < 0 {
		return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
	}
	// Sum the user's spend over the half-open [start, end) window for
	// the active period.
	start, end := ComputeUsagePeriodBounds(now, period)
	spendTotal, err := db.GetUserChatSpendInPeriod(authCtx, database.GetUserChatSpendInPeriodParams{
		UserID:    userID,
		StartTime: start,
		EndTime:   end,
	})
	if err != nil {
		return nil, err
	}
	return &codersdk.ChatUsageLimitStatus{
		IsLimited:        true,
		Period:           period,
		SpendLimitMicros: &effectiveLimit,
		CurrentSpend:     spendTotal,
		PeriodStart:      start,
		PeriodEnd:        end,
	}, nil
}
// mapDBPeriodToSDK converts the database's string representation of a
// usage-limit period to the SDK enum. The second return value is false
// when the string matches no known period.
func mapDBPeriodToSDK(dbPeriod string) (codersdk.ChatUsageLimitPeriod, bool) {
	for _, candidate := range []codersdk.ChatUsageLimitPeriod{
		codersdk.ChatUsageLimitPeriodDay,
		codersdk.ChatUsageLimitPeriodWeek,
		codersdk.ChatUsageLimitPeriodMonth,
	} {
		if dbPeriod == string(candidate) {
			return candidate, true
		}
	}
	return "", false
}
-132
View File
@@ -1,132 +0,0 @@
package chatd //nolint:testpackage // Keeps chatd unit tests in the package.
import (
"testing"
"time"
"github.com/coder/coder/v2/codersdk"
)
// TestComputeUsagePeriodBounds exercises the UTC period-bound
// computation across day/week/month periods, covering boundary
// instants, an ISO-week year crossing, leap-year February, and a
// non-UTC input time.
func TestComputeUsagePeriodBounds(t *testing.T) {
	t.Parallel()
	newYork, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatalf("load America/New_York: %v", err)
	}
	tests := []struct {
		name      string
		now       time.Time
		period    codersdk.ChatUsageLimitPeriod
		wantStart time.Time
		wantEnd   time.Time
	}{
		{
			name:      "day/mid_day",
			now:       time.Date(2025, time.June, 15, 14, 30, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "day/midnight_exactly",
			now:       time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "day/end_of_day",
			now:       time.Date(2025, time.June, 15, 23, 59, 59, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "week/wednesday",
			now:       time.Date(2025, time.June, 11, 10, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "week/monday",
			now:       time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "week/sunday",
			now:       time.Date(2025, time.June, 15, 23, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
		},
		{
			// ISO weeks may straddle the calendar year boundary.
			name:      "week/year_boundary",
			now:       time.Date(2024, time.December, 31, 12, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodWeek,
			wantStart: time.Date(2024, time.December, 30, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.January, 6, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/mid_month",
			now:       time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/first_day",
			now:       time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/last_day",
			now:       time.Date(2025, time.June, 30, 23, 59, 59, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/february",
			now:       time.Date(2025, time.February, 15, 12, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2025, time.February, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.March, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			name:      "month/leap_year_february",
			now:       time.Date(2024, time.February, 29, 12, 0, 0, 0, time.UTC),
			period:    codersdk.ChatUsageLimitPeriodMonth,
			wantStart: time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC),
		},
		{
			// 22:00 in New York is already the next calendar day in
			// UTC, so the UTC day bounds shift forward by one day.
			name:      "day/non_utc_timezone",
			now:       time.Date(2025, time.June, 15, 22, 0, 0, 0, newYork),
			period:    codersdk.ChatUsageLimitPeriodDay,
			wantStart: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
			wantEnd:   time.Date(2025, time.June, 17, 0, 0, 0, 0, time.UTC),
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			start, end := ComputeUsagePeriodBounds(tc.now, tc.period)
			if !start.Equal(tc.wantStart) {
				t.Errorf("start: got %v, want %v", start, tc.wantStart)
			}
			if !end.Equal(tc.wantEnd) {
				t.Errorf("end: got %v, want %v", end, tc.wantEnd)
			}
		})
	}
}
+238 -1086
View File
File diff suppressed because it is too large Load Diff
+139 -1430
View File
File diff suppressed because it is too large Load Diff
+2 -48
View File
@@ -44,7 +44,6 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
"github.com/coder/coder/v2/coderd/aiseats"
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/audit"
@@ -628,11 +627,8 @@ func New(options *Options) *API {
options.Database,
options.Pubsub,
),
dbRolluper: options.DatabaseRolluper,
ProfileCollector: defaultProfileCollector{},
AISeatTracker: aiseats.Noop{},
dbRolluper: options.DatabaseRolluper,
}
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
ctx,
options.Logger.Named("workspaceapps"),
@@ -1142,13 +1138,6 @@ func New(options *Options) *API {
r.Post("/", api.postChats)
r.Get("/models", api.listChatModels)
r.Get("/watch", api.watchChats)
r.Route("/cost", func(r chi.Router) {
r.Get("/users", api.chatCostUsers)
r.Route("/{user}", func(r chi.Router) {
r.Use(httpmw.ExtractUserParam(options.Database))
r.Get("/summary", api.chatCostSummary)
})
})
r.Route("/files", func(r chi.Router) {
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
r.Post("/", api.postChatFile)
@@ -1178,31 +1167,17 @@ func New(options *Options) *API {
r.Delete("/", api.deleteChatModelConfig)
})
})
r.Route("/usage-limits", func(r chi.Router) {
r.Get("/", api.getChatUsageLimitConfig)
r.Put("/", api.updateChatUsageLimitConfig)
r.Get("/status", api.getMyChatUsageLimitStatus)
r.Route("/overrides/{user}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitOverride)
r.Delete("/", api.deleteChatUsageLimitOverride)
})
r.Route("/group-overrides/{group}", func(r chi.Router) {
r.Put("/", api.upsertChatUsageLimitGroupOverride)
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
})
})
r.Route("/{chat}", func(r chi.Router) {
r.Use(httpmw.ExtractChatParam(options.Database))
r.Get("/", api.getChat)
r.Get("/git/watch", api.watchChatGit)
r.Get("/desktop", api.watchChatDesktop)
r.Post("/archive", api.archiveChat)
r.Post("/unarchive", api.unarchiveChat)
r.Get("/messages", api.getChatMessages)
r.Post("/messages", api.postChatMessages)
r.Patch("/messages/{message}", api.patchChatMessage)
r.Get("/stream", api.streamChat)
r.Post("/interrupt", api.interruptChat)
r.Get("/diff-status", api.getChatDiffStatus)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
r.Delete("/", api.deleteChatQueuedMessage)
@@ -1219,13 +1194,6 @@ func New(options *Options) *API {
// MCP HTTP transport endpoint with mandatory authentication
r.Mount("/http", api.mcpHTTPHandler())
})
r.Route("/watch-all-workspacebuilds", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceBuildUpdates),
)
r.Get("/", api.watchAllWorkspaceBuilds)
})
})
r.Route("/api/v2", func(r chi.Router) {
@@ -1763,8 +1731,6 @@ func New(options *Options) *API {
}
r.Method("GET", "/expvar", expvar.Handler()) // contains DERP metrics as well as cmdline and memstats
r.Post("/profile", api.debugCollectProfile)
r.Route("/pprof", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
// Some of the pprof handlers strip the `/debug/pprof`
@@ -2049,20 +2015,9 @@ type API struct {
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
// AISeatTracker records AI seat usage.
AISeatTracker aiseats.SeatTracker
// gitSyncWorker refreshes stale chat diff statuses in the
// background.
gitSyncWorker *gitsync.Worker
// ProfileCollector abstracts the runtime/pprof and runtime/trace
// calls used by the /debug/profile endpoint. Tests override this
// with a stub to avoid process-global side-effects.
ProfileCollector ProfileCollector
// ProfileCollecting is used as a concurrency guard so that only one
// profile collection (via /debug/profile) can run at a time. The CPU
// profiler is process-global, so concurrent collections would fail.
ProfileCollecting atomic.Bool
}
// Close waits for all WebSocket connections to drain before returning.
@@ -2263,7 +2218,6 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
provisionerdserver.Options{
OIDCConfig: api.OIDCConfig,
ExternalAuthConfigs: api.ExternalAuthConfigs,
AISeatTracker: api.AISeatTracker,
Clock: api.Clock,
HeartbeatFn: options.heartbeatFn,
},
+16 -23
View File
@@ -6,27 +6,20 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
)
+263 -108
View File
@@ -12,6 +12,7 @@ import (
"strings"
"time"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/hashicorp/hcl/v2"
"github.com/sqlc-dev/pqtype"
@@ -21,7 +22,6 @@ import (
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/render"
@@ -1059,20 +1059,15 @@ func ChatMessage(m database.ChatMessage) codersdk.ChatMessage {
if !m.ModelConfigID.Valid {
modelConfigID = nil
}
createdBy := &m.CreatedBy.UUID
if !m.CreatedBy.Valid {
createdBy = nil
}
msg := codersdk.ChatMessage{
ID: m.ID,
ChatID: m.ChatID,
CreatedBy: createdBy,
ModelConfigID: modelConfigID,
CreatedAt: m.CreatedAt,
Role: codersdk.ChatMessageRole(m.Role),
Role: m.Role,
}
if m.Content.Valid {
parts, err := chatMessageParts(m)
parts, err := chatMessageParts(m.Role, m.Content)
if err == nil {
msg.Content = parts
}
@@ -1114,15 +1109,9 @@ func chatMessageUsage(m database.ChatMessage) *codersdk.ChatMessageUsage {
// ChatQueuedMessage converts a queued message to its SDK representation.
func ChatQueuedMessage(message database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
// Queued messages are always written by current code via
// MarshalParts, so they are always current content version.
parts, err := chatMessageParts(database.ChatMessage{
Role: database.ChatMessageRoleUser,
Content: pqtype.NullRawMessage{
RawMessage: message.Content,
Valid: len(message.Content) > 0,
},
ContentVersion: chatprompt.CurrentContentVersion,
parts, err := chatMessageParts(string(fantasy.MessageRoleUser), pqtype.NullRawMessage{
RawMessage: message.Content,
Valid: len(message.Content) > 0,
})
if err != nil {
parts = nil
@@ -1146,16 +1135,265 @@ func ChatQueuedMessages(messages []database.ChatQueuedMessage) []codersdk.ChatQu
return out
}
func chatMessageParts(m database.ChatMessage) ([]codersdk.ChatMessagePart, error) {
parts, err := chatprompt.ParseContent(m)
if err != nil {
return nil, err
func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error) {
switch role {
case string(fantasy.MessageRoleSystem):
content, err := parseSystemContent(raw)
if err != nil {
return nil, err
}
if strings.TrimSpace(content) == "" {
return nil, nil
}
return []codersdk.ChatMessagePart{{
Type: codersdk.ChatMessagePartTypeText,
Text: content,
}}, nil
case string(fantasy.MessageRoleUser), string(fantasy.MessageRoleAssistant):
content, err := parseContentBlocks(role, raw)
if err != nil {
return nil, err
}
var rawBlocks []json.RawMessage
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
parts := make([]codersdk.ChatMessagePart, 0, len(content))
for i, block := range content {
part := contentBlockToPart(block)
if part.Type == "" {
continue
}
if i < len(rawBlocks) {
switch part.Type {
case codersdk.ChatMessagePartTypeReasoning:
part.Title = reasoningStoredTitle(rawBlocks[i])
case codersdk.ChatMessagePartTypeFile:
if fid, err := chatprompt.ExtractFileID(rawBlocks[i]); err == nil {
part.FileID = uuid.NullUUID{UUID: fid, Valid: true}
}
// When a file_id is present, omit inline data
// from the response. Clients fetch content via
// the GET /chats/files/{id} endpoint instead.
if part.FileID.Valid {
part.Data = nil
}
}
}
parts = append(parts, part)
}
return parts, nil
case string(fantasy.MessageRoleTool):
results, err := parseToolResults(raw)
if err != nil {
return nil, err
}
parts := make([]codersdk.ChatMessagePart, 0, len(results))
for _, result := range results {
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: result.ToolCallID,
ToolName: result.ToolName,
Result: result.Result,
IsError: result.IsError,
})
}
return parts, nil
default:
return nil, nil
}
// Strip internal-only fields before API responses.
for i := range parts {
parts[i].StripInternal()
}
// parseSystemContent decodes system message content, which is persisted as a
// single JSON-encoded string rather than an array of content blocks.
// Missing or empty content yields ("", nil).
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
	if !raw.Valid || len(raw.RawMessage) == 0 {
		return "", nil
	}
	var content string
	if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
		return "", xerrors.Errorf("parse system content: %w", err)
	}
	return content, nil
}
// parseContentBlocks decodes persisted user/assistant message content into
// typed fantasy content blocks. Missing or empty content yields (nil, nil).
func parseContentBlocks(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
	if !raw.Valid || len(raw.RawMessage) == 0 {
		return nil, nil
	}
	// Legacy user messages were stored as a bare JSON string; surface those
	// as a single text block instead of failing the array decode below.
	if role == string(fantasy.MessageRoleUser) {
		var legacyText string
		if json.Unmarshal(raw.RawMessage, &legacyText) == nil {
			return []fantasy.Content{
				fantasy.TextContent{Text: legacyText},
			}, nil
		}
	}
	var encoded []json.RawMessage
	if err := json.Unmarshal(raw.RawMessage, &encoded); err != nil {
		return nil, xerrors.Errorf("parse content blocks: %w", err)
	}
	decoded := make([]fantasy.Content, 0, len(encoded))
	for _, item := range encoded {
		block, err := fantasy.UnmarshalContent(item)
		if err != nil {
			return nil, xerrors.Errorf("parse content block: %w", err)
		}
		decoded = append(decoded, block)
	}
	return decoded, nil
}
// toolResultRow is used only for extracting top-level fields from
// persisted tool result JSON. The result payload is kept as raw JSON.
type toolResultRow struct {
	// ToolCallID links the result back to the originating tool call.
	ToolCallID string `json:"tool_call_id"`
	// ToolName is the name of the tool that produced this result.
	ToolName string `json:"tool_name"`
	// Result is the tool's output, passed through undecoded.
	Result json.RawMessage `json:"result"`
	// IsError marks the result as a tool failure.
	IsError bool `json:"is_error,omitempty"`
}
// parseToolResults decodes persisted tool-role message content into its
// top-level tool result rows. Missing or empty content yields (nil, nil).
func parseToolResults(raw pqtype.NullRawMessage) ([]toolResultRow, error) {
	if !raw.Valid || len(raw.RawMessage) == 0 {
		return nil, nil
	}
	var rows []toolResultRow
	if err := json.Unmarshal(raw.RawMessage, &rows); err != nil {
		return nil, xerrors.Errorf("parse tool results: %w", err)
	}
	return rows, nil
}
// reasoningStoredTitle extracts the title persisted alongside a reasoning
// content block. It returns "" when the payload does not decode, is not a
// reasoning block, or carries no title.
func reasoningStoredTitle(raw json.RawMessage) string {
	var stored struct {
		Type string `json:"type"`
		Data struct {
			Title string `json:"title"`
		} `json:"data"`
	}
	if json.Unmarshal(raw, &stored) != nil {
		return ""
	}
	if !strings.EqualFold(stored.Type, string(fantasy.ContentTypeReasoning)) {
		return ""
	}
	return strings.TrimSpace(stored.Data.Title)
}
// contentBlockToPart converts a single fantasy content block into its SDK
// message part representation. Unknown block types yield a zero part (empty
// Type), which callers skip.
func contentBlockToPart(block fantasy.Content) codersdk.ChatMessagePart {
	// Normalize pointer blocks to their value form so each content type is
	// handled exactly once below, instead of duplicating every case for the
	// value and pointer variants.
	switch v := block.(type) {
	case *fantasy.TextContent:
		block = *v
	case *fantasy.ReasoningContent:
		block = *v
	case *fantasy.ToolCallContent:
		block = *v
	case *fantasy.SourceContent:
		block = *v
	case *fantasy.FileContent:
		block = *v
	case *fantasy.ToolResultContent:
		block = *v
	}
	switch value := block.(type) {
	case fantasy.TextContent:
		return codersdk.ChatMessagePart{
			Type: codersdk.ChatMessagePartTypeText,
			Text: value.Text,
		}
	case fantasy.ReasoningContent:
		return codersdk.ChatMessagePart{
			Type: codersdk.ChatMessagePartTypeReasoning,
			Text: value.Text,
		}
	case fantasy.ToolCallContent:
		return codersdk.ChatMessagePart{
			Type:       codersdk.ChatMessagePartTypeToolCall,
			ToolCallID: value.ToolCallID,
			ToolName:   value.ToolName,
			Args:       []byte(value.Input),
		}
	case fantasy.SourceContent:
		return codersdk.ChatMessagePart{
			Type:     codersdk.ChatMessagePartTypeSource,
			SourceID: value.ID,
			URL:      value.URL,
			Title:    value.Title,
		}
	case fantasy.FileContent:
		return codersdk.ChatMessagePart{
			Type:      codersdk.ChatMessagePartTypeFile,
			MediaType: value.MediaType,
			Data:      value.Data,
		}
	case fantasy.ToolResultContent:
		return chatprompt.ToolResultToPart(
			value.ToolCallID,
			value.ToolName,
			toolResultOutputToRawJSON(value.Result),
			toolResultOutputIsError(value.Result),
		)
	default:
		return codersdk.ChatMessagePart{}
	}
}
// toolResultOutputToRawJSON serializes a tool result output into a raw JSON
// payload suitable for API responses. Text output that is not already valid
// JSON is wrapped in an {"output": ...} object so the payload is always
// well-formed JSON.
func toolResultOutputToRawJSON(output fantasy.ToolResultOutputContent) json.RawMessage {
	switch v := output.(type) {
	case fantasy.ToolResultOutputContentError:
		message := ""
		if v.Error != nil {
			message = v.Error.Error()
		}
		encoded, _ := json.Marshal(map[string]any{"error": message})
		return encoded
	case fantasy.ToolResultOutputContentText:
		if candidate := json.RawMessage(v.Text); json.Valid(candidate) {
			return candidate
		}
		wrapped, _ := json.Marshal(map[string]any{"output": v.Text})
		return wrapped
	case fantasy.ToolResultOutputContentMedia:
		wrapped, _ := json.Marshal(map[string]any{
			"data":      v.Data,
			"mime_type": v.MediaType,
			"text":      v.Text,
		})
		return wrapped
	default:
		return json.RawMessage(`{}`)
	}
}
// toolResultOutputIsError reports whether a tool result output represents a
// tool failure.
func toolResultOutputIsError(output fantasy.ToolResultOutputContent) bool {
	switch output.(type) {
	case fantasy.ToolResultOutputContentError:
		return true
	default:
		return false
	}
}
func nullInt64Ptr(v sql.NullInt64) *int64 {
@@ -1165,86 +1403,3 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
value := v.Int64
return &value
}
// ChatDiffStatus converts a database.ChatDiffStatus to a
// codersdk.ChatDiffStatus. When status is nil an empty value
// containing only the chatID is returned.
func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.ChatDiffStatus {
result := codersdk.ChatDiffStatus{
ChatID: chatID,
}
if status == nil {
return result
}
result.ChatID = status.ChatID
if status.Url.Valid {
u := strings.TrimSpace(status.Url.String)
if u != "" {
result.URL = &u
}
}
if result.URL == nil {
// Try to build a branch URL from the stored origin.
// Since this function does not have access to the API
// instance, we construct a GitHub provider directly as
// a best-effort fallback.
// TODO: This uses the default github.com API base URL,
// so branch URLs for GitHub Enterprise instances will
// be incorrect. To fix this, this function would need
// access to the external auth configs.
gp := gitprovider.New("github", "", nil)
if gp != nil {
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
if branchURL != "" {
result.URL = &branchURL
}
}
}
}
if status.PullRequestState.Valid {
pullRequestState := strings.TrimSpace(status.PullRequestState.String)
if pullRequestState != "" {
result.PullRequestState = &pullRequestState
}
}
result.PullRequestTitle = status.PullRequestTitle
result.PullRequestDraft = status.PullRequestDraft
result.ChangesRequested = status.ChangesRequested
result.Additions = status.Additions
result.Deletions = status.Deletions
result.ChangedFiles = status.ChangedFiles
if status.AuthorLogin.Valid {
result.AuthorLogin = &status.AuthorLogin.String
}
if status.AuthorAvatarUrl.Valid {
result.AuthorAvatarURL = &status.AuthorAvatarUrl.String
}
if status.BaseBranch.Valid {
result.BaseBranch = &status.BaseBranch.String
}
if status.HeadBranch.Valid {
result.HeadBranch = &status.HeadBranch.String
}
if status.PrNumber.Valid {
result.PRNumber = &status.PrNumber.Int32
}
if status.Commits.Valid {
result.Commits = &status.Commits.Int32
}
if status.Approved.Valid {
result.Approved = &status.Approved.Bool
}
if status.ReviewerCount.Valid {
result.ReviewerCount = &status.ReviewerCount.Int32
}
if status.RefreshedAt.Valid {
refreshedAt := status.RefreshedAt.Time
result.RefreshedAt = &refreshedAt
}
staleAt := status.StaleAt
result.StaleAt = &staleAt
return result
}
+95 -54
View File
@@ -9,6 +9,7 @@ import (
"time"
"charm.land/fantasy"
fantasyopenai "charm.land/fantasy/providers/openai"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
@@ -437,67 +438,87 @@ func TestAIBridgeInterception(t *testing.T) {
}
}
func TestChatMessage_PreservesProviderExecutedOnToolResults(t *testing.T) {
func TestChatMessage_ReasoningPartWithoutPersistedTitleIsEmpty(t *testing.T) {
t.Parallel()
toolCallID := uuid.New().String()
toolName := "web_search"
// Build assistant content blocks with ProviderExecuted set.
toolCall := fantasy.ToolCallContent{
ToolCallID: toolCallID,
ToolName: toolName,
Input: `{"query":"test"}`,
ProviderExecuted: true,
}
toolResult := fantasy.ToolResultContent{
ToolCallID: toolCallID,
ToolName: toolName,
Result: fantasy.ToolResultOutputContentText{Text: `{"results":[]}`},
ProviderExecuted: true,
}
tcJSON, err := json.Marshal(toolCall)
require.NoError(t, err)
trJSON, err := json.Marshal(toolResult)
assistantContent, err := json.Marshal([]fantasy.Content{
fantasy.ReasoningContent{
Text: "Plan migration",
ProviderMetadata: fantasy.ProviderMetadata{
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
ItemID: "reasoning-1",
Summary: []string{"Plan migration"},
},
},
},
})
require.NoError(t, err)
rawContent := json.RawMessage("[" + string(tcJSON) + "," + string(trJSON) + "]")
dbMsg := database.ChatMessage{
ID: 1,
ChatID: uuid.New(),
Role: database.ChatMessageRoleAssistant,
message := db2sdk.ChatMessage(database.ChatMessage{
ID: 1,
ChatID: uuid.New(),
CreatedAt: time.Now(),
Role: string(fantasy.MessageRoleAssistant),
Content: pqtype.NullRawMessage{
RawMessage: rawContent,
RawMessage: assistantContent,
Valid: true,
},
})
require.Len(t, message.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
require.Equal(t, "Plan migration", message.Content[0].Text)
require.Empty(t, message.Content[0].Title)
}
func TestChatMessage_ReasoningPartPrefersPersistedTitle(t *testing.T) {
t.Parallel()
reasoningContent, err := json.Marshal(fantasy.ReasoningContent{
Text: "Verify schema updates, then apply changes in order.",
ProviderMetadata: fantasy.ProviderMetadata{
fantasyopenai.Name: &fantasyopenai.ResponsesReasoningMetadata{
ItemID: "reasoning-1",
Summary: []string{
"**Metadata-derived title**\n\nLonger explanation.",
},
},
},
})
require.NoError(t, err)
var envelope map[string]any
require.NoError(t, json.Unmarshal(reasoningContent, &envelope))
dataValue, ok := envelope["data"].(map[string]any)
require.True(t, ok)
dataValue["title"] = "Persisted stream title"
encodedReasoning, err := json.Marshal(envelope)
require.NoError(t, err)
assistantContent, err := json.Marshal([]json.RawMessage{encodedReasoning})
require.NoError(t, err)
message := db2sdk.ChatMessage(database.ChatMessage{
ID: 1,
ChatID: uuid.New(),
CreatedAt: time.Now(),
}
Role: string(fantasy.MessageRoleAssistant),
Content: pqtype.NullRawMessage{
RawMessage: assistantContent,
Valid: true,
},
})
result := db2sdk.ChatMessage(dbMsg)
require.Len(t, result.Content, 2)
// First part: tool call.
require.Equal(t, codersdk.ChatMessagePartTypeToolCall, result.Content[0].Type)
require.Equal(t, toolCallID, result.Content[0].ToolCallID)
require.Equal(t, toolName, result.Content[0].ToolName)
require.True(t, result.Content[0].ProviderExecuted, "tool call should preserve ProviderExecuted")
// Second part: tool result.
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, result.Content[1].Type)
require.Equal(t, toolCallID, result.Content[1].ToolCallID)
require.Equal(t, toolName, result.Content[1].ToolName)
require.True(t, result.Content[1].ProviderExecuted, "tool result should preserve ProviderExecuted")
require.Len(t, message.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, message.Content[0].Type)
require.Equal(t, "Persisted stream title", message.Content[0].Title)
}
func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
t.Parallel()
// Queued messages are always written via MarshalParts (SDK format).
rawContent, err := json.Marshal([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("queued text"),
rawContent, err := json.Marshal([]fantasy.Content{
fantasy.TextContent{Text: "queued text"},
})
require.NoError(t, err)
@@ -513,15 +534,35 @@ func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
require.Equal(t, "queued text", queued.Content[0].Text)
}
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
func TestChatQueuedMessage_FallsBackToTextForLegacyContent(t *testing.T) {
t.Parallel()
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
ID: 1,
ChatID: uuid.New(),
Content: json.RawMessage(`{"unexpected":"shape"}`),
CreatedAt: time.Now(),
t.Run("legacy_string", func(t *testing.T) {
t.Parallel()
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
ID: 1,
ChatID: uuid.New(),
Content: json.RawMessage(`"legacy queued text"`),
CreatedAt: time.Now(),
})
require.Len(t, queued.Content, 1)
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
require.Equal(t, "legacy queued text", queued.Content[0].Text)
})
require.Empty(t, queued.Content)
t.Run("malformed_payload", func(t *testing.T) {
t.Parallel()
raw := json.RawMessage(`{"unexpected":"shape"}`)
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
ID: 1,
ChatID: uuid.New(),
Content: raw,
CreatedAt: time.Now(),
})
require.Empty(t, queued.Content)
})
}
+185 -149
View File
@@ -707,7 +707,6 @@ var (
rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
rbac.ResourceWorkspace.Type: {policy.ActionRead},
rbac.ResourceDeploymentConfig.Type: {policy.ActionRead},
rbac.ResourceUser.Type: {policy.ActionReadPersonal},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
@@ -1513,13 +1512,13 @@ func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.Prov
return nil
}
func (q *querier) AcquireChats(ctx context.Context, arg database.AcquireChatsParams) ([]database.Chat, error) {
// AcquireChats is a system-level operation used by the chat processor.
func (q *querier) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
// AcquireChat is a system-level operation used by the chat processor.
// Authorization is done at the system level, not per-user.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
return database.Chat{}, err
}
return q.db.AcquireChats(ctx, arg)
return q.db.AcquireChat(ctx, arg)
}
func (q *querier) AcquireLock(ctx context.Context, id int64) error {
@@ -1726,13 +1725,6 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
}
func (q *querier) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return 0, err
}
return q.db.CountEnabledModelsWithoutPricing(ctx)
}
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
return nil, err
@@ -1836,6 +1828,18 @@ func (q *querier) DeleteChatMessagesAfterID(ctx context.Context, arg database.De
return q.db.DeleteChatMessagesAfterID(ctx, arg)
}
// DeleteChatMessagesByChatID removes every message belonging to a chat.
// Deleting a chat's messages requires delete permission on the chat itself.
func (q *querier) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
	chat, err := q.db.GetChatByID(ctx, chatID)
	if err != nil {
		return err
	}
	if err = q.authorizeContext(ctx, policy.ActionDelete, chat); err != nil {
		return err
	}
	return q.db.DeleteChatMessagesByChatID(ctx, chatID)
}
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -1861,20 +1865,6 @@ func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.Dele
return q.db.DeleteChatQueuedMessage(ctx, arg)
}
func (q *querier) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.DeleteChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -1909,6 +1899,10 @@ func (q *querier) DeleteExternalAuthLink(ctx context.Context, arg database.Delet
}, q.db.DeleteExternalAuthLink)(ctx, arg)
}
// DeleteGitSSHKey removes a user's Git SSH key after verifying the actor may
// update that user's personal data (ActionUpdatePersonal on the fetched key).
func (q *querier) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error {
	return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, q.db.GetGitSSHKey, q.db.DeleteGitSSHKey)(ctx, userID)
}
func (q *querier) DeleteGroupByID(ctx context.Context, id uuid.UUID) error {
return deleteQ(q.log, q.auth, q.db.GetGroupByID, q.db.DeleteGroupByID)(ctx, id)
}
@@ -2348,13 +2342,6 @@ func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Tim
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed)
}
func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceLicense); err != nil {
return 0, err
}
return q.db.GetActiveAISeatCount(ctx)
}
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
return nil, err
@@ -2410,6 +2397,13 @@ func (q *querier) GetAnnouncementBanners(ctx context.Context) (string, error) {
return q.db.GetAnnouncementBanners(ctx)
}
func (q *querier) GetAppSecurityKey(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return "", err
}
return q.db.GetAppSecurityKey(ctx)
}
func (q *querier) GetApplicationName(ctx context.Context) (string, error) {
// No authz checks
return q.db.GetApplicationName(ctx)
@@ -2454,34 +2448,6 @@ func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (datab
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
}
func (q *querier) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
return nil, err
}
return q.db.GetChatCostPerChat(ctx, arg)
}
func (q *querier) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
return nil, err
}
return q.db.GetChatCostPerModel(ctx, arg)
}
func (q *querier) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
return nil, err
}
return q.db.GetChatCostPerUser(ctx, arg)
}
func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
return database.GetChatCostSummaryRow{}, err
}
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2560,14 +2526,6 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
return q.db.GetChatMessagesByChatID(ctx, arg)
}
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatIDDescPaginated(ctx, arg)
}
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
@@ -2584,6 +2542,13 @@ func (q *querier) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (dat
return q.db.GetChatModelConfigByID(ctx, id)
}
func (q *querier) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatModelConfig{}, err
}
return q.db.GetChatModelConfigByProviderAndModel(ctx, arg)
}
func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
@@ -2632,27 +2597,6 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
return q.db.GetChatSystemPrompt(ctx)
}
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatUsageLimitConfig{}, err
}
return q.db.GetChatUsageLimitConfig(ctx)
}
func (q *querier) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitGroupOverrideRow{}, err
}
return q.db.GetChatUsageLimitGroupOverride(ctx, groupID)
}
func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.GetChatUsageLimitUserOverrideRow{}, err
}
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
}
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
}
@@ -2672,6 +2616,13 @@ func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetC
return q.db.GetAuthorizedConnectionLogsOffset(ctx, arg, prep)
}
func (q *querier) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return "", err
}
return q.db.GetCoordinatorResumeTokenSigningKey(ctx)
}
func (q *querier) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceCryptoKey); err != nil {
return database.CryptoKey{}, err
@@ -2725,6 +2676,14 @@ func (q *querier) GetDefaultProxyConfig(ctx context.Context) (database.GetDefaul
return q.db.GetDefaultProxyConfig(ctx)
}
// Only used by metrics cache.
func (q *querier) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) {
	// System-scoped: the metrics cache runs as a system subject, not on
	// behalf of an end user.
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
		return nil, err
	}
	return q.db.GetDeploymentDAUs(ctx, tzOffset)
}
func (q *querier) GetDeploymentID(ctx context.Context) (string, error) {
// No authz checks
return q.db.GetDeploymentID(ctx)
@@ -2807,6 +2766,22 @@ func (q *querier) GetFileByID(ctx context.Context, id uuid.UUID) (database.File,
return file, nil
}
// GetFileIDByTemplateVersionID returns the ID of the source file backing a
// template version, authorizing read access against the file itself.
func (q *querier) GetFileIDByTemplateVersionID(ctx context.Context, templateVersionID uuid.UUID) (uuid.UUID, error) {
	fileID, err := q.db.GetFileIDByTemplateVersionID(ctx, templateVersionID)
	if err != nil {
		return uuid.Nil, err
	}
	// Regular users essentially never hold file read permission, so this
	// check effectively restricts the query to system subjects. That is
	// acceptable because the query does not currently back any user-facing
	// data.
	if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceFile.WithID(fileID)); err != nil {
		return uuid.Nil, err
	}
	return fileID, nil
}
func (q *querier) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]database.GetFileTemplatesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
@@ -3015,6 +2990,13 @@ func (q *querier) GetOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) (d
return q.db.GetOAuth2ProviderAppByID(ctx, id)
}
func (q *querier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2App); err != nil {
return database.OAuth2ProviderApp{}, err
}
return q.db.GetOAuth2ProviderAppByRegistrationToken(ctx, registrationAccessToken)
}
func (q *querier) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) {
return fetch(q.log, q.auth, q.db.GetOAuth2ProviderAppCodeByID)(ctx, id)
}
@@ -3083,6 +3065,13 @@ func (q *querier) GetOAuth2ProviderAppsByUserID(ctx context.Context, userID uuid
return q.db.GetOAuth2ProviderAppsByUserID(ctx, userID)
}
// GetOAuthSigningKey returns the deployment's OAuth signing key.
// NOTE(review): this read authorizes ActionUpdate rather than ActionRead on
// ResourceSystem — presumably to limit the key to subjects that could also
// rotate it; confirm this asymmetry is intentional.
func (q *querier) GetOAuthSigningKey(ctx context.Context) (string, error) {
	if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
		return "", err
	}
	return q.db.GetOAuthSigningKey(ctx)
}
func (q *querier) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
return fetch(q.log, q.auth, q.db.GetOrganizationByID)(ctx, id)
}
@@ -3322,6 +3311,23 @@ func (q *querier) GetProvisionerJobTimingsByJobID(ctx context.Context, jobID uui
return q.db.GetProvisionerJobTimingsByJobID(ctx, jobID)
}
// GetProvisionerJobsByIDs fetches the requested provisioner jobs and then
// authorizes read access once per distinct organization the jobs belong to.
// Any failed organization check fails the whole request.
func (q *querier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
	jobs, err := q.db.GetProvisionerJobsByIDs(ctx, ids)
	if err != nil {
		return nil, err
	}
	// Deduplicate organizations while iterating so each org is authorized
	// exactly once.
	checked := make(map[uuid.UUID]struct{})
	for _, job := range jobs {
		if _, done := checked[job.OrganizationID]; done {
			continue
		}
		checked[job.OrganizationID] = struct{}{}
		if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceProvisionerJobs.InOrg(job.OrganizationID)); err != nil {
			return nil, err
		}
	}
	return jobs, nil
}
func (q *querier) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
// TODO: Remove this once we have a proper rbac check for provisioner jobs.
// Details in https://github.com/coder/coder/issues/16160
@@ -3528,6 +3534,14 @@ func (q *querier) GetTemplateByOrganizationAndName(ctx context.Context, arg data
return fetch(q.log, q.auth, q.db.GetTemplateByOrganizationAndName)(ctx, arg)
}
// Only used by metrics cache.
func (q *querier) GetTemplateDAUs(ctx context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetTemplateDAUs(ctx, arg)
}
func (q *querier) GetTemplateInsights(ctx context.Context, arg database.GetTemplateInsightsParams) (database.GetTemplateInsightsRow, error) {
if err := q.authorizeTemplateInsights(ctx, arg.TemplateIDs); err != nil {
return database.GetTemplateInsightsRow{}, err
@@ -3624,6 +3638,17 @@ func (q *querier) GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg
return tv, nil
}
// GetTemplateVersionHasAITask reports whether a template version declares an
// AI task. Access is gated on reading the template version itself: an actor
// who can fetch the version is allowed to learn this flag.
func (q *querier) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
	_, err := q.GetTemplateVersionByID(ctx, id)
	if err != nil {
		return false, err
	}
	return q.db.GetTemplateVersionHasAITask(ctx, id)
}
func (q *querier) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
// An actor can read template version parameters if they can read the related template.
tv, err := q.db.GetTemplateVersionByID(ctx, templateVersionID)
@@ -3807,13 +3832,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
return q.db.GetUserChatCustomPrompt(ctx, userID)
}
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
return 0, err
}
return q.db.GetUserChatSpendInPeriod(ctx, arg)
}
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return 0, err
@@ -3821,13 +3839,6 @@ func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64,
return q.db.GetUserCount(ctx, includeSystem)
}
func (q *querier) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
return 0, err
}
return q.db.GetUserGroupSpendLimit(ctx, userID)
}
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
@@ -4275,6 +4286,15 @@ func (q *querier) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuil
return q.db.GetWorkspaceBuildParameters(ctx, workspaceBuildID)
}
// GetWorkspaceBuildParametersByBuildIDs returns build parameters for the
// given build IDs, restricted to workspaces the actor may read. Authorization
// is pushed into the query itself via a prepared SQL filter.
func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
	filter, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceWorkspace.Type)
	if err != nil {
		return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
	}
	return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, filter)
}
func (q *querier) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, buildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
// Fetching the provisioner state requires Update permission on the template.
return fetchWithAction(q.log, q.auth, policy.ActionUpdate, q.db.GetWorkspaceBuildProvisionerStateByID)(ctx, buildID)
@@ -4922,6 +4942,16 @@ func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertU
return q.db.InsertUserGroupsByID(ctx, arg)
}
// InsertUserGroupsByName adds the user to all named groups, which counts as
// updating a group. NOTE: instead of checking permission to update each
// group individually, we check permission to update *a* group in the org.
func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
	orgGroup := func(_ context.Context, arg database.InsertUserGroupsByNameParams) (rbac.Objecter, error) {
		return rbac.ResourceGroup.InOrg(arg.OrganizationID), nil
	}
	return update(q.log, q.auth, orgGroup, q.db.InsertUserGroupsByName)(ctx, arg)
}
// TODO: Should this be in system.go?
func (q *querier) InsertUserLink(ctx context.Context, arg database.InsertUserLinkParams) (database.UserLink, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceUserObject(arg.UserID)); err != nil {
@@ -5186,18 +5216,12 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
}
func (q *querier) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitGroupOverrides(ctx)
// ListChatsByRootID lists chats under a root chat, post-filtering each
// returned chat against the actor's read permission.
func (q *querier) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
	list := fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChatsByRootID)
	return list(ctx, rootChatID)
}
func (q *querier) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return nil, err
}
return q.db.ListChatUsageLimitOverrides(ctx)
// ListChildChatsByParentID lists the direct children of a chat,
// post-filtering each returned chat against the actor's read permission.
func (q *querier) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
	list := fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListChildChatsByParentID)
	return list(ctx, parentChatID)
}
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
@@ -5311,6 +5335,14 @@ func (q *querier) RegisterWorkspaceProxy(ctx context.Context, arg database.Regis
return updateWithReturn(q.log, q.auth, fetch, q.db.RegisterWorkspaceProxy)(ctx, arg)
}
// RemoveUserFromAllGroups is a system function used to clear user groups in
// group sync; it requires update permission on the system resource.
func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem)
	if err != nil {
		return err
	}
	return q.db.RemoveUserFromAllGroups(ctx, userID)
}
func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
// This is a system function to clear user groups in group sync.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
@@ -5319,13 +5351,6 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
return q.db.RemoveUserFromGroups(ctx, arg)
}
// ResolveUserChatSpendLimit requires read access to chats owned by the given
// user before delegating to the underlying store.
func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
	err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String()))
	if err != nil {
		return 0, err
	}
	return q.db.ResolveUserChatSpendLimit(ctx, userID)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -5652,6 +5677,13 @@ func (q *querier) UpdateOAuth2ProviderAppByID(ctx context.Context, arg database.
return q.db.UpdateOAuth2ProviderAppByID(ctx, arg)
}
// UpdateOAuth2ProviderAppSecretByID requires update permission on the OAuth2
// app secret resource before delegating to the underlying store.
func (q *querier) UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg database.UpdateOAuth2ProviderAppSecretByIDParams) (database.OAuth2ProviderAppSecret, error) {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceOauth2AppSecret)
	if err != nil {
		return database.OAuth2ProviderAppSecret{}, err
	}
	return q.db.UpdateOAuth2ProviderAppSecretByID(ctx, arg)
}
func (q *querier) UpdateOrganization(ctx context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) {
fetch := func(ctx context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) {
return q.db.GetOrganizationByID(ctx, arg.ID)
@@ -6104,6 +6136,13 @@ func (q *querier) UpdateUserLink(ctx context.Context, arg database.UpdateUserLin
return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpdateUserLink)(ctx, arg)
}
// UpdateUserLinkedID requires update permission on the system resource
// before delegating to the underlying store.
func (q *querier) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem)
	if err != nil {
		return database.UserLink{}, err
	}
	return q.db.UpdateUserLinkedID(ctx, arg)
}
func (q *querier) UpdateUserLoginType(ctx context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return database.User{}, err
@@ -6509,13 +6548,6 @@ func (q *querier) UpdateWorkspacesTTLByTemplateID(ctx context.Context, arg datab
return q.db.UpdateWorkspacesTTLByTemplateID(ctx, arg)
}
// UpsertAISeatState requires create permission on the system resource
// before delegating to the underlying store.
func (q *querier) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
	err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem)
	if err != nil {
		return false, err
	}
	return q.db.UpsertAISeatState(ctx, arg)
}
func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -6523,6 +6555,13 @@ func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) e
return q.db.UpsertAnnouncementBanners(ctx, value)
}
// UpsertAppSecurityKey requires update permission on the system resource
// before delegating to the underlying store.
func (q *querier) UpsertAppSecurityKey(ctx context.Context, data string) error {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem)
	if err != nil {
		return err
	}
	return q.db.UpsertAppSecurityKey(ctx, data)
}
func (q *querier) UpsertApplicationName(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -6568,27 +6607,6 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
return q.db.UpsertChatSystemPrompt(ctx, value)
}
// UpsertChatUsageLimitConfig requires update permission on the deployment
// config resource before delegating to the underlying store.
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig)
	if err != nil {
		return database.ChatUsageLimitConfig{}, err
	}
	return q.db.UpsertChatUsageLimitConfig(ctx, arg)
}
// UpsertChatUsageLimitGroupOverride requires update permission on the
// deployment config resource before delegating to the underlying store.
func (q *querier) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig)
	if err != nil {
		return database.UpsertChatUsageLimitGroupOverrideRow{}, err
	}
	return q.db.UpsertChatUsageLimitGroupOverride(ctx, arg)
}
// UpsertChatUsageLimitUserOverride requires update permission on the
// deployment config resource before delegating to the underlying store.
func (q *querier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig)
	if err != nil {
		return database.UpsertChatUsageLimitUserOverrideRow{}, err
	}
	return q.db.UpsertChatUsageLimitUserOverride(ctx, arg)
}
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
return database.ConnectionLog{}, err
@@ -6596,6 +6614,13 @@ func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertCo
return q.db.UpsertConnectionLog(ctx, arg)
}
// UpsertCoordinatorResumeTokenSigningKey requires update permission on the
// system resource before delegating to the underlying store.
func (q *querier) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem)
	if err != nil {
		return err
	}
	return q.db.UpsertCoordinatorResumeTokenSigningKey(ctx, value)
}
func (q *querier) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err
@@ -6645,6 +6670,13 @@ func (q *querier) UpsertOAuth2GithubDefaultEligible(ctx context.Context, eligibl
return q.db.UpsertOAuth2GithubDefaultEligible(ctx, eligible)
}
// UpsertOAuthSigningKey requires update permission on the system resource
// before delegating to the underlying store.
func (q *querier) UpsertOAuthSigningKey(ctx context.Context, value string) error {
	err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem)
	if err != nil {
		return err
	}
	return q.db.UpsertOAuthSigningKey(ctx, value)
}
func (q *querier) UpsertPrebuildsSettings(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -6834,6 +6866,10 @@ func (q *querier) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Context,
return q.GetWorkspacesAndAgentsByOwnerID(ctx, ownerID)
}
// GetAuthorizedWorkspaceBuildParametersByBuildIDs ignores the provided
// prepared authorizer: GetWorkspaceBuildParametersByBuildIDs already
// prepares its own SQL filter and performs authorization itself.
func (q *querier) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID, _ rbac.PreparedAuthorized) ([]database.WorkspaceBuildParameter, error) {
	return q.GetWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs)
}
// GetAuthorizedUsers is not required for dbauthz since GetUsers is already
// authenticated.
func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
+163 -241
View File
@@ -373,15 +373,14 @@ func (s *MethodTestSuite) TestConnectionLogs() {
}
func (s *MethodTestSuite) TestChats() {
s.Run("AcquireChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.AcquireChatsParams{
s.Run("AcquireChat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.AcquireChatParams{
StartedAt: dbtime.Now(),
WorkerID: uuid.New(),
NumChats: 1,
}
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().AcquireChats(gomock.Any(), arg).Return([]database.Chat{chat}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]database.Chat{chat})
dbm.EXPECT().AcquireChat(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(chat)
}))
s.Run("DeleteAllChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
@@ -401,6 +400,12 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
}))
s.Run("DeleteChatMessagesByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().DeleteChatMessagesByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionDelete).Returns()
}))
s.Run("DeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.DeleteChatMessagesAfterIDParams{
@@ -438,85 +443,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
}))
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerChatParams{
OwnerID: uuid.New(),
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
rows := []database.GetChatCostPerChatRow{{
RootChatID: uuid.New(),
ChatTitle: "chat-cost",
TotalCostMicros: 123,
MessageCount: 4,
TotalInputTokens: 55,
TotalOutputTokens: 89,
}}
dbm.EXPECT().GetChatCostPerChat(gomock.Any(), arg).Return(rows, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
}))
s.Run("GetChatCostPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerModelParams{
OwnerID: uuid.New(),
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
rows := []database.GetChatCostPerModelRow{{
ModelConfigID: uuid.New(),
DisplayName: "GPT 4.1",
Provider: "openai",
Model: "gpt-4.1",
TotalCostMicros: 456,
MessageCount: 7,
TotalInputTokens: 144,
TotalOutputTokens: 233,
}}
dbm.EXPECT().GetChatCostPerModel(gomock.Any(), arg).Return(rows, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
}))
s.Run("GetChatCostPerUser", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerUserParams{
PageOffset: 0,
PageLimit: 25,
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
Username: "cost-user",
}
rows := []database.GetChatCostPerUserRow{{
UserID: uuid.New(),
Username: "cost-user",
Name: "Cost User",
AvatarURL: "https://example.com/avatar.png",
TotalCostMicros: 789,
MessageCount: 11,
ChatCount: 3,
TotalInputTokens: 377,
TotalOutputTokens: 610,
TotalCount: 1,
}}
dbm.EXPECT().GetChatCostPerUser(gomock.Any(), arg).Return(rows, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(rows)
}))
s.Run("GetChatCostSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostSummaryParams{
OwnerID: uuid.New(),
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
row := database.GetChatCostSummaryRow{
TotalCostMicros: 987,
PricedMessageCount: 12,
UnpricedMessageCount: 2,
TotalInputTokens: 400,
TotalOutputTokens: 800,
}
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
}))
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
}))
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
@@ -562,18 +488,10 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
arg := database.GetChatMessagesByChatIDDescPaginatedParams{ChatID: chat.ID, BeforeID: 0, LimitVal: 50}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetLastChatMessageByRole", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
arg := database.GetLastChatMessageByRoleParams{ChatID: chat.ID, Role: database.ChatMessageRoleAssistant}
arg := database.GetLastChatMessageByRoleParams{ChatID: chat.ID, Role: "assistant"}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetLastChatMessageByRole(gomock.Any(), arg).Return(msg, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msg)
@@ -595,6 +513,15 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetDefaultChatModelConfig(gomock.Any()).Return(config, nil).AnyTimes()
check.Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatModelConfigByProviderAndModel", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
config := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
args := database.GetChatModelConfigByProviderAndModelParams{
Provider: config.Provider,
Model: config.Model,
}
dbm.EXPECT().GetChatModelConfigByProviderAndModel(gomock.Any(), args).Return(config, nil).AnyTimes()
check.Args(args).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
@@ -648,6 +575,20 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
}))
s.Run("ListChatsByRootID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
rootChatID := uuid.New()
chatA := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
chatB := testutil.Fake(s.T(), faker, database.Chat{RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}})
dbm.EXPECT().ListChatsByRootID(gomock.Any(), rootChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(rootChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("ListChildChatsByParentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
parentChatID := uuid.New()
chatA := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
chatB := testutil.Fake(s.T(), faker, database.Chat{ParentChatID: uuid.NullUUID{UUID: parentChatID, Valid: true}})
dbm.EXPECT().ListChildChatsByParentID(gomock.Any(), parentChatID).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
check.Args(parentChatID).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
}))
s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
threshold := dbtime.Now()
chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})}
@@ -845,142 +786,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetUserChatSpendInPeriodParams{
UserID: uuid.New(),
StartTime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
spend := int64(123)
dbm.EXPECT().GetUserChatSpendInPeriod(gomock.Any(), arg).Return(spend, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.UserID.String()), policy.ActionRead).Returns(spend)
}))
s.Run("GetUserGroupSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(456)
dbm.EXPECT().GetUserGroupSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("ResolveUserChatSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
limit := int64(789)
dbm.EXPECT().ResolveUserChatSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
}))
s.Run("GetChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: true,
DefaultLimitMicros: 1_000_000,
Period: "monthly",
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(config, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
}))
s.Run("GetChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
override := database.GetChatUsageLimitGroupOverrideRow{
GroupID: groupID,
SpendLimitMicros: sql.NullInt64{Int64: 2_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(override, nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("GetChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
override := database.GetChatUsageLimitUserOverrideRow{
UserID: userID,
SpendLimitMicros: sql.NullInt64{Int64: 3_000_000, Valid: true},
}
dbm.EXPECT().GetChatUsageLimitUserOverride(gomock.Any(), userID).Return(override, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
}))
s.Run("ListChatUsageLimitGroupOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitGroupOverridesRow{{
GroupID: uuid.New(),
GroupName: "group-name",
GroupDisplayName: "Group Name",
GroupAvatarUrl: "https://example.com/group.png",
SpendLimitMicros: sql.NullInt64{Int64: 4_000_000, Valid: true},
MemberCount: 5,
}}
dbm.EXPECT().ListChatUsageLimitGroupOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("ListChatUsageLimitOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
overrides := []database.ListChatUsageLimitOverridesRow{{
UserID: uuid.New(),
Username: "usage-limit-user",
Name: "Usage Limit User",
AvatarURL: "https://example.com/avatar.png",
SpendLimitMicros: sql.NullInt64{Int64: 5_000_000, Valid: true},
}}
dbm.EXPECT().ListChatUsageLimitOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
}))
s.Run("UpsertChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
now := dbtime.Now()
arg := database.UpsertChatUsageLimitConfigParams{
Enabled: true,
DefaultLimitMicros: 6_000_000,
Period: "monthly",
}
config := database.ChatUsageLimitConfig{
ID: 1,
Singleton: true,
Enabled: arg.Enabled,
DefaultLimitMicros: arg.DefaultLimitMicros,
Period: arg.Period,
CreatedAt: now,
UpdatedAt: now,
}
dbm.EXPECT().UpsertChatUsageLimitConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
}))
s.Run("UpsertChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitGroupOverrideParams{
SpendLimitMicros: 7_000_000,
GroupID: uuid.New(),
}
override := database.UpsertChatUsageLimitGroupOverrideRow{
GroupID: arg.GroupID,
Name: "group",
DisplayName: "Group",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitGroupOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("UpsertChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.UpsertChatUsageLimitUserOverrideParams{
SpendLimitMicros: 8_000_000,
UserID: uuid.New(),
}
override := database.UpsertChatUsageLimitUserOverrideRow{
UserID: arg.UserID,
Username: "user",
Name: "User",
AvatarURL: "",
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
}
dbm.EXPECT().UpsertChatUsageLimitUserOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
}))
s.Run("DeleteChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
groupID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(nil).AnyTimes()
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("DeleteChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
userID := uuid.New()
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
}
func (s *MethodTestSuite) TestFile() {
@@ -1000,6 +805,12 @@ func (s *MethodTestSuite) TestFile() {
dbm.EXPECT().GetFileTemplates(gomock.Any(), f.ID).Return([]database.GetFileTemplatesRow{}, nil).AnyTimes()
check.Args(f.ID).Asserts(f, policy.ActionRead).Returns(f)
}))
s.Run("GetFileIDByTemplateVersionID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
tvID := uuid.New()
fileID := uuid.New()
dbm.EXPECT().GetFileIDByTemplateVersionID(gomock.Any(), tvID).Return(fileID, nil).AnyTimes()
check.Args(tvID).Asserts(rbac.ResourceFile.WithID(fileID), policy.ActionRead).Returns(fileID)
}))
s.Run("InsertFile", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
ret := testutil.Fake(s.T(), faker, database.File{CreatedBy: u.ID})
@@ -1103,6 +914,16 @@ func (s *MethodTestSuite) TestGroup() {
check.Args(arg).Asserts(g, policy.ActionUpdate).Returns()
}))
s.Run("InsertUserGroupsByName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
o := testutil.Fake(s.T(), faker, database.Organization{})
u1 := testutil.Fake(s.T(), faker, database.User{})
g1 := testutil.Fake(s.T(), faker, database.Group{OrganizationID: o.ID})
g2 := testutil.Fake(s.T(), faker, database.Group{OrganizationID: o.ID})
arg := database.InsertUserGroupsByNameParams{OrganizationID: o.ID, UserID: u1.ID, GroupNames: slice.New(g1.Name, g2.Name)}
dbm.EXPECT().InsertUserGroupsByName(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns()
}))
s.Run("InsertUserGroupsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
o := testutil.Fake(s.T(), faker, database.Organization{})
u1 := testutil.Fake(s.T(), faker, database.User{})
@@ -1115,6 +936,12 @@ func (s *MethodTestSuite) TestGroup() {
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(returns)
}))
s.Run("RemoveUserFromAllGroups", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u1 := testutil.Fake(s.T(), faker, database.User{})
dbm.EXPECT().RemoveUserFromAllGroups(gomock.Any(), u1.ID).Return(nil).AnyTimes()
check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns()
}))
s.Run("RemoveUserFromGroups", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
o := testutil.Fake(s.T(), faker, database.Organization{})
u1 := testutil.Fake(s.T(), faker, database.User{})
@@ -1271,6 +1098,18 @@ func (s *MethodTestSuite) TestProvisionerJob() {
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(canceledJobs, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(canceledJobs)
}))
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
org := testutil.Fake(s.T(), faker, database.Organization{})
org2 := testutil.Fake(s.T(), faker, database.Organization{})
a := testutil.Fake(s.T(), faker, database.ProvisionerJob{OrganizationID: org.ID})
b := testutil.Fake(s.T(), faker, database.ProvisionerJob{OrganizationID: org2.ID})
ids := []uuid.UUID{a.ID, b.ID}
dbm.EXPECT().GetProvisionerJobsByIDs(gomock.Any(), ids).Return([]database.ProvisionerJob{a, b}, nil).AnyTimes()
check.Args(ids).Asserts(
rbac.ResourceProvisionerJobs.InOrg(org.ID), policy.ActionRead,
rbac.ResourceProvisionerJobs.InOrg(org2.ID), policy.ActionRead,
).OutOfOrder().Returns(slice.New(a, b))
}))
s.Run("GetProvisionerLogsAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
j := testutil.Fake(s.T(), faker, database.ProvisionerJob{Type: database.ProvisionerJobTypeWorkspaceBuild})
@@ -1303,14 +1142,6 @@ func (s *MethodTestSuite) TestProvisionerJob() {
}
func (s *MethodTestSuite) TestLicense() {
s.Run("GetActiveAISeatCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetActiveAISeatCount(gomock.Any()).Return(int64(100), nil).AnyTimes()
check.Args().Asserts(rbac.ResourceLicense, policy.ActionRead).Returns(int64(100))
}))
s.Run("UpsertAISeatState", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertAISeatState(gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes()
check.Args(database.UpsertAISeatStateParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
}))
s.Run("GetLicenses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
a := database.License{ID: 1}
b := database.License{ID: 2}
@@ -1742,6 +1573,14 @@ func (s *MethodTestSuite) TestTemplate() {
dbm.EXPECT().GetTemplateVersionsCreatedAfter(gomock.Any(), now.Add(-time.Hour)).Return([]database.TemplateVersion{}, nil).AnyTimes()
check.Args(now.Add(-time.Hour)).Asserts(rbac.ResourceTemplate.All(), policy.ActionRead)
}))
s.Run("GetTemplateVersionHasAITask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
t := testutil.Fake(s.T(), faker, database.Template{})
tv := testutil.Fake(s.T(), faker, database.TemplateVersion{TemplateID: uuid.NullUUID{UUID: t.ID, Valid: true}})
dbm.EXPECT().GetTemplateVersionByID(gomock.Any(), tv.ID).Return(tv, nil).AnyTimes()
dbm.EXPECT().GetTemplateByID(gomock.Any(), t.ID).Return(t, nil).AnyTimes()
dbm.EXPECT().GetTemplateVersionHasAITask(gomock.Any(), tv.ID).Return(false, nil).AnyTimes()
check.Args(tv.ID).Asserts(t, policy.ActionRead)
}))
s.Run("GetTemplatesWithFilter", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
a := testutil.Fake(s.T(), faker, database.Template{})
arg := database.GetTemplatesWithFilterParams{}
@@ -2125,6 +1964,12 @@ func (s *MethodTestSuite) TestUser() {
dbm.EXPECT().UpdateUserStatus(gomock.Any(), arg).Return(u, nil).AnyTimes()
check.Args(arg).Asserts(u, policy.ActionUpdate).Returns(u)
}))
s.Run("DeleteGitSSHKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
key := testutil.Fake(s.T(), faker, database.GitSSHKey{})
dbm.EXPECT().GetGitSSHKey(gomock.Any(), key.UserID).Return(key, nil).AnyTimes()
dbm.EXPECT().DeleteGitSSHKey(gomock.Any(), key.UserID).Return(nil).AnyTimes()
check.Args(key.UserID).Asserts(rbac.ResourceUserObject(key.UserID), policy.ActionUpdatePersonal).Returns()
}))
s.Run("GetGitSSHKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
key := testutil.Fake(s.T(), faker, database.GitSSHKey{})
dbm.EXPECT().GetGitSSHKey(gomock.Any(), key.UserID).Return(key, nil).AnyTimes()
@@ -2379,6 +2224,18 @@ func (s *MethodTestSuite) TestWorkspace() {
// No asserts here because SQLFilter.
check.Args(ws.OwnerID, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("GetWorkspaceBuildParametersByBuildIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{}
dbm.EXPECT().GetAuthorizedWorkspaceBuildParametersByBuildIDs(gomock.Any(), ids, gomock.Any()).Return([]database.WorkspaceBuildParameter{}, nil).AnyTimes()
// no asserts here because SQLFilter
check.Args(ids).Asserts()
}))
s.Run("GetAuthorizedWorkspaceBuildParametersByBuildIDs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ids := []uuid.UUID{}
dbm.EXPECT().GetAuthorizedWorkspaceBuildParametersByBuildIDs(gomock.Any(), ids, gomock.Any()).Return([]database.WorkspaceBuildParameter{}, nil).AnyTimes()
// no asserts here because SQLFilter
check.Args(ids, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("GetWorkspaceACLByID", s.Mocked(func(dbM *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
dbM.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
@@ -3540,6 +3397,13 @@ func (s *MethodTestSuite) TestCryptoKeys() {
}
func (s *MethodTestSuite) TestSystemFunctions() {
s.Run("UpdateUserLinkedID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
u := testutil.Fake(s.T(), faker, database.User{})
l := testutil.Fake(s.T(), faker, database.UserLink{UserID: u.ID})
arg := database.UpdateUserLinkedIDParams{UserID: u.ID, LinkedID: l.LinkedID, LoginType: database.LoginTypeGithub}
dbm.EXPECT().UpdateUserLinkedID(gomock.Any(), arg).Return(l, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(l)
}))
s.Run("GetLatestWorkspaceAppStatusByAppID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
appID := uuid.New()
dbm.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), appID).Return(database.WorkspaceAppStatus{}, nil).AnyTimes()
@@ -3732,6 +3596,16 @@ func (s *MethodTestSuite) TestSystemFunctions() {
Asserts(rbac.ResourceSystem, policy.ActionRead).
Returns([]database.WorkspaceAgent{agt})
}))
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
org := testutil.Fake(s.T(), faker, database.Organization{})
a := testutil.Fake(s.T(), faker, database.ProvisionerJob{OrganizationID: org.ID})
b := testutil.Fake(s.T(), faker, database.ProvisionerJob{OrganizationID: org.ID})
ids := []uuid.UUID{a.ID, b.ID}
dbm.EXPECT().GetProvisionerJobsByIDs(gomock.Any(), ids).Return([]database.ProvisionerJob{a, b}, nil).AnyTimes()
check.Args(ids).
Asserts(rbac.ResourceProvisionerJobs.InOrg(org.ID), policy.ActionRead).
Returns(slice.New(a, b))
}))
s.Run("DeleteWorkspaceSubAgentByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
ws := testutil.Fake(s.T(), faker, database.Workspace{})
agent := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
@@ -3915,11 +3789,29 @@ func (s *MethodTestSuite) TestSystemFunctions() {
dbm.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), arg).Return([]database.WorkspaceAgentLogSource{}, nil).AnyTimes()
check.Args(arg).Asserts()
}))
s.Run("GetTemplateDAUs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetTemplateDAUsParams{}
dbm.EXPECT().GetTemplateDAUs(gomock.Any(), arg).Return([]database.GetTemplateDAUsRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetActiveWorkspaceBuildsByTemplateID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
id := uuid.New()
dbm.EXPECT().GetActiveWorkspaceBuildsByTemplateID(gomock.Any(), id).Return([]database.WorkspaceBuild{}, nil).AnyTimes()
check.Args(id).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.WorkspaceBuild{})
}))
s.Run("GetDeploymentDAUs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
tz := int32(0)
dbm.EXPECT().GetDeploymentDAUs(gomock.Any(), tz).Return([]database.GetDeploymentDAUsRow{}, nil).AnyTimes()
check.Args(tz).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetAppSecurityKey(gomock.Any()).Return("", sql.ErrNoRows).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows)
}))
s.Run("UpsertAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertAppSecurityKey(gomock.Any(), "foo").Return(nil).AnyTimes()
check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetApplicationName", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetApplicationName(gomock.Any()).Return("foo", nil).AnyTimes()
check.Args().Asserts()
@@ -3973,6 +3865,22 @@ func (s *MethodTestSuite) TestSystemFunctions() {
dbm.EXPECT().GetProvisionerJobsToBeReaped(gomock.Any(), arg).Return([]database.ProvisionerJob{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceProvisionerJobs, policy.ActionRead)
}))
s.Run("UpsertOAuthSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertOAuthSigningKey(gomock.Any(), "foo").Return(nil).AnyTimes()
check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetOAuthSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetOAuthSigningKey(gomock.Any()).Return("foo", nil).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("UpsertCoordinatorResumeTokenSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertCoordinatorResumeTokenSigningKey(gomock.Any(), "foo").Return(nil).AnyTimes()
check.Args("foo").Asserts(rbac.ResourceSystem, policy.ActionUpdate)
}))
s.Run("GetCoordinatorResumeTokenSigningKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetCoordinatorResumeTokenSigningKey(gomock.Any()).Return("foo", nil).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("InsertMissingGroups", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.InsertMissingGroupsParams{}
dbm.EXPECT().InsertMissingGroups(gomock.Any(), arg).Return([]database.Group{}, xerrors.New("any error")).AnyTimes()
@@ -4626,6 +4534,12 @@ func (s *MethodTestSuite) TestOAuth2ProviderApps() {
UpdatedAt: app.UpdatedAt,
}).Asserts(rbac.ResourceOauth2App, policy.ActionUpdate).Returns(app)
}))
s.Run("GetOAuth2ProviderAppByRegistrationToken", s.Subtest(func(db database.Store, check *expects) {
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{
RegistrationAccessToken: []byte("test-token"),
})
check.Args([]byte("test-token")).Asserts(rbac.ResourceOauth2App, policy.ActionRead).Returns(app)
}))
}
func (s *MethodTestSuite) TestOAuth2ProviderAppSecrets() {
@@ -4670,6 +4584,18 @@ func (s *MethodTestSuite) TestOAuth2ProviderAppSecrets() {
AppID: app.ID,
}).Asserts(rbac.ResourceOauth2AppSecret, policy.ActionCreate)
}))
s.Run("UpdateOAuth2ProviderAppSecretByID", s.Subtest(func(db database.Store, check *expects) {
dbtestutil.DisableForeignKeysAndTriggers(s.T(), db)
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{
AppID: app.ID,
})
secret.LastUsedAt = sql.NullTime{Time: dbtestutil.NowInDefaultTimezone(), Valid: true}
check.Args(database.UpdateOAuth2ProviderAppSecretByIDParams{
ID: secret.ID,
LastUsedAt: secret.LastUsedAt,
}).Asserts(rbac.ResourceOauth2AppSecret, policy.ActionUpdate).Returns(secret)
}))
s.Run("DeleteOAuth2ProviderAppSecretByID", s.Subtest(func(db database.Store, check *expects) {
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{})
secret := dbgen.OAuth2ProviderAppSecret(s.T(), db, database.OAuth2ProviderAppSecret{
@@ -5532,10 +5458,6 @@ func TestAsChatd(t *testing.T) {
// DeploymentConfig read.
err = auth.Authorize(ctx, actor, policy.ActionRead, rbac.ResourceDeploymentConfig)
require.NoError(t, err, "deployment config read should be allowed")
// User read_personal (needed for GetUserChatCustomPrompt).
err = auth.Authorize(ctx, actor, policy.ActionReadPersonal, rbac.ResourceUser)
require.NoError(t, err, "user read_personal should be allowed")
})
t.Run("DeniedActions", func(t *testing.T) {
+10 -20
View File
@@ -578,27 +578,17 @@ func WorkspaceBuildParameters(t testing.TB, db database.Store, orig []database.W
}
func User(t testing.TB, db database.Store, orig database.User) database.User {
loginType := takeFirst(orig.LoginType, database.LoginTypePassword)
email := takeFirst(orig.Email, testutil.GetRandomName(t))
// A DB constraint requires login_type = 'none' and email = '' for service
// accounts.
if orig.IsServiceAccount {
loginType = database.LoginTypeNone
email = ""
}
user, err := db.InsertUser(genCtx, database.InsertUserParams{
ID: takeFirst(orig.ID, uuid.New()),
Email: email,
Username: takeFirst(orig.Username, testutil.GetRandomName(t)),
Name: takeFirst(orig.Name, testutil.GetRandomName(t)),
HashedPassword: takeFirstSlice(orig.HashedPassword, []byte(must(cryptorand.String(32)))),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
RBACRoles: takeFirstSlice(orig.RBACRoles, []string{}),
LoginType: loginType,
Status: string(takeFirst(orig.Status, database.UserStatusDormant)),
IsServiceAccount: orig.IsServiceAccount,
ID: takeFirst(orig.ID, uuid.New()),
Email: takeFirst(orig.Email, testutil.GetRandomName(t)),
Username: takeFirst(orig.Username, testutil.GetRandomName(t)),
Name: takeFirst(orig.Name, testutil.GetRandomName(t)),
HashedPassword: takeFirstSlice(orig.HashedPassword, []byte(must(cryptorand.String(32)))),
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
RBACRoles: takeFirstSlice(orig.RBACRoles, []string{}),
LoginType: takeFirst(orig.LoginType, database.LoginTypePassword),
Status: string(takeFirst(orig.Status, database.UserStatusDormant)),
})
require.NoError(t, err, "insert user")
-14
View File
@@ -213,20 +213,6 @@ func TestGenerator(t *testing.T) {
require.Equal(t, exp, must(db.GetUserByID(context.Background(), exp.ID)))
})
t.Run("ServiceAccountUser", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{
IsServiceAccount: true,
Email: "should-be-overridden@coder.com",
LoginType: database.LoginTypePassword,
})
require.True(t, user.IsServiceAccount)
require.Empty(t, user.Email)
require.Equal(t, database.LoginTypeNone, user.LoginType)
require.Equal(t, user, must(db.GetUserByID(context.Background(), user.ID)))
})
t.Run("SSHKey", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
+180 -164
View File
@@ -104,11 +104,11 @@ func (m queryMetricsStore) DeleteOrganization(ctx context.Context, id uuid.UUID)
return r0
}
func (m queryMetricsStore) AcquireChats(ctx context.Context, arg database.AcquireChatsParams) ([]database.Chat, error) {
func (m queryMetricsStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.AcquireChats(ctx, arg)
m.queryLatencies.WithLabelValues("AcquireChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireChats").Inc()
r0, r1 := m.s.AcquireChat(ctx, arg)
m.queryLatencies.WithLabelValues("AcquireChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireChat").Inc()
return r0, r1
}
@@ -288,14 +288,6 @@ func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountEnabledModelsWithoutPricing(ctx)
m.queryLatencies.WithLabelValues("CountEnabledModelsWithoutPricing").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountEnabledModelsWithoutPricing").Inc()
return r0, r1
}
func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
start := time.Now()
r0, r1 := m.s.CountInProgressPrebuilds(ctx)
@@ -392,6 +384,14 @@ func (m queryMetricsStore) DeleteChatMessagesAfterID(ctx context.Context, arg da
return r0
}
func (m queryMetricsStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatMessagesByChatID(ctx, chatID)
m.queryLatencies.WithLabelValues("DeleteChatMessagesByChatID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesByChatID").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
@@ -416,22 +416,6 @@ func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg data
return r0
}
func (m queryMetricsStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatUsageLimitGroupOverride(ctx, groupID)
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitGroupOverride").Inc()
return r0
}
func (m queryMetricsStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteChatUsageLimitUserOverride(ctx, userID)
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitUserOverride").Inc()
return r0
}
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
@@ -464,6 +448,14 @@ func (m queryMetricsStore) DeleteExternalAuthLink(ctx context.Context, arg datab
return r0
}
func (m queryMetricsStore) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteGitSSHKey(ctx, userID)
m.queryLatencies.WithLabelValues("DeleteGitSSHKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteGitSSHKey").Inc()
return r0
}
func (m queryMetricsStore) DeleteGroupByID(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteGroupByID(ctx, id)
@@ -895,14 +887,6 @@ func (m queryMetricsStore) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed
return r0, r1
}
func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetActiveAISeatCount(ctx)
m.queryLatencies.WithLabelValues("GetActiveAISeatCount").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveAISeatCount").Inc()
return r0, r1
}
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
start := time.Now()
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
@@ -967,6 +951,14 @@ func (m queryMetricsStore) GetAnnouncementBanners(ctx context.Context) (string,
return r0, r1
}
func (m queryMetricsStore) GetAppSecurityKey(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetAppSecurityKey(ctx)
m.queryLatencies.WithLabelValues("GetAppSecurityKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAppSecurityKey").Inc()
return r0, r1
}
func (m queryMetricsStore) GetApplicationName(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetApplicationName(ctx)
@@ -1015,38 +1007,6 @@ func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUI
return r0, r1
}
func (m queryMetricsStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostPerChat(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostPerChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerChat").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostPerModel(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostPerModel").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerModel").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostPerUser(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostPerUser").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerUser").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostSummary(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostSummary").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostSummary").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
@@ -1095,14 +1055,6 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDDescPaginated").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDDescPaginated").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
@@ -1119,6 +1071,14 @@ func (m queryMetricsStore) GetChatModelConfigByID(ctx context.Context, id uuid.U
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigByProviderAndModel(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatModelConfigByProviderAndModel").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigByProviderAndModel").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigs(ctx)
@@ -1167,30 +1127,6 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
m.queryLatencies.WithLabelValues("GetChatUsageLimitConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitGroupOverride(ctx, groupID)
m.queryLatencies.WithLabelValues("GetChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitGroupOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatUsageLimitUserOverride(ctx, userID)
m.queryLatencies.WithLabelValues("GetChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitUserOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
@@ -1207,6 +1143,14 @@ func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg data
return r0, r1
}
func (m queryMetricsStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetCoordinatorResumeTokenSigningKey(ctx)
m.queryLatencies.WithLabelValues("GetCoordinatorResumeTokenSigningKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetCoordinatorResumeTokenSigningKey").Inc()
return r0, r1
}
func (m queryMetricsStore) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
start := time.Now()
r0, r1 := m.s.GetCryptoKeyByFeatureAndSequence(ctx, arg)
@@ -1271,6 +1215,14 @@ func (m queryMetricsStore) GetDefaultProxyConfig(ctx context.Context) (database.
return r0, r1
}
func (m queryMetricsStore) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) {
start := time.Now()
r0, r1 := m.s.GetDeploymentDAUs(ctx, tzOffset)
m.queryLatencies.WithLabelValues("GetDeploymentDAUs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetDeploymentDAUs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetDeploymentID(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetDeploymentID(ctx)
@@ -1367,6 +1319,14 @@ func (m queryMetricsStore) GetFileByID(ctx context.Context, id uuid.UUID) (datab
return r0, r1
}
func (m queryMetricsStore) GetFileIDByTemplateVersionID(ctx context.Context, templateVersionID uuid.UUID) (uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.GetFileIDByTemplateVersionID(ctx, templateVersionID)
m.queryLatencies.WithLabelValues("GetFileIDByTemplateVersionID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetFileIDByTemplateVersionID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]database.GetFileTemplatesRow, error) {
start := time.Now()
r0, r1 := m.s.GetFileTemplates(ctx, fileID)
@@ -1607,6 +1567,14 @@ func (m queryMetricsStore) GetOAuth2ProviderAppByID(ctx context.Context, id uuid
return r0, r1
}
func (m queryMetricsStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
start := time.Now()
r0, r1 := m.s.GetOAuth2ProviderAppByRegistrationToken(ctx, registrationAccessToken)
m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppByRegistrationToken").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetOAuth2ProviderAppByRegistrationToken").Inc()
return r0, r1
}
func (m queryMetricsStore) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) {
start := time.Now()
r0, r1 := m.s.GetOAuth2ProviderAppCodeByID(ctx, id)
@@ -1679,6 +1647,14 @@ func (m queryMetricsStore) GetOAuth2ProviderAppsByUserID(ctx context.Context, us
return r0, r1
}
func (m queryMetricsStore) GetOAuthSigningKey(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetOAuthSigningKey(ctx)
m.queryLatencies.WithLabelValues("GetOAuthSigningKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetOAuthSigningKey").Inc()
return r0, r1
}
func (m queryMetricsStore) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
start := time.Now()
r0, r1 := m.s.GetOrganizationByID(ctx, id)
@@ -1879,6 +1855,14 @@ func (m queryMetricsStore) GetProvisionerJobTimingsByJobID(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
start := time.Now()
r0, r1 := m.s.GetProvisionerJobsByIDs(ctx, ids)
m.queryLatencies.WithLabelValues("GetProvisionerJobsByIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetProvisionerJobsByIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
start := time.Now()
r0, r1 := m.s.GetProvisionerJobsByIDsWithQueuePosition(ctx, arg)
@@ -2127,6 +2111,14 @@ func (m queryMetricsStore) GetTemplateByOrganizationAndName(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) GetTemplateDAUs(ctx context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateDAUs(ctx, arg)
m.queryLatencies.WithLabelValues("GetTemplateDAUs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTemplateDAUs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTemplateInsights(ctx context.Context, arg database.GetTemplateInsightsParams) (database.GetTemplateInsightsRow, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateInsights(ctx, arg)
@@ -2199,6 +2191,14 @@ func (m queryMetricsStore) GetTemplateVersionByTemplateIDAndName(ctx context.Con
return r0, r1
}
func (m queryMetricsStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateVersionHasAITask(ctx, id)
m.queryLatencies.WithLabelValues("GetTemplateVersionHasAITask").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetTemplateVersionHasAITask").Inc()
return r0, r1
}
func (m queryMetricsStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
start := time.Now()
r0, r1 := m.s.GetTemplateVersionParameters(ctx, templateVersionID)
@@ -2319,14 +2319,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
return r0, r1
}
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
m.queryLatencies.WithLabelValues("GetUserChatSpendInPeriod").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatSpendInPeriod").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserCount(ctx, includeSystem)
@@ -2335,14 +2327,6 @@ func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool)
return r0, r1
}
func (m queryMetricsStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.GetUserGroupSpendLimit(ctx, userID)
m.queryLatencies.WithLabelValues("GetUserGroupSpendLimit").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserGroupSpendLimit").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
start := time.Now()
r0, r1 := m.s.GetUserLatencyInsights(ctx, arg)
@@ -2727,6 +2711,14 @@ func (m queryMetricsStore) GetWorkspaceBuildParameters(ctx context.Context, work
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIds []uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIds)
m.queryLatencies.WithLabelValues("GetWorkspaceBuildParametersByBuildIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetWorkspaceBuildParametersByBuildIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
start := time.Now()
r0, r1 := m.s.GetWorkspaceBuildProvisionerStateByID(ctx, workspaceBuildID)
@@ -3367,6 +3359,14 @@ func (m queryMetricsStore) InsertUserGroupsByID(ctx context.Context, arg databas
return r0, r1
}
func (m queryMetricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
start := time.Now()
r0 := m.s.InsertUserGroupsByName(ctx, arg)
m.queryLatencies.WithLabelValues("InsertUserGroupsByName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertUserGroupsByName").Inc()
return r0
}
func (m queryMetricsStore) InsertUserLink(ctx context.Context, arg database.InsertUserLinkParams) (database.UserLink, error) {
start := time.Now()
r0, r1 := m.s.InsertUserLink(ctx, arg)
@@ -3575,19 +3575,19 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
return r0, r1
}
func (m queryMetricsStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
func (m queryMetricsStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.ListChatUsageLimitGroupOverrides(ctx)
m.queryLatencies.WithLabelValues("ListChatUsageLimitGroupOverrides").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitGroupOverrides").Inc()
r0, r1 := m.s.ListChatsByRootID(ctx, rootChatID)
m.queryLatencies.WithLabelValues("ListChatsByRootID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatsByRootID").Inc()
return r0, r1
}
func (m queryMetricsStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
func (m queryMetricsStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
start := time.Now()
r0, r1 := m.s.ListChatUsageLimitOverrides(ctx)
m.queryLatencies.WithLabelValues("ListChatUsageLimitOverrides").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitOverrides").Inc()
r0, r1 := m.s.ListChildChatsByParentID(ctx, parentChatID)
m.queryLatencies.WithLabelValues("ListChildChatsByParentID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChildChatsByParentID").Inc()
return r0, r1
}
@@ -3695,6 +3695,14 @@ func (m queryMetricsStore) RegisterWorkspaceProxy(ctx context.Context, arg datab
return r0, r1
}
func (m queryMetricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
start := time.Now()
r0 := m.s.RemoveUserFromAllGroups(ctx, userID)
m.queryLatencies.WithLabelValues("RemoveUserFromAllGroups").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "RemoveUserFromAllGroups").Inc()
return r0
}
func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.RemoveUserFromGroups(ctx, arg)
@@ -3703,14 +3711,6 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas
return r0, r1
}
func (m queryMetricsStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.ResolveUserChatSpendLimit(ctx, userID)
m.queryLatencies.WithLabelValues("ResolveUserChatSpendLimit").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ResolveUserChatSpendLimit").Inc()
return r0, r1
}
func (m queryMetricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
@@ -3943,6 +3943,14 @@ func (m queryMetricsStore) UpdateOAuth2ProviderAppByID(ctx context.Context, arg
return r0, r1
}
func (m queryMetricsStore) UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg database.UpdateOAuth2ProviderAppSecretByIDParams) (database.OAuth2ProviderAppSecret, error) {
start := time.Now()
r0, r1 := m.s.UpdateOAuth2ProviderAppSecretByID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateOAuth2ProviderAppSecretByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOAuth2ProviderAppSecretByID").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateOrganization(ctx context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) {
start := time.Now()
r0, r1 := m.s.UpdateOrganization(ctx, arg)
@@ -4230,6 +4238,14 @@ func (m queryMetricsStore) UpdateUserLink(ctx context.Context, arg database.Upda
return r0, r1
}
func (m queryMetricsStore) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserLinkedID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserLinkedID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserLinkedID").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateUserLoginType(ctx context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserLoginType(ctx, arg)
@@ -4510,14 +4526,6 @@ func (m queryMetricsStore) UpdateWorkspacesTTLByTemplateID(ctx context.Context,
return r0
}
func (m queryMetricsStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
start := time.Now()
r0, r1 := m.s.UpsertAISeatState(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertAISeatState").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertAISeatState").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertAnnouncementBanners(ctx, value)
@@ -4526,6 +4534,14 @@ func (m queryMetricsStore) UpsertAnnouncementBanners(ctx context.Context, value
return r0
}
func (m queryMetricsStore) UpsertAppSecurityKey(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertAppSecurityKey(ctx, value)
m.queryLatencies.WithLabelValues("UpsertAppSecurityKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertAppSecurityKey").Inc()
return r0
}
func (m queryMetricsStore) UpsertApplicationName(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertApplicationName(ctx, value)
@@ -4566,30 +4582,6 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
return r0
}
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitConfig").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitConfig").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitGroupOverride(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitGroupOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
start := time.Now()
r0, r1 := m.s.UpsertChatUsageLimitUserOverride(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitUserOverride").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
start := time.Now()
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
@@ -4598,6 +4590,14 @@ func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertCoordinatorResumeTokenSigningKey(ctx, value)
m.queryLatencies.WithLabelValues("UpsertCoordinatorResumeTokenSigningKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertCoordinatorResumeTokenSigningKey").Inc()
return r0
}
func (m queryMetricsStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
start := time.Now()
r0 := m.s.UpsertDefaultProxy(ctx, arg)
@@ -4654,6 +4654,14 @@ func (m queryMetricsStore) UpsertOAuth2GithubDefaultEligible(ctx context.Context
return r0
}
func (m queryMetricsStore) UpsertOAuthSigningKey(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertOAuthSigningKey(ctx, value)
m.queryLatencies.WithLabelValues("UpsertOAuthSigningKey").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertOAuthSigningKey").Inc()
return r0
}
func (m queryMetricsStore) UpsertPrebuildsSettings(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertPrebuildsSettings(ctx, value)
@@ -4822,6 +4830,14 @@ func (m queryMetricsStore) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context
return r0, r1
}
func (m queryMetricsStore) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID, prepared rbac.PreparedAuthorized) ([]database.WorkspaceBuildParameter, error) {
start := time.Now()
r0, r1 := m.s.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prepared)
m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaceBuildParametersByBuildIDs").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedWorkspaceBuildParametersByBuildIDs").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
start := time.Now()
r0, r1 := m.s.GetAuthorizedUsers(ctx, arg, prepared)
+329 -304
View File
@@ -44,19 +44,19 @@ func (m *MockStore) EXPECT() *MockStoreMockRecorder {
return m.recorder
}
// AcquireChats mocks base method.
func (m *MockStore) AcquireChats(ctx context.Context, arg database.AcquireChatsParams) ([]database.Chat, error) {
// AcquireChat mocks base method.
func (m *MockStore) AcquireChat(ctx context.Context, arg database.AcquireChatParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcquireChats", ctx, arg)
ret0, _ := ret[0].([]database.Chat)
ret := m.ctrl.Call(m, "AcquireChat", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcquireChats indicates an expected call of AcquireChats.
func (mr *MockStoreMockRecorder) AcquireChats(ctx, arg any) *gomock.Call {
// AcquireChat indicates an expected call of AcquireChat.
func (mr *MockStoreMockRecorder) AcquireChat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireChats", reflect.TypeOf((*MockStore)(nil).AcquireChats), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireChat", reflect.TypeOf((*MockStore)(nil).AcquireChat), ctx, arg)
}
// AcquireLock mocks base method.
@@ -424,21 +424,6 @@ func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg)
}
// CountEnabledModelsWithoutPricing mocks base method.
func (m *MockStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountEnabledModelsWithoutPricing", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountEnabledModelsWithoutPricing indicates an expected call of CountEnabledModelsWithoutPricing.
func (mr *MockStoreMockRecorder) CountEnabledModelsWithoutPricing(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEnabledModelsWithoutPricing", reflect.TypeOf((*MockStore)(nil).CountEnabledModelsWithoutPricing), ctx)
}
// CountInProgressPrebuilds mocks base method.
func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
m.ctrl.T.Helper()
@@ -612,6 +597,20 @@ func (mr *MockStoreMockRecorder) DeleteChatMessagesAfterID(ctx, arg any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesAfterID), ctx, arg)
}
// DeleteChatMessagesByChatID mocks base method.
func (m *MockStore) DeleteChatMessagesByChatID(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatMessagesByChatID", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatMessagesByChatID indicates an expected call of DeleteChatMessagesByChatID.
func (mr *MockStoreMockRecorder) DeleteChatMessagesByChatID(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesByChatID), ctx, chatID)
}
// DeleteChatModelConfigByID mocks base method.
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -654,34 +653,6 @@ func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
}
// DeleteChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatUsageLimitGroupOverride", ctx, groupID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatUsageLimitGroupOverride indicates an expected call of DeleteChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitGroupOverride), ctx, groupID)
}
// DeleteChatUsageLimitUserOverride mocks base method.
func (m *MockStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteChatUsageLimitUserOverride", ctx, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteChatUsageLimitUserOverride indicates an expected call of DeleteChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitUserOverride), ctx, userID)
}
// DeleteCryptoKey mocks base method.
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -740,6 +711,20 @@ func (mr *MockStoreMockRecorder) DeleteExternalAuthLink(ctx, arg any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExternalAuthLink", reflect.TypeOf((*MockStore)(nil).DeleteExternalAuthLink), ctx, arg)
}
// DeleteGitSSHKey mocks base method.
func (m *MockStore) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteGitSSHKey", ctx, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteGitSSHKey indicates an expected call of DeleteGitSSHKey.
func (mr *MockStoreMockRecorder) DeleteGitSSHKey(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGitSSHKey", reflect.TypeOf((*MockStore)(nil).DeleteGitSSHKey), ctx, userID)
}
// DeleteGroupByID mocks base method.
func (m *MockStore) DeleteGroupByID(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -1521,21 +1506,6 @@ func (mr *MockStoreMockRecorder) GetAPIKeysLastUsedAfter(ctx, lastUsed any) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysLastUsedAfter", reflect.TypeOf((*MockStore)(nil).GetAPIKeysLastUsedAfter), ctx, lastUsed)
}
// GetActiveAISeatCount mocks base method.
func (m *MockStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveAISeatCount", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveAISeatCount indicates an expected call of GetActiveAISeatCount.
func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
}
// GetActivePresetPrebuildSchedules mocks base method.
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
m.ctrl.T.Helper()
@@ -1656,6 +1626,21 @@ func (mr *MockStoreMockRecorder) GetAnnouncementBanners(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnnouncementBanners", reflect.TypeOf((*MockStore)(nil).GetAnnouncementBanners), ctx)
}
// GetAppSecurityKey mocks base method.
func (m *MockStore) GetAppSecurityKey(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAppSecurityKey", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAppSecurityKey indicates an expected call of GetAppSecurityKey.
func (mr *MockStoreMockRecorder) GetAppSecurityKey(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAppSecurityKey", reflect.TypeOf((*MockStore)(nil).GetAppSecurityKey), ctx)
}
// GetApplicationName mocks base method.
func (m *MockStore) GetApplicationName(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
@@ -1776,6 +1761,21 @@ func (mr *MockStoreMockRecorder) GetAuthorizedUsers(ctx, arg, prepared any) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedUsers", reflect.TypeOf((*MockStore)(nil).GetAuthorizedUsers), ctx, arg, prepared)
}
// GetAuthorizedWorkspaceBuildParametersByBuildIDs mocks base method.
func (m *MockStore) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID, prepared rbac.PreparedAuthorized) ([]database.WorkspaceBuildParameter, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthorizedWorkspaceBuildParametersByBuildIDs", ctx, workspaceBuildIDs, prepared)
ret0, _ := ret[0].([]database.WorkspaceBuildParameter)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthorizedWorkspaceBuildParametersByBuildIDs indicates an expected call of GetAuthorizedWorkspaceBuildParametersByBuildIDs.
func (mr *MockStoreMockRecorder) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prepared any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedWorkspaceBuildParametersByBuildIDs", reflect.TypeOf((*MockStore)(nil).GetAuthorizedWorkspaceBuildParametersByBuildIDs), ctx, workspaceBuildIDs, prepared)
}
// GetAuthorizedWorkspaces mocks base method.
func (m *MockStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) {
m.ctrl.T.Helper()
@@ -1836,66 +1836,6 @@ func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id)
}
// GetChatCostPerChat mocks base method.
func (m *MockStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostPerChat", ctx, arg)
ret0, _ := ret[0].([]database.GetChatCostPerChatRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostPerChat indicates an expected call of GetChatCostPerChat.
func (mr *MockStoreMockRecorder) GetChatCostPerChat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerChat", reflect.TypeOf((*MockStore)(nil).GetChatCostPerChat), ctx, arg)
}
// GetChatCostPerModel mocks base method.
func (m *MockStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostPerModel", ctx, arg)
ret0, _ := ret[0].([]database.GetChatCostPerModelRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostPerModel indicates an expected call of GetChatCostPerModel.
func (mr *MockStoreMockRecorder) GetChatCostPerModel(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerModel", reflect.TypeOf((*MockStore)(nil).GetChatCostPerModel), ctx, arg)
}
// GetChatCostPerUser mocks base method.
func (m *MockStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostPerUser", ctx, arg)
ret0, _ := ret[0].([]database.GetChatCostPerUserRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostPerUser indicates an expected call of GetChatCostPerUser.
func (mr *MockStoreMockRecorder) GetChatCostPerUser(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerUser", reflect.TypeOf((*MockStore)(nil).GetChatCostPerUser), ctx, arg)
}
// GetChatCostSummary mocks base method.
func (m *MockStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostSummary", ctx, arg)
ret0, _ := ret[0].(database.GetChatCostSummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostSummary indicates an expected call of GetChatCostSummary.
func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
}
// GetChatDiffStatusByChatID mocks base method.
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
@@ -1986,21 +1926,6 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
}
// GetChatMessagesByChatIDDescPaginated mocks base method.
func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessagesByChatIDDescPaginated", ctx, arg)
ret0, _ := ret[0].([]database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessagesByChatIDDescPaginated indicates an expected call of GetChatMessagesByChatIDDescPaginated.
func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDDescPaginated(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDDescPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDDescPaginated), ctx, arg)
}
// GetChatMessagesForPromptByChatID mocks base method.
func (m *MockStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -2031,6 +1956,21 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigByID(ctx, id any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByID", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByID), ctx, id)
}
// GetChatModelConfigByProviderAndModel mocks base method.
func (m *MockStore) GetChatModelConfigByProviderAndModel(ctx context.Context, arg database.GetChatModelConfigByProviderAndModelParams) (database.ChatModelConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatModelConfigByProviderAndModel", ctx, arg)
ret0, _ := ret[0].(database.ChatModelConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatModelConfigByProviderAndModel indicates an expected call of GetChatModelConfigByProviderAndModel.
func (mr *MockStoreMockRecorder) GetChatModelConfigByProviderAndModel(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigByProviderAndModel", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigByProviderAndModel), ctx, arg)
}
// GetChatModelConfigs mocks base method.
func (m *MockStore) GetChatModelConfigs(ctx context.Context) ([]database.ChatModelConfig, error) {
m.ctrl.T.Helper()
@@ -2121,51 +2061,6 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
}
// GetChatUsageLimitConfig mocks base method.
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitConfig", ctx)
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitConfig indicates an expected call of GetChatUsageLimitConfig.
func (mr *MockStoreMockRecorder) GetChatUsageLimitConfig(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitConfig), ctx)
}
// GetChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitGroupOverride", ctx, groupID)
ret0, _ := ret[0].(database.GetChatUsageLimitGroupOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitGroupOverride indicates an expected call of GetChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) GetChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitGroupOverride), ctx, groupID)
}
// GetChatUsageLimitUserOverride mocks base method.
func (m *MockStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatUsageLimitUserOverride", ctx, userID)
ret0, _ := ret[0].(database.GetChatUsageLimitUserOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatUsageLimitUserOverride indicates an expected call of GetChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
}
// GetChatsByOwnerID mocks base method.
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
m.ctrl.T.Helper()
@@ -2196,6 +2091,21 @@ func (mr *MockStoreMockRecorder) GetConnectionLogsOffset(ctx, arg any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConnectionLogsOffset", reflect.TypeOf((*MockStore)(nil).GetConnectionLogsOffset), ctx, arg)
}
// GetCoordinatorResumeTokenSigningKey mocks base method.
func (m *MockStore) GetCoordinatorResumeTokenSigningKey(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCoordinatorResumeTokenSigningKey", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetCoordinatorResumeTokenSigningKey indicates an expected call of GetCoordinatorResumeTokenSigningKey.
func (mr *MockStoreMockRecorder) GetCoordinatorResumeTokenSigningKey(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).GetCoordinatorResumeTokenSigningKey), ctx)
}
// GetCryptoKeyByFeatureAndSequence mocks base method.
func (m *MockStore) GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg database.GetCryptoKeyByFeatureAndSequenceParams) (database.CryptoKey, error) {
m.ctrl.T.Helper()
@@ -2316,6 +2226,21 @@ func (mr *MockStoreMockRecorder) GetDefaultProxyConfig(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultProxyConfig", reflect.TypeOf((*MockStore)(nil).GetDefaultProxyConfig), ctx)
}
// GetDeploymentDAUs mocks base method.
func (m *MockStore) GetDeploymentDAUs(ctx context.Context, tzOffset int32) ([]database.GetDeploymentDAUsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDeploymentDAUs", ctx, tzOffset)
ret0, _ := ret[0].([]database.GetDeploymentDAUsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetDeploymentDAUs indicates an expected call of GetDeploymentDAUs.
func (mr *MockStoreMockRecorder) GetDeploymentDAUs(ctx, tzOffset any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDeploymentDAUs", reflect.TypeOf((*MockStore)(nil).GetDeploymentDAUs), ctx, tzOffset)
}
// GetDeploymentID mocks base method.
func (m *MockStore) GetDeploymentID(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
@@ -2496,6 +2421,21 @@ func (mr *MockStoreMockRecorder) GetFileByID(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFileByID", reflect.TypeOf((*MockStore)(nil).GetFileByID), ctx, id)
}
// GetFileIDByTemplateVersionID mocks base method.
func (m *MockStore) GetFileIDByTemplateVersionID(ctx context.Context, templateVersionID uuid.UUID) (uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetFileIDByTemplateVersionID", ctx, templateVersionID)
ret0, _ := ret[0].(uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetFileIDByTemplateVersionID indicates an expected call of GetFileIDByTemplateVersionID.
func (mr *MockStoreMockRecorder) GetFileIDByTemplateVersionID(ctx, templateVersionID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFileIDByTemplateVersionID", reflect.TypeOf((*MockStore)(nil).GetFileIDByTemplateVersionID), ctx, templateVersionID)
}
// GetFileTemplates mocks base method.
func (m *MockStore) GetFileTemplates(ctx context.Context, fileID uuid.UUID) ([]database.GetFileTemplatesRow, error) {
m.ctrl.T.Helper()
@@ -2946,6 +2886,21 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppByID(ctx, id any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppByID", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppByID), ctx, id)
}
// GetOAuth2ProviderAppByRegistrationToken mocks base method.
func (m *MockStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOAuth2ProviderAppByRegistrationToken", ctx, registrationAccessToken)
ret0, _ := ret[0].(database.OAuth2ProviderApp)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOAuth2ProviderAppByRegistrationToken indicates an expected call of GetOAuth2ProviderAppByRegistrationToken.
func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppByRegistrationToken(ctx, registrationAccessToken any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppByRegistrationToken", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppByRegistrationToken), ctx, registrationAccessToken)
}
// GetOAuth2ProviderAppCodeByID mocks base method.
func (m *MockStore) GetOAuth2ProviderAppCodeByID(ctx context.Context, id uuid.UUID) (database.OAuth2ProviderAppCode, error) {
m.ctrl.T.Helper()
@@ -3081,6 +3036,21 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppsByUserID(ctx, userID any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuth2ProviderAppsByUserID", reflect.TypeOf((*MockStore)(nil).GetOAuth2ProviderAppsByUserID), ctx, userID)
}
// GetOAuthSigningKey mocks base method.
func (m *MockStore) GetOAuthSigningKey(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOAuthSigningKey", ctx)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOAuthSigningKey indicates an expected call of GetOAuthSigningKey.
func (mr *MockStoreMockRecorder) GetOAuthSigningKey(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOAuthSigningKey", reflect.TypeOf((*MockStore)(nil).GetOAuthSigningKey), ctx)
}
// GetOrganizationByID mocks base method.
func (m *MockStore) GetOrganizationByID(ctx context.Context, id uuid.UUID) (database.Organization, error) {
m.ctrl.T.Helper()
@@ -3456,6 +3426,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobTimingsByJobID(ctx, jobID any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobTimingsByJobID", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobTimingsByJobID), ctx, jobID)
}
// GetProvisionerJobsByIDs mocks base method.
func (m *MockStore) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]database.ProvisionerJob, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProvisionerJobsByIDs", ctx, ids)
ret0, _ := ret[0].([]database.ProvisionerJob)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProvisionerJobsByIDs indicates an expected call of GetProvisionerJobsByIDs.
func (mr *MockStoreMockRecorder) GetProvisionerJobsByIDs(ctx, ids any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsByIDs", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsByIDs), ctx, ids)
}
// GetProvisionerJobsByIDsWithQueuePosition mocks base method.
func (m *MockStore) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, arg database.GetProvisionerJobsByIDsWithQueuePositionParams) ([]database.GetProvisionerJobsByIDsWithQueuePositionRow, error) {
m.ctrl.T.Helper()
@@ -3921,6 +3906,21 @@ func (mr *MockStoreMockRecorder) GetTemplateByOrganizationAndName(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateByOrganizationAndName", reflect.TypeOf((*MockStore)(nil).GetTemplateByOrganizationAndName), ctx, arg)
}
// GetTemplateDAUs mocks base method.
func (m *MockStore) GetTemplateDAUs(ctx context.Context, arg database.GetTemplateDAUsParams) ([]database.GetTemplateDAUsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTemplateDAUs", ctx, arg)
ret0, _ := ret[0].([]database.GetTemplateDAUsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTemplateDAUs indicates an expected call of GetTemplateDAUs.
func (mr *MockStoreMockRecorder) GetTemplateDAUs(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateDAUs", reflect.TypeOf((*MockStore)(nil).GetTemplateDAUs), ctx, arg)
}
// GetTemplateGroupRoles mocks base method.
func (m *MockStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) {
m.ctrl.T.Helper()
@@ -4086,6 +4086,21 @@ func (mr *MockStoreMockRecorder) GetTemplateVersionByTemplateIDAndName(ctx, arg
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionByTemplateIDAndName", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionByTemplateIDAndName), ctx, arg)
}
// GetTemplateVersionHasAITask mocks base method.
func (m *MockStore) GetTemplateVersionHasAITask(ctx context.Context, id uuid.UUID) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTemplateVersionHasAITask", ctx, id)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTemplateVersionHasAITask indicates an expected call of GetTemplateVersionHasAITask.
func (mr *MockStoreMockRecorder) GetTemplateVersionHasAITask(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemplateVersionHasAITask", reflect.TypeOf((*MockStore)(nil).GetTemplateVersionHasAITask), ctx, id)
}
// GetTemplateVersionParameters mocks base method.
func (m *MockStore) GetTemplateVersionParameters(ctx context.Context, templateVersionID uuid.UUID) ([]database.TemplateVersionParameter, error) {
m.ctrl.T.Helper()
@@ -4311,21 +4326,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
}
// GetUserChatSpendInPeriod mocks base method.
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserChatSpendInPeriod", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserChatSpendInPeriod indicates an expected call of GetUserChatSpendInPeriod.
func (mr *MockStoreMockRecorder) GetUserChatSpendInPeriod(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatSpendInPeriod", reflect.TypeOf((*MockStore)(nil).GetUserChatSpendInPeriod), ctx, arg)
}
// GetUserCount mocks base method.
func (m *MockStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
m.ctrl.T.Helper()
@@ -4341,21 +4341,6 @@ func (mr *MockStoreMockRecorder) GetUserCount(ctx, includeSystem any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCount", reflect.TypeOf((*MockStore)(nil).GetUserCount), ctx, includeSystem)
}
// GetUserGroupSpendLimit mocks base method.
func (m *MockStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserGroupSpendLimit", ctx, userID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserGroupSpendLimit indicates an expected call of GetUserGroupSpendLimit.
func (mr *MockStoreMockRecorder) GetUserGroupSpendLimit(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserGroupSpendLimit", reflect.TypeOf((*MockStore)(nil).GetUserGroupSpendLimit), ctx, userID)
}
// GetUserLatencyInsights mocks base method.
func (m *MockStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
m.ctrl.T.Helper()
@@ -5076,6 +5061,21 @@ func (mr *MockStoreMockRecorder) GetWorkspaceBuildParameters(ctx, workspaceBuild
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildParameters", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildParameters), ctx, workspaceBuildID)
}
// GetWorkspaceBuildParametersByBuildIDs mocks base method.
func (m *MockStore) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIds []uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetWorkspaceBuildParametersByBuildIDs", ctx, workspaceBuildIds)
ret0, _ := ret[0].([]database.WorkspaceBuildParameter)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetWorkspaceBuildParametersByBuildIDs indicates an expected call of GetWorkspaceBuildParametersByBuildIDs.
func (mr *MockStoreMockRecorder) GetWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceBuildParametersByBuildIDs", reflect.TypeOf((*MockStore)(nil).GetWorkspaceBuildParametersByBuildIDs), ctx, workspaceBuildIds)
}
// GetWorkspaceBuildProvisionerStateByID mocks base method.
func (m *MockStore) GetWorkspaceBuildProvisionerStateByID(ctx context.Context, workspaceBuildID uuid.UUID) (database.GetWorkspaceBuildProvisionerStateByIDRow, error) {
m.ctrl.T.Helper()
@@ -6280,6 +6280,20 @@ func (mr *MockStoreMockRecorder) InsertUserGroupsByID(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), ctx, arg)
}
// InsertUserGroupsByName mocks base method.
func (m *MockStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertUserGroupsByName", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// InsertUserGroupsByName indicates an expected call of InsertUserGroupsByName.
func (mr *MockStoreMockRecorder) InsertUserGroupsByName(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByName", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByName), ctx, arg)
}
// InsertUserLink mocks base method.
func (m *MockStore) InsertUserLink(ctx context.Context, arg database.InsertUserLinkParams) (database.UserLink, error) {
m.ctrl.T.Helper()
@@ -6695,34 +6709,34 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
}
// ListChatUsageLimitGroupOverrides mocks base method.
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
// ListChatsByRootID mocks base method.
func (m *MockStore) ListChatsByRootID(ctx context.Context, rootChatID uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChatUsageLimitGroupOverrides", ctx)
ret0, _ := ret[0].([]database.ListChatUsageLimitGroupOverridesRow)
ret := m.ctrl.Call(m, "ListChatsByRootID", ctx, rootChatID)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChatUsageLimitGroupOverrides indicates an expected call of ListChatUsageLimitGroupOverrides.
func (mr *MockStoreMockRecorder) ListChatUsageLimitGroupOverrides(ctx any) *gomock.Call {
// ListChatsByRootID indicates an expected call of ListChatsByRootID.
func (mr *MockStoreMockRecorder) ListChatsByRootID(ctx, rootChatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitGroupOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitGroupOverrides), ctx)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatsByRootID", reflect.TypeOf((*MockStore)(nil).ListChatsByRootID), ctx, rootChatID)
}
// ListChatUsageLimitOverrides mocks base method.
func (m *MockStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
// ListChildChatsByParentID mocks base method.
func (m *MockStore) ListChildChatsByParentID(ctx context.Context, parentChatID uuid.UUID) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListChatUsageLimitOverrides", ctx)
ret0, _ := ret[0].([]database.ListChatUsageLimitOverridesRow)
ret := m.ctrl.Call(m, "ListChildChatsByParentID", ctx, parentChatID)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListChatUsageLimitOverrides indicates an expected call of ListChatUsageLimitOverrides.
func (mr *MockStoreMockRecorder) ListChatUsageLimitOverrides(ctx any) *gomock.Call {
// ListChildChatsByParentID indicates an expected call of ListChildChatsByParentID.
func (mr *MockStoreMockRecorder) ListChildChatsByParentID(ctx, parentChatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitOverrides), ctx)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChildChatsByParentID", reflect.TypeOf((*MockStore)(nil).ListChildChatsByParentID), ctx, parentChatID)
}
// ListProvisionerKeysByOrganization mocks base method.
@@ -6948,6 +6962,20 @@ func (mr *MockStoreMockRecorder) RegisterWorkspaceProxy(ctx, arg any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterWorkspaceProxy", reflect.TypeOf((*MockStore)(nil).RegisterWorkspaceProxy), ctx, arg)
}
// RemoveUserFromAllGroups mocks base method.
func (m *MockStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveUserFromAllGroups", ctx, userID)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveUserFromAllGroups indicates an expected call of RemoveUserFromAllGroups.
func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), ctx, userID)
}
// RemoveUserFromGroups mocks base method.
func (m *MockStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
@@ -6963,21 +6991,6 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg)
}
// ResolveUserChatSpendLimit mocks base method.
func (m *MockStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResolveUserChatSpendLimit", ctx, userID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ResolveUserChatSpendLimit indicates an expected call of ResolveUserChatSpendLimit.
func (mr *MockStoreMockRecorder) ResolveUserChatSpendLimit(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveUserChatSpendLimit", reflect.TypeOf((*MockStore)(nil).ResolveUserChatSpendLimit), ctx, userID)
}
// RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
m.ctrl.T.Helper()
@@ -7404,6 +7417,21 @@ func (mr *MockStoreMockRecorder) UpdateOAuth2ProviderAppByID(ctx, arg any) *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOAuth2ProviderAppByID", reflect.TypeOf((*MockStore)(nil).UpdateOAuth2ProviderAppByID), ctx, arg)
}
// UpdateOAuth2ProviderAppSecretByID mocks base method.
func (m *MockStore) UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg database.UpdateOAuth2ProviderAppSecretByIDParams) (database.OAuth2ProviderAppSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateOAuth2ProviderAppSecretByID", ctx, arg)
ret0, _ := ret[0].(database.OAuth2ProviderAppSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateOAuth2ProviderAppSecretByID indicates an expected call of UpdateOAuth2ProviderAppSecretByID.
func (mr *MockStoreMockRecorder) UpdateOAuth2ProviderAppSecretByID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOAuth2ProviderAppSecretByID", reflect.TypeOf((*MockStore)(nil).UpdateOAuth2ProviderAppSecretByID), ctx, arg)
}
// UpdateOrganization mocks base method.
func (m *MockStore) UpdateOrganization(ctx context.Context, arg database.UpdateOrganizationParams) (database.Organization, error) {
m.ctrl.T.Helper()
@@ -7918,6 +7946,21 @@ func (mr *MockStoreMockRecorder) UpdateUserLink(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserLink", reflect.TypeOf((*MockStore)(nil).UpdateUserLink), ctx, arg)
}
// UpdateUserLinkedID mocks base method.
func (m *MockStore) UpdateUserLinkedID(ctx context.Context, arg database.UpdateUserLinkedIDParams) (database.UserLink, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserLinkedID", ctx, arg)
ret0, _ := ret[0].(database.UserLink)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserLinkedID indicates an expected call of UpdateUserLinkedID.
func (mr *MockStoreMockRecorder) UpdateUserLinkedID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserLinkedID", reflect.TypeOf((*MockStore)(nil).UpdateUserLinkedID), ctx, arg)
}
// UpdateUserLoginType mocks base method.
func (m *MockStore) UpdateUserLoginType(ctx context.Context, arg database.UpdateUserLoginTypeParams) (database.User, error) {
m.ctrl.T.Helper()
@@ -8422,21 +8465,6 @@ func (mr *MockStoreMockRecorder) UpdateWorkspacesTTLByTemplateID(ctx, arg any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspacesTTLByTemplateID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspacesTTLByTemplateID), ctx, arg)
}
// UpsertAISeatState mocks base method.
func (m *MockStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertAISeatState", ctx, arg)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertAISeatState indicates an expected call of UpsertAISeatState.
func (mr *MockStoreMockRecorder) UpsertAISeatState(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAISeatState", reflect.TypeOf((*MockStore)(nil).UpsertAISeatState), ctx, arg)
}
// UpsertAnnouncementBanners mocks base method.
func (m *MockStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
m.ctrl.T.Helper()
@@ -8451,6 +8479,20 @@ func (mr *MockStoreMockRecorder) UpsertAnnouncementBanners(ctx, value any) *gomo
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAnnouncementBanners", reflect.TypeOf((*MockStore)(nil).UpsertAnnouncementBanners), ctx, value)
}
// UpsertAppSecurityKey mocks base method.
func (m *MockStore) UpsertAppSecurityKey(ctx context.Context, value string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertAppSecurityKey", ctx, value)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertAppSecurityKey indicates an expected call of UpsertAppSecurityKey.
func (mr *MockStoreMockRecorder) UpsertAppSecurityKey(ctx, value any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAppSecurityKey", reflect.TypeOf((*MockStore)(nil).UpsertAppSecurityKey), ctx, value)
}
// UpsertApplicationName mocks base method.
func (m *MockStore) UpsertApplicationName(ctx context.Context, value string) error {
m.ctrl.T.Helper()
@@ -8524,51 +8566,6 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
}
// UpsertChatUsageLimitConfig mocks base method.
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitConfig", ctx, arg)
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitConfig indicates an expected call of UpsertChatUsageLimitConfig.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitConfig(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitConfig), ctx, arg)
}
// UpsertChatUsageLimitGroupOverride mocks base method.
func (m *MockStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitGroupOverride", ctx, arg)
ret0, _ := ret[0].(database.UpsertChatUsageLimitGroupOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitGroupOverride indicates an expected call of UpsertChatUsageLimitGroupOverride.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitGroupOverride(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitGroupOverride), ctx, arg)
}
// UpsertChatUsageLimitUserOverride mocks base method.
func (m *MockStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatUsageLimitUserOverride", ctx, arg)
ret0, _ := ret[0].(database.UpsertChatUsageLimitUserOverrideRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatUsageLimitUserOverride indicates an expected call of UpsertChatUsageLimitUserOverride.
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitUserOverride(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitUserOverride), ctx, arg)
}
// UpsertConnectionLog mocks base method.
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
m.ctrl.T.Helper()
@@ -8584,6 +8581,20 @@ func (mr *MockStoreMockRecorder) UpsertConnectionLog(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertConnectionLog", reflect.TypeOf((*MockStore)(nil).UpsertConnectionLog), ctx, arg)
}
// UpsertCoordinatorResumeTokenSigningKey mocks base method.
func (m *MockStore) UpsertCoordinatorResumeTokenSigningKey(ctx context.Context, value string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertCoordinatorResumeTokenSigningKey", ctx, value)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertCoordinatorResumeTokenSigningKey indicates an expected call of UpsertCoordinatorResumeTokenSigningKey.
func (mr *MockStoreMockRecorder) UpsertCoordinatorResumeTokenSigningKey(ctx, value any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertCoordinatorResumeTokenSigningKey", reflect.TypeOf((*MockStore)(nil).UpsertCoordinatorResumeTokenSigningKey), ctx, value)
}
// UpsertDefaultProxy mocks base method.
func (m *MockStore) UpsertDefaultProxy(ctx context.Context, arg database.UpsertDefaultProxyParams) error {
m.ctrl.T.Helper()
@@ -8682,6 +8693,20 @@ func (mr *MockStoreMockRecorder) UpsertOAuth2GithubDefaultEligible(ctx, eligible
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertOAuth2GithubDefaultEligible", reflect.TypeOf((*MockStore)(nil).UpsertOAuth2GithubDefaultEligible), ctx, eligible)
}
// UpsertOAuthSigningKey mocks base method.
func (m *MockStore) UpsertOAuthSigningKey(ctx context.Context, value string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertOAuthSigningKey", ctx, value)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertOAuthSigningKey indicates an expected call of UpsertOAuthSigningKey.
func (mr *MockStoreMockRecorder) UpsertOAuthSigningKey(ctx, value any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertOAuthSigningKey", reflect.TypeOf((*MockStore)(nil).UpsertOAuthSigningKey), ctx, value)
}
// UpsertPrebuildsSettings mocks base method.
func (m *MockStore) UpsertPrebuildsSettings(ctx context.Context, value string) error {
m.ctrl.T.Helper()
+9 -98
View File
@@ -10,11 +10,6 @@ CREATE TYPE agent_key_scope_enum AS ENUM (
'no_user_data'
);
CREATE TYPE ai_seat_usage_reason AS ENUM (
'aibridge',
'task'
);
CREATE TYPE api_key_scope AS ENUM (
'coder:all',
'coder:application_connect',
@@ -270,23 +265,12 @@ CREATE TYPE build_reason AS ENUM (
'task_resume'
);
CREATE TYPE chat_message_role AS ENUM (
'system',
'user',
'assistant',
'tool'
);
CREATE TYPE chat_message_visibility AS ENUM (
'user',
'model',
'both'
);
CREATE TYPE chat_mode AS ENUM (
'computer_use'
);
CREATE TYPE chat_status AS ENUM (
'waiting',
'pending',
@@ -508,8 +492,7 @@ CREATE TYPE resource_type AS ENUM (
'workspace_agent',
'workspace_app',
'prebuilds_settings',
'task',
'ai_seat'
'task'
);
CREATE TYPE startup_script_behavior AS ENUM (
@@ -1052,15 +1035,6 @@ BEGIN
END;
$$;
CREATE TABLE ai_seat_state (
user_id uuid NOT NULL,
first_used_at timestamp with time zone NOT NULL,
last_used_at timestamp with time zone NOT NULL,
last_event_type ai_seat_usage_reason NOT NULL,
last_event_description text NOT NULL,
updated_at timestamp with time zone NOT NULL
);
CREATE TABLE aibridge_interceptions (
id uuid NOT NULL,
initiator_id uuid NOT NULL,
@@ -1213,17 +1187,7 @@ CREATE TABLE chat_diff_statuses (
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
git_branch text DEFAULT ''::text NOT NULL,
git_remote_origin text DEFAULT ''::text NOT NULL,
pull_request_title text DEFAULT ''::text NOT NULL,
pull_request_draft boolean DEFAULT false NOT NULL,
author_login text,
author_avatar_url text,
base_branch text,
pr_number integer,
commits integer,
approved boolean,
reviewer_count integer,
head_branch text
git_remote_origin text DEFAULT ''::text NOT NULL
);
CREATE TABLE chat_files (
@@ -1241,7 +1205,7 @@ CREATE TABLE chat_messages (
chat_id uuid NOT NULL,
model_config_id uuid,
created_at timestamp with time zone DEFAULT now() NOT NULL,
role chat_message_role NOT NULL,
role text NOT NULL,
content jsonb,
visibility chat_message_visibility DEFAULT 'both'::chat_message_visibility NOT NULL,
input_tokens bigint,
@@ -1251,10 +1215,7 @@ CREATE TABLE chat_messages (
cache_creation_tokens bigint,
cache_read_tokens bigint,
context_limit bigint,
compressed boolean DEFAULT false NOT NULL,
created_by uuid,
content_version smallint NOT NULL,
total_cost_micros bigint
compressed boolean DEFAULT false NOT NULL
);
CREATE SEQUENCE chat_messages_id_seq
@@ -1318,28 +1279,6 @@ CREATE SEQUENCE chat_queued_messages_id_seq
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
CREATE TABLE chat_usage_limit_config (
id bigint NOT NULL,
singleton boolean DEFAULT true NOT NULL,
enabled boolean DEFAULT false NOT NULL,
default_limit_micros bigint DEFAULT 0 NOT NULL,
period text DEFAULT 'month'::text NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
CONSTRAINT chat_usage_limit_config_default_limit_micros_check CHECK ((default_limit_micros >= 0)),
CONSTRAINT chat_usage_limit_config_period_check CHECK ((period = ANY (ARRAY['day'::text, 'week'::text, 'month'::text]))),
CONSTRAINT chat_usage_limit_config_singleton_check CHECK (singleton)
);
CREATE SEQUENCE chat_usage_limit_config_id_seq
START WITH 1
INCREMENT BY 1
NO MINVALUE
NO MAXVALUE
CACHE 1;
ALTER SEQUENCE chat_usage_limit_config_id_seq OWNED BY chat_usage_limit_config.id;
CREATE TABLE chats (
id uuid DEFAULT gen_random_uuid() NOT NULL,
owner_id uuid NOT NULL,
@@ -1355,8 +1294,7 @@ CREATE TABLE chats (
root_chat_id uuid,
last_model_config_id uuid NOT NULL,
archived boolean DEFAULT false NOT NULL,
last_error text,
mode chat_mode
last_error text
);
CREATE TABLE connection_logs (
@@ -1496,9 +1434,7 @@ CREATE TABLE groups (
avatar_url text DEFAULT ''::text NOT NULL,
quota_allowance integer DEFAULT 0 NOT NULL,
display_name text DEFAULT ''::text NOT NULL,
source group_source DEFAULT 'user'::group_source NOT NULL,
chat_spend_limit_micros bigint,
CONSTRAINT groups_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0)))
source group_source DEFAULT 'user'::group_source NOT NULL
);
COMMENT ON COLUMN groups.display_name IS 'Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.';
@@ -1532,12 +1468,7 @@ CREATE TABLE users (
hashed_one_time_passcode bytea,
one_time_passcode_expires_at timestamp with time zone,
is_system boolean DEFAULT false NOT NULL,
is_service_account boolean DEFAULT false NOT NULL,
chat_spend_limit_micros bigint,
CONSTRAINT one_time_passcode_set CHECK ((((hashed_one_time_passcode IS NULL) AND (one_time_passcode_expires_at IS NULL)) OR ((hashed_one_time_passcode IS NOT NULL) AND (one_time_passcode_expires_at IS NOT NULL)))),
CONSTRAINT users_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0))),
CONSTRAINT users_email_not_empty CHECK (((is_service_account = true) = (email = ''::text))),
CONSTRAINT users_service_account_login_type CHECK (((is_service_account = false) OR (login_type = 'none'::login_type))),
CONSTRAINT users_username_min_length CHECK ((length(username) >= 1))
);
@@ -1553,8 +1484,6 @@ COMMENT ON COLUMN users.one_time_passcode_expires_at IS 'The time when the one-t
COMMENT ON COLUMN users.is_system IS 'Determines if a user is a system user, and therefore cannot login or perform normal actions';
COMMENT ON COLUMN users.is_service_account IS 'Determines if a user is an admin-managed account that cannot login';
CREATE VIEW group_members_expanded AS
WITH all_members AS (
SELECT group_members.user_id,
@@ -3182,8 +3111,6 @@ ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_message
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
ALTER TABLE ONLY chat_usage_limit_config ALTER COLUMN id SET DEFAULT nextval('chat_usage_limit_config_id_seq'::regclass);
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
@@ -3199,9 +3126,6 @@ ALTER TABLE ONLY workspace_resource_metadata ALTER COLUMN id SET DEFAULT nextval
ALTER TABLE ONLY workspace_agent_stats
ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
ALTER TABLE ONLY ai_seat_state
ADD CONSTRAINT ai_seat_state_pkey PRIMARY KEY (user_id);
ALTER TABLE ONLY aibridge_interceptions
ADD CONSTRAINT aibridge_interceptions_pkey PRIMARY KEY (id);
@@ -3244,12 +3168,6 @@ ALTER TABLE ONLY chat_providers
-- Primary keys and uniqueness constraints for chat tables.
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
ALTER TABLE ONLY chat_usage_limit_config
ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
-- UNIQUE on 'singleton' presumably caps this config table at one row (the usual
-- fixed-value singleton-column pattern) — TODO confirm the column's definition.
ALTER TABLE ONLY chat_usage_limit_config
ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
ALTER TABLE ONLY chats
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
@@ -3598,11 +3516,7 @@ CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
-- Partial index to find the most recent compressed-summary boundary message in a
-- chat (newest-first on created_at/id). The role literal must be cast to the
-- chat_message_role enum: a duplicate definition comparing role against a ::text
-- literal was removed — it reused the same index name (second CREATE INDEX would
-- fail with "relation already exists") and PostgreSQL has no enum = text operator.
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::chat_message_role) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility])));
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
-- Spend aggregation per chat; restricted to rows that actually carry a cost.
CREATE INDEX idx_chat_messages_owner_spend ON chat_messages USING btree (chat_id, created_at) WHERE (total_cost_micros IS NOT NULL);
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
@@ -3684,7 +3598,7 @@ CREATE INDEX idx_user_deleted_deleted_at ON user_deleted USING btree (deleted_at
CREATE INDEX idx_user_status_changes_changed_at ON user_status_changes USING btree (changed_at);
-- Email uniqueness applies only to live accounts with a non-empty email: service
-- accounts store email = '' (see users_email_not_empty CHECK), so empty emails
-- must be excluded or only one service account could ever exist. A duplicate
-- definition without the email <> '' exclusion was removed — it reused the same
-- index name and would re-impose that restriction.
CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE ((deleted = false) AND (email <> ''::text));
CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false);
@@ -3734,7 +3648,7 @@ CREATE UNIQUE INDEX user_secrets_user_file_path_idx ON user_secrets USING btree
CREATE UNIQUE INDEX user_secrets_user_name_idx ON user_secrets USING btree (user_id, name);
-- Case-insensitive email uniqueness, mirroring idx_users_email: empty emails
-- (service accounts) are excluded from the constraint. A duplicate definition
-- lacking the email <> '' exclusion was removed — it reused the same index name
-- (second CREATE UNIQUE INDEX would fail) and would block multiple service
-- accounts, all of which share email = ''.
CREATE UNIQUE INDEX users_email_lower_idx ON users USING btree (lower(email)) WHERE ((deleted = false) AND (email <> ''::text));
CREATE UNIQUE INDEX users_username_lower_idx ON users USING btree (lower(username)) WHERE (deleted = false);
@@ -3870,9 +3784,6 @@ COMMENT ON TRIGGER workspace_agent_name_unique_trigger ON workspace_agents IS 'U
the uniqueness requirement. A trigger allows us to enforce uniqueness going
forward without requiring a migration to clean up historical data.';
-- Seat state is strictly per-user and is cleaned up automatically when the
-- user row is deleted.
ALTER TABLE ONLY ai_seat_state
ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
-- No ON DELETE action specified, so the PostgreSQL default (NO ACTION) applies:
-- a user with interceptions cannot be deleted until those rows are handled.
ALTER TABLE ONLY aibridge_interceptions
ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);

Some files were not shown because too many files have changed in this diff Show More