Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fe682828dd | |||
| 16c9141edd | |||
| 15925edc08 | |||
| 53f13edec8 |
@@ -113,7 +113,7 @@ Coder emphasizes clear error handling, with specific patterns required:
|
||||
|
||||
All tests should run in parallel using `t.Parallel()` to ensure efficient testing and expose potential race conditions. The codebase is rigorously linted with golangci-lint to maintain consistent code quality.
|
||||
|
||||
Git contributions follow [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/). See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
|
||||
Git contributions follow a standard format with commit messages structured as `type: <message>`, where type is one of `feat`, `fix`, or `chore`.
|
||||
|
||||
## Development Workflow
|
||||
|
||||
|
||||
@@ -4,13 +4,22 @@ This guide documents the PR description style used in the Coder repository, base
|
||||
|
||||
## PR Title Format
|
||||
|
||||
Format: `type(scope): description`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
|
||||
Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/) format:
|
||||
|
||||
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
|
||||
- Scopes must be a real path (directory or file stem) containing all changed files
|
||||
- Omit scope if changes span multiple top-level directories
|
||||
```text
|
||||
type(scope): brief description
|
||||
```
|
||||
|
||||
Examples:
|
||||
**Common types:**
|
||||
|
||||
- `feat`: New features
|
||||
- `fix`: Bug fixes
|
||||
- `refactor`: Code refactoring without behavior change
|
||||
- `perf`: Performance improvements
|
||||
- `docs`: Documentation changes
|
||||
- `chore`: Dependency updates, tooling changes
|
||||
|
||||
**Examples:**
|
||||
|
||||
- `feat: add tracing to aibridge`
|
||||
- `fix: move contexts to appropriate locations`
|
||||
|
||||
@@ -136,11 +136,9 @@ Then make your changes and push normally. Don't use `git push --force` unless th
|
||||
|
||||
## Commit Style
|
||||
|
||||
Format: `type(scope): message`. See [CONTRIBUTING.md](docs/about/contributing/CONTRIBUTING.md#commit-messages) for full rules. PR titles are linted in CI.
|
||||
|
||||
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci`, `chore`, `revert`
|
||||
- Scopes must be a real path (directory or file stem) containing all changed files
|
||||
- Omit scope if changes span multiple top-level directories
|
||||
- Follow [Conventional Commits 1.0.0](https://www.conventionalcommits.org/en/v1.0.0/)
|
||||
- Format: `type(scope): message`
|
||||
- Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`
|
||||
- Keep message titles concise (~70 characters)
|
||||
- Use imperative, present tense in commit titles
|
||||
|
||||
|
||||
@@ -64,7 +64,6 @@ runs:
|
||||
TEST_PACKAGES: ${{ inputs.test-packages }}
|
||||
RACE_DETECTION: ${{ inputs.race-detection }}
|
||||
TS_DEBUG_DISCO: "true"
|
||||
TS_DEBUG_DERP: "true"
|
||||
LC_CTYPE: "en_US.UTF-8"
|
||||
LC_ALL: "en_US.UTF-8"
|
||||
run: |
|
||||
|
||||
@@ -1198,7 +1198,7 @@ jobs:
|
||||
make -j \
|
||||
build/coder_linux_{amd64,arm64,armv7} \
|
||||
build/coder_"$version"_windows_amd64.zip \
|
||||
build/coder_"$version"_linux_{amd64,arm64,armv7}.{tar.gz,deb}
|
||||
build/coder_"$version"_linux_amd64.{tar.gz,deb}
|
||||
env:
|
||||
# The Windows and Darwin slim binaries must be signed for Coder
|
||||
# Desktop to accept them.
|
||||
@@ -1216,28 +1216,11 @@ jobs:
|
||||
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
|
||||
JSIGN_PATH: /tmp/jsign-6.0.jar
|
||||
|
||||
# Free up disk space before building Docker images. The preceding
|
||||
# Build step produces ~2 GB of binaries and packages, the Go build
|
||||
# cache is ~1.3 GB, and node_modules is ~500 MB. Docker image
|
||||
# builds, pushes, and SBOM generation need headroom that isn't
|
||||
# available without reclaiming some of that space.
|
||||
- name: Clean up build cache
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
# Go caches are no longer needed — binaries are already compiled.
|
||||
go clean -cache -modcache
|
||||
# Remove .apk and .rpm packages that are not uploaded as
|
||||
# artifacts and were only built as make prerequisites.
|
||||
rm -f ./build/*.apk ./build/*.rpm
|
||||
|
||||
- name: Build Linux Docker images
|
||||
id: build-docker
|
||||
env:
|
||||
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
|
||||
DOCKER_CLI_EXPERIMENTAL: "enabled"
|
||||
# Skip building .deb/.rpm/.apk/.tar.gz as prerequisites for
|
||||
# the Docker image targets — they were already built above.
|
||||
DOCKER_IMAGE_NO_PREREQUISITES: "true"
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
|
||||
@@ -1455,60 +1438,15 @@ jobs:
|
||||
^v
|
||||
prune-untagged: true
|
||||
|
||||
- name: Upload build artifact (coder-linux-amd64.tar.gz)
|
||||
- name: Upload build artifacts
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-amd64.tar.gz
|
||||
path: ./build/*_linux_amd64.tar.gz
|
||||
retention-days: 7
|
||||
|
||||
- name: Upload build artifact (coder-linux-amd64.deb)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-amd64.deb
|
||||
path: ./build/*_linux_amd64.deb
|
||||
retention-days: 7
|
||||
|
||||
- name: Upload build artifact (coder-linux-arm64.tar.gz)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-arm64.tar.gz
|
||||
path: ./build/*_linux_arm64.tar.gz
|
||||
retention-days: 7
|
||||
|
||||
- name: Upload build artifact (coder-linux-arm64.deb)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-arm64.deb
|
||||
path: ./build/*_linux_arm64.deb
|
||||
retention-days: 7
|
||||
|
||||
- name: Upload build artifact (coder-linux-armv7.tar.gz)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-armv7.tar.gz
|
||||
path: ./build/*_linux_armv7.tar.gz
|
||||
retention-days: 7
|
||||
|
||||
- name: Upload build artifact (coder-linux-armv7.deb)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-linux-armv7.deb
|
||||
path: ./build/*_linux_armv7.deb
|
||||
retention-days: 7
|
||||
|
||||
- name: Upload build artifact (coder-windows-amd64.zip)
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
|
||||
with:
|
||||
name: coder-windows-amd64.zip
|
||||
path: ./build/*_windows_amd64.zip
|
||||
name: coder
|
||||
path: |
|
||||
./build/*.zip
|
||||
./build/*.tar.gz
|
||||
./build/*.deb
|
||||
retention-days: 7
|
||||
|
||||
# Deploy is handled in deploy.yaml so we can apply concurrency limits.
|
||||
|
||||
@@ -23,44 +23,6 @@ permissions:
|
||||
concurrency: pr-${{ github.ref }}
|
||||
|
||||
jobs:
|
||||
community-label:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
if: >-
|
||||
${{
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.action == 'opened' &&
|
||||
github.event.pull_request.author_association != 'MEMBER' &&
|
||||
github.event.pull_request.author_association != 'COLLABORATOR' &&
|
||||
github.event.pull_request.author_association != 'OWNER'
|
||||
}}
|
||||
steps:
|
||||
- name: Add community label
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
script: |
|
||||
const params = {
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
}
|
||||
|
||||
const labels = context.payload.pull_request.labels.map((label) => label.name)
|
||||
if (labels.includes("community")) {
|
||||
console.log('PR already has "community" label.')
|
||||
return
|
||||
}
|
||||
|
||||
console.log(
|
||||
'Adding "community" label for author association "%s".',
|
||||
context.payload.pull_request.author_association,
|
||||
)
|
||||
await github.rest.issues.addLabels({
|
||||
...params,
|
||||
labels: ["community"],
|
||||
})
|
||||
|
||||
cla:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
@@ -83,109 +45,6 @@ jobs:
|
||||
# Some users have signed a corporate CLA with Coder so are exempt from signing our community one.
|
||||
allowlist: "coryb,aaronlehmann,dependabot*,blink-so*,blinkagent*"
|
||||
|
||||
title:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event_name == 'pull_request_target' }}
|
||||
steps:
|
||||
- name: Validate PR title
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
script: |
|
||||
const { pull_request } = context.payload;
|
||||
const title = pull_request.title;
|
||||
const repo = { owner: context.repo.owner, repo: context.repo.repo };
|
||||
|
||||
const allowedTypes = [
|
||||
"feat", "fix", "docs", "style", "refactor",
|
||||
"perf", "test", "build", "ci", "chore", "revert",
|
||||
];
|
||||
const expectedFormat = `"type(scope): description" or "type: description"`;
|
||||
const guidelinesLink = `See: https://github.com/coder/coder/blob/main/docs/about/contributing/CONTRIBUTING.md#commit-messages`;
|
||||
const scopeHint = (type) =>
|
||||
`Use a broader scope or no scope (e.g., "${type}: ...") for cross-cutting changes.\n` +
|
||||
guidelinesLink;
|
||||
|
||||
console.log("Title: %s", title);
|
||||
|
||||
// Parse conventional commit format: type(scope)!: description
|
||||
const match = title.match(/^(\w+)(\(([^)]*)\))?(!)?\s*:\s*.+/);
|
||||
if (!match) {
|
||||
core.setFailed(
|
||||
`PR title does not match conventional commit format.\n` +
|
||||
`Expected: ${expectedFormat}\n` +
|
||||
`Allowed types: ${allowedTypes.join(", ")}\n` +
|
||||
guidelinesLink
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const type = match[1];
|
||||
const scope = match[3]; // undefined if no parentheses
|
||||
|
||||
// Validate type.
|
||||
if (!allowedTypes.includes(type)) {
|
||||
core.setFailed(
|
||||
`PR title has invalid type "${type}".\n` +
|
||||
`Expected: ${expectedFormat}\n` +
|
||||
`Allowed types: ${allowedTypes.join(", ")}\n` +
|
||||
guidelinesLink
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// If no scope, we're done.
|
||||
if (!scope) {
|
||||
console.log("No scope provided, title is valid.");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("Scope: %s", scope);
|
||||
|
||||
// Fetch changed files.
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
...repo,
|
||||
pull_number: pull_request.number,
|
||||
per_page: 100,
|
||||
});
|
||||
const changedPaths = files.map(f => f.filename);
|
||||
console.log("Changed files: %d", changedPaths.length);
|
||||
|
||||
// Derive scope type from the changed files. The diff is the
|
||||
// source of truth: if files exist under the scope, the path
|
||||
// exists on the PR branch. No need for Contents API calls.
|
||||
const isDir = changedPaths.some(f => f.startsWith(scope + "/"));
|
||||
const isFile = changedPaths.some(f => f === scope);
|
||||
const isStem = changedPaths.some(f => f.startsWith(scope + "."));
|
||||
|
||||
if (!isDir && !isFile && !isStem) {
|
||||
core.setFailed(
|
||||
`PR title scope "${scope}" does not match any files changed in this PR.\n` +
|
||||
`Scopes must reference a path (directory or file stem) that contains changed files.\n` +
|
||||
scopeHint(type)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Verify all changed files fall under the scope.
|
||||
const outsideFiles = changedPaths.filter(f => {
|
||||
if (isDir && f.startsWith(scope + "/")) return false;
|
||||
if (f === scope) return false;
|
||||
if (isStem && f.startsWith(scope + ".")) return false;
|
||||
return true;
|
||||
});
|
||||
|
||||
if (outsideFiles.length > 0) {
|
||||
const listed = outsideFiles.map(f => " - " + f).join("\n");
|
||||
core.setFailed(
|
||||
`PR title scope "${scope}" does not contain all changed files.\n` +
|
||||
`Files outside scope:\n${listed}\n\n` +
|
||||
scopeHint(type)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("PR title is valid.");
|
||||
|
||||
release-labels:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
|
||||
@@ -61,7 +61,7 @@ jobs:
|
||||
if: needs.should-deploy.outputs.verdict == 'DEPLOY'
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write # to authenticate to EKS cluster
|
||||
id-token: write
|
||||
packages: write # to retag image as dogfood
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
@@ -82,23 +82,27 @@ jobs:
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Configure AWS Credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # v6.0.0
|
||||
- name: Authenticate to Google Cloud
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
with:
|
||||
role-to-assume: ${{ vars.AWS_DOGFOOD_DEPLOY_ROLE }}
|
||||
aws-region: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
|
||||
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
|
||||
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
|
||||
|
||||
- name: Get Cluster Credentials
|
||||
run: aws eks update-kubeconfig --name "$AWS_DOGFOOD_CLUSTER_NAME" --region "$AWS_DOGFOOD_DEPLOY_REGION"
|
||||
env:
|
||||
AWS_DOGFOOD_CLUSTER_NAME: ${{ vars.AWS_DOGFOOD_CLUSTER_NAME }}
|
||||
AWS_DOGFOOD_DEPLOY_REGION: ${{ vars.AWS_DOGFOOD_DEPLOY_REGION }}
|
||||
- name: Set up Google Cloud SDK
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
|
||||
- name: Set up Flux CLI
|
||||
uses: fluxcd/flux2/action@8454b02a32e48d775b9f563cb51fdcb1787b5b93 # v2.7.5
|
||||
with:
|
||||
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
|
||||
version: "2.8.2"
|
||||
version: "2.7.0"
|
||||
|
||||
- name: Get Cluster Credentials
|
||||
uses: google-github-actions/get-gke-credentials@3da1e46a907576cefaa90c484278bb5b259dd395 # v3.0.0
|
||||
with:
|
||||
cluster_name: dogfood-v2
|
||||
location: us-central1-a
|
||||
project_id: coder-dogfood-v2
|
||||
|
||||
# Retag image as dogfood while maintaining the multi-arch manifest
|
||||
- name: Tag image as dogfood
|
||||
@@ -109,16 +113,16 @@ jobs:
|
||||
- name: Reconcile Flux
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
flux --namespace flux-system reconcile --verbose --timeout=5m source git flux-system
|
||||
flux --namespace flux-system reconcile --verbose --timeout=5m source git coder-main
|
||||
flux --namespace flux-system reconcile --verbose --timeout=5m kustomization flux-system
|
||||
flux --namespace flux-system reconcile --verbose --timeout=5m kustomization coder
|
||||
flux --namespace flux-system reconcile --verbose --timeout=5m source chart coder-coder
|
||||
flux --namespace flux-system reconcile --verbose --timeout=5m source chart coder-coder-provisioner
|
||||
flux --namespace coder reconcile --verbose --timeout=10m helmrelease coder
|
||||
flux --namespace coder reconcile --verbose --timeout=10m helmrelease coder-provisioner
|
||||
flux --namespace coder reconcile --verbose --timeout=10m helmrelease coder-provisioner-tagged
|
||||
flux --namespace coder reconcile --verbose --timeout=10m helmrelease coder-provisioner-tagged-prebuilds
|
||||
flux --namespace flux-system reconcile source git flux-system
|
||||
flux --namespace flux-system reconcile source git coder-main
|
||||
flux --namespace flux-system reconcile kustomization flux-system
|
||||
flux --namespace flux-system reconcile kustomization coder
|
||||
flux --namespace flux-system reconcile source chart coder-coder
|
||||
flux --namespace flux-system reconcile source chart coder-coder-provisioner
|
||||
flux --namespace coder reconcile helmrelease coder
|
||||
flux --namespace coder reconcile helmrelease coder-provisioner
|
||||
flux --namespace coder reconcile helmrelease coder-provisioner-tagged
|
||||
flux --namespace coder reconcile helmrelease coder-provisioner-tagged-prebuilds
|
||||
|
||||
# Just updating Flux is usually not enough. The Helm release may get
|
||||
# redeployed, but unless something causes the Deployment to update the
|
||||
|
||||
@@ -700,9 +700,11 @@ jobs:
|
||||
name: Publish to Homebrew tap
|
||||
runs-on: ubuntu-latest
|
||||
needs: release
|
||||
if: ${{ !inputs.dry_run && inputs.release_channel == 'mainline' }}
|
||||
if: ${{ !inputs.dry_run }}
|
||||
|
||||
steps:
|
||||
# TODO: skip this if it's not a new release (i.e. a backport). This is
|
||||
# fine right now because it just makes a PR that we can close.
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
with:
|
||||
|
||||
@@ -100,31 +100,6 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
|
||||
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
```
|
||||
|
||||
### API Design
|
||||
|
||||
- Add swagger annotations when introducing new HTTP endpoints. Do this in
|
||||
the same change as the handler so the docs do not get missed before
|
||||
release.
|
||||
- For user-scoped or resource-scoped routes, prefer path parameters over
|
||||
query parameters when that matches existing route patterns.
|
||||
- For experimental or unstable API paths, skip public doc generation with
|
||||
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
|
||||
keeps them out of the published API reference until they stabilize.
|
||||
|
||||
### Database Query Naming
|
||||
|
||||
- Use `ByX` when `X` is the lookup or filter column.
|
||||
- Use `PerX` or `GroupedByX` when `X` is the aggregation or grouping
|
||||
dimension.
|
||||
- Avoid `ByX` names for grouped queries.
|
||||
|
||||
### Database-to-SDK Conversions
|
||||
|
||||
- Extract explicit db-to-SDK conversion helpers instead of inlining large
|
||||
conversion blocks inside handlers.
|
||||
- Keep nullable-field handling, type coercion, and response shaping in the
|
||||
converter so handlers stay focused on request flow and authorization.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Full workflows available in imported WORKFLOWS.md
|
||||
@@ -209,21 +184,6 @@ seems like it should use `time.Sleep`, read through https://github.com/coder/qua
|
||||
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
|
||||
- Commit format: `type(scope): message`
|
||||
|
||||
### Frontend Patterns
|
||||
|
||||
- Prefer existing shared UI components and utilities over custom
|
||||
implementations. Reuse common primitives such as loading, table, and error
|
||||
handling components when they fit the use case.
|
||||
- Use Storybook stories for all component and page testing, including
|
||||
visual presentation, user interactions, keyboard navigation, focus
|
||||
management, and accessibility behavior. Do not create standalone
|
||||
vitest/RTL test files for components or pages. Stories double as living
|
||||
documentation, visual regression coverage, and interaction test suites
|
||||
via `play` functions. Reserve plain vitest files for pure logic only:
|
||||
utility functions, data transformations, hooks tested via
|
||||
`renderHook()` that do not require DOM assertions, and query/cache
|
||||
operations with no rendered output.
|
||||
|
||||
### Writing Comments
|
||||
|
||||
Code comments should be clear, well-formatted, and add meaningful context.
|
||||
|
||||
@@ -136,10 +136,18 @@ endif
|
||||
# the search path so that these exclusions match.
|
||||
FIND_EXCLUSIONS= \
|
||||
-not \( \( -path '*/.git/*' -o -path './build/*' -o -path './vendor/*' -o -path './.coderv2/*' -o -path '*/node_modules/*' -o -path '*/out/*' -o -path './coderd/apidoc/*' -o -path '*/.next/*' -o -path '*/.terraform/*' -o -path './_gen/*' \) -prune \)
|
||||
|
||||
# Source files used for make targets, evaluated on use.
|
||||
GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -name '*_test.go')
|
||||
|
||||
# Same as GO_SRC_FILES but excluding certain files that have problematic
|
||||
# Makefile dependencies (e.g. pnpm).
|
||||
MOST_GO_SRC_FILES := $(shell \
|
||||
find . \
|
||||
$(FIND_EXCLUSIONS) \
|
||||
-type f \
|
||||
-name '*.go' \
|
||||
-not -name '*_test.go' \
|
||||
-not -wholename './agent/agentcontainers/dcspec/dcspec_gen.go' \
|
||||
)
|
||||
# All the shell files in the repo, excluding ignored files.
|
||||
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
|
||||
|
||||
@@ -506,12 +514,6 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
cp "$<" "$$output_file"
|
||||
.PHONY: install
|
||||
|
||||
# Only wildcard the go files in the develop directory to avoid rebuilds
|
||||
# when project files are changd. Technically changes to some imports may
|
||||
# not be detected, but it's unlikely to cause any issues.
|
||||
build/.bin/develop: go.mod go.sum $(wildcard scripts/develop/*.go)
|
||||
CGO_ENABLED=0 go build -o $@ ./scripts/develop
|
||||
|
||||
BOLD := $(shell tput bold 2>/dev/null)
|
||||
GREEN := $(shell tput setaf 2 2>/dev/null)
|
||||
RED := $(shell tput setaf 1 2>/dev/null)
|
||||
|
||||
+2
-16
@@ -39,7 +39,6 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/clistat"
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentdesktop"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentfiles"
|
||||
"github.com/coder/coder/v2/agent/agentgit"
|
||||
@@ -311,7 +310,6 @@ type agent struct {
|
||||
filesAPI *agentfiles.API
|
||||
gitAPI *agentgit.API
|
||||
processAPI *agentproc.API
|
||||
desktopAPI *agentdesktop.API
|
||||
|
||||
socketServerEnabled bool
|
||||
socketPath string
|
||||
@@ -385,18 +383,10 @@ func (a *agent) init() {
|
||||
|
||||
pathStore := agentgit.NewPathStore()
|
||||
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem, pathStore)
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore, func() string {
|
||||
if m := a.manifest.Load(); m != nil {
|
||||
return m.Directory
|
||||
}
|
||||
return ""
|
||||
})
|
||||
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv, pathStore)
|
||||
gitOpts := append([]agentgit.Option{agentgit.WithClock(a.clock)}, a.gitAPIOptions...)
|
||||
a.gitAPI = agentgit.NewAPI(a.logger.Named("git"), pathStore, gitOpts...)
|
||||
desktop := agentdesktop.NewPortableDesktop(
|
||||
a.logger.Named("desktop"), a.execer, a.scriptRunner.ScriptBinDir(),
|
||||
)
|
||||
a.desktopAPI = agentdesktop.NewAPI(a.logger.Named("desktop"), desktop, a.clock)
|
||||
|
||||
a.reconnectingPTYServer = reconnectingpty.NewServer(
|
||||
a.logger.Named("reconnecting-pty"),
|
||||
a.sshServer,
|
||||
@@ -2067,10 +2057,6 @@ func (a *agent) Close() error {
|
||||
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if err := a.desktopAPI.Close(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "desktop API close", slog.Error(err))
|
||||
}
|
||||
|
||||
if a.boundaryLogProxy != nil {
|
||||
err = a.boundaryLogProxy.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,536 +0,0 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// DesktopAction is the request body for the desktop action endpoint.
|
||||
type DesktopAction struct {
|
||||
Action string `json:"action"`
|
||||
Coordinate *[2]int `json:"coordinate,omitempty"`
|
||||
StartCoordinate *[2]int `json:"start_coordinate,omitempty"`
|
||||
Text *string `json:"text,omitempty"`
|
||||
Duration *int `json:"duration,omitempty"`
|
||||
ScrollAmount *int `json:"scroll_amount,omitempty"`
|
||||
ScrollDirection *string `json:"scroll_direction,omitempty"`
|
||||
// ScaledWidth and ScaledHeight are the coordinate space the
|
||||
// model is using. When provided, coordinates are linearly
|
||||
// mapped from scaled → native before dispatching.
|
||||
ScaledWidth *int `json:"scaled_width,omitempty"`
|
||||
ScaledHeight *int `json:"scaled_height,omitempty"`
|
||||
}
|
||||
|
||||
// DesktopActionResponse is the response from the desktop action
|
||||
// endpoint.
|
||||
type DesktopActionResponse struct {
|
||||
Output string `json:"output,omitempty"`
|
||||
ScreenshotData string `json:"screenshot_data,omitempty"`
|
||||
ScreenshotWidth int `json:"screenshot_width,omitempty"`
|
||||
ScreenshotHeight int `json:"screenshot_height,omitempty"`
|
||||
}
|
||||
|
||||
// API exposes the desktop streaming HTTP routes for the agent.
|
||||
type API struct {
|
||||
logger slog.Logger
|
||||
desktop Desktop
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// NewAPI creates a new desktop streaming API.
|
||||
func NewAPI(logger slog.Logger, desktop Desktop, clock quartz.Clock) *API {
|
||||
if clock == nil {
|
||||
clock = quartz.NewReal()
|
||||
}
|
||||
return &API{
|
||||
logger: logger,
|
||||
desktop: desktop,
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
// Routes returns the chi router for mounting at /api/v0/desktop.
|
||||
func (a *API) Routes() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/vnc", a.handleDesktopVNC)
|
||||
r.Post("/action", a.handleAction)
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *API) handleDesktopVNC(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Start the desktop session (idempotent).
|
||||
_, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start desktop session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Get a VNC connection.
|
||||
vncConn, err := a.desktop.VNCConn(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to connect to VNC server.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer vncConn.Close()
|
||||
|
||||
// Accept WebSocket from coderd.
|
||||
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// No read limit — RFB framebuffer updates can be large.
|
||||
conn.SetReadLimit(-1)
|
||||
|
||||
wsCtx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
// Bicopy raw bytes between WebSocket and VNC TCP.
|
||||
agentssh.Bicopy(wsCtx, wsNetConn, vncConn)
|
||||
}
|
||||
|
||||
func (a *API) handleAction(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
handlerStart := a.clock.Now()
|
||||
|
||||
// Ensure the desktop is running and grab native dimensions.
|
||||
cfg, err := a.desktop.Start(ctx)
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: desktop.Start failed",
|
||||
slog.Error(err),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to start desktop session.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var action DesktopAction
|
||||
if err := json.NewDecoder(r.Body).Decode(&action); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to decode request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
a.logger.Info(ctx, "handleAction: started",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
|
||||
// Helper to scale a coordinate pair from the model's space to
|
||||
// native display pixels.
|
||||
scaleXY := func(x, y int) (int, int) {
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
x = scaleCoordinate(x, *action.ScaledWidth, cfg.Width)
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
y = scaleCoordinate(y, *action.ScaledHeight, cfg.Height)
|
||||
}
|
||||
return x, y
|
||||
}
|
||||
|
||||
var resp DesktopActionResponse
|
||||
|
||||
switch action.Action {
|
||||
case "key":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for key action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := a.desktop.KeyPress(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key press failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "key action performed"
|
||||
|
||||
case "type":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for type action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := a.desktop.Type(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Type action failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "type action performed"
|
||||
|
||||
case "cursor_position":
|
||||
x, y, err := a.desktop.CursorPosition(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Cursor position failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "x=" + strconv.Itoa(x) + ",y=" + strconv.Itoa(y)
|
||||
|
||||
case "mouse_move":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Move(ctx, x, y); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Mouse move failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "mouse_move action performed"
|
||||
|
||||
case "left_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
stepStart := a.clock.Now()
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: Click failed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step", "click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "handleAction: Click completed",
|
||||
slog.F("action", "left_click"),
|
||||
slog.F("step_ms", time.Since(stepStart).Milliseconds()),
|
||||
slog.F("elapsed_ms", a.clock.Since(handlerStart).Milliseconds()),
|
||||
)
|
||||
resp.Output = "left_click action performed"
|
||||
|
||||
case "left_click_drag":
|
||||
if action.Coordinate == nil || action.StartCoordinate == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"coordinate\" or \"start_coordinate\" for left_click_drag.",
|
||||
})
|
||||
return
|
||||
}
|
||||
sx, sy := scaleXY(action.StartCoordinate[0], action.StartCoordinate[1])
|
||||
ex, ey := scaleXY(action.Coordinate[0], action.Coordinate[1])
|
||||
if err := a.desktop.Drag(ctx, sx, sy, ex, ey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left click drag failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_click_drag action performed"
|
||||
|
||||
case "left_mouse_down":
|
||||
if err := a.desktop.ButtonDown(ctx, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left mouse down failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_mouse_down action performed"
|
||||
|
||||
case "left_mouse_up":
|
||||
if err := a.desktop.ButtonUp(ctx, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Left mouse up failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "left_mouse_up action performed"
|
||||
|
||||
case "right_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonRight); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Right click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "right_click action performed"
|
||||
|
||||
case "middle_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonMiddle); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Middle click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "middle_click action performed"
|
||||
|
||||
case "double_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
if err := a.desktop.DoubleClick(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Double click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "double_click action performed"
|
||||
|
||||
case "triple_click":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
for range 3 {
|
||||
if err := a.desktop.Click(ctx, x, y, MouseButtonLeft); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Triple click failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
resp.Output = "triple_click action performed"
|
||||
|
||||
case "scroll":
|
||||
x, y, err := coordFromAction(action)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
x, y = scaleXY(x, y)
|
||||
|
||||
amount := 3
|
||||
if action.ScrollAmount != nil {
|
||||
amount = *action.ScrollAmount
|
||||
}
|
||||
direction := "down"
|
||||
if action.ScrollDirection != nil {
|
||||
direction = *action.ScrollDirection
|
||||
}
|
||||
|
||||
var dx, dy int
|
||||
switch direction {
|
||||
case "up":
|
||||
dy = -amount
|
||||
case "down":
|
||||
dy = amount
|
||||
case "left":
|
||||
dx = -amount
|
||||
case "right":
|
||||
dx = amount
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid scroll direction: " + direction,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := a.desktop.Scroll(ctx, x, y, dx, dy); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Scroll failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "scroll action performed"
|
||||
|
||||
case "hold_key":
|
||||
if action.Text == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Missing \"text\" for hold_key action.",
|
||||
})
|
||||
return
|
||||
}
|
||||
dur := 1000
|
||||
if action.Duration != nil {
|
||||
dur = *action.Duration
|
||||
}
|
||||
if err := a.desktop.KeyDown(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key down failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
timer := a.clock.NewTimer(time.Duration(dur)*time.Millisecond, "agentdesktop", "hold_key")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context canceled; release the key immediately.
|
||||
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
|
||||
a.logger.Warn(ctx, "handleAction: KeyUp after context cancel", slog.Error(err))
|
||||
}
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
if err := a.desktop.KeyUp(ctx, *action.Text); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Key up failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "hold_key action performed"
|
||||
|
||||
case "screenshot":
|
||||
var opts ScreenshotOptions
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 {
|
||||
opts.TargetWidth = *action.ScaledWidth
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 {
|
||||
opts.TargetHeight = *action.ScaledHeight
|
||||
}
|
||||
result, err := a.desktop.Screenshot(ctx, opts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Screenshot failed.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
resp.Output = "screenshot"
|
||||
resp.ScreenshotData = result.Data
|
||||
if action.ScaledWidth != nil && *action.ScaledWidth > 0 && *action.ScaledWidth != cfg.Width {
|
||||
resp.ScreenshotWidth = *action.ScaledWidth
|
||||
} else {
|
||||
resp.ScreenshotWidth = cfg.Width
|
||||
}
|
||||
if action.ScaledHeight != nil && *action.ScaledHeight > 0 && *action.ScaledHeight != cfg.Height {
|
||||
resp.ScreenshotHeight = *action.ScaledHeight
|
||||
} else {
|
||||
resp.ScreenshotHeight = cfg.Height
|
||||
}
|
||||
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Unknown action: " + action.Action,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
elapsedMs := a.clock.Since(handlerStart).Milliseconds()
|
||||
if ctx.Err() != nil {
|
||||
a.logger.Error(ctx, "handleAction: context canceled before writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
slog.Error(ctx.Err()),
|
||||
)
|
||||
return
|
||||
}
|
||||
a.logger.Info(ctx, "handleAction: writing response",
|
||||
slog.F("action", action.Action),
|
||||
slog.F("elapsed_ms", elapsedMs),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session if one is running.
|
||||
func (a *API) Close() error {
|
||||
return a.desktop.Close()
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
// returning an error if the coordinate field is missing.
|
||||
func coordFromAction(action DesktopAction) (x, y int, err error) {
|
||||
if action.Coordinate == nil {
|
||||
return 0, 0, &missingFieldError{field: "coordinate", action: action.Action}
|
||||
}
|
||||
return action.Coordinate[0], action.Coordinate[1], nil
|
||||
}
|
||||
|
||||
// missingFieldError is returned when a required field is absent from
|
||||
// a DesktopAction.
|
||||
type missingFieldError struct {
|
||||
field string
|
||||
action string
|
||||
}
|
||||
|
||||
func (e *missingFieldError) Error() string {
|
||||
return "Missing \"" + e.field + "\" for " + e.action + " action."
|
||||
}
|
||||
|
||||
// scaleCoordinate maps a coordinate from scaled → native space.
|
||||
func scaleCoordinate(scaled, scaledDim, nativeDim int) int {
|
||||
if scaledDim == 0 || scaledDim == nativeDim {
|
||||
return scaled
|
||||
}
|
||||
native := (float64(scaled)+0.5)*float64(nativeDim)/float64(scaledDim) - 0.5
|
||||
// Clamp to valid range.
|
||||
native = math.Max(native, 0)
|
||||
native = math.Min(native, float64(nativeDim-1))
|
||||
return int(native)
|
||||
}
|
||||
@@ -1,467 +0,0 @@
|
||||
package agentdesktop_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentdesktop"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Ensure fakeDesktop satisfies the Desktop interface at compile time.
|
||||
var _ agentdesktop.Desktop = (*fakeDesktop)(nil)
|
||||
|
||||
// fakeDesktop is a minimal Desktop implementation for unit tests.
|
||||
type fakeDesktop struct {
|
||||
startErr error
|
||||
startCfg agentdesktop.DisplayConfig
|
||||
vncConnErr error
|
||||
screenshotErr error
|
||||
screenshotRes agentdesktop.ScreenshotResult
|
||||
closed bool
|
||||
|
||||
// Track calls for assertions.
|
||||
lastMove [2]int
|
||||
lastClick [3]int // x, y, button
|
||||
lastScroll [4]int // x, y, dx, dy
|
||||
lastKey string
|
||||
lastTyped string
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Start(context.Context) (agentdesktop.DisplayConfig, error) {
|
||||
return f.startCfg, f.startErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) VNCConn(context.Context) (net.Conn, error) {
|
||||
return nil, f.vncConnErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Screenshot(_ context.Context, _ agentdesktop.ScreenshotOptions) (agentdesktop.ScreenshotResult, error) {
|
||||
return f.screenshotRes, f.screenshotErr
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Move(_ context.Context, x, y int) error {
|
||||
f.lastMove = [2]int{x, y}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Click(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
|
||||
f.lastClick = [3]int{x, y, 1}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) DoubleClick(_ context.Context, x, y int, _ agentdesktop.MouseButton) error {
|
||||
f.lastClick = [3]int{x, y, 2}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) ButtonDown(context.Context, agentdesktop.MouseButton) error { return nil }
|
||||
func (*fakeDesktop) ButtonUp(context.Context, agentdesktop.MouseButton) error { return nil }
|
||||
|
||||
func (f *fakeDesktop) Scroll(_ context.Context, x, y, dx, dy int) error {
|
||||
f.lastScroll = [4]int{x, y, dx, dy}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) Drag(context.Context, int, int, int, int) error { return nil }
|
||||
|
||||
func (f *fakeDesktop) KeyPress(_ context.Context, key string) error {
|
||||
f.lastKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) KeyDown(_ context.Context, key string) error {
|
||||
f.lastKeyDown = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) KeyUp(_ context.Context, key string) error {
|
||||
f.lastKeyUp = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Type(_ context.Context, text string) error {
|
||||
f.lastTyped = text
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*fakeDesktop) CursorPosition(context.Context) (x int, y int, err error) {
|
||||
return 10, 20, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) Close() error {
|
||||
f.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandleDesktopVNC_StartError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{startErr: xerrors.New("no desktop")}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Failed to start desktop session.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: workspacesdk.DesktopDisplayWidth, Height: workspacesdk.DesktopDisplayHeight},
|
||||
screenshotRes: agentdesktop.ScreenshotResult{Data: "base64data"},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "screenshot"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var result agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&result)
|
||||
require.NoError(t, err)
|
||||
// Dimensions come from DisplayConfig, not the screenshot CLI.
|
||||
assert.Equal(t, "screenshot", result.Output)
|
||||
assert.Equal(t, "base64data", result.ScreenshotData)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayWidth, result.ScreenshotWidth)
|
||||
assert.Equal(t, workspacesdk.DesktopDisplayHeight, result.ScreenshotHeight)
|
||||
}
|
||||
|
||||
func TestHandleAction_LeftClick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "left_click",
|
||||
Coordinate: &[2]int{100, 200},
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "left_click action performed", resp.Output)
|
||||
assert.Equal(t, [3]int{100, 200, 1}, fake.lastClick)
|
||||
}
|
||||
|
||||
func TestHandleAction_UnknownAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "explode"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
}
|
||||
|
||||
func TestHandleAction_KeyAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
text := "Return"
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "key",
|
||||
Text: &text,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "Return", fake.lastKey)
|
||||
}
|
||||
|
||||
func TestHandleAction_TypeAction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
text := "hello world"
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "type",
|
||||
Text: &text,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "hello world", fake.lastTyped)
|
||||
}
|
||||
|
||||
func TestHandleAction_HoldKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
mClk := quartz.NewMock(t)
|
||||
trap := mClk.Trap().NewTimer("agentdesktop", "hold_key")
|
||||
defer trap.Close()
|
||||
api := agentdesktop.NewAPI(logger, fake, mClk)
|
||||
defer api.Close()
|
||||
|
||||
text := "Shift_L"
|
||||
dur := 100
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "hold_key",
|
||||
Text: &text,
|
||||
Duration: &dur,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
handler.ServeHTTP(rr, req)
|
||||
}()
|
||||
|
||||
// Wait for the timer to be created, then advance past it.
|
||||
trap.MustWait(req.Context()).MustRelease(req.Context())
|
||||
mClk.Advance(time.Duration(dur) * time.Millisecond).MustWait(req.Context())
|
||||
|
||||
<-done
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
var resp agentdesktop.DesktopActionResponse
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hold_key action performed", resp.Output)
|
||||
assert.Equal(t, "Shift_L", fake.lastKeyDown)
|
||||
assert.Equal(t, "Shift_L", fake.lastKeyUp)
|
||||
}
|
||||
|
||||
func TestHandleAction_HoldKeyMissingText(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
body := agentdesktop.DesktopAction{Action: "hold_key"}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rr.Code)
|
||||
|
||||
var resp codersdk.Response
|
||||
err = json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Missing \"text\" for hold_key action.", resp.Message)
|
||||
}
|
||||
|
||||
func TestHandleAction_ScrollDown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
dir := "down"
|
||||
amount := 5
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "scroll",
|
||||
Coordinate: &[2]int{500, 400},
|
||||
ScrollDirection: &dir,
|
||||
ScrollAmount: &amount,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// dy should be positive 5 for "down".
|
||||
assert.Equal(t, [4]int{500, 400, 0, 5}, fake.lastScroll)
|
||||
}
|
||||
|
||||
func TestHandleAction_CoordinateScaling(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
// Native display is 1920x1080.
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
// Model is working in a 1280x720 coordinate space.
|
||||
sw := 1280
|
||||
sh := 720
|
||||
body := agentdesktop.DesktopAction{
|
||||
Action: "mouse_move",
|
||||
Coordinate: &[2]int{640, 360},
|
||||
ScaledWidth: &sw,
|
||||
ScaledHeight: &sh,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/action", bytes.NewReader(b))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
// 640 in 1280-space → 960 in 1920-space (midpoint maps to
|
||||
// midpoint).
|
||||
assert.Equal(t, 960, fake.lastMove[0])
|
||||
assert.Equal(t, 540, fake.lastMove[1])
|
||||
}
|
||||
|
||||
func TestClose_DelegatesToDesktop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
err := api.Close()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, fake.closed)
|
||||
}
|
||||
|
||||
func TestClose_PreventsNewSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// After Close(), Start() will return an error because the
|
||||
// underlying Desktop is closed.
|
||||
fake := &fakeDesktop{}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
|
||||
err := api.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the closed desktop returning an error on Start().
|
||||
fake.startErr = xerrors.New("desktop is closed")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/vnc", nil)
|
||||
|
||||
handler := api.Routes()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Desktop abstracts a virtual desktop session running inside a workspace.
|
||||
type Desktop interface {
|
||||
// Start launches the desktop session. It is idempotent — calling
|
||||
// Start on an already-running session returns the existing
|
||||
// config. The returned DisplayConfig describes the running
|
||||
// session.
|
||||
Start(ctx context.Context) (DisplayConfig, error)
|
||||
|
||||
// VNCConn dials the desktop's VNC server and returns a raw
|
||||
// net.Conn carrying RFB binary frames. Each call returns a new
|
||||
// connection; multiple clients can connect simultaneously.
|
||||
// Start must be called before VNCConn.
|
||||
VNCConn(ctx context.Context) (net.Conn, error)
|
||||
|
||||
// Screenshot captures the current framebuffer as a PNG and
|
||||
// returns it base64-encoded. TargetWidth/TargetHeight in opts
|
||||
// are the desired output dimensions (the implementation
|
||||
// rescales); pass 0 to use native resolution.
|
||||
Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error)
|
||||
|
||||
// Mouse operations.
|
||||
|
||||
// Move moves the mouse cursor to absolute coordinates.
|
||||
Move(ctx context.Context, x, y int) error
|
||||
// Click performs a mouse button click at the given coordinates.
|
||||
Click(ctx context.Context, x, y int, button MouseButton) error
|
||||
// DoubleClick performs a double-click at the given coordinates.
|
||||
DoubleClick(ctx context.Context, x, y int, button MouseButton) error
|
||||
// ButtonDown presses and holds a mouse button.
|
||||
ButtonDown(ctx context.Context, button MouseButton) error
|
||||
// ButtonUp releases a mouse button.
|
||||
ButtonUp(ctx context.Context, button MouseButton) error
|
||||
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
|
||||
Scroll(ctx context.Context, x, y, dx, dy int) error
|
||||
// Drag moves from (startX,startY) to (endX,endY) while holding
|
||||
// the left mouse button.
|
||||
Drag(ctx context.Context, startX, startY, endX, endY int) error
|
||||
|
||||
// Keyboard operations.
|
||||
|
||||
// KeyPress sends a key-down then key-up for a key combo string
|
||||
// (e.g. "Return", "ctrl+c").
|
||||
KeyPress(ctx context.Context, keys string) error
|
||||
// KeyDown presses and holds a key.
|
||||
KeyDown(ctx context.Context, key string) error
|
||||
// KeyUp releases a key.
|
||||
KeyUp(ctx context.Context, key string) error
|
||||
// Type types a string of text character-by-character.
|
||||
Type(ctx context.Context, text string) error
|
||||
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
CursorPosition(ctx context.Context) (x, y int, err error)
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
type DisplayConfig struct {
|
||||
Width int // native width in pixels
|
||||
Height int // native height in pixels
|
||||
VNCPort int // local TCP port for the VNC server
|
||||
Display int // X11 display number (e.g. 1 for :1), -1 if N/A
|
||||
}
|
||||
|
||||
// MouseButton identifies a mouse button.
|
||||
type MouseButton string
|
||||
|
||||
const (
|
||||
MouseButtonLeft MouseButton = "left"
|
||||
MouseButtonRight MouseButton = "right"
|
||||
MouseButtonMiddle MouseButton = "middle"
|
||||
)
|
||||
|
||||
// ScreenshotOptions configures a screenshot capture.
|
||||
type ScreenshotOptions struct {
|
||||
TargetWidth int // 0 = native
|
||||
TargetHeight int // 0 = native
|
||||
}
|
||||
|
||||
// ScreenshotResult is a captured screenshot.
|
||||
type ScreenshotResult struct {
|
||||
Data string // base64-encoded PNG
|
||||
}
|
||||
@@ -1,399 +0,0 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// portableDesktopOutput is the JSON output from
|
||||
// `portabledesktop up --json`.
|
||||
type portableDesktopOutput struct {
|
||||
VNCPort int `json:"vncPort"`
|
||||
Geometry string `json:"geometry"` // e.g. "1920x1080"
|
||||
}
|
||||
|
||||
// desktopSession tracks a running portabledesktop process.
|
||||
type desktopSession struct {
|
||||
cmd *exec.Cmd
|
||||
vncPort int
|
||||
width int // native width, parsed from geometry
|
||||
height int // native height, parsed from geometry
|
||||
display int // X11 display number, -1 if not available
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// cursorOutput is the JSON output from `portabledesktop cursor --json`.
|
||||
type cursorOutput struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
}
|
||||
|
||||
// screenshotOutput is the JSON output from
|
||||
// `portabledesktop screenshot --json`.
|
||||
type screenshotOutput struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// portableDesktop implements Desktop by shelling out to the
|
||||
// portabledesktop CLI via agentexec.Execer.
|
||||
type portableDesktop struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
scriptBinDir string // coder script bin directory
|
||||
|
||||
mu sync.Mutex
|
||||
session *desktopSession // nil until started
|
||||
binPath string // resolved path to binary, cached
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewPortableDesktop creates a Desktop backed by the portabledesktop
|
||||
// CLI binary, using execer to spawn child processes. scriptBinDir is
|
||||
// the coder script bin directory checked for the binary.
|
||||
func NewPortableDesktop(
|
||||
logger slog.Logger,
|
||||
execer agentexec.Execer,
|
||||
scriptBinDir string,
|
||||
) Desktop {
|
||||
return &portableDesktop{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the desktop session (idempotent).
|
||||
func (p *portableDesktop) Start(ctx context.Context) (DisplayConfig, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return DisplayConfig{}, xerrors.New("desktop is closed")
|
||||
}
|
||||
|
||||
if err := p.ensureBinary(ctx); err != nil {
|
||||
return DisplayConfig{}, xerrors.Errorf("ensure portabledesktop binary: %w", err)
|
||||
}
|
||||
|
||||
// If we have an existing session, check if it's still alive.
|
||||
if p.session != nil {
|
||||
if !(p.session.cmd.ProcessState != nil && p.session.cmd.ProcessState.Exited()) {
|
||||
return DisplayConfig{
|
||||
Width: p.session.width,
|
||||
Height: p.session.height,
|
||||
VNCPort: p.session.vncPort,
|
||||
Display: p.session.display,
|
||||
}, nil
|
||||
}
|
||||
// Process died — clean up and recreate.
|
||||
p.logger.Warn(ctx, "portabledesktop process died, recreating session")
|
||||
p.session.cancel()
|
||||
p.session = nil
|
||||
}
|
||||
|
||||
// Spawn portabledesktop up --json.
|
||||
sessionCtx, sessionCancel := context.WithCancel(context.Background())
|
||||
|
||||
//nolint:gosec // portabledesktop is a trusted binary resolved via ensureBinary.
|
||||
cmd := p.execer.CommandContext(sessionCtx, p.binPath, "up", "--json",
|
||||
"--geometry", fmt.Sprintf("%dx%d", workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight))
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sessionCancel()
|
||||
return DisplayConfig{}, xerrors.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
sessionCancel()
|
||||
return DisplayConfig{}, xerrors.Errorf("start portabledesktop: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON output to get VNC port and geometry.
|
||||
var output portableDesktopOutput
|
||||
if err := json.NewDecoder(stdout).Decode(&output); err != nil {
|
||||
sessionCancel()
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
return DisplayConfig{}, xerrors.Errorf("parse portabledesktop output: %w", err)
|
||||
}
|
||||
|
||||
if output.VNCPort == 0 {
|
||||
sessionCancel()
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
return DisplayConfig{}, xerrors.New("portabledesktop returned port 0")
|
||||
}
|
||||
|
||||
var w, h int
|
||||
if output.Geometry != "" {
|
||||
if _, err := fmt.Sscanf(output.Geometry, "%dx%d", &w, &h); err != nil {
|
||||
p.logger.Warn(ctx, "failed to parse geometry, using defaults",
|
||||
slog.F("geometry", output.Geometry),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.Info(ctx, "started portabledesktop session",
|
||||
slog.F("vnc_port", output.VNCPort),
|
||||
slog.F("width", w),
|
||||
slog.F("height", h),
|
||||
slog.F("pid", cmd.Process.Pid),
|
||||
)
|
||||
|
||||
p.session = &desktopSession{
|
||||
cmd: cmd,
|
||||
vncPort: output.VNCPort,
|
||||
width: w,
|
||||
height: h,
|
||||
display: -1,
|
||||
cancel: sessionCancel,
|
||||
}
|
||||
|
||||
return DisplayConfig{
|
||||
Width: w,
|
||||
Height: h,
|
||||
VNCPort: output.VNCPort,
|
||||
Display: -1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VNCConn dials the desktop's VNC server and returns a raw
|
||||
// net.Conn carrying RFB binary frames.
|
||||
func (p *portableDesktop) VNCConn(_ context.Context) (net.Conn, error) {
|
||||
p.mu.Lock()
|
||||
session := p.session
|
||||
p.mu.Unlock()
|
||||
|
||||
if session == nil {
|
||||
return nil, xerrors.New("desktop session not started")
|
||||
}
|
||||
|
||||
return net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", session.vncPort))
|
||||
}
|
||||
|
||||
// Screenshot captures the current framebuffer as a base64-encoded PNG.
|
||||
func (p *portableDesktop) Screenshot(ctx context.Context, opts ScreenshotOptions) (ScreenshotResult, error) {
|
||||
args := []string{"screenshot", "--json"}
|
||||
if opts.TargetWidth > 0 {
|
||||
args = append(args, "--target-width", strconv.Itoa(opts.TargetWidth))
|
||||
}
|
||||
if opts.TargetHeight > 0 {
|
||||
args = append(args, "--target-height", strconv.Itoa(opts.TargetHeight))
|
||||
}
|
||||
|
||||
out, err := p.runCmd(ctx, args...)
|
||||
if err != nil {
|
||||
return ScreenshotResult{}, err
|
||||
}
|
||||
|
||||
var result screenshotOutput
|
||||
if err := json.Unmarshal([]byte(out), &result); err != nil {
|
||||
return ScreenshotResult{}, xerrors.Errorf("parse screenshot output: %w", err)
|
||||
}
|
||||
|
||||
return ScreenshotResult(result), nil
|
||||
}
|
||||
|
||||
// Move moves the mouse cursor to absolute coordinates.
|
||||
func (p *portableDesktop) Move(ctx context.Context, x, y int) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y))
|
||||
return err
|
||||
}
|
||||
|
||||
// Click performs a mouse button click at the given coordinates.
|
||||
func (p *portableDesktop) Click(ctx context.Context, x, y int, button MouseButton) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "click", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// DoubleClick performs a double-click at the given coordinates.
|
||||
func (p *portableDesktop) DoubleClick(ctx context.Context, x, y int, button MouseButton) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "click", string(button)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "click", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// ButtonDown presses and holds a mouse button.
|
||||
func (p *portableDesktop) ButtonDown(ctx context.Context, button MouseButton) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "down", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// ButtonUp releases a mouse button.
|
||||
func (p *portableDesktop) ButtonUp(ctx context.Context, button MouseButton) error {
|
||||
_, err := p.runCmd(ctx, "mouse", "up", string(button))
|
||||
return err
|
||||
}
|
||||
|
||||
// Scroll scrolls by (dx, dy) clicks at the given coordinates.
|
||||
func (p *portableDesktop) Scroll(ctx context.Context, x, y, dx, dy int) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(x), strconv.Itoa(y)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "scroll", strconv.Itoa(dx), strconv.Itoa(dy))
|
||||
return err
|
||||
}
|
||||
|
||||
// Drag moves from (startX,startY) to (endX,endY) while holding the
|
||||
// left mouse button.
|
||||
func (p *portableDesktop) Drag(ctx context.Context, startX, startY, endX, endY int) error {
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(startX), strconv.Itoa(startY)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "down", string(MouseButtonLeft)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := p.runCmd(ctx, "mouse", "move", strconv.Itoa(endX), strconv.Itoa(endY)); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := p.runCmd(ctx, "mouse", "up", string(MouseButtonLeft))
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyPress sends a key-down then key-up for a key combo string.
|
||||
func (p *portableDesktop) KeyPress(ctx context.Context, keys string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "key", keys)
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyDown presses and holds a key.
|
||||
func (p *portableDesktop) KeyDown(ctx context.Context, key string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "down", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// KeyUp releases a key.
|
||||
func (p *portableDesktop) KeyUp(ctx context.Context, key string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "up", key)
|
||||
return err
|
||||
}
|
||||
|
||||
// Type types a string of text character-by-character.
|
||||
func (p *portableDesktop) Type(ctx context.Context, text string) error {
|
||||
_, err := p.runCmd(ctx, "keyboard", "type", text)
|
||||
return err
|
||||
}
|
||||
|
||||
// CursorPosition returns the current cursor coordinates.
|
||||
func (p *portableDesktop) CursorPosition(ctx context.Context) (x int, y int, err error) {
|
||||
out, err := p.runCmd(ctx, "cursor", "--json")
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var result cursorOutput
|
||||
if err := json.Unmarshal([]byte(out), &result); err != nil {
|
||||
return 0, 0, xerrors.Errorf("parse cursor output: %w", err)
|
||||
}
|
||||
|
||||
return result.X, result.Y, nil
|
||||
}
|
||||
|
||||
// Close shuts down the desktop session and cleans up resources.
|
||||
func (p *portableDesktop) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.closed = true
|
||||
if p.session != nil {
|
||||
p.session.cancel()
|
||||
// Xvnc is a child process — killing it cleans up the X
|
||||
// session.
|
||||
_ = p.session.cmd.Process.Kill()
|
||||
_ = p.session.cmd.Wait()
|
||||
p.session = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runCmd executes a portabledesktop subcommand and returns combined
|
||||
// output. The caller must have previously called ensureBinary.
|
||||
func (p *portableDesktop) runCmd(ctx context.Context, args ...string) (string, error) {
|
||||
start := time.Now()
|
||||
//nolint:gosec // args are constructed by the caller, not user input.
|
||||
cmd := p.execer.CommandContext(ctx, p.binPath, args...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "portabledesktop command failed",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
slog.Error(err),
|
||||
slog.F("output", string(out)),
|
||||
)
|
||||
return "", xerrors.Errorf("portabledesktop %s: %w: %s", args[0], err, string(out))
|
||||
}
|
||||
if elapsed > 5*time.Second {
|
||||
p.logger.Warn(ctx, "portabledesktop command slow",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
)
|
||||
} else {
|
||||
p.logger.Debug(ctx, "portabledesktop command completed",
|
||||
slog.F("args", args),
|
||||
slog.F("elapsed_ms", elapsed.Milliseconds()),
|
||||
)
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// ensureBinary resolves the portabledesktop binary from PATH or the
|
||||
// coder script bin directory. It must be called while p.mu is held.
|
||||
func (p *portableDesktop) ensureBinary(ctx context.Context) error {
|
||||
if p.binPath != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. Check PATH.
|
||||
if path, err := exec.LookPath("portabledesktop"); err == nil {
|
||||
p.logger.Info(ctx, "found portabledesktop in PATH",
|
||||
slog.F("path", path),
|
||||
)
|
||||
p.binPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. Check the coder script bin directory.
|
||||
scriptBinPath := filepath.Join(p.scriptBinDir, "portabledesktop")
|
||||
if info, err := os.Stat(scriptBinPath); err == nil && !info.IsDir() {
|
||||
// On Windows, permission bits don't indicate executability,
|
||||
// so accept any regular file.
|
||||
if runtime.GOOS == "windows" || info.Mode()&0o111 != 0 {
|
||||
p.logger.Info(ctx, "found portabledesktop in script bin directory",
|
||||
slog.F("path", scriptBinPath),
|
||||
)
|
||||
p.binPath = scriptBinPath
|
||||
return nil
|
||||
}
|
||||
p.logger.Warn(ctx, "portabledesktop found in script bin directory but not executable",
|
||||
slog.F("path", scriptBinPath),
|
||||
slog.F("mode", info.Mode().String()),
|
||||
)
|
||||
}
|
||||
|
||||
return xerrors.New("portabledesktop binary not found in PATH or script bin directory")
|
||||
}
|
||||
@@ -1,545 +0,0 @@
|
||||
package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/pty"
|
||||
)
|
||||
|
||||
// recordedExecer implements agentexec.Execer by recording every
|
||||
// invocation and delegating to a real shell command built from a
|
||||
// caller-supplied mapping of subcommand → shell script body.
|
||||
type recordedExecer struct {
|
||||
mu sync.Mutex
|
||||
commands [][]string
|
||||
// scripts maps a subcommand keyword (e.g. "up", "screenshot")
|
||||
// to a shell snippet whose stdout will be the command output.
|
||||
scripts map[string]string
|
||||
}
|
||||
|
||||
func (r *recordedExecer) record(cmd string, args ...string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.commands = append(r.commands, append([]string{cmd}, args...))
|
||||
}
|
||||
|
||||
func (r *recordedExecer) allCommands() [][]string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
out := make([][]string, len(r.commands))
|
||||
copy(out, r.commands)
|
||||
return out
|
||||
}
|
||||
|
||||
// scriptFor finds the first matching script key present in args.
|
||||
func (r *recordedExecer) scriptFor(args []string) string {
|
||||
for _, a := range args {
|
||||
if s, ok := r.scripts[a]; ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: succeed silently.
|
||||
return "true"
|
||||
}
|
||||
|
||||
func (r *recordedExecer) CommandContext(ctx context.Context, cmd string, args ...string) *exec.Cmd {
|
||||
r.record(cmd, args...)
|
||||
script := r.scriptFor(args)
|
||||
//nolint:gosec // Test helper — script content is controlled by the test.
|
||||
return exec.CommandContext(ctx, "sh", "-c", script)
|
||||
}
|
||||
|
||||
func (r *recordedExecer) PTYCommandContext(ctx context.Context, cmd string, args ...string) *pty.Cmd {
|
||||
r.record(cmd, args...)
|
||||
return pty.CommandContext(ctx, "sh", "-c", r.scriptFor(args))
|
||||
}
|
||||
|
||||
// --- portableDesktop tests ---
|
||||
|
||||
func TestPortableDesktop_Start_ParsesOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
// The "up" script prints the JSON line then sleeps until
|
||||
// the context is canceled (simulating a long-running process).
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop", // pre-set so ensureBinary is a no-op
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
cfg, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1920, cfg.Width)
|
||||
assert.Equal(t, 1080, cfg.Height)
|
||||
assert.Equal(t, 5901, cfg.VNCPort)
|
||||
assert.Equal(t, -1, cfg.Display)
|
||||
|
||||
// Clean up the long-running process.
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Start_Idempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
cfg1, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg2, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg1, cfg2, "second Start should return the same config")
|
||||
|
||||
// The execer should have been called exactly once for "up".
|
||||
cmds := rec.allCommands()
|
||||
upCalls := 0
|
||||
for _, c := range cmds {
|
||||
for _, a := range c {
|
||||
if a == "up" {
|
||||
upCalls++
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, upCalls, "expected exactly one 'up' invocation")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"screenshot": `echo '{"data":"abc123"}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
result, err := pd.Screenshot(ctx, ScreenshotOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "abc123", result.Data)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Screenshot_WithTargetDimensions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"screenshot": `echo '{"data":"x"}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
_, err := pd.Screenshot(ctx, ScreenshotOptions{
|
||||
TargetWidth: 800,
|
||||
TargetHeight: 600,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
|
||||
// The last command should contain the target dimension flags.
|
||||
last := cmds[len(cmds)-1]
|
||||
joined := strings.Join(last, " ")
|
||||
assert.Contains(t, joined, "--target-width 800")
|
||||
assert.Contains(t, joined, "--target-height 600")
|
||||
}
|
||||
|
||||
func TestPortableDesktop_MouseMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Each sub-test verifies a single mouse method dispatches the
|
||||
// correct CLI arguments.
|
||||
tests := []struct {
|
||||
name string
|
||||
invoke func(context.Context, *portableDesktop) error
|
||||
wantArgs []string // substrings expected in a recorded command
|
||||
}{
|
||||
{
|
||||
name: "Move",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Move(ctx, 42, 99)
|
||||
},
|
||||
wantArgs: []string{"mouse", "move", "42", "99"},
|
||||
},
|
||||
{
|
||||
name: "Click",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Click(ctx, 10, 20, MouseButtonLeft)
|
||||
},
|
||||
// Click does move then click.
|
||||
wantArgs: []string{"mouse", "click", "left"},
|
||||
},
|
||||
{
|
||||
name: "DoubleClick",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.DoubleClick(ctx, 5, 6, MouseButtonRight)
|
||||
},
|
||||
wantArgs: []string{"mouse", "click", "right"},
|
||||
},
|
||||
{
|
||||
name: "ButtonDown",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.ButtonDown(ctx, MouseButtonMiddle)
|
||||
},
|
||||
wantArgs: []string{"mouse", "down", "middle"},
|
||||
},
|
||||
{
|
||||
name: "ButtonUp",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.ButtonUp(ctx, MouseButtonLeft)
|
||||
},
|
||||
wantArgs: []string{"mouse", "up", "left"},
|
||||
},
|
||||
{
|
||||
name: "Scroll",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Scroll(ctx, 50, 60, 3, 4)
|
||||
},
|
||||
wantArgs: []string{"mouse", "scroll", "3", "4"},
|
||||
},
|
||||
{
|
||||
name: "Drag",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Drag(ctx, 10, 20, 30, 40)
|
||||
},
|
||||
// Drag ends with mouse up left.
|
||||
wantArgs: []string{"mouse", "up", "left"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"mouse": `echo ok`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds, "expected at least one command")
|
||||
|
||||
// Find at least one recorded command that contains
|
||||
// all expected argument substrings.
|
||||
found := false
|
||||
for _, cmd := range cmds {
|
||||
joined := strings.Join(cmd, " ")
|
||||
match := true
|
||||
for _, want := range tt.wantArgs {
|
||||
if !strings.Contains(joined, want) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found,
|
||||
"no recorded command matched %v; got %v", tt.wantArgs, cmds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortableDesktop_KeyboardMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
invoke func(context.Context, *portableDesktop) error
|
||||
wantArgs []string
|
||||
}{
|
||||
{
|
||||
name: "KeyPress",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyPress(ctx, "Return")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "key", "Return"},
|
||||
},
|
||||
{
|
||||
name: "KeyDown",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyDown(ctx, "shift")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "down", "shift"},
|
||||
},
|
||||
{
|
||||
name: "KeyUp",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.KeyUp(ctx, "shift")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "up", "shift"},
|
||||
},
|
||||
{
|
||||
name: "Type",
|
||||
invoke: func(ctx context.Context, pd *portableDesktop) error {
|
||||
return pd.Type(ctx, "hello world")
|
||||
},
|
||||
wantArgs: []string{"keyboard", "type", "hello world"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"keyboard": `echo ok`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
err := tt.invoke(t.Context(), pd)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmds := rec.allCommands()
|
||||
require.NotEmpty(t, cmds)
|
||||
|
||||
last := cmds[len(cmds)-1]
|
||||
joined := strings.Join(last, " ")
|
||||
for _, want := range tt.wantArgs {
|
||||
assert.Contains(t, joined, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortableDesktop_CursorPosition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"cursor": `echo '{"x":100,"y":200}'`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
x, y, err := pd.CursorPosition(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, x)
|
||||
assert.Equal(t, 200, y)
|
||||
}
|
||||
|
||||
func TestPortableDesktop_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1024x768"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "portabledesktop",
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
_, err := pd.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Session should exist.
|
||||
pd.mu.Lock()
|
||||
require.NotNil(t, pd.session)
|
||||
pd.mu.Unlock()
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
|
||||
// Session should be cleaned up.
|
||||
pd.mu.Lock()
|
||||
assert.Nil(t, pd.session)
|
||||
assert.True(t, pd.closed)
|
||||
pd.mu.Unlock()
|
||||
|
||||
// Subsequent Start must fail.
|
||||
_, err = pd.Start(ctx)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "desktop is closed")
|
||||
}
|
||||
|
||||
// --- ensureBinary tests ---
|
||||
|
||||
func TestEnsureBinary_UsesCachedBinPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// When binPath is already set, ensureBinary should return
|
||||
// immediately without doing any work.
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(),
|
||||
binPath: "/already/set",
|
||||
}
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/already/set", pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_UsesScriptBinDir(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
require.NoError(t, os.Chmod(binPath, 0o755))
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binPath, pd.binPath)
|
||||
}
|
||||
|
||||
func TestEnsureBinary_ScriptBinDirNotExecutable(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Windows does not support Unix permission bits")
|
||||
}
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
|
||||
scriptBinDir := t.TempDir()
|
||||
binPath := filepath.Join(scriptBinDir, "portabledesktop")
|
||||
// Write without execute permission.
|
||||
require.NoError(t, os.WriteFile(binPath, []byte("#!/bin/sh\n"), 0o600))
|
||||
_ = binPath
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: scriptBinDir,
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestEnsureBinary_NotFound(t *testing.T) {
|
||||
// Cannot use t.Parallel because t.Setenv modifies the process
|
||||
// environment.
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: agentexec.DefaultExecer,
|
||||
scriptBinDir: t.TempDir(), // empty directory
|
||||
}
|
||||
|
||||
// Clear PATH so LookPath won't find a real binary.
|
||||
t.Setenv("PATH", "")
|
||||
|
||||
err := pd.ensureBinary(t.Context())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
// Ensure that portableDesktop satisfies the Desktop interface at
|
||||
// compile time. This uses the unexported type so it lives in the
|
||||
// internal test package.
|
||||
var _ Desktop = (*portableDesktop)(nil)
|
||||
+38
-89
@@ -447,10 +447,13 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
|
||||
content := string(data)
|
||||
|
||||
for _, edit := range edits {
|
||||
var err error
|
||||
content, err = fuzzyReplace(content, edit)
|
||||
if err != nil {
|
||||
return http.StatusBadRequest, xerrors.Errorf("edit %s: %w", path, err)
|
||||
var ok bool
|
||||
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
|
||||
if !ok {
|
||||
api.logger.Warn(ctx, "edit search string not found, skipping",
|
||||
slog.F("path", path),
|
||||
slog.F("search_preview", truncate(edit.Search, 64)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,92 +480,51 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// fuzzyReplace attempts to find `search` inside `content` and replace it
|
||||
// with `replace`. It uses a cascading match strategy inspired by
|
||||
// fuzzyReplace attempts to find `search` inside `content` and replace its first
|
||||
// occurrence with `replace`. It uses a cascading match strategy inspired by
|
||||
// openai/codex's apply_patch:
|
||||
//
|
||||
// 1. Exact substring match (byte-for-byte).
|
||||
// 2. Line-by-line match ignoring trailing whitespace on each line.
|
||||
// 3. Line-by-line match ignoring all leading/trailing whitespace
|
||||
// (indentation-tolerant).
|
||||
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
|
||||
//
|
||||
// When edit.ReplaceAll is false (the default), the search string must
|
||||
// match exactly one location. If multiple matches are found, an error
|
||||
// is returned asking the caller to include more context or set
|
||||
// replace_all.
|
||||
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
|
||||
// at the byte offsets of the original content so that surrounding text (including
|
||||
// indentation of untouched lines) is preserved.
|
||||
//
|
||||
// When a fuzzy match is found (passes 2 or 3), the replacement is still
|
||||
// applied at the byte offsets of the original content so that surrounding
|
||||
// text (including indentation of untouched lines) is preserved.
|
||||
func fuzzyReplace(content string, edit workspacesdk.FileEdit) (string, error) {
|
||||
search := edit.Search
|
||||
replace := edit.Replace
|
||||
|
||||
// Pass 1 – exact substring match.
|
||||
// Returns the (possibly modified) content and a bool indicating whether a match
|
||||
// was found.
|
||||
func fuzzyReplace(content, search, replace string) (string, bool) {
|
||||
// Pass 1 – exact substring (replace all occurrences).
|
||||
if strings.Contains(content, search) {
|
||||
if edit.ReplaceAll {
|
||||
return strings.ReplaceAll(content, search, replace), nil
|
||||
}
|
||||
count := strings.Count(content, search)
|
||||
if count > 1 {
|
||||
return "", xerrors.Errorf("search string matches %d occurrences "+
|
||||
"(expected exactly 1). Include more surrounding "+
|
||||
"context to make the match unique, or set "+
|
||||
"replace_all to true", count)
|
||||
}
|
||||
// Exactly one match.
|
||||
return strings.Replace(content, search, replace, 1), nil
|
||||
return strings.ReplaceAll(content, search, replace), true
|
||||
}
|
||||
|
||||
// For line-level fuzzy matching we split both content and search
|
||||
// into lines.
|
||||
// For line-level fuzzy matching we split both content and search into lines.
|
||||
contentLines := strings.SplitAfter(content, "\n")
|
||||
searchLines := strings.SplitAfter(search, "\n")
|
||||
|
||||
// A trailing newline in the search produces an empty final element
|
||||
// from SplitAfter. Drop it so it doesn't interfere with line
|
||||
// matching.
|
||||
// A trailing newline in the search produces an empty final element from
|
||||
// SplitAfter. Drop it so it doesn't interfere with line matching.
|
||||
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
|
||||
searchLines = searchLines[:len(searchLines)-1]
|
||||
}
|
||||
|
||||
trimRight := func(a, b string) bool {
|
||||
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
|
||||
}
|
||||
trimAll := func(a, b string) bool {
|
||||
return strings.TrimSpace(a) == strings.TrimSpace(b)
|
||||
}
|
||||
|
||||
// Pass 2 – trim trailing whitespace on each line.
|
||||
if start, end, ok := seekLines(contentLines, searchLines, trimRight); ok {
|
||||
if !edit.ReplaceAll {
|
||||
if count := countLineMatches(contentLines, searchLines, trimRight); count > 1 {
|
||||
return "", xerrors.Errorf("search string matches %d occurrences "+
|
||||
"(expected exactly 1). Include more surrounding "+
|
||||
"context to make the match unique, or set "+
|
||||
"replace_all to true", count)
|
||||
}
|
||||
}
|
||||
return spliceLines(contentLines, start, end, replace), nil
|
||||
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
|
||||
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
|
||||
}); ok {
|
||||
return spliceLines(contentLines, start, end, replace), true
|
||||
}
|
||||
|
||||
// Pass 3 – trim all leading and trailing whitespace
|
||||
// (indentation-tolerant).
|
||||
if start, end, ok := seekLines(contentLines, searchLines, trimAll); ok {
|
||||
if !edit.ReplaceAll {
|
||||
if count := countLineMatches(contentLines, searchLines, trimAll); count > 1 {
|
||||
return "", xerrors.Errorf("search string matches %d occurrences "+
|
||||
"(expected exactly 1). Include more surrounding "+
|
||||
"context to make the match unique, or set "+
|
||||
"replace_all to true", count)
|
||||
}
|
||||
}
|
||||
return spliceLines(contentLines, start, end, replace), nil
|
||||
// Pass 3 – trim all leading and trailing whitespace (indentation-tolerant).
|
||||
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
|
||||
return strings.TrimSpace(a) == strings.TrimSpace(b)
|
||||
}); ok {
|
||||
return spliceLines(contentLines, start, end, replace), true
|
||||
}
|
||||
|
||||
return "", xerrors.New("search string not found in file. Verify the search " +
|
||||
"string matches the file content exactly, including whitespace " +
|
||||
"and indentation")
|
||||
return content, false
|
||||
}
|
||||
|
||||
// seekLines scans contentLines looking for a contiguous subsequence that matches
|
||||
@@ -587,26 +549,6 @@ outer:
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
// countLineMatches counts how many non-overlapping contiguous
|
||||
// subsequences of contentLines match searchLines according to eq.
|
||||
func countLineMatches(contentLines, searchLines []string, eq func(a, b string) bool) int {
|
||||
count := 0
|
||||
if len(searchLines) == 0 || len(searchLines) > len(contentLines) {
|
||||
return count
|
||||
}
|
||||
outer:
|
||||
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
|
||||
for j, sLine := range searchLines {
|
||||
if !eq(contentLines[i+j], sLine) {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
count++
|
||||
i += len(searchLines) - 1 // skip past this match
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// spliceLines replaces contentLines[start:end] with replacement text, returning
|
||||
// the full content as a single string.
|
||||
func spliceLines(contentLines []string, start, end int, replacement string) string {
|
||||
@@ -620,3 +562,10 @@ func spliceLines(contentLines []string, start, end int, replacement string) stri
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
@@ -576,9 +576,7 @@ func TestEditFiles(t *testing.T) {
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
|
||||
},
|
||||
{
|
||||
// When the second edit creates ambiguity (two "bar"
|
||||
// occurrences), it should fail.
|
||||
name: "EditEditAmbiguous",
|
||||
name: "EditEdit", // Edits affect previous edits.
|
||||
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
@@ -595,33 +593,7 @@ func TestEditFiles(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
errCode: http.StatusBadRequest,
|
||||
errors: []string{"matches 2 occurrences"},
|
||||
// File should not be modified on error.
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
|
||||
},
|
||||
{
|
||||
// With replace_all the cascading edit replaces
|
||||
// both occurrences.
|
||||
name: "EditEditReplaceAll",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "foo bar"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: filepath.Join(tmpdir, "edit-edit-ra"),
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{
|
||||
Search: "foo",
|
||||
Replace: "bar",
|
||||
},
|
||||
{
|
||||
Search: "bar",
|
||||
Replace: "qux",
|
||||
ReplaceAll: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit-edit-ra"): "qux qux"},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
|
||||
},
|
||||
{
|
||||
name: "Multiline",
|
||||
@@ -748,7 +720,7 @@ func TestEditFiles(t *testing.T) {
|
||||
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
|
||||
},
|
||||
{
|
||||
name: "NoMatchErrors",
|
||||
name: "NoMatchStillSucceeds",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
@@ -761,46 +733,9 @@ func TestEditFiles(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
errCode: http.StatusBadRequest,
|
||||
errors: []string{"search string not found in file"},
|
||||
// File should remain unchanged.
|
||||
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
|
||||
},
|
||||
{
|
||||
name: "AmbiguousExactMatch",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: filepath.Join(tmpdir, "ambig-exact"),
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{
|
||||
Search: "foo",
|
||||
Replace: "qux",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
errCode: http.StatusBadRequest,
|
||||
errors: []string{"matches 3 occurrences"},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "ambig-exact"): "foo bar foo baz foo"},
|
||||
},
|
||||
{
|
||||
name: "ReplaceAllExact",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "ra-exact"): "foo bar foo baz foo"},
|
||||
edits: []workspacesdk.FileEdits{
|
||||
{
|
||||
Path: filepath.Join(tmpdir, "ra-exact"),
|
||||
Edits: []workspacesdk.FileEdit{
|
||||
{
|
||||
Search: "foo",
|
||||
Replace: "qux",
|
||||
ReplaceAll: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{filepath.Join(tmpdir, "ra-exact"): "qux bar qux baz qux"},
|
||||
},
|
||||
{
|
||||
name: "MixedWhitespaceMultiline",
|
||||
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
|
||||
|
||||
@@ -26,10 +26,10 @@ type API struct {
|
||||
}
|
||||
|
||||
// NewAPI creates a new process API handler.
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore, workingDir func() string) *API {
|
||||
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), pathStore *agentgit.PathStore) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
manager: newManager(logger, execer, updateEnv, workingDir),
|
||||
manager: newManager(logger, execer, updateEnv),
|
||||
pathStore: pathStore,
|
||||
}
|
||||
}
|
||||
|
||||
+3
-105
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -98,25 +97,18 @@ func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.
|
||||
// execer, returning the handler and API.
|
||||
func newTestAPI(t *testing.T) http.Handler {
|
||||
t.Helper()
|
||||
return newTestAPIWithOptions(t, nil, nil)
|
||||
return newTestAPIWithUpdateEnv(t, nil)
|
||||
}
|
||||
|
||||
// newTestAPIWithUpdateEnv creates a new API with an optional
|
||||
// updateEnv hook for testing environment injection.
|
||||
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
|
||||
t.Helper()
|
||||
return newTestAPIWithOptions(t, updateEnv, nil)
|
||||
}
|
||||
|
||||
// newTestAPIWithOptions creates a new API with optional
|
||||
// updateEnv and workingDir hooks.
|
||||
func newTestAPIWithOptions(t *testing.T, updateEnv func([]string) ([]string, error), workingDir func() string) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil, workingDir)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv, nil)
|
||||
t.Cleanup(func() {
|
||||
_ = api.Close()
|
||||
})
|
||||
@@ -261,100 +253,6 @@ func TestStartProcess(t *testing.T) {
|
||||
require.Contains(t, resp.Output, "marker.txt")
|
||||
})
|
||||
|
||||
t.Run("DefaultWorkDirIsHome", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// No working directory closure, so the process
|
||||
// should fall back to $HOME. We verify through
|
||||
// the process list API which reports the resolved
|
||||
// working directory using native OS paths,
|
||||
// avoiding shell path format mismatches on
|
||||
// Windows (Git Bash returns POSIX paths).
|
||||
handler := newTestAPI(t)
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo ok",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var listResp workspacesdk.ListProcessesResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
|
||||
var proc *workspacesdk.ProcessInfo
|
||||
for i := range listResp.Processes {
|
||||
if listResp.Processes[i].ID == id {
|
||||
proc = &listResp.Processes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, proc, "process not found in list")
|
||||
require.Equal(t, homeDir, proc.WorkDir)
|
||||
})
|
||||
|
||||
t.Run("DefaultWorkDirFromClosure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The closure provides a valid directory, so the
|
||||
// process should start there. Use the marker file
|
||||
// pattern to avoid path format mismatches on
|
||||
// Windows.
|
||||
tmpDir := t.TempDir()
|
||||
handler := newTestAPIWithOptions(t, nil, func() string {
|
||||
return tmpDir
|
||||
})
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "touch marker.txt && ls marker.txt",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
require.Contains(t, resp.Output, "marker.txt")
|
||||
})
|
||||
|
||||
t.Run("DefaultWorkDirClosureNonExistentFallsBackToHome", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The closure returns a path that doesn't exist,
|
||||
// so the process should fall back to $HOME.
|
||||
handler := newTestAPIWithOptions(t, nil, func() string {
|
||||
return "/tmp/nonexistent-dir-" + fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
})
|
||||
|
||||
homeDir, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo ok",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
|
||||
w := getList(t, handler)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var listResp workspacesdk.ListProcessesResponse
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&listResp))
|
||||
var proc *workspacesdk.ProcessInfo
|
||||
for i := range listResp.Processes {
|
||||
if listResp.Processes[i].ID == id {
|
||||
proc = &listResp.Processes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, proc, "process not found in list")
|
||||
require.Equal(t, homeDir, proc.WorkDir)
|
||||
})
|
||||
|
||||
t.Run("CustomEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -883,7 +781,7 @@ func TestHandleStartProcess_ChatHeaders_EmptyWorkDir_StillNotifies(t *testing.T)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, func(current []string) ([]string, error) {
|
||||
return current, nil
|
||||
}, pathStore, nil)
|
||||
}, pathStore)
|
||||
defer api.Close()
|
||||
|
||||
routes := api.Routes()
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// procSysProcAttr returns the SysProcAttr to use when spawning
|
||||
// processes. On Unix, Setpgid creates a new process group so
|
||||
// that signals can be delivered to the entire group (the shell
|
||||
// and all its children).
|
||||
func procSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
|
||||
// signalProcess sends a signal to the process group rooted at p.
|
||||
// Using the negative PID sends the signal to every process in the
|
||||
// group, ensuring child processes (e.g. from shell pipelines) are
|
||||
// also signaled.
|
||||
func signalProcess(p *os.Process, sig syscall.Signal) error {
|
||||
return syscall.Kill(-p.Pid, sig)
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package agentproc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// procSysProcAttr returns the SysProcAttr to use when spawning
|
||||
// processes. On Windows, process groups are not supported in the
|
||||
// same way as Unix, so this returns an empty struct.
|
||||
func procSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{}
|
||||
}
|
||||
|
||||
// signalProcess sends a signal directly to the process. Windows
|
||||
// does not support process group signaling, so we fall back to
|
||||
// sending the signal to the process itself.
|
||||
func signalProcess(p *os.Process, _ syscall.Signal) error {
|
||||
return p.Kill()
|
||||
}
|
||||
+21
-45
@@ -70,25 +70,23 @@ func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
|
||||
|
||||
// manager tracks processes spawned by the agent.
|
||||
type manager struct {
|
||||
mu sync.Mutex
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
clock quartz.Clock
|
||||
procs map[string]*process
|
||||
closed bool
|
||||
updateEnv func(current []string) (updated []string, err error)
|
||||
workingDir func() string
|
||||
mu sync.Mutex
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
clock quartz.Clock
|
||||
procs map[string]*process
|
||||
closed bool
|
||||
updateEnv func(current []string) (updated []string, err error)
|
||||
}
|
||||
|
||||
// newManager creates a new process manager.
|
||||
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error), workingDir func() string) *manager {
|
||||
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
|
||||
return &manager{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
clock: quartz.NewReal(),
|
||||
procs: make(map[string]*process),
|
||||
updateEnv: updateEnv,
|
||||
workingDir: workingDir,
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
clock: quartz.NewReal(),
|
||||
procs: make(map[string]*process),
|
||||
updateEnv: updateEnv,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,9 +109,10 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
// the process is not tied to any HTTP request.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
|
||||
cmd.Dir = m.resolveWorkDir(req.WorkDir)
|
||||
if req.WorkDir != "" {
|
||||
cmd.Dir = req.WorkDir
|
||||
}
|
||||
cmd.Stdin = nil
|
||||
cmd.SysProcAttr = procSysProcAttr()
|
||||
|
||||
// WaitDelay ensures cmd.Wait returns promptly after
|
||||
// the process is killed, even if child processes are
|
||||
@@ -158,7 +157,7 @@ func (m *manager) start(req workspacesdk.StartProcessRequest, chatID string) (*p
|
||||
proc := &process{
|
||||
id: id,
|
||||
command: req.Command,
|
||||
workDir: cmd.Dir,
|
||||
workDir: req.WorkDir,
|
||||
background: req.Background,
|
||||
chatID: chatID,
|
||||
cmd: cmd,
|
||||
@@ -273,15 +272,13 @@ func (m *manager) signal(id string, sig string) error {
|
||||
|
||||
switch sig {
|
||||
case "kill":
|
||||
// Use process group kill to ensure child processes
|
||||
// (e.g. from shell pipelines) are also killed.
|
||||
if err := signalProcess(proc.cmd.Process, syscall.SIGKILL); err != nil {
|
||||
if err := proc.cmd.Process.Kill(); err != nil {
|
||||
return xerrors.Errorf("kill process: %w", err)
|
||||
}
|
||||
case "terminate":
|
||||
// Use process group signal to ensure child processes
|
||||
// are also terminated.
|
||||
if err := signalProcess(proc.cmd.Process, syscall.SIGTERM); err != nil {
|
||||
//nolint:revive // syscall.SIGTERM is portable enough
|
||||
// for our supported platforms.
|
||||
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
return xerrors.Errorf("terminate process: %w", err)
|
||||
}
|
||||
default:
|
||||
@@ -319,24 +316,3 @@ func (m *manager) Close() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveWorkDir returns the directory a process should start in.
|
||||
// Priority: explicit request dir > agent configured dir > $HOME.
|
||||
// Falls through when a candidate is empty or does not exist on
|
||||
// disk, matching the behavior of SSH sessions.
|
||||
func (m *manager) resolveWorkDir(requested string) string {
|
||||
if requested != "" {
|
||||
return requested
|
||||
}
|
||||
if m.workingDir != nil {
|
||||
if dir := m.workingDir(); dir != "" {
|
||||
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
||||
return dir
|
||||
}
|
||||
}
|
||||
}
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
return home
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ func (a *agent) apiHandler() http.Handler {
|
||||
r.Mount("/api/v0", a.filesAPI.Routes())
|
||||
r.Mount("/api/v0/git", a.gitAPI.Routes())
|
||||
r.Mount("/api/v0/processes", a.processAPI.Routes())
|
||||
r.Mount("/api/v0/desktop", a.desktopAPI.Routes())
|
||||
|
||||
if a.devcontainers {
|
||||
r.Mount("/api/v0/containers", a.containerAPI.Routes())
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -22,6 +23,26 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// logSink captures structured log entries for testing.
|
||||
type logSink struct {
|
||||
mu sync.Mutex
|
||||
entries []slog.SinkEntry
|
||||
}
|
||||
|
||||
func (s *logSink) LogEntry(_ context.Context, e slog.SinkEntry) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.entries = append(s.entries, e)
|
||||
}
|
||||
|
||||
func (*logSink) Sync() {}
|
||||
|
||||
func (s *logSink) getEntries() []slog.SinkEntry {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return append([]slog.SinkEntry{}, s.entries...)
|
||||
}
|
||||
|
||||
// getField returns the value of a field by name from a slog.Map.
|
||||
func getField(fields slog.Map, name string) interface{} {
|
||||
for _, f := range fields {
|
||||
@@ -55,8 +76,8 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, srv.Close()) })
|
||||
|
||||
sink := testutil.NewFakeSink(t)
|
||||
logger := sink.Logger(slog.LevelInfo)
|
||||
sink := &logSink{}
|
||||
logger := slog.Make(sink)
|
||||
workspaceID := uuid.New()
|
||||
templateID := uuid.New()
|
||||
templateVersionID := uuid.New()
|
||||
@@ -97,10 +118,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
|
||||
sendBoundaryLogsRequest(t, conn, req)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return len(sink.Entries()) >= 1
|
||||
return len(sink.getEntries()) >= 1
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
entries := sink.Entries()
|
||||
entries := sink.getEntries()
|
||||
require.Len(t, entries, 1)
|
||||
entry := entries[0]
|
||||
require.Equal(t, slog.LevelInfo, entry.Level)
|
||||
@@ -131,10 +152,10 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
|
||||
sendBoundaryLogsRequest(t, conn, req2)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return len(sink.Entries()) >= 2
|
||||
return len(sink.getEntries()) >= 2
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
entries = sink.Entries()
|
||||
entries = sink.getEntries()
|
||||
entry = entries[1]
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, slog.LevelInfo, entry.Level)
|
||||
|
||||
+8
-31
@@ -2,7 +2,6 @@ package reaper
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-reap"
|
||||
|
||||
@@ -43,42 +42,20 @@ func WithLogger(logger slog.Logger) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithReaperStop sets a channel that, when closed, stops the reaper
|
||||
// WithDone sets a channel that, when closed, stops the reaper
|
||||
// goroutine. Callers that invoke ForkReap more than once in the
|
||||
// same process (e.g. tests) should use this to prevent goroutine
|
||||
// accumulation.
|
||||
func WithReaperStop(ch chan struct{}) Option {
|
||||
func WithDone(ch chan struct{}) Option {
|
||||
return func(o *options) {
|
||||
o.ReaperStop = ch
|
||||
}
|
||||
}
|
||||
|
||||
// WithReaperStopped sets a channel that is closed after the
|
||||
// reaper goroutine has fully exited.
|
||||
func WithReaperStopped(ch chan struct{}) Option {
|
||||
return func(o *options) {
|
||||
o.ReaperStopped = ch
|
||||
}
|
||||
}
|
||||
|
||||
// WithReapLock sets a mutex shared between the reaper and Wait4.
|
||||
// The reaper holds the write lock while reaping, and ForkReap
|
||||
// holds the read lock during Wait4, preventing the reaper from
|
||||
// stealing the child's exit status. This is only needed for
|
||||
// tests with instant-exit children where the race window is
|
||||
// large.
|
||||
func WithReapLock(mu *sync.RWMutex) Option {
|
||||
return func(o *options) {
|
||||
o.ReapLock = mu
|
||||
o.Done = ch
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
CatchSignals []os.Signal
|
||||
Logger slog.Logger
|
||||
ReaperStop chan struct{}
|
||||
ReaperStopped chan struct{}
|
||||
ReapLock *sync.RWMutex
|
||||
ExecArgs []string
|
||||
PIDs reap.PidCh
|
||||
CatchSignals []os.Signal
|
||||
Logger slog.Logger
|
||||
Done chan struct{}
|
||||
}
|
||||
|
||||
+34
-99
@@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -19,82 +18,35 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// subprocessEnvKey is set when a test re-execs itself as an
|
||||
// isolated subprocess. Tests that call ForkReap or send signals
|
||||
// to their own process check this to decide whether to run real
|
||||
// test logic or launch the subprocess and wait for it.
|
||||
const subprocessEnvKey = "CODER_REAPER_TEST_SUBPROCESS"
|
||||
// withDone returns an option that stops the reaper goroutine when t
|
||||
// completes, preventing goroutine accumulation across subtests.
|
||||
func withDone(t *testing.T) reaper.Option {
|
||||
t.Helper()
|
||||
done := make(chan struct{})
|
||||
t.Cleanup(func() { close(done) })
|
||||
return reaper.WithDone(done)
|
||||
}
|
||||
|
||||
// runSubprocess re-execs the current test binary in a new process
|
||||
// running only the named test. This isolates ForkReap's
|
||||
// syscall.ForkExec and any process-directed signals (e.g. SIGINT)
|
||||
// from the parent test binary, making these tests safe to run in
|
||||
// CI and alongside other tests.
|
||||
// TestReap checks that's the reaper is successfully reaping
|
||||
// exited processes and passing the PIDs through the shared
|
||||
// channel.
|
||||
//
|
||||
// Returns true inside the subprocess (caller should proceed with
|
||||
// the real test logic). Returns false in the parent after the
|
||||
// subprocess exits successfully (caller should return).
|
||||
func runSubprocess(t *testing.T) bool {
|
||||
t.Helper()
|
||||
|
||||
if os.Getenv(subprocessEnvKey) == "1" {
|
||||
return true
|
||||
}
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
//nolint:gosec // Test-controlled arguments.
|
||||
cmd := exec.CommandContext(ctx, os.Args[0],
|
||||
"-test.run=^"+t.Name()+"$",
|
||||
"-test.v",
|
||||
)
|
||||
cmd.Env = append(os.Environ(), subprocessEnvKey+"=1")
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
t.Logf("Subprocess output:\n%s", out)
|
||||
require.NoError(t, err, "subprocess failed")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// withDone returns options that stop the reaper goroutine when t
|
||||
// completes and wait for it to fully exit, preventing
|
||||
// overlapping reapers across sequential subtests.
|
||||
func withDone(t *testing.T) []reaper.Option {
|
||||
t.Helper()
|
||||
stop := make(chan struct{})
|
||||
stopped := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
close(stop)
|
||||
<-stopped
|
||||
})
|
||||
return []reaper.Option{
|
||||
reaper.WithReaperStop(stop),
|
||||
reaper.WithReaperStopped(stopped),
|
||||
}
|
||||
}
|
||||
|
||||
// TestReap checks that the reaper successfully reaps exited
|
||||
// processes and passes their PIDs through the shared channel.
|
||||
//nolint:paralleltest
|
||||
func TestReap(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Don't run the reaper test in CI. It does weird
|
||||
// things like forkexecing which may have unintended
|
||||
// consequences in CI.
|
||||
if testutil.InCI() {
|
||||
t.Skip("Detected CI, skipping reaper tests")
|
||||
}
|
||||
if !runSubprocess(t) {
|
||||
return
|
||||
}
|
||||
|
||||
pids := make(reap.PidCh, 1)
|
||||
var reapLock sync.RWMutex
|
||||
opts := append([]reaper.Option{
|
||||
exitCode, err := reaper.ForkReap(
|
||||
reaper.WithPIDCallback(pids),
|
||||
// Provide some argument that immediately exits.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", "exit 0"),
|
||||
reaper.WithReapLock(&reapLock),
|
||||
}, withDone(t)...)
|
||||
reapLock.RLock()
|
||||
exitCode, err := reaper.ForkReap(opts...)
|
||||
reapLock.RUnlock()
|
||||
withDone(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, exitCode)
|
||||
|
||||
@@ -114,7 +66,7 @@ func TestReap(t *testing.T) {
|
||||
|
||||
expectedPIDs := []int{cmd.Process.Pid, cmd2.Process.Pid}
|
||||
|
||||
for range len(expectedPIDs) {
|
||||
for i := 0; i < len(expectedPIDs); i++ {
|
||||
select {
|
||||
case <-time.After(testutil.WaitShort):
|
||||
t.Fatalf("Timed out waiting for process")
|
||||
@@ -124,15 +76,11 @@ func TestReap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:tparallel // Subtests must be sequential, each starts its own reaper.
|
||||
//nolint:paralleltest
|
||||
func TestForkReapExitCodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testutil.InCI() {
|
||||
t.Skip("Detected CI, skipping reaper tests")
|
||||
}
|
||||
if !runSubprocess(t) {
|
||||
return
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -147,35 +95,26 @@ func TestForkReapExitCodes(t *testing.T) {
|
||||
{"SIGTERM", "kill -15 $$", 128 + 15},
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Subtests must be sequential, each starts its own reaper.
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var reapLock sync.RWMutex
|
||||
opts := append([]reaper.Option{
|
||||
exitCode, err := reaper.ForkReap(
|
||||
reaper.WithExecArgs("/bin/sh", "-c", tt.command),
|
||||
reaper.WithReapLock(&reapLock),
|
||||
}, withDone(t)...)
|
||||
reapLock.RLock()
|
||||
exitCode, err := reaper.ForkReap(opts...)
|
||||
reapLock.RUnlock()
|
||||
withDone(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedCode, exitCode, "exit code mismatch for %q", tt.command)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReapInterrupt verifies that ForkReap forwards caught signals
|
||||
// to the child process. The test sends SIGINT to its own process
|
||||
// and checks that the child receives it. Running in a subprocess
|
||||
// ensures SIGINT cannot kill the parent test binary.
|
||||
//nolint:paralleltest // Signal handling.
|
||||
func TestReapInterrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Don't run the reaper test in CI. It does weird
|
||||
// things like forkexecing which may have unintended
|
||||
// consequences in CI.
|
||||
if testutil.InCI() {
|
||||
t.Skip("Detected CI, skipping reaper tests")
|
||||
}
|
||||
if !runSubprocess(t) {
|
||||
return
|
||||
}
|
||||
|
||||
errC := make(chan error, 1)
|
||||
pids := make(reap.PidCh, 1)
|
||||
@@ -187,28 +126,24 @@ func TestReapInterrupt(t *testing.T) {
|
||||
defer signal.Stop(usrSig)
|
||||
|
||||
go func() {
|
||||
opts := append([]reaper.Option{
|
||||
exitCode, err := reaper.ForkReap(
|
||||
reaper.WithPIDCallback(pids),
|
||||
reaper.WithCatchSignals(os.Interrupt),
|
||||
withDone(t),
|
||||
// Signal propagation does not extend to children of children, so
|
||||
// we create a little bash script to ensure sleep is interrupted.
|
||||
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf(
|
||||
"pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait",
|
||||
os.Getpid(), os.Getpid(),
|
||||
)),
|
||||
}, withDone(t)...)
|
||||
exitCode, err := reaper.ForkReap(opts...)
|
||||
reaper.WithExecArgs("/bin/sh", "-c", fmt.Sprintf("pid=0; trap 'kill -USR2 %d; kill -TERM $pid' INT; sleep 10 &\npid=$!; kill -USR1 %d; wait", os.Getpid(), os.Getpid())),
|
||||
)
|
||||
// The child exits with 128 + SIGTERM (15) = 143, but the trap catches
|
||||
// SIGINT and sends SIGTERM to the sleep process, so exit code varies.
|
||||
_ = exitCode
|
||||
errC <- err
|
||||
}()
|
||||
|
||||
require.Equal(t, syscall.SIGUSR1, <-usrSig)
|
||||
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR1)
|
||||
err := syscall.Kill(os.Getpid(), syscall.SIGINT)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, <-usrSig, syscall.SIGUSR2)
|
||||
|
||||
require.Equal(t, syscall.SIGUSR2, <-usrSig)
|
||||
require.NoError(t, <-errC)
|
||||
}
|
||||
|
||||
+14
-24
@@ -19,36 +19,31 @@ func IsInitProcess() bool {
|
||||
return os.Getpid() == 1
|
||||
}
|
||||
|
||||
// startSignalForwarding registers signal handlers synchronously
|
||||
// then forwards caught signals to the child in a background
|
||||
// goroutine. Registering before the goroutine starts ensures no
|
||||
// signal is lost between ForkExec and the handler being ready.
|
||||
func startSignalForwarding(logger slog.Logger, pid int, sigs []os.Signal) {
|
||||
func catchSignals(logger slog.Logger, pid int, sigs []os.Signal) {
|
||||
if len(sigs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sc := make(chan os.Signal, 1)
|
||||
signal.Notify(sc, sigs...)
|
||||
defer signal.Stop(sc)
|
||||
|
||||
logger.Info(context.Background(), "reaper catching signals",
|
||||
slog.F("signals", sigs),
|
||||
slog.F("child_pid", pid),
|
||||
)
|
||||
|
||||
go func() {
|
||||
defer signal.Stop(sc)
|
||||
for s := range sc {
|
||||
sig, ok := s.(syscall.Signal)
|
||||
if ok {
|
||||
logger.Info(context.Background(), "reaper caught signal, killing child process",
|
||||
slog.F("signal", sig.String()),
|
||||
slog.F("child_pid", pid),
|
||||
)
|
||||
_ = syscall.Kill(pid, sig)
|
||||
}
|
||||
for {
|
||||
s := <-sc
|
||||
sig, ok := s.(syscall.Signal)
|
||||
if ok {
|
||||
logger.Info(context.Background(), "reaper caught signal, killing child process",
|
||||
slog.F("signal", sig.String()),
|
||||
slog.F("child_pid", pid),
|
||||
)
|
||||
_ = syscall.Kill(pid, sig)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// ForkReap spawns a goroutine that reaps children. In order to avoid
|
||||
@@ -69,12 +64,7 @@ func ForkReap(opt ...Option) (int, error) {
|
||||
o(opts)
|
||||
}
|
||||
|
||||
go func() {
|
||||
reap.ReapChildren(opts.PIDs, nil, opts.ReaperStop, opts.ReapLock)
|
||||
if opts.ReaperStopped != nil {
|
||||
close(opts.ReaperStopped)
|
||||
}
|
||||
}()
|
||||
go reap.ReapChildren(opts.PIDs, nil, opts.Done, nil)
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
@@ -100,7 +90,7 @@ func ForkReap(opt ...Option) (int, error) {
|
||||
return 1, xerrors.Errorf("fork exec: %w", err)
|
||||
}
|
||||
|
||||
startSignalForwarding(opts.Logger, pid, opts.CatchSignals)
|
||||
go catchSignals(opts.Logger, pid, opts.CatchSignals)
|
||||
|
||||
var wstatus syscall.WaitStatus
|
||||
_, err = syscall.Wait4(pid, &wstatus, 0, nil)
|
||||
|
||||
@@ -46,7 +46,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
autoUpdates string
|
||||
copyParametersFrom string
|
||||
useParameterDefaults bool
|
||||
noWait bool
|
||||
// Organization context is only required if more than 1 template
|
||||
// shares the same name across multiple organizations.
|
||||
orgContext = NewOrganizationContext()
|
||||
@@ -373,14 +372,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
|
||||
cliutil.WarnMatchedProvisioners(inv.Stderr, workspace.LatestBuild.MatchedProvisioners, workspace.LatestBuild.Job)
|
||||
|
||||
if noWait {
|
||||
_, _ = fmt.Fprintf(inv.Stdout,
|
||||
"\nThe %s workspace has been created and is building in the background.\n",
|
||||
cliui.Keyword(workspace.Name),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("watch build: %w", err)
|
||||
@@ -454,12 +445,6 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
Description: "Automatically accept parameter defaults when no value is provided.",
|
||||
Value: serpent.BoolOf(&useParameterDefaults),
|
||||
},
|
||||
serpent.Option{
|
||||
Flag: "no-wait",
|
||||
Env: "CODER_CREATE_NO_WAIT",
|
||||
Description: "Return immediately after creating the workspace. The build will run in the background.",
|
||||
Value: serpent.BoolOf(&noWait),
|
||||
},
|
||||
cliui.SkipPromptOption(),
|
||||
)
|
||||
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
|
||||
|
||||
@@ -603,81 +603,6 @@ func TestCreate(t *testing.T) {
|
||||
assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NoWait", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv, root := clitest.New(t, "create", "my-workspace",
|
||||
"--template", template.Name,
|
||||
"-y",
|
||||
"--no-wait",
|
||||
)
|
||||
clitest.SetupConfig(t, member, root)
|
||||
doneChan := make(chan struct{})
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatchContext(ctx, "building in the background")
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
|
||||
// Verify workspace was actually created.
|
||||
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ws.TemplateName, template.Name)
|
||||
})
|
||||
|
||||
t.Run("NoWaitWithParameterDefaults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses([]*proto.RichParameter{
|
||||
{Name: "region", Type: "string", DefaultValue: "us-east-1"},
|
||||
{Name: "instance_type", Type: "string", DefaultValue: "t3.micro"},
|
||||
}))
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
inv, root := clitest.New(t, "create", "my-workspace",
|
||||
"--template", template.Name,
|
||||
"-y",
|
||||
"--use-parameter-defaults",
|
||||
"--no-wait",
|
||||
)
|
||||
clitest.SetupConfig(t, member, root)
|
||||
doneChan := make(chan struct{})
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
pty.ExpectMatchContext(ctx, "building in the background")
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
|
||||
// Verify workspace was created and parameters were applied.
|
||||
ws, err := member.WorkspaceByOwnerAndName(ctx, codersdk.Me, "my-workspace", codersdk.WorkspaceOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ws.TemplateName, template.Name)
|
||||
|
||||
buildParams, err := member.WorkspaceBuildParameters(ctx, ws.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "region", Value: "us-east-1"})
|
||||
assert.Contains(t, buildParams, codersdk.WorkspaceBuildParameter{Name: "instance_type", Value: "t3.micro"})
|
||||
})
|
||||
}
|
||||
|
||||
func prepareEchoResponses(parameters []*proto.RichParameter, presets ...*proto.Preset) *echo.Responses {
|
||||
|
||||
@@ -1000,12 +1000,6 @@ func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool
|
||||
Properties: sdkTool.Schema.Properties,
|
||||
Required: sdkTool.Schema.Required,
|
||||
},
|
||||
Annotations: mcp.ToolAnnotation{
|
||||
ReadOnlyHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.ReadOnlyHint),
|
||||
DestructiveHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.DestructiveHint),
|
||||
IdempotentHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.IdempotentHint),
|
||||
OpenWorldHint: mcp.ToBoolPtr(sdkTool.MCPAnnotations.OpenWorldHint),
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
+1
-16
@@ -81,13 +81,7 @@ func TestExpMcpServer(t *testing.T) {
|
||||
var toolsResponse struct {
|
||||
Result struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Annotations struct {
|
||||
ReadOnlyHint *bool `json:"readOnlyHint"`
|
||||
DestructiveHint *bool `json:"destructiveHint"`
|
||||
IdempotentHint *bool `json:"idempotentHint"`
|
||||
OpenWorldHint *bool `json:"openWorldHint"`
|
||||
} `json:"annotations"`
|
||||
Name string `json:"name"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
@@ -100,15 +94,6 @@ func TestExpMcpServer(t *testing.T) {
|
||||
}
|
||||
slices.Sort(foundTools)
|
||||
require.Equal(t, []string{"coder_get_authenticated_user"}, foundTools)
|
||||
annotations := toolsResponse.Result.Tools[0].Annotations
|
||||
require.NotNil(t, annotations.ReadOnlyHint)
|
||||
require.NotNil(t, annotations.DestructiveHint)
|
||||
require.NotNil(t, annotations.IdempotentHint)
|
||||
require.NotNil(t, annotations.OpenWorldHint)
|
||||
assert.True(t, *annotations.ReadOnlyHint)
|
||||
assert.False(t, *annotations.DestructiveHint)
|
||||
assert.True(t, *annotations.IdempotentHint)
|
||||
assert.False(t, *annotations.OpenWorldHint)
|
||||
|
||||
// Call the tool and ensure it works.
|
||||
toolPayload := `{"jsonrpc":"2.0","id":3,"method":"tools/call", "params": {"name": "coder_get_authenticated_user", "arguments": {}}}`
|
||||
|
||||
+45
-74
@@ -1732,18 +1732,19 @@ const (
|
||||
|
||||
func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
var (
|
||||
workspaceCount int64
|
||||
workspaceJobTimeout time.Duration
|
||||
autostartBuildTimeout time.Duration
|
||||
autostartDelay time.Duration
|
||||
template string
|
||||
noCleanup bool
|
||||
workspaceCount int64
|
||||
workspaceJobTimeout time.Duration
|
||||
autostartDelay time.Duration
|
||||
autostartTimeout time.Duration
|
||||
template string
|
||||
noCleanup bool
|
||||
|
||||
parameterFlags workspaceParameterFlags
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
@@ -1771,7 +1772,7 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse output flags: %w", err)
|
||||
return xerrors.Errorf("could not parse --output flags")
|
||||
}
|
||||
|
||||
tpl, err := parseTemplate(ctx, client, me.OrganizationIDs, template)
|
||||
@@ -1802,41 +1803,15 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
}
|
||||
tracer := tracerProvider.Tracer(scaletestTracerName)
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := autostart.NewMetrics(reg)
|
||||
|
||||
setupBarrier := new(sync.WaitGroup)
|
||||
setupBarrier.Add(int(workspaceCount))
|
||||
|
||||
// The workspace-build-updates experiment must be enabled to use
|
||||
// the centralized pubsub channel for coordinating workspace builds.
|
||||
experiments, err := client.Experiments(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get experiments: %w", err)
|
||||
}
|
||||
if !experiments.Enabled(codersdk.ExperimentWorkspaceBuildUpdates) {
|
||||
return xerrors.New("the workspace-build-updates experiment must be enabled to run the autostart scaletest")
|
||||
}
|
||||
|
||||
workspaceNames := make([]string, 0, workspaceCount)
|
||||
resultSink := make(chan autostart.RunResult, workspaceCount)
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
for i := range workspaceCount {
|
||||
id := strconv.Itoa(int(i))
|
||||
workspaceNames = append(workspaceNames, loadtestutil.GenerateDeterministicWorkspaceName(id))
|
||||
}
|
||||
dispatcher := autostart.NewWorkspaceDispatcher(workspaceNames)
|
||||
|
||||
decoder, err := client.WatchAllWorkspaceBuilds(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("watch all workspace builds: %w", err)
|
||||
}
|
||||
defer decoder.Close()
|
||||
|
||||
// Start the dispatcher. It will run in a goroutine and automatically
|
||||
// close all workspace channels when the build updates channel closes.
|
||||
dispatcher.Start(ctx, decoder.Chan())
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
for workspaceName, buildUpdatesChannel := range dispatcher.Channels {
|
||||
id := strings.TrimPrefix(workspaceName, loadtestutil.ScaleTestPrefix+"-")
|
||||
|
||||
config := autostart.Config{
|
||||
User: createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
@@ -1846,16 +1821,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
Request: codersdk.CreateWorkspaceRequest{
|
||||
TemplateID: tpl.ID,
|
||||
RichParameterValues: richParameters,
|
||||
// Use deterministic workspace name so we can pre-create the channel.
|
||||
Name: workspaceName,
|
||||
},
|
||||
},
|
||||
WorkspaceJobTimeout: workspaceJobTimeout,
|
||||
AutostartBuildTimeout: autostartBuildTimeout,
|
||||
AutostartDelay: autostartDelay,
|
||||
SetupBarrier: setupBarrier,
|
||||
BuildUpdates: buildUpdatesChannel,
|
||||
ResultSink: resultSink,
|
||||
WorkspaceJobTimeout: workspaceJobTimeout,
|
||||
AutostartDelay: autostartDelay,
|
||||
AutostartTimeout: autostartTimeout,
|
||||
Metrics: metrics,
|
||||
SetupBarrier: setupBarrier,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
@@ -1877,11 +1849,18 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
th.AddRun(autostartTestName, id, runner)
|
||||
}
|
||||
|
||||
logger := inv.Logger
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
defer func() {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
|
||||
if err := closeTracing(ctx); err != nil {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
|
||||
}
|
||||
// Wait for prometheus metrics to be scraped
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Running autostart load test...")
|
||||
@@ -1892,40 +1871,31 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// Collect all metrics from the channel.
|
||||
close(resultSink)
|
||||
var runResults []autostart.RunResult
|
||||
for r := range resultSink {
|
||||
runResults = append(runResults, r)
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "\nAll %d autostart builds completed successfully (elapsed: %s)\n", res.TotalRuns, time.Duration(res.Elapsed).Round(time.Millisecond))
|
||||
|
||||
if len(runResults) > 0 {
|
||||
results := autostart.NewRunResults(runResults)
|
||||
for _, out := range outputs {
|
||||
if err := out.write(results.ToHarnessResults(), inv.Stdout); err != nil {
|
||||
return xerrors.Errorf("write output: %w", err)
|
||||
}
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCleanup {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(context.Background())
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
err = th.Cleanup(cleanupCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Cleanup complete")
|
||||
} else {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nSkipping cleanup (--no-cleanup specified). Resources left running.")
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1948,13 +1918,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
Description: "Timeout for workspace jobs (e.g. build, start).",
|
||||
Value: serpent.DurationOf(&workspaceJobTimeout),
|
||||
},
|
||||
{
|
||||
Flag: "autostart-build-timeout",
|
||||
Env: "CODER_SCALETEST_AUTOSTART_BUILD_TIMEOUT",
|
||||
Default: "15m",
|
||||
Description: "Timeout for the autostart build to complete. Must be longer than workspace-job-timeout to account for queueing time in high-load scenarios.",
|
||||
Value: serpent.DurationOf(&autostartBuildTimeout),
|
||||
},
|
||||
{
|
||||
Flag: "autostart-delay",
|
||||
Env: "CODER_SCALETEST_AUTOSTART_DELAY",
|
||||
@@ -1962,6 +1925,13 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
Description: "How long after all the workspaces have been stopped to schedule them to be started again.",
|
||||
Value: serpent.DurationOf(&autostartDelay),
|
||||
},
|
||||
{
|
||||
Flag: "autostart-timeout",
|
||||
Env: "CODER_SCALETEST_AUTOSTART_TIMEOUT",
|
||||
Default: "5m",
|
||||
Description: "Timeout for the autostart build to be initiated after the scheduled start time.",
|
||||
Value: serpent.DurationOf(&autostartTimeout),
|
||||
},
|
||||
{
|
||||
Flag: "template",
|
||||
FlagShorthand: "t",
|
||||
@@ -1980,9 +1950,10 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
|
||||
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
} else {
|
||||
updated, err = client.CreateOrganizationRole(ctx, customRole)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create role: %w", err)
|
||||
return xerrors.Errorf("patch role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -79,29 +79,6 @@ func (r *RootCmd) start() *serpent.Command {
|
||||
)
|
||||
build = workspace.LatestBuild
|
||||
default:
|
||||
// If the last build was a failed start, run a stop
|
||||
// first to clean up any partially-provisioned
|
||||
// resources.
|
||||
if workspace.LatestBuild.Status == codersdk.WorkspaceStatusFailed &&
|
||||
workspace.LatestBuild.Transition == codersdk.WorkspaceTransitionStart {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "The last start build failed. Cleaning up before retrying...\n")
|
||||
stopBuild, stopErr := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
|
||||
Transition: codersdk.WorkspaceTransitionStop,
|
||||
})
|
||||
if stopErr != nil {
|
||||
return xerrors.Errorf("cleanup stop after failed start: %w", stopErr)
|
||||
}
|
||||
stopErr = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, stopBuild.ID)
|
||||
if stopErr != nil {
|
||||
return xerrors.Errorf("wait for cleanup stop: %w", stopErr)
|
||||
}
|
||||
// Re-fetch workspace after stop completes so
|
||||
// startWorkspace sees the latest state.
|
||||
workspace, err = namedWorkspace(inv.Context(), client, inv.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
build, err = startWorkspace(inv, client, workspace, parameterFlags, bflags, WorkspaceStart)
|
||||
// It's possible for a workspace build to fail due to the template requiring starting
|
||||
// workspaces with the active version.
|
||||
|
||||
@@ -534,55 +534,3 @@ func TestStart_WithReason(t *testing.T) {
|
||||
workspace = coderdtest.MustWorkspace(t, member, workspace.ID)
|
||||
require.Equal(t, codersdk.BuildReasonCLI, workspace.LatestBuild.Reason)
|
||||
}
|
||||
|
||||
func TestStart_FailedStartCleansUp(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
store, ps := dbtestutil.NewDB(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: store,
|
||||
Pubsub: ps,
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, memberClient, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// Insert a failed start build directly into the database so that
|
||||
// the workspace's latest build is a failed "start" transition.
|
||||
dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
ID: workspace.ID,
|
||||
OwnerID: member.ID,
|
||||
OrganizationID: owner.OrganizationID,
|
||||
TemplateID: template.ID,
|
||||
}).
|
||||
Seed(database.WorkspaceBuild{
|
||||
TemplateVersionID: version.ID,
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
BuildNumber: workspace.LatestBuild.BuildNumber + 1,
|
||||
}).
|
||||
Failed().
|
||||
Do()
|
||||
|
||||
inv, root := clitest.New(t, "start", workspace.Name)
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// The CLI should detect the failed start and clean up first.
|
||||
pty.ExpectMatch("Cleaning up before retrying")
|
||||
pty.ExpectMatch("workspace has been started")
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, doneChan)
|
||||
}
|
||||
|
||||
+17
-26
@@ -113,20 +113,6 @@ func (r *RootCmd) supportBundle() *serpent.Command {
|
||||
)
|
||||
cliLog.Debug(inv.Context(), "invocation", slog.F("args", strings.Join(os.Args, " ")))
|
||||
|
||||
// Bypass rate limiting for support bundle collection since it makes many API calls.
|
||||
// Note: this can only be done by the owner user.
|
||||
if ok, err := support.CanGenerateFull(inv.Context(), client); err == nil && ok {
|
||||
cliLog.Debug(inv.Context(), "running as owner")
|
||||
client.HTTPClient.Transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
|
||||
}
|
||||
} else if !ok {
|
||||
cliLog.Warn(inv.Context(), "not running as owner, not all information available")
|
||||
} else {
|
||||
cliLog.Error(inv.Context(), "failed to look up current user", slog.Error(err))
|
||||
}
|
||||
|
||||
// Check if we're running inside a workspace
|
||||
if val, found := os.LookupEnv("CODER"); found && val == "true" {
|
||||
cliui.Warn(inv.Stderr, "Running inside Coder workspace; this can affect results!")
|
||||
@@ -214,6 +200,12 @@ func (r *RootCmd) supportBundle() *serpent.Command {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "pprof data collection will take approximately 30 seconds...")
|
||||
}
|
||||
|
||||
// Bypass rate limiting for support bundle collection since it makes many API calls.
|
||||
client.HTTPClient.Transport = &codersdk.HeaderTransport{
|
||||
Transport: client.HTTPClient.Transport,
|
||||
Header: http.Header{codersdk.BypassRatelimitHeader: {"true"}},
|
||||
}
|
||||
|
||||
deps := support.Deps{
|
||||
Client: client,
|
||||
// Support adds a sink so we don't need to supply one ourselves.
|
||||
@@ -362,20 +354,19 @@ func summarizeBundle(inv *serpent.Invocation, bun *support.Bundle) {
|
||||
return
|
||||
}
|
||||
|
||||
var docsURL string
|
||||
if bun.Deployment.Config != nil {
|
||||
docsURL = bun.Deployment.Config.Values.DocsURL.String()
|
||||
} else {
|
||||
cliui.Warn(inv.Stdout, "No deployment configuration available. This may require the Owner role.")
|
||||
if bun.Deployment.Config == nil {
|
||||
cliui.Error(inv.Stdout, "No deployment configuration available!")
|
||||
return
|
||||
}
|
||||
|
||||
if bun.Deployment.HealthReport != nil {
|
||||
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
|
||||
if len(deployHealthSummary) > 0 {
|
||||
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
|
||||
}
|
||||
} else {
|
||||
cliui.Warn(inv.Stdout, "No deployment health report available.")
|
||||
docsURL := bun.Deployment.Config.Values.DocsURL.String()
|
||||
if bun.Deployment.HealthReport == nil {
|
||||
cliui.Error(inv.Stdout, "No deployment health report available!")
|
||||
return
|
||||
}
|
||||
deployHealthSummary := bun.Deployment.HealthReport.Summarize(docsURL)
|
||||
if len(deployHealthSummary) > 0 {
|
||||
cliui.Warn(inv.Stdout, "Deployment health issues detected:", deployHealthSummary...)
|
||||
}
|
||||
|
||||
if bun.Network.Netcheck == nil {
|
||||
|
||||
+3
-30
@@ -132,35 +132,12 @@ func TestSupportBundle(t *testing.T) {
|
||||
assertBundleContents(t, path, true, false, []string{secretValue})
|
||||
})
|
||||
|
||||
t.Run("MemberCanGenerateBundle", func(t *testing.T) {
|
||||
t.Run("NoPrivilege", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d := t.TempDir()
|
||||
path := filepath.Join(d, "bundle.zip")
|
||||
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--output-file", path, "--yes")
|
||||
inv, root := clitest.New(t, "support", "bundle", memberWorkspace.Workspace.Name, "--yes")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
err := inv.Run()
|
||||
require.NoError(t, err)
|
||||
r, err := zip.OpenReader(path)
|
||||
require.NoError(t, err, "open zip file")
|
||||
defer r.Close()
|
||||
fileNames := make(map[string]struct{}, len(r.File))
|
||||
for _, f := range r.File {
|
||||
fileNames[f.Name] = struct{}{}
|
||||
}
|
||||
// These should always be present in the zip structure, even if
|
||||
// the content is null/empty for non-admin users.
|
||||
for _, name := range []string{
|
||||
"deployment/buildinfo.json",
|
||||
"deployment/config.json",
|
||||
"workspace/workspace.json",
|
||||
"logs.txt",
|
||||
"cli_logs.txt",
|
||||
"network/netcheck.json",
|
||||
"network/interfaces.json",
|
||||
} {
|
||||
require.Contains(t, fileNames, name)
|
||||
}
|
||||
require.ErrorContains(t, err, "failed authorization check")
|
||||
})
|
||||
|
||||
// This ensures that the CLI does not panic when trying to generate a support bundle
|
||||
@@ -182,10 +159,6 @@ func TestSupportBundle(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("received request: %s %s", r.Method, r.URL)
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me":
|
||||
resp := codersdk.User{}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
assert.NoError(t, json.NewEncoder(w).Encode(resp))
|
||||
case "/api/v2/authcheck":
|
||||
// Fake auth check
|
||||
resp := codersdk.AuthorizationResponse{
|
||||
|
||||
-4
@@ -20,10 +20,6 @@ OPTIONS:
|
||||
--copy-parameters-from string, $CODER_WORKSPACE_COPY_PARAMETERS_FROM
|
||||
Specify the source workspace name to copy parameters from.
|
||||
|
||||
--no-wait bool, $CODER_CREATE_NO_WAIT
|
||||
Return immediately after creating the workspace. The build will run in
|
||||
the background.
|
||||
|
||||
--parameter string-array, $CODER_RICH_PARAMETER
|
||||
Rich parameter value in the format "name=value".
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"last_seen_at": "====[timestamp]=====",
|
||||
"name": "test-daemon",
|
||||
"version": "v0.0.0-devel",
|
||||
"api_version": "1.16",
|
||||
"api_version": "1.15",
|
||||
"provisioners": [
|
||||
"echo"
|
||||
],
|
||||
|
||||
+5
@@ -143,6 +143,11 @@ AI BRIDGE OPTIONS:
|
||||
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
|
||||
Whether to start an in-memory aibridged instance.
|
||||
|
||||
--aibridge-inject-coder-mcp-tools bool, $CODER_AIBRIDGE_INJECT_CODER_MCP_TOOLS (default: false)
|
||||
Whether to inject Coder's MCP tools into intercepted AI Bridge
|
||||
requests (requires the "oauth2" and "mcp-server-http" experiments to
|
||||
be enabled).
|
||||
|
||||
--aibridge-max-concurrency int, $CODER_AIBRIDGE_MAX_CONCURRENCY (default: 0)
|
||||
Maximum number of concurrent AI Bridge requests per replica. Set to 0
|
||||
to disable (unlimited).
|
||||
|
||||
@@ -24,10 +24,6 @@ OPTIONS:
|
||||
-p, --password string
|
||||
Specifies a password for the new user.
|
||||
|
||||
--service-account bool
|
||||
Create a user account intended to be used by a service or as an
|
||||
intermediary rather than by a human.
|
||||
|
||||
-u, --username string
|
||||
Specifies a username for the new user.
|
||||
|
||||
|
||||
+2
-9
@@ -752,11 +752,6 @@ workspace_prebuilds:
|
||||
# limit; disabled when set to zero.
|
||||
# (default: 3, type: int)
|
||||
failure_hard_limit: 3
|
||||
# Configure the background chat processing daemon.
|
||||
chat:
|
||||
# How many pending chats a worker should acquire per polling cycle.
|
||||
# (default: 10, type: int)
|
||||
acquireBatchSize: 10
|
||||
aibridge:
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
# (default: false, type: bool)
|
||||
@@ -783,10 +778,8 @@ aibridge:
|
||||
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
|
||||
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
|
||||
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
|
||||
# Deprecated: Injected MCP in AI Bridge is deprecated and will be removed in a
|
||||
# future release. Whether to inject Coder's MCP tools into intercepted AI Bridge
|
||||
# requests (requires the "oauth2" and "mcp-server-http" experiments to be
|
||||
# enabled).
|
||||
# Whether to inject Coder's MCP tools into intercepted AI Bridge requests
|
||||
# (requires the "oauth2" and "mcp-server-http" experiments to be enabled).
|
||||
# (default: false, type: bool)
|
||||
inject_coder_mcp_tools: false
|
||||
# Length of time to retain data such as interceptions and all related records
|
||||
|
||||
+12
-37
@@ -17,14 +17,13 @@ import (
|
||||
|
||||
func (r *RootCmd) userCreate() *serpent.Command {
|
||||
var (
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
serviceAccount bool
|
||||
orgContext = NewOrganizationContext()
|
||||
email string
|
||||
username string
|
||||
name string
|
||||
password string
|
||||
disableLogin bool
|
||||
loginType string
|
||||
orgContext = NewOrganizationContext()
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "create",
|
||||
@@ -33,23 +32,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
serpent.RequireNArgs(0),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
if serviceAccount {
|
||||
switch {
|
||||
case loginType != "":
|
||||
return xerrors.New("You cannot use --login-type with --service-account")
|
||||
case password != "":
|
||||
return xerrors.New("You cannot use --password with --service-account")
|
||||
case email != "":
|
||||
return xerrors.New("You cannot use --email with --service-account")
|
||||
case disableLogin:
|
||||
return xerrors.New("You cannot use --disable-login with --service-account")
|
||||
}
|
||||
}
|
||||
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -77,7 +59,7 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if email == "" && !serviceAccount {
|
||||
if email == "" {
|
||||
email, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Email:",
|
||||
Validate: func(s string) error {
|
||||
@@ -105,7 +87,10 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
}
|
||||
}
|
||||
userLoginType := codersdk.LoginTypePassword
|
||||
if disableLogin || serviceAccount {
|
||||
if disableLogin && loginType != "" {
|
||||
return xerrors.New("You cannot specify both --disable-login and --login-type")
|
||||
}
|
||||
if disableLogin {
|
||||
userLoginType = codersdk.LoginTypeNone
|
||||
} else if loginType != "" {
|
||||
userLoginType = codersdk.LoginType(loginType)
|
||||
@@ -126,7 +111,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
Password: password,
|
||||
OrganizationIDs: []uuid.UUID{organization.ID},
|
||||
UserLoginType: userLoginType,
|
||||
ServiceAccount: serviceAccount,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -143,10 +127,6 @@ func (r *RootCmd) userCreate() *serpent.Command {
|
||||
case codersdk.LoginTypeOIDC:
|
||||
authenticationMethod = `Login is authenticated through the configured OIDC provider.`
|
||||
}
|
||||
if serviceAccount {
|
||||
email = "n/a"
|
||||
authenticationMethod = "Service accounts must authenticate with a token and cannot log in."
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, `A new user has been created!
|
||||
Share the instructions below to get them started.
|
||||
@@ -214,11 +194,6 @@ Create a workspace `+pretty.Sprint(cliui.DefaultStyles.Code, "coder create")+`!
|
||||
)),
|
||||
Value: serpent.StringOf(&loginType),
|
||||
},
|
||||
{
|
||||
Flag: "service-account",
|
||||
Description: "Create a user account intended to be used by a service or as an intermediary rather than by a human.",
|
||||
Value: serpent.BoolOf(&serviceAccount),
|
||||
},
|
||||
}
|
||||
|
||||
orgContext.AttachOptions(cmd)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -125,56 +124,4 @@ func TestUserCreate(t *testing.T) {
|
||||
assert.Equal(t, args[5], created.Username)
|
||||
assert.Empty(t, created.Name)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "ServiceAccount",
|
||||
args: []string{"--service-account", "-u", "dean"},
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountLoginType",
|
||||
args: []string{"--service-account", "-u", "dean", "--login-type", "none"},
|
||||
err: "You cannot use --login-type with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountDisableLogin",
|
||||
args: []string{"--service-account", "-u", "dean", "--disable-login"},
|
||||
err: "You cannot use --disable-login with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountEmail",
|
||||
args: []string{"--service-account", "-u", "dean", "--email", "dean@coder.com"},
|
||||
err: "You cannot use --email with --service-account",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountPassword",
|
||||
args: []string{"--service-account", "-u", "dean", "--password", "1n5ecureP4ssw0rd!"},
|
||||
err: "You cannot use --password with --service-account",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
inv, root := clitest.New(t, append([]string{"users", "create"}, tt.args...)...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
err := inv.Run()
|
||||
if tt.err == "" {
|
||||
require.NoError(t, err)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
created, err := client.User(ctx, "dean")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, codersdk.LoginTypeNone, created.LoginType)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
// Package aiseats is the AGPL version the package.
|
||||
// The actual implementation is in `enterprise/aiseats`.
|
||||
package aiseats
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
)
|
||||
|
||||
type Reason struct {
|
||||
EventType database.AiSeatUsageReason
|
||||
Description string
|
||||
}
|
||||
|
||||
// ReasonAIBridge constructs a reason for usage originating from AI Bridge.
|
||||
func ReasonAIBridge(description string) Reason {
|
||||
return Reason{EventType: database.AiSeatUsageReasonAibridge, Description: description}
|
||||
}
|
||||
|
||||
// ReasonTask constructs a reason for usage originating from tasks.
|
||||
func ReasonTask(description string) Reason {
|
||||
return Reason{EventType: database.AiSeatUsageReasonTask, Description: description}
|
||||
}
|
||||
|
||||
// SeatTracker records AI seat consumption state.
|
||||
type SeatTracker interface {
|
||||
// RecordUsage does not return an error to prevent blocking the user from using
|
||||
// AI features. This method is used to record usage, not enforce it.
|
||||
RecordUsage(ctx context.Context, userID uuid.UUID, reason Reason)
|
||||
}
|
||||
|
||||
// Noop is an AGPL seat tracker that does nothing.
|
||||
type Noop struct{}
|
||||
|
||||
func (Noop) RecordUsage(context.Context, uuid.UUID, Reason) {}
|
||||
Generated
+1683
-2030
File diff suppressed because it is too large
Load Diff
Generated
+1683
-2008
File diff suppressed because it is too large
Load Diff
@@ -32,8 +32,7 @@ type Auditable interface {
|
||||
idpsync.OrganizationSyncSettings |
|
||||
idpsync.GroupSyncSettings |
|
||||
idpsync.RoleSyncSettings |
|
||||
database.TaskTable |
|
||||
database.AiSeatState
|
||||
database.TaskTable
|
||||
}
|
||||
|
||||
// Map is a map of changed fields in an audited resource. It maps field names to
|
||||
|
||||
@@ -132,8 +132,6 @@ func ResourceTarget[T Auditable](tgt T) string {
|
||||
return "Organization Role Sync"
|
||||
case database.TaskTable:
|
||||
return typed.Name
|
||||
case database.AiSeatState:
|
||||
return "AI Seat"
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceTarget", tgt))
|
||||
}
|
||||
@@ -198,8 +196,6 @@ func ResourceID[T Auditable](tgt T) uuid.UUID {
|
||||
return noID // Org field on audit log has org id
|
||||
case database.TaskTable:
|
||||
return typed.ID
|
||||
case database.AiSeatState:
|
||||
return typed.UserID
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceID", tgt))
|
||||
}
|
||||
@@ -255,8 +251,6 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
|
||||
return database.ResourceTypeIdpSyncSettingsGroup
|
||||
case database.TaskTable:
|
||||
return database.ResourceTypeTask
|
||||
case database.AiSeatState:
|
||||
return database.ResourceTypeAiSeat
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceType", typed))
|
||||
}
|
||||
@@ -315,8 +309,6 @@ func ResourceRequiresOrgID[T Auditable]() bool {
|
||||
return true
|
||||
case database.TaskTable:
|
||||
return true
|
||||
case database.AiSeatState:
|
||||
return false
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T for ResourceRequiresOrgID", tgt))
|
||||
}
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package chatcost
|
||||
|
||||
import (
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Returns cost in micros -- millionths of a dollar, rounded up to the next
|
||||
// whole microdollar.
|
||||
// Returns nil when pricing is not configured or when all priced usage fields
|
||||
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
|
||||
func CalculateTotalCostMicros(
|
||||
usage codersdk.ChatMessageUsage,
|
||||
cost *codersdk.ModelCostConfig,
|
||||
) *int64 {
|
||||
if cost == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// A cost config with no prices set means pricing is effectively
|
||||
// unconfigured — return nil (unpriced) rather than zero.
|
||||
if cost.InputPricePerMillionTokens == nil &&
|
||||
cost.OutputPricePerMillionTokens == nil &&
|
||||
cost.CacheReadPricePerMillionTokens == nil &&
|
||||
cost.CacheWritePricePerMillionTokens == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if usage.InputTokens == nil &&
|
||||
usage.OutputTokens == nil &&
|
||||
usage.ReasoningTokens == nil &&
|
||||
usage.CacheCreationTokens == nil &&
|
||||
usage.CacheReadTokens == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OutputTokens already includes reasoning tokens per provider
|
||||
// semantics (e.g. OpenAI's completion_tokens encompasses
|
||||
// reasoning_tokens). Adding ReasoningTokens here would
|
||||
// double-count.
|
||||
|
||||
// Preserve nil when usage exists only in categories without configured
|
||||
// pricing, so callers can distinguish "unpriced" from "priced at zero".
|
||||
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
|
||||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
|
||||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
|
||||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
|
||||
if !hasMatchingPrice {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
|
||||
outputMicros := calcCost(usage.OutputTokens, cost.OutputPricePerMillionTokens)
|
||||
cacheReadMicros := calcCost(usage.CacheReadTokens, cost.CacheReadPricePerMillionTokens)
|
||||
cacheWriteMicros := calcCost(usage.CacheCreationTokens, cost.CacheWritePricePerMillionTokens)
|
||||
|
||||
total := inputMicros.
|
||||
Add(outputMicros).
|
||||
Add(cacheReadMicros).
|
||||
Add(cacheWriteMicros)
|
||||
rounded := total.Ceil().IntPart()
|
||||
return &rounded
|
||||
}
|
||||
|
||||
// calcCost returns the cost in fractional microdollars (millionths of a USD)
|
||||
// for the given token count at the specified per-million-token price.
|
||||
func calcCost(tokens *int64, pricePerMillion *decimal.Decimal) decimal.Decimal {
|
||||
return decimal.NewFromInt(ptr.NilToEmpty(tokens)).Mul(ptr.NilToEmpty(pricePerMillion))
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package chatcost_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
usage codersdk.ChatMessageUsage
|
||||
cost *codersdk.ModelCostConfig
|
||||
want *int64
|
||||
}{
|
||||
{
|
||||
name: "nil cost returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "all priced usage fields nil returns nil",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
TotalTokens: ptr.Ref[int64](1234),
|
||||
ContextLimit: ptr.Ref[int64](8192),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "sub-micro total rounds up to 1",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
|
||||
},
|
||||
want: ptr.Ref[int64](1),
|
||||
},
|
||||
{
|
||||
name: "simple input only",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
},
|
||||
{
|
||||
name: "simple output only",
|
||||
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
},
|
||||
{
|
||||
name: "reasoning tokens included in output total",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
OutputTokens: ptr.Ref[int64](500),
|
||||
ReasoningTokens: ptr.Ref[int64](200),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
},
|
||||
{
|
||||
name: "cache read tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
},
|
||||
{
|
||||
name: "cache creation tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
|
||||
},
|
||||
want: ptr.Ref[int64](18750),
|
||||
},
|
||||
{
|
||||
name: "full mixed usage totals all components exactly",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
InputTokens: ptr.Ref[int64](101),
|
||||
OutputTokens: ptr.Ref[int64](201),
|
||||
ReasoningTokens: ptr.Ref[int64](52),
|
||||
CacheReadTokens: ptr.Ref[int64](1005),
|
||||
CacheCreationTokens: ptr.Ref[int64](33),
|
||||
TotalTokens: ptr.Ref[int64](1391),
|
||||
ContextLimit: ptr.Ref[int64](4096),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")),
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")),
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
|
||||
},
|
||||
want: ptr.Ref[int64](2005),
|
||||
},
|
||||
{
|
||||
name: "partial pricing only input contributes",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
InputTokens: ptr.Ref[int64](1234),
|
||||
OutputTokens: ptr.Ref[int64](999),
|
||||
ReasoningTokens: ptr.Ref[int64](111),
|
||||
CacheReadTokens: ptr.Ref[int64](500),
|
||||
CacheCreationTokens: ptr.Ref[int64](250),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
|
||||
},
|
||||
want: ptr.Ref[int64](3085),
|
||||
},
|
||||
{
|
||||
name: "zero tokens with pricing returns zero pointer",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](0),
|
||||
},
|
||||
{
|
||||
name: "usage only in unpriced categories returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "non nil usage with empty cost config returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
|
||||
cost: &codersdk.ModelCostConfig{},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
|
||||
|
||||
if tt.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, *tt.want, *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+503
-952
File diff suppressed because it is too large
Load Diff
@@ -2,20 +2,13 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
)
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
@@ -91,135 +84,3 @@ func TestRefreshChatWorkspaceSnapshot_ReturnsReloadError(t *testing.T) {
|
||||
require.ErrorContains(t, err, loadErr.Error())
|
||||
require.Equal(t, chat, refreshed)
|
||||
}
|
||||
|
||||
func TestResolveInstructionsReusesTurnLocalWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
workspaceAgent := database.WorkspaceAgent{
|
||||
ID: uuid.New(),
|
||||
OperatingSystem: "linux",
|
||||
Directory: "/home/coder/project",
|
||||
ExpandedDirectory: "/home/coder/project",
|
||||
}
|
||||
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{workspaceAgent}, nil).Times(1)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
conn.EXPECT().LS(gomock.Any(), "", gomock.Any()).Return(
|
||||
workspacesdk.LSResponse{},
|
||||
codersdk.NewTestError(404, "POST", "/api/v0/list-directory"),
|
||||
).Times(1)
|
||||
conn.EXPECT().ReadFile(
|
||||
gomock.Any(),
|
||||
"/home/coder/project/AGENTS.md",
|
||||
int64(0),
|
||||
int64(maxInstructionFileBytes+1),
|
||||
).Return(
|
||||
nil,
|
||||
"",
|
||||
codersdk.NewTestError(404, "GET", "/api/v0/read-file"),
|
||||
).Times(1)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
instructionCache: make(map[uuid.UUID]cachedInstruction),
|
||||
agentConnFn: func(context.Context, uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
return conn, func() {}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
instruction := server.resolveInstructions(
|
||||
ctx,
|
||||
chat,
|
||||
workspaceCtx.getWorkspaceAgent,
|
||||
workspaceCtx.getWorkspaceConn,
|
||||
)
|
||||
require.Contains(t, instruction, "Operating System: linux")
|
||||
require.Contains(t, instruction, "Working Directory: /home/coder/project")
|
||||
}
|
||||
|
||||
func TestTurnWorkspaceContextGetWorkspaceConnRefreshesWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
chat := database.Chat{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspaceID,
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
initialAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
refreshedAgent := database.WorkspaceAgent{ID: uuid.New()}
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{initialAgent}, nil),
|
||||
db.EXPECT().GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
||||
gomock.Any(),
|
||||
workspaceID,
|
||||
).Return([]database.WorkspaceAgent{refreshedAgent}, nil),
|
||||
)
|
||||
|
||||
conn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
conn.EXPECT().SetExtraHeaders(gomock.Any()).Times(1)
|
||||
|
||||
var dialed []uuid.UUID
|
||||
server := &Server{db: db}
|
||||
server.agentConnFn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
|
||||
dialed = append(dialed, agentID)
|
||||
if agentID == initialAgent.ID {
|
||||
return nil, nil, xerrors.New("dial failed")
|
||||
}
|
||||
return conn, func() {}, nil
|
||||
}
|
||||
|
||||
chatStateMu := &sync.Mutex{}
|
||||
currentChat := chat
|
||||
workspaceCtx := turnWorkspaceContext{
|
||||
server: server,
|
||||
chatStateMu: chatStateMu,
|
||||
currentChat: ¤tChat,
|
||||
loadChatSnapshot: func(context.Context, uuid.UUID) (database.Chat, error) { return database.Chat{}, nil },
|
||||
}
|
||||
t.Cleanup(workspaceCtx.close)
|
||||
|
||||
gotConn, err := workspaceCtx.getWorkspaceConn(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, conn, gotConn)
|
||||
require.Equal(t, []uuid.UUID{initialAgent.ID, refreshedAgent.ID}, dialed)
|
||||
}
|
||||
|
||||
+101
-1701
File diff suppressed because it is too large
Load Diff
@@ -42,11 +42,6 @@ type PersistedStep struct {
|
||||
Content []fantasy.Content
|
||||
Usage fantasy.Usage
|
||||
ContextLimit sql.NullInt64
|
||||
// Runtime is the wall-clock duration of this step,
|
||||
// covering LLM streaming, tool execution, and retries.
|
||||
// Zero indicates the duration was not measured (e.g.
|
||||
// interrupted steps).
|
||||
Runtime time.Duration
|
||||
}
|
||||
|
||||
// RunOptions configures a single streaming chat loop run.
|
||||
@@ -68,16 +63,15 @@ type RunOptions struct {
|
||||
// of the provider, which lives in chatd, not chatloop.
|
||||
ProviderOptions fantasy.ProviderOptions
|
||||
|
||||
// ProviderTools are provider-native tools (like web search
|
||||
// and computer use) whose definitions are passed directly
|
||||
// to the provider API. When a ProviderTool has a non-nil
|
||||
// Runner, tool calls are executed locally; otherwise the
|
||||
// provider handles execution (e.g. web search).
|
||||
ProviderTools []ProviderTool
|
||||
// ProviderTools are provider-native tools (like web search)
|
||||
// that are passed directly to the provider API alongside
|
||||
// function tool definitions. These are not necessarily
|
||||
// executed server-side; handling is provider-specific.
|
||||
ProviderTools []fantasy.Tool
|
||||
|
||||
PersistStep func(context.Context, PersistedStep) error
|
||||
PublishMessagePart func(
|
||||
role codersdk.ChatMessageRole,
|
||||
role fantasy.MessageRole,
|
||||
part codersdk.ChatMessagePart,
|
||||
)
|
||||
Compaction *CompactionOptions
|
||||
@@ -94,16 +88,6 @@ type RunOptions struct {
|
||||
OnInterruptedPersistError func(error)
|
||||
}
|
||||
|
||||
// ProviderTool pairs a provider-native tool definition with an
|
||||
// optional local executor. When Runner is nil the tool is fully
|
||||
// provider-executed (e.g. web search). When Runner is non-nil
|
||||
// the definition is sent to the API but execution is handled
|
||||
// locally (e.g. computer use).
|
||||
type ProviderTool struct {
|
||||
Definition fantasy.Tool
|
||||
Runner fantasy.AgentTool
|
||||
}
|
||||
|
||||
// stepResult holds the accumulated output of a single streaming
|
||||
// step. Since we own the stream consumer, all content is tracked
|
||||
// directly here — no shadow draft state needed.
|
||||
@@ -232,7 +216,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
opts.MaxSteps = 1
|
||||
}
|
||||
|
||||
publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
|
||||
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
||||
if opts.PublishMessagePart == nil {
|
||||
return
|
||||
}
|
||||
@@ -265,7 +249,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
|
||||
for step := 0; totalSteps < opts.MaxSteps; step++ {
|
||||
totalSteps++
|
||||
stepStart := time.Now()
|
||||
// Copy messages so that provider-specific caching
|
||||
// mutations don't leak back to the caller's slice.
|
||||
// copy copies Message structs by value, so field
|
||||
@@ -332,9 +315,9 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
||||
toolResults = executeTools(ctx, opts.Tools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
fantasy.MessageRoleTool,
|
||||
chatprompt.PartFromContent(tr),
|
||||
)
|
||||
})
|
||||
@@ -371,7 +354,6 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
Content: result.content,
|
||||
Usage: result.usage,
|
||||
ContextLimit: contextLimit,
|
||||
Runtime: time.Since(stepStart),
|
||||
}); err != nil {
|
||||
if errors.Is(err, ErrInterrupted) {
|
||||
persistInterruptedStep(ctx, opts, &result)
|
||||
@@ -473,7 +455,7 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
func processStepStream(
|
||||
ctx context.Context,
|
||||
stream fantasy.StreamResponse,
|
||||
publishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart),
|
||||
publishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart),
|
||||
) (stepResult, error) {
|
||||
var result stepResult
|
||||
|
||||
@@ -492,7 +474,10 @@ func processStepStream(
|
||||
if _, exists := activeTextContent[part.ID]; exists {
|
||||
activeTextContent[part.ID] += part.Delta
|
||||
}
|
||||
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(part.Delta))
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: part.Delta,
|
||||
})
|
||||
|
||||
case fantasy.StreamPartTypeTextEnd:
|
||||
if text, exists := activeTextContent[part.ID]; exists {
|
||||
@@ -515,7 +500,10 @@ func processStepStream(
|
||||
active.options = part.ProviderMetadata
|
||||
activeReasoningContent[part.ID] = active
|
||||
}
|
||||
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageReasoning(part.Delta))
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: part.Delta,
|
||||
})
|
||||
|
||||
case fantasy.StreamPartTypeReasoningEnd:
|
||||
if active, exists := activeReasoningContent[part.ID]; exists {
|
||||
@@ -547,7 +535,7 @@ func processStepStream(
|
||||
providerExecuted = toolCall.ProviderExecuted
|
||||
}
|
||||
toolName := toolNames[part.ID]
|
||||
publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: part.ID,
|
||||
ToolName: toolName,
|
||||
@@ -575,7 +563,7 @@ func processStepStream(
|
||||
delete(activeToolCalls, part.ID)
|
||||
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleAssistant,
|
||||
fantasy.MessageRoleAssistant,
|
||||
chatprompt.PartFromContent(tc),
|
||||
)
|
||||
|
||||
@@ -589,7 +577,7 @@ func processStepStream(
|
||||
}
|
||||
result.content = append(result.content, sourceContent)
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleAssistant,
|
||||
fantasy.MessageRoleAssistant,
|
||||
chatprompt.PartFromContent(sourceContent),
|
||||
)
|
||||
|
||||
@@ -607,7 +595,7 @@ func processStepStream(
|
||||
}
|
||||
result.content = append(result.content, tr)
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
fantasy.MessageRoleTool,
|
||||
chatprompt.PartFromContent(tr),
|
||||
)
|
||||
}
|
||||
@@ -617,12 +605,10 @@ func processStepStream(
|
||||
result.providerMetadata = part.ProviderMetadata
|
||||
|
||||
case fantasy.StreamPartTypeError:
|
||||
// Detect interruption: the stream may surface the
|
||||
// cancel as context.Canceled or propagate the
|
||||
// ErrInterrupted cause directly, depending on
|
||||
// the provider implementation.
|
||||
if errors.Is(context.Cause(ctx), ErrInterrupted) &&
|
||||
(errors.Is(part.Error, context.Canceled) || errors.Is(part.Error, ErrInterrupted)) {
|
||||
// Detect interruption: context canceled with
|
||||
// ErrInterrupted as the cause.
|
||||
if errors.Is(part.Error, context.Canceled) &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
// Flush in-progress content so that
|
||||
// persistInterruptedStep has access to partial
|
||||
// text, reasoning, and tool calls that were
|
||||
@@ -640,23 +626,6 @@ func processStepStream(
|
||||
}
|
||||
}
|
||||
|
||||
// The stream iterator may stop yielding parts without
|
||||
// producing a StreamPartTypeError when the context is
|
||||
// canceled (e.g. some providers close the response body
|
||||
// silently). Detect this case and flush partial content
|
||||
// so that persistInterruptedStep can save it.
|
||||
if ctx.Err() != nil &&
|
||||
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
||||
flushActiveState(
|
||||
&result,
|
||||
activeTextContent,
|
||||
activeReasoningContent,
|
||||
activeToolCalls,
|
||||
toolNames,
|
||||
)
|
||||
return result, ErrInterrupted
|
||||
}
|
||||
|
||||
hasLocalToolCalls := false
|
||||
for _, tc := range result.toolCalls {
|
||||
if !tc.ProviderExecuted {
|
||||
@@ -676,7 +645,6 @@ func processStepStream(
|
||||
func executeTools(
|
||||
ctx context.Context,
|
||||
allTools []fantasy.AgentTool,
|
||||
providerTools []ProviderTool,
|
||||
toolCalls []fantasy.ToolCallContent,
|
||||
onResult func(fantasy.ToolResultContent),
|
||||
) []fantasy.ToolResultContent {
|
||||
@@ -702,13 +670,6 @@ func executeTools(
|
||||
for _, t := range allTools {
|
||||
toolMap[t.Info().Name] = t
|
||||
}
|
||||
// Include runners from provider tools so locally-executed
|
||||
// provider tools (e.g. computer use) can be dispatched.
|
||||
for _, pt := range providerTools {
|
||||
if pt.Runner != nil {
|
||||
toolMap[pt.Runner.Info().Name] = pt.Runner
|
||||
}
|
||||
}
|
||||
|
||||
results := make([]fantasy.ToolResultContent, len(localToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
@@ -908,16 +869,15 @@ func persistInterruptedStep(
|
||||
// buildToolDefinitions converts AgentTool definitions into the
|
||||
// fantasy.Tool slice expected by fantasy.Call. When activeTools
|
||||
// is non-empty, only function tools whose name appears in the
|
||||
// list are included. Provider tool definitions are always
|
||||
// appended unconditionally.
|
||||
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, providerTools []ProviderTool) []fantasy.Tool {
|
||||
prepared := make([]fantasy.Tool, 0, len(tools)+len(providerTools))
|
||||
// list are included. Provider tools bypass this filter and are
|
||||
// always appended unconditionally.
|
||||
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, providerTools []fantasy.Tool) []fantasy.Tool {
|
||||
prepared := make([]fantasy.Tool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
info := tool.Info()
|
||||
if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
inputSchema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": info.Parameters,
|
||||
@@ -931,9 +891,7 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, provi
|
||||
ProviderOptions: tool.ProviderOptions(),
|
||||
})
|
||||
}
|
||||
for _, pt := range providerTools {
|
||||
prepared = append(prepared, pt.Definition)
|
||||
}
|
||||
prepared = append(prepared, providerTools...)
|
||||
return prepared
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
@@ -65,8 +64,6 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
require.Equal(t, 1, persistStepCalls)
|
||||
require.True(t, persistedStep.ContextLimit.Valid)
|
||||
require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64)
|
||||
require.Greater(t, persistedStep.Runtime, time.Duration(0),
|
||||
"step runtime should be positive")
|
||||
|
||||
require.NotEmpty(t, capturedCall.Prompt)
|
||||
require.False(t, containsPromptSentinel(capturedCall.Prompt))
|
||||
|
||||
@@ -55,7 +55,7 @@ type CompactionOptions struct {
|
||||
// PublishMessagePart publishes streaming parts to connected
|
||||
// clients so they see "Summarizing..." / "Summarized" UI
|
||||
// transitions during compaction.
|
||||
PublishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart)
|
||||
PublishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart)
|
||||
|
||||
OnError func(error)
|
||||
}
|
||||
@@ -110,8 +110,12 @@ func tryCompact(
|
||||
// connected clients see activity during summary generation.
|
||||
if config.PublishMessagePart != nil && config.ToolCallID != "" {
|
||||
config.PublishMessagePart(
|
||||
codersdk.ChatMessageRoleAssistant,
|
||||
codersdk.ChatMessageToolCall(config.ToolCallID, config.ToolName, nil),
|
||||
fantasy.MessageRoleAssistant,
|
||||
codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: config.ToolCallID,
|
||||
ToolName: config.ToolName,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -159,8 +163,13 @@ func tryCompact(
|
||||
"context_limit_tokens": contextLimit,
|
||||
})
|
||||
config.PublishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false),
|
||||
fantasy.MessageRoleTool,
|
||||
codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: config.ToolCallID,
|
||||
ToolName: config.ToolName,
|
||||
Result: resultJSON,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -177,8 +186,14 @@ func publishCompactionError(config CompactionOptions, msg string) {
|
||||
"error": msg,
|
||||
})
|
||||
config.PublishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true),
|
||||
fantasy.MessageRoleTool,
|
||||
codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: config.ToolCallID,
|
||||
ToolName: config.ToolName,
|
||||
Result: errJSON,
|
||||
IsError: true,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
SummaryPrompt: "summarize now",
|
||||
ToolCallID: "test-tool-call-id",
|
||||
ToolName: "chat_summarized",
|
||||
PublishMessagePart: func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) {
|
||||
PublishMessagePart: func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeToolCall:
|
||||
callOrder = append(callOrder, "publish_tool_call")
|
||||
@@ -218,7 +218,7 @@ func TestRun_Compaction(t *testing.T) {
|
||||
ThresholdPercent: 70,
|
||||
ToolCallID: "test-tool-call-id",
|
||||
ToolName: "chat_summarized",
|
||||
PublishMessagePart: func(_ codersdk.ChatMessageRole, _ codersdk.ChatMessagePart) {
|
||||
PublishMessagePart: func(_ fantasy.MessageRole, _ codersdk.ChatMessagePart) {
|
||||
publishCalled = true
|
||||
},
|
||||
Persist: func(_ context.Context, _ CompactionResult) error {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -553,6 +553,34 @@ func normalizedEnumValue(value string, allowed ...string) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeMissingCallConfig fills unset call config values from a provider or
|
||||
// profile default config.
|
||||
func MergeMissingCallConfig(
|
||||
dst *codersdk.ChatModelCallConfig,
|
||||
defaults codersdk.ChatModelCallConfig,
|
||||
) {
|
||||
if dst.MaxOutputTokens == nil {
|
||||
dst.MaxOutputTokens = defaults.MaxOutputTokens
|
||||
}
|
||||
if dst.Temperature == nil {
|
||||
dst.Temperature = defaults.Temperature
|
||||
}
|
||||
if dst.TopP == nil {
|
||||
dst.TopP = defaults.TopP
|
||||
}
|
||||
if dst.TopK == nil {
|
||||
dst.TopK = defaults.TopK
|
||||
}
|
||||
if dst.PresencePenalty == nil {
|
||||
dst.PresencePenalty = defaults.PresencePenalty
|
||||
}
|
||||
if dst.FrequencyPenalty == nil {
|
||||
dst.FrequencyPenalty = defaults.FrequencyPenalty
|
||||
}
|
||||
MergeMissingModelCostConfig(&dst.Cost, defaults.Cost)
|
||||
MergeMissingProviderOptions(&dst.ProviderOptions, defaults.ProviderOptions)
|
||||
}
|
||||
|
||||
// MergeMissingModelCostConfig fills unset pricing metadata from defaults.
|
||||
func MergeMissingModelCostConfig(
|
||||
dst **codersdk.ModelCostConfig,
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
|
||||
options := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(true),
|
||||
},
|
||||
Provider: &codersdk.ChatModelOpenRouterProvider{
|
||||
@@ -92,7 +92,7 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
}
|
||||
defaults := &codersdk.ChatModelProviderOptions{
|
||||
OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{
|
||||
Reasoning: &codersdk.ChatModelReasoningOptions{
|
||||
Reasoning: &codersdk.ChatModelOpenRouterReasoningOptions{
|
||||
Enabled: boolPtr(false),
|
||||
Exclude: boolPtr(true),
|
||||
MaxTokens: int64Ptr(123),
|
||||
@@ -137,6 +137,61 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
|
||||
}
|
||||
|
||||
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := codersdk.ChatModelCallConfig{
|
||||
Temperature: float64Ptr(0.2),
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: float64Ptr(0.7),
|
||||
},
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("alice"),
|
||||
},
|
||||
},
|
||||
}
|
||||
defaultCallConfig := codersdk.ChatModelCallConfig{
|
||||
MaxOutputTokens: int64Ptr(512),
|
||||
Temperature: float64Ptr(0.9),
|
||||
TopP: float64Ptr(0.8),
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: float64Ptr(0.15),
|
||||
OutputPricePerMillionTokens: float64Ptr(0.9),
|
||||
CacheReadPricePerMillionTokens: float64Ptr(0.03),
|
||||
CacheWritePricePerMillionTokens: float64Ptr(0.3),
|
||||
},
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("bob"),
|
||||
ReasoningEffort: stringPtr("medium"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatprovider.MergeMissingCallConfig(&dst, defaultCallConfig)
|
||||
|
||||
require.NotNil(t, dst.MaxOutputTokens)
|
||||
require.EqualValues(t, 512, *dst.MaxOutputTokens)
|
||||
require.NotNil(t, dst.Temperature)
|
||||
require.Equal(t, 0.2, *dst.Temperature)
|
||||
require.NotNil(t, dst.TopP)
|
||||
require.Equal(t, 0.8, *dst.TopP)
|
||||
require.NotNil(t, dst.Cost)
|
||||
require.NotNil(t, dst.Cost.InputPricePerMillionTokens)
|
||||
require.Equal(t, 0.15, *dst.Cost.InputPricePerMillionTokens)
|
||||
require.NotNil(t, dst.Cost.OutputPricePerMillionTokens)
|
||||
require.Equal(t, 0.7, *dst.Cost.OutputPricePerMillionTokens)
|
||||
require.NotNil(t, dst.Cost.CacheReadPricePerMillionTokens)
|
||||
require.Equal(t, 0.03, *dst.Cost.CacheReadPricePerMillionTokens)
|
||||
require.NotNil(t, dst.Cost.CacheWritePricePerMillionTokens)
|
||||
require.Equal(t, 0.3, *dst.Cost.CacheWritePricePerMillionTokens)
|
||||
require.NotNil(t, dst.ProviderOptions)
|
||||
require.NotNil(t, dst.ProviderOptions.OpenAI)
|
||||
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
|
||||
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
|
||||
}
|
||||
|
||||
func stringPtr(value string) *string {
|
||||
return &value
|
||||
}
|
||||
@@ -148,3 +203,7 @@ func boolPtr(value bool) *bool {
|
||||
func int64Ptr(value int64) *int64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
func float64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@@ -96,7 +96,6 @@ type AnthropicDeltaBlock struct {
|
||||
// anthropicServer is a test server that mocks the Anthropic API.
|
||||
type anthropicServer struct {
|
||||
mu sync.Mutex
|
||||
t testing.TB
|
||||
server *httptest.Server
|
||||
handler AnthropicHandler
|
||||
request *AnthropicRequest
|
||||
@@ -110,7 +109,6 @@ func NewAnthropic(t testing.TB, handler AnthropicHandler) string {
|
||||
t.Helper()
|
||||
|
||||
s := &anthropicServer{
|
||||
t: t,
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
@@ -145,7 +143,7 @@ func (s *anthropicServer) handleMessages(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
func (s *anthropicServer) writeResponse(w http.ResponseWriter, req *AnthropicRequest, resp AnthropicResponse) {
|
||||
if resp.Error != nil {
|
||||
writeErrorResponse(s.t, w, resp.Error)
|
||||
writeErrorResponse(w, resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -225,6 +223,7 @@ func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <
|
||||
}
|
||||
|
||||
func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp *AnthropicMessage) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
response := map[string]interface{}{
|
||||
"id": resp.ID,
|
||||
"type": resp.Type,
|
||||
@@ -242,9 +241,7 @@ func (s *anthropicServer) writeNonStreamingResponse(w http.ResponseWriter, resp
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("anthropic-version", "2023-06-01")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
s.t.Errorf("writeNonStreamingResponse: failed to encode response: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// AnthropicStreamingResponse creates a streaming response from chunks.
|
||||
|
||||
@@ -3,7 +3,6 @@ package chattest
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ErrorResponse describes an HTTP error that a test server should return
|
||||
@@ -16,7 +15,7 @@ type ErrorResponse struct {
|
||||
|
||||
// writeErrorResponse writes a JSON error response matching the common
|
||||
// provider error format used by both Anthropic and OpenAI.
|
||||
func writeErrorResponse(t testing.TB, w http.ResponseWriter, errResp *ErrorResponse) {
|
||||
func writeErrorResponse(w http.ResponseWriter, errResp *ErrorResponse) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(errResp.StatusCode)
|
||||
body := map[string]interface{}{
|
||||
@@ -25,9 +24,7 @@ func writeErrorResponse(t testing.TB, w http.ResponseWriter, errResp *ErrorRespo
|
||||
"message": errResp.Message,
|
||||
},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(body); err != nil {
|
||||
t.Errorf("writeErrorResponse: failed to encode error response: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
|
||||
// AnthropicErrorResponse returns an AnthropicResponse that causes the
|
||||
|
||||
@@ -113,7 +113,6 @@ type OpenAICompletion struct {
|
||||
// openAIServer is a test server that mocks the OpenAI API.
|
||||
type openAIServer struct {
|
||||
mu sync.Mutex
|
||||
t testing.TB
|
||||
server *httptest.Server
|
||||
handler OpenAIHandler
|
||||
request *OpenAIRequest
|
||||
@@ -127,7 +126,6 @@ func NewOpenAI(t testing.TB, handler OpenAIHandler) string {
|
||||
t.Helper()
|
||||
|
||||
s := &openAIServer{
|
||||
t: t,
|
||||
handler: handler,
|
||||
}
|
||||
|
||||
@@ -178,7 +176,7 @@ func (s *openAIServer) handleResponses(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
if resp.Error != nil {
|
||||
writeErrorResponse(s.t, w, resp.Error)
|
||||
writeErrorResponse(w, resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,7 +205,7 @@ func (s *openAIServer) writeChatCompletionsResponse(w http.ResponseWriter, req *
|
||||
|
||||
func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *OpenAIRequest, resp OpenAIResponse) {
|
||||
if resp.Error != nil {
|
||||
writeErrorResponse(s.t, w, resp.Error)
|
||||
writeErrorResponse(w, resp.Error)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -228,7 +226,7 @@ func (s *openAIServer) writeResponsesAPIResponse(w http.ResponseWriter, req *Ope
|
||||
http.Error(w, "handler returned streaming response for non-streaming request", http.StatusInternalServerError)
|
||||
return
|
||||
case hasStreaming:
|
||||
writeResponsesAPIStreaming(s.t, w, req.Request, resp.StreamingChunks)
|
||||
writeResponsesAPIStreaming(w, req.Request, resp.StreamingChunks)
|
||||
default:
|
||||
s.writeResponsesAPINonStreaming(w, resp.Response)
|
||||
}
|
||||
@@ -320,7 +318,7 @@ func writeSSEEvent(w http.ResponseWriter, v interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
|
||||
func writeResponsesAPIStreaming(w http.ResponseWriter, r *http.Request, chunks <-chan OpenAIChunk) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
@@ -347,28 +345,19 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
|
||||
// the fantasy client closes open text
|
||||
// blocks and persists the step content.
|
||||
for outputIndex, itemID := range itemIDs {
|
||||
if err := writeSSEEvent(w, responses.ResponseTextDoneEvent{
|
||||
_ = writeSSEEvent(w, responses.ResponseTextDoneEvent{
|
||||
ItemID: itemID,
|
||||
OutputIndex: int64(outputIndex),
|
||||
}); err != nil {
|
||||
t.Logf("writeResponsesAPIStreaming: failed to write ResponseTextDoneEvent: %v", err)
|
||||
return
|
||||
}
|
||||
if err := writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{
|
||||
})
|
||||
_ = writeSSEEvent(w, responses.ResponseOutputItemDoneEvent{
|
||||
OutputIndex: int64(outputIndex),
|
||||
Item: responses.ResponseOutputItemUnion{
|
||||
ID: itemID,
|
||||
Type: "message",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemDoneEvent: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := writeSSEEvent(w, responses.ResponseCompletedEvent{}); err != nil {
|
||||
t.Logf("writeResponsesAPIStreaming: failed to write ResponseCompletedEvent: %v", err)
|
||||
return
|
||||
})
|
||||
}
|
||||
_ = writeSSEEvent(w, responses.ResponseCompletedEvent{})
|
||||
flusher.Flush()
|
||||
return
|
||||
}
|
||||
@@ -393,7 +382,6 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
|
||||
Type: "message",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Logf("writeResponsesAPIStreaming: failed to write ResponseOutputItemAddedEvent: %v", err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -411,12 +399,10 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
|
||||
|
||||
chunkBytes, err := json.Marshal(chunkData)
|
||||
if err != nil {
|
||||
t.Logf("writeResponsesAPIStreaming: failed to marshal chunk data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "data: %s\n\n", chunkBytes); err != nil {
|
||||
t.Logf("writeResponsesAPIStreaming: failed to write chunk data: %v", err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
@@ -425,13 +411,13 @@ func writeResponsesAPIStreaming(t testing.TB, w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeChatCompletionsNonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.t.Errorf("writeChatCompletionsNonStreaming: failed to encode response: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp *OpenAICompletion) {
|
||||
_ = s // receiver unused but kept for consistency
|
||||
// Convert all choices to output format
|
||||
outputs := make([]map[string]interface{}, len(resp.Choices))
|
||||
for i, choice := range resp.Choices {
|
||||
@@ -457,9 +443,7 @@ func (s *openAIServer) writeResponsesAPINonStreaming(w http.ResponseWriter, resp
|
||||
"usage": resp.Usage,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
s.t.Errorf("writeResponsesAPINonStreaming: failed to encode response: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// OpenAIStreamingResponse creates a streaming response from chunks.
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// ComputerUseModelProvider is the provider for the computer
|
||||
// use model.
|
||||
ComputerUseModelProvider = "anthropic"
|
||||
// ComputerUseModelName is the model used for computer use
|
||||
// subagents.
|
||||
ComputerUseModelName = "claude-opus-4-6"
|
||||
)
|
||||
|
||||
// computerUseTool implements fantasy.AgentTool and
|
||||
// chatloop.ToolDefiner for Anthropic computer use.
|
||||
type computerUseTool struct {
|
||||
displayWidth int
|
||||
displayHeight int
|
||||
getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error)
|
||||
providerOptions fantasy.ProviderOptions
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// NewComputerUseTool creates a computer use AgentTool that
|
||||
// delegates to the agent's desktop endpoints.
|
||||
func NewComputerUseTool(
|
||||
displayWidth, displayHeight int,
|
||||
getWorkspaceConn func(ctx context.Context) (workspacesdk.AgentConn, error),
|
||||
clock quartz.Clock,
|
||||
) fantasy.AgentTool {
|
||||
return &computerUseTool{
|
||||
displayWidth: displayWidth,
|
||||
displayHeight: displayHeight,
|
||||
getWorkspaceConn: getWorkspaceConn,
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
func (*computerUseTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{
|
||||
Name: "computer",
|
||||
Description: "Control the desktop: take screenshots, move the mouse, click, type, and scroll.",
|
||||
Parameters: map[string]any{},
|
||||
Required: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// ComputerUseProviderTool creates the provider-defined tool
|
||||
// definition for Anthropic computer use. This is passed via
|
||||
// ProviderTools so the API receives the correct wire format.
|
||||
func ComputerUseProviderTool(displayWidth, displayHeight int) fantasy.Tool {
|
||||
return fantasyanthropic.NewComputerUseTool(
|
||||
fantasyanthropic.ComputerUseToolOptions{
|
||||
DisplayWidthPx: int64(displayWidth),
|
||||
DisplayHeightPx: int64(displayHeight),
|
||||
ToolVersion: fantasyanthropic.ComputerUse20251124,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (t *computerUseTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
func (t *computerUseTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
|
||||
func (t *computerUseTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
input, err := fantasyanthropic.ParseComputerUseInput(call.Input)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("invalid computer use input: %v", err),
|
||||
), nil
|
||||
}
|
||||
|
||||
conn, err := t.getWorkspaceConn(ctx)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("failed to connect to workspace: %v", err),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Compute scaled screenshot size for Anthropic constraints.
|
||||
scaledW, scaledH := computeScaledScreenshotSize(
|
||||
t.displayWidth, t.displayHeight,
|
||||
)
|
||||
|
||||
// For wait actions, sleep then return a screenshot.
|
||||
if input.Action == fantasyanthropic.ActionWait {
|
||||
d := input.Duration
|
||||
if d <= 0 {
|
||||
d = 1000
|
||||
}
|
||||
timer := t.clock.NewTimer(time.Duration(d)*time.Millisecond, "computeruse", "wait")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-timer.C:
|
||||
}
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("screenshot failed: %v", sErr),
|
||||
), nil
|
||||
}
|
||||
return fantasy.NewImageResponse(
|
||||
[]byte(screenResp.ScreenshotData), "image/png",
|
||||
), nil
|
||||
}
|
||||
|
||||
// For screenshot action, use ExecuteDesktopAction.
|
||||
if input.Action == fantasyanthropic.ActionScreenshot {
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("screenshot failed: %v", sErr),
|
||||
), nil
|
||||
}
|
||||
return fantasy.NewImageResponse(
|
||||
[]byte(screenResp.ScreenshotData), "image/png",
|
||||
), nil
|
||||
}
|
||||
|
||||
// Build the action request.
|
||||
action := workspacesdk.DesktopAction{
|
||||
Action: string(input.Action),
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
if input.Coordinate != ([2]int64{}) {
|
||||
coord := [2]int{int(input.Coordinate[0]), int(input.Coordinate[1])}
|
||||
action.Coordinate = &coord
|
||||
}
|
||||
if input.StartCoordinate != ([2]int64{}) {
|
||||
coord := [2]int{int(input.StartCoordinate[0]), int(input.StartCoordinate[1])}
|
||||
action.StartCoordinate = &coord
|
||||
}
|
||||
if input.Text != "" {
|
||||
action.Text = &input.Text
|
||||
}
|
||||
if input.Duration > 0 {
|
||||
d := int(input.Duration)
|
||||
action.Duration = &d
|
||||
}
|
||||
if input.ScrollAmount > 0 {
|
||||
s := int(input.ScrollAmount)
|
||||
action.ScrollAmount = &s
|
||||
}
|
||||
if input.ScrollDirection != "" {
|
||||
action.ScrollDirection = &input.ScrollDirection
|
||||
}
|
||||
|
||||
// Execute the action.
|
||||
_, err = conn.ExecuteDesktopAction(ctx, action)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("action %q failed: %v", input.Action, err),
|
||||
), nil
|
||||
}
|
||||
|
||||
// Take a screenshot after every action (Anthropic pattern).
|
||||
screenshotAction := workspacesdk.DesktopAction{
|
||||
Action: "screenshot",
|
||||
ScaledWidth: &scaledW,
|
||||
ScaledHeight: &scaledH,
|
||||
}
|
||||
screenResp, sErr := conn.ExecuteDesktopAction(ctx, screenshotAction)
|
||||
if sErr != nil {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("screenshot failed: %v", sErr),
|
||||
), nil
|
||||
}
|
||||
|
||||
return fantasy.NewImageResponse(
|
||||
[]byte(screenResp.ScreenshotData), "image/png",
|
||||
), nil
|
||||
}
|
||||
|
||||
// computeScaledScreenshotSize computes the target screenshot
|
||||
// dimensions to fit within Anthropic's constraints.
|
||||
func computeScaledScreenshotSize(width, height int) (scaledWidth int, scaledHeight int) {
|
||||
const maxLongEdge = 1568
|
||||
const maxTotalPixels = 1_150_000
|
||||
|
||||
longEdge := max(width, height)
|
||||
totalPixels := width * height
|
||||
longEdgeScale := float64(maxLongEdge) / float64(longEdge)
|
||||
totalPixelsScale := math.Sqrt(
|
||||
float64(maxTotalPixels) / float64(totalPixels),
|
||||
)
|
||||
scale := min(1.0, longEdgeScale, totalPixelsScale)
|
||||
|
||||
if scale >= 1.0 {
|
||||
return width, height
|
||||
}
|
||||
return max(1, int(float64(width)*scale)),
|
||||
max(1, int(float64(height)*scale))
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package chattool
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestComputeScaledScreenshotSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
width, height int
|
||||
wantW, wantH int
|
||||
}{
|
||||
{
|
||||
name: "1920x1080_scales_down",
|
||||
width: 1920,
|
||||
height: 1080,
|
||||
wantW: 1429,
|
||||
wantH: 804,
|
||||
},
|
||||
{
|
||||
name: "1280x800_no_scaling",
|
||||
width: 1280,
|
||||
height: 800,
|
||||
wantW: 1280,
|
||||
wantH: 800,
|
||||
},
|
||||
{
|
||||
name: "3840x2160_large_display",
|
||||
width: 3840,
|
||||
height: 2160,
|
||||
wantW: 1429,
|
||||
wantH: 804,
|
||||
},
|
||||
{
|
||||
name: "1568x1000_pixel_cap_applies",
|
||||
width: 1568,
|
||||
height: 1000,
|
||||
wantW: 1342,
|
||||
wantH: 856,
|
||||
},
|
||||
{
|
||||
name: "100x100_small_display",
|
||||
width: 100,
|
||||
height: 100,
|
||||
wantW: 100,
|
||||
wantH: 100,
|
||||
},
|
||||
{
|
||||
name: "4000x3000_stays_within_limits",
|
||||
width: 4000,
|
||||
// Both constraints apply. The function should keep
|
||||
// the result within maxLongEdge=1568 and
|
||||
// totalPixels<=1,150,000.
|
||||
height: 3000,
|
||||
wantW: 1238,
|
||||
wantH: 928,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
gotW, gotH := computeScaledScreenshotSize(tt.width, tt.height)
|
||||
assert.Equal(t, tt.wantW, gotW)
|
||||
assert.Equal(t, tt.wantH, gotH)
|
||||
|
||||
// Invariant: results must respect Anthropic constraints.
|
||||
const maxLongEdge = 1568
|
||||
const maxTotalPixels = 1_150_000
|
||||
longEdge := max(gotW, gotH)
|
||||
assert.LessOrEqual(t, longEdge, maxLongEdge,
|
||||
"long edge %d exceeds max %d", longEdge, maxLongEdge)
|
||||
assert.LessOrEqual(t, gotW*gotH, maxTotalPixels,
|
||||
"total pixels %d exceeds max %d", gotW*gotH, maxTotalPixels)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
package chattool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestComputerUseTool_Info(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, nil, quartz.NewReal())
|
||||
info := tool.Info()
|
||||
assert.Equal(t, "computer", info.Name)
|
||||
assert.NotEmpty(t, info.Description)
|
||||
}
|
||||
|
||||
func TestComputerUseProviderTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
def := chattool.ComputerUseProviderTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight)
|
||||
pdt, ok := def.(fantasy.ProviderDefinedTool)
|
||||
require.True(t, ok, "ComputerUseProviderTool should return a ProviderDefinedTool")
|
||||
assert.Contains(t, pdt.ID, "computer")
|
||||
assert.Equal(t, "computer", pdt.Name)
|
||||
// Verify display dimensions are passed through.
|
||||
assert.Equal(t, int64(workspacesdk.DesktopDisplayWidth), pdt.Args["display_width_px"])
|
||||
assert.Equal(t, int64(workspacesdk.DesktopDisplayHeight), pdt.Args["display_height_px"])
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Screenshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "base64png",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-1",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("base64png"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_LeftClick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
|
||||
// Expect the action call first.
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "left_click performed",
|
||||
}, nil)
|
||||
|
||||
// Then expect a screenshot (auto-screenshot after action).
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-click",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-2",
|
||||
Name: "computer",
|
||||
Input: `{"action":"left_click","coordinate":[100,200]}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, []byte("after-click"), resp.Data)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_Wait(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockConn := agentconnmock.NewMockAgentConn(ctrl)
|
||||
// Expect a screenshot after the wait completes.
|
||||
mockConn.EXPECT().ExecuteDesktopAction(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).Return(workspacesdk.DesktopActionResponse{
|
||||
Output: "screenshot",
|
||||
ScreenshotData: "after-wait",
|
||||
ScreenshotWidth: 1024,
|
||||
ScreenshotHeight: 768,
|
||||
}, nil)
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return mockConn, nil
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-3",
|
||||
Name: "computer",
|
||||
Input: `{"action":"wait","duration":10}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "image", resp.Type)
|
||||
assert.Equal(t, "image/png", resp.MediaType)
|
||||
assert.Equal(t, []byte("after-wait"), resp.Data)
|
||||
assert.False(t, resp.IsError)
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_ConnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("workspace not available")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-4",
|
||||
Name: "computer",
|
||||
Input: `{"action":"screenshot"}`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "workspace not available")
|
||||
}
|
||||
|
||||
func TestComputerUseTool_Run_InvalidInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := chattool.NewComputerUseTool(workspacesdk.DesktopDisplayWidth, workspacesdk.DesktopDisplayHeight, func(_ context.Context) (workspacesdk.AgentConn, error) {
|
||||
return nil, xerrors.New("should not be called")
|
||||
}, quartz.NewReal())
|
||||
|
||||
call := fantasy.ToolCall{
|
||||
ID: "test-5",
|
||||
Name: "computer",
|
||||
Input: `{invalid json`,
|
||||
}
|
||||
|
||||
resp, err := tool.Run(context.Background(), call)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsError)
|
||||
assert.Contains(t, resp.Content, "invalid computer use input")
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/util/namesgenerator"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -68,7 +67,6 @@ type CreateWorkspaceOptions struct {
|
||||
CreateFn CreateWorkspaceFn
|
||||
AgentConnFn AgentConnFunc
|
||||
WorkspaceMu *sync.Mutex
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
type createWorkspaceArgs struct {
|
||||
@@ -194,19 +192,13 @@ func CreateWorkspace(options CreateWorkspaceOptions) fantasy.AgentTool {
|
||||
|
||||
// Persist workspace + agent association on the chat.
|
||||
if options.DB != nil && options.ChatID != uuid.Nil {
|
||||
if _, err := options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
|
||||
_, _ = options.DB.UpdateChatWorkspace(ctx, database.UpdateChatWorkspaceParams{
|
||||
ID: options.ChatID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
options.Logger.Error(ctx, "failed to persist chat workspace association",
|
||||
slog.F("chat_id", options.ChatID),
|
||||
slog.F("workspace_id", workspace.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for the agent to come online and startup scripts to finish.
|
||||
|
||||
@@ -3,14 +3,12 @@ package chattool
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
@@ -78,10 +76,10 @@ type ProcessToolOptions struct {
|
||||
|
||||
// ExecuteArgs are the parameters accepted by the execute tool.
|
||||
type ExecuteArgs struct {
|
||||
Command string `json:"command" description:"The shell command to execute."`
|
||||
Timeout *string `json:"timeout,omitempty" description:"Timeout duration (e.g. '30s', '5m'). Default is 10s. Only applies to foreground commands."`
|
||||
WorkDir *string `json:"workdir,omitempty" description:"Working directory for the command."`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty" description:"Run this command in the background without blocking. Use for long-running processes like dev servers, file watchers, or builds that run longer than 5 seconds. Do NOT use shell & to background processes — it will not work correctly. Always use this parameter instead."`
|
||||
Command string `json:"command"`
|
||||
Timeout *string `json:"timeout,omitempty"`
|
||||
WorkDir *string `json:"workdir,omitempty"`
|
||||
RunInBackground *bool `json:"run_in_background,omitempty"`
|
||||
}
|
||||
|
||||
// Execute returns an AgentTool that runs a shell command in the
|
||||
@@ -89,7 +87,7 @@ type ExecuteArgs struct {
|
||||
func Execute(options ExecuteOptions) fantasy.AgentTool {
|
||||
return fantasy.NewAgentTool(
|
||||
"execute",
|
||||
"Execute a shell command in the workspace. Use run_in_background=true for long-running processes (dev servers, file watchers, builds). Never use shell '&' for backgrounding.",
|
||||
"Execute a shell command in the workspace.",
|
||||
func(ctx context.Context, args ExecuteArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if options.GetWorkspaceConn == nil {
|
||||
return fantasy.NewTextErrorResponse("workspace connection resolver is not configured"), nil
|
||||
@@ -122,16 +120,6 @@ func executeTool(
|
||||
|
||||
background := args.RunInBackground != nil && *args.RunInBackground
|
||||
|
||||
// Detect shell-style backgrounding (trailing &) and promote to
|
||||
// background mode. Models sometimes use "cmd &" instead of the
|
||||
// run_in_background parameter, which causes the shell to fork
|
||||
// and exit immediately, leaving an untracked orphan process.
|
||||
trimmed := strings.TrimSpace(args.Command)
|
||||
if !background && strings.HasSuffix(trimmed, "&") && !strings.HasSuffix(trimmed, "&&") {
|
||||
background = true
|
||||
args.Command = strings.TrimSpace(strings.TrimSuffix(trimmed, "&"))
|
||||
}
|
||||
|
||||
var workDir string
|
||||
if args.WorkDir != nil {
|
||||
workDir = *args.WorkDir
|
||||
@@ -257,18 +245,14 @@ func pollProcess(
|
||||
context.Background(),
|
||||
5*time.Second,
|
||||
)
|
||||
outputResp, outputErr := conn.ProcessOutput(bgCtx, processID)
|
||||
outputResp, _ := conn.ProcessOutput(bgCtx, processID)
|
||||
bgCancel()
|
||||
output := truncateOutput(outputResp.Output)
|
||||
timeoutErr := xerrors.Errorf("command timed out after %s", timeout)
|
||||
if outputErr != nil {
|
||||
timeoutErr = errors.Join(timeoutErr, xerrors.Errorf("failed to get output: %w", outputErr))
|
||||
}
|
||||
return ExecuteResult{
|
||||
Success: false,
|
||||
Output: output,
|
||||
ExitCode: -1,
|
||||
Error: timeoutErr.Error(),
|
||||
Error: fmt.Sprintf("command timed out after %s", timeout),
|
||||
Truncated: outputResp.Truncated,
|
||||
}
|
||||
case <-ticker.C:
|
||||
|
||||
@@ -92,19 +92,17 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
|
||||
// Verify the chat completed and messages were persisted.
|
||||
chatData, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
chatMsgs, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Chat status after step 1: %s, messages: %d",
|
||||
chatData.Status, len(chatMsgs.Messages))
|
||||
logMessages(t, chatMsgs.Messages)
|
||||
chatData.Chat.Status, len(chatData.Messages))
|
||||
logMessages(t, chatData.Messages)
|
||||
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, chatData.Status,
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, chatData.Chat.Status,
|
||||
"chat should be in waiting status after step 1")
|
||||
|
||||
// Find the first assistant message and verify it has the
|
||||
// content parts the UI needs to render web search results:
|
||||
// tool-call(PE), source, tool-result(PE), and text.
|
||||
assistantMsg := findAssistantWithText(t, chatMsgs.Messages)
|
||||
assistantMsg := findAssistantWithText(t, chatData.Messages)
|
||||
require.NotNil(t, assistantMsg,
|
||||
"expected an assistant message with text content after step 1")
|
||||
|
||||
@@ -154,19 +152,17 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
|
||||
// Verify the follow-up completed and produced content.
|
||||
chatData2, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
chatMsgs2, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Chat status after step 2: %s, messages: %d",
|
||||
chatData2.Status, len(chatMsgs2.Messages))
|
||||
logMessages(t, chatMsgs2.Messages)
|
||||
chatData2.Chat.Status, len(chatData2.Messages))
|
||||
logMessages(t, chatData2.Messages)
|
||||
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, chatData2.Status,
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, chatData2.Chat.Status,
|
||||
"chat should be in waiting status after step 2")
|
||||
require.Greater(t, len(chatMsgs2.Messages), len(chatMsgs.Messages),
|
||||
require.Greater(t, len(chatData2.Messages), len(chatData.Messages),
|
||||
"follow-up should have added more messages")
|
||||
|
||||
// The last assistant message should have text.
|
||||
lastAssistant := findLastAssistantWithText(t, chatMsgs2.Messages)
|
||||
lastAssistant := findLastAssistantWithText(t, chatData2.Messages)
|
||||
require.NotNil(t, lastAssistant,
|
||||
"expected an assistant message with text in the follow-up")
|
||||
|
||||
|
||||
+17
-23
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
const titleGenerationPrompt = "You are a title generator. Your ONLY job is to output a short title (2-8 words) " +
|
||||
@@ -62,7 +61,6 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
generatedTitle *generatedChatTitle,
|
||||
logger slog.Logger,
|
||||
) {
|
||||
input, ok := titleInput(chat, messages)
|
||||
@@ -112,8 +110,7 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
return
|
||||
}
|
||||
chat.Title = title
|
||||
generatedTitle.Store(title)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindTitleChange)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -161,12 +158,14 @@ func titleInput(
|
||||
}
|
||||
|
||||
switch message.Role {
|
||||
case database.ChatMessageRoleAssistant, database.ChatMessageRoleTool:
|
||||
case string(fantasy.MessageRoleAssistant), string(fantasy.MessageRoleTool):
|
||||
return "", false
|
||||
case database.ChatMessageRoleUser:
|
||||
case string(fantasy.MessageRoleUser):
|
||||
userCount++
|
||||
if firstUserText == "" {
|
||||
parsed, err := chatprompt.ParseContent(message)
|
||||
parsed, err := chatprompt.ParseContent(
|
||||
string(fantasy.MessageRoleUser), message.Content,
|
||||
)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
@@ -227,21 +226,22 @@ func fallbackChatTitle(message string) string {
|
||||
return truncateRunes(title, maxRunes)
|
||||
}
|
||||
|
||||
// contentBlocksToText concatenates the text parts of SDK chat
|
||||
// message parts into a single space-separated string.
|
||||
func contentBlocksToText(parts []codersdk.ChatMessagePart) string {
|
||||
texts := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeText {
|
||||
// contentBlocksToText concatenates the text parts of content blocks
|
||||
// into a single space-separated string.
|
||||
func contentBlocksToText(content []fantasy.Content) string {
|
||||
parts := make([]string, 0, len(content))
|
||||
for _, block := range content {
|
||||
textBlock, ok := fantasy.AsContentType[fantasy.TextContent](block)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
text := strings.TrimSpace(part.Text)
|
||||
text := strings.TrimSpace(textBlock.Text)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
texts = append(texts, text)
|
||||
parts = append(parts, text)
|
||||
}
|
||||
return strings.Join(texts, " ")
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func truncateRunes(value string, maxLen int) string {
|
||||
@@ -343,13 +343,7 @@ func generateShortText(
|
||||
return "", xerrors.Errorf("generate short text: %w", err)
|
||||
}
|
||||
|
||||
responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content))
|
||||
for _, block := range response.Content {
|
||||
if p := chatprompt.PartFromContent(block); p.Type != "" {
|
||||
responseParts = append(responseParts, p)
|
||||
}
|
||||
}
|
||||
text := strings.TrimSpace(contentBlocksToText(responseParts))
|
||||
text := strings.TrimSpace(contentBlocksToText(response.Content))
|
||||
text = strings.Trim(text, "\"'`")
|
||||
return text, nil
|
||||
}
|
||||
|
||||
+6
-135
@@ -13,10 +13,8 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat")
|
||||
@@ -27,30 +25,11 @@ const (
|
||||
defaultSubagentWaitTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
// computerUseSubagentSystemPrompt is the system prompt prepended to
|
||||
// every computer use subagent chat. It instructs the model on how to
|
||||
// interact with the desktop environment via the computer tool.
|
||||
const computerUseSubagentSystemPrompt = `You are a computer use agent with access to a desktop environment. You can see the screen, move the mouse, click, type, scroll, and drag.
|
||||
|
||||
Your primary tool is the "computer" tool which lets you interact with the desktop. After every action you take, you will receive a screenshot showing the current state of the screen. Use these screenshots to verify your actions and plan next steps.
|
||||
|
||||
Guidelines:
|
||||
- Always start by taking a screenshot to see the current state of the desktop.
|
||||
- Be precise with coordinates when clicking or typing.
|
||||
- Wait for UI elements to load before interacting with them.
|
||||
- If an action doesn't produce the expected result, try alternative approaches.
|
||||
- Report what you accomplished when done.`
|
||||
|
||||
type spawnAgentArgs struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type spawnComputerUseAgentArgs struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Title string `json:"title,omitempty"`
|
||||
}
|
||||
|
||||
type waitAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
TimeoutSeconds *int `json:"timeout_seconds,omitempty"`
|
||||
@@ -66,34 +45,8 @@ type closeAgentArgs struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
}
|
||||
|
||||
// isAnthropicConfigured reports whether an Anthropic API key is
|
||||
// available, either from static provider keys or from the database.
|
||||
func (p *Server) isAnthropicConfigured(ctx context.Context) bool {
|
||||
if p.providerAPIKeys.APIKey("anthropic") != "" {
|
||||
return true
|
||||
}
|
||||
dbProviders, err := p.db.GetEnabledChatProviders(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, prov := range dbProviders {
|
||||
if chatprovider.NormalizeProvider(prov.Provider) == "anthropic" && strings.TrimSpace(prov.APIKey) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Server) isDesktopEnabled(ctx context.Context) bool {
|
||||
enabled, err := p.db.GetChatDesktopEnabled(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (p *Server) subagentTools(ctx context.Context, currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
tools := []fantasy.AgentTool{
|
||||
func (p *Server) subagentTools(currentChat func() database.Chat) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
fantasy.NewAgentTool(
|
||||
"spawn_agent",
|
||||
"Spawn a delegated child agent to work on a clearly scoped, "+
|
||||
@@ -259,88 +212,6 @@ func (p *Server) subagentTools(ctx context.Context, currentChat func() database.
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
// Only include the computer use tool when an Anthropic
|
||||
// provider is configured and desktop is enabled.
|
||||
if p.isAnthropicConfigured(ctx) && p.isDesktopEnabled(ctx) {
|
||||
tools = append(tools, fantasy.NewAgentTool(
|
||||
"spawn_computer_use_agent",
|
||||
"Spawn a dedicated computer use agent that can see the desktop "+
|
||||
"(take screenshots) and interact with it (mouse, keyboard, "+
|
||||
"scroll). The agent runs on a model optimized for computer "+
|
||||
"use and has the same workspace tools as a standard subagent "+
|
||||
"plus the native Anthropic computer tool. Use this for tasks "+
|
||||
"that require visual interaction with a desktop GUI (e.g. "+
|
||||
"browser automation, GUI testing, visual inspection). After "+
|
||||
"spawning, use wait_agent to collect the result.",
|
||||
func(ctx context.Context, args spawnComputerUseAgentArgs, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if currentChat == nil {
|
||||
return fantasy.NewTextErrorResponse("subagent callbacks are not configured"), nil
|
||||
}
|
||||
|
||||
parent := currentChat()
|
||||
if parent.ParentChatID.Valid {
|
||||
return fantasy.NewTextErrorResponse("delegated chats cannot create child subagents"), nil
|
||||
}
|
||||
|
||||
parent, err := p.db.GetChatByID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(args.Prompt)
|
||||
if prompt == "" {
|
||||
return fantasy.NewTextErrorResponse("prompt is required"), nil
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(args.Title)
|
||||
if title == "" {
|
||||
title = subagentFallbackChatTitle(prompt)
|
||||
}
|
||||
|
||||
rootChatID := parent.ID
|
||||
if parent.RootChatID.Valid {
|
||||
rootChatID = parent.RootChatID.UUID
|
||||
}
|
||||
if parent.LastModelConfigID == uuid.Nil {
|
||||
return fantasy.NewTextErrorResponse("parent chat model config id is required"), nil
|
||||
}
|
||||
|
||||
// Create the child chat with Mode set to
|
||||
// computer_use. This signals runChat to use the
|
||||
// predefined computer use model and include the
|
||||
// computer tool.
|
||||
childChat, err := p.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
WorkspaceID: parent.WorkspaceID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: rootChatID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: computerUseSubagentSystemPrompt + "\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
return toolJSONResponse(map[string]any{
|
||||
"chat_id": childChat.ID.String(),
|
||||
"title": childChat.Title,
|
||||
"status": string(childChat.Status),
|
||||
}), nil
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
func parseSubagentToolChatID(raw string) (uuid.UUID, error) {
|
||||
@@ -392,7 +263,7 @@ func (p *Server) createChildSubagentChat(
|
||||
},
|
||||
ModelConfigID: parent.LastModelConfigID,
|
||||
Title: title,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: prompt}},
|
||||
})
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("create child chat: %w", err)
|
||||
@@ -430,7 +301,7 @@ func (p *Server) sendSubagentMessage(
|
||||
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
|
||||
ChatID: targetChatID,
|
||||
CreatedBy: targetChat.OwnerID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText(message)},
|
||||
Content: []fantasy.Content{fantasy.TextContent{Text: message}},
|
||||
BusyBehavior: busyBehavior,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -610,12 +481,12 @@ func latestSubagentAssistantMessage(
|
||||
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if message.Role != database.ChatMessageRoleAssistant ||
|
||||
if message.Role != string(fantasy.MessageRoleAssistant) ||
|
||||
message.Visibility == database.ChatMessageVisibilityModel {
|
||||
continue
|
||||
}
|
||||
|
||||
content, parseErr := chatprompt.ParseContent(message)
|
||||
content, parseErr := chatprompt.ParseContent(message.Role, message.Content)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestComputerUseSubagentSystemPrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Verify the system prompt constant is non-empty and contains
|
||||
// key instructions for the computer use agent.
|
||||
assert.NotEmpty(t, computerUseSubagentSystemPrompt)
|
||||
assert.Contains(t, computerUseSubagentSystemPrompt, "computer")
|
||||
assert.Contains(t, computerUseSubagentSystemPrompt, "screenshot")
|
||||
}
|
||||
|
||||
func TestSubagentFallbackChatTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "EmptyPrompt",
|
||||
input: "",
|
||||
want: "New Chat",
|
||||
},
|
||||
{
|
||||
name: "ShortPrompt",
|
||||
input: "Open Firefox",
|
||||
want: "Open Firefox",
|
||||
},
|
||||
{
|
||||
name: "LongPrompt",
|
||||
input: "Please open the Firefox browser and navigate to the settings page",
|
||||
want: "Please open the Firefox browser and...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := subagentFallbackChatTitle(tt.input)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newInternalTestServer creates a Server for internal tests with
|
||||
// custom provider API keys. The server is automatically closed
|
||||
// when the test finishes.
|
||||
func newInternalTestServer(
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ps pubsub.Pubsub,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) *Server {
|
||||
t.Helper()
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := New(Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
// Use a very long interval so the background loop
|
||||
// does not interfere with test assertions.
|
||||
PendingChatAcquireInterval: testutil.WaitLong,
|
||||
ProviderAPIKeys: keys,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
return server
|
||||
}
|
||||
|
||||
// seedInternalChatDeps inserts an OpenAI provider and model config
|
||||
// into the database and returns the created user and model. This
|
||||
// deliberately does NOT create an Anthropic provider.
|
||||
func seedInternalChatDeps(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
) (database.User, database.ChatModelConfig) {
|
||||
t.Helper()
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: "",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return user, model
|
||||
}
|
||||
|
||||
// findToolByName returns the tool with the given name from the
|
||||
// slice, or nil if no match is found.
|
||||
func findToolByName(tools []fantasy.AgentTool, name string) fantasy.AgentTool {
|
||||
for _, tool := range tools {
|
||||
if tool.Info().Name == name {
|
||||
return tool
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func chatdTestContext(t *testing.T) context.Context {
|
||||
t.Helper()
|
||||
return dbauthz.AsChatd(testutil.Context(t, testutil.WaitLong))
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NoAnthropicProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// No Anthropic key in ProviderAPIKeys.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-no-anthropic",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-fetch so LastModelConfigID is populated from the DB.
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when Anthropic is not configured")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_NotAvailableForChildChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the provider check passes.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// Create a root parent chat.
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "root-parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a child chat under the parent.
|
||||
child, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
Title: "child-subagent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("do something")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-fetch the child so ParentChatID is populated.
|
||||
childChat, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, childChat.ParentChatID.Valid,
|
||||
"child chat must have a parent")
|
||||
|
||||
// Get tools as if the child chat is the current chat.
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return childChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
require.NotNil(t, tool, "spawn_computer_use_agent tool must be present")
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-2",
|
||||
Name: "spawn_computer_use_agent",
|
||||
Input: `{"prompt":"open browser"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, resp.IsError, "expected an error response")
|
||||
assert.Contains(t, resp.Content, "delegated chats cannot create child subagents")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_DesktopDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-desktop-disabled",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
assert.Nil(t, tool, "spawn_computer_use_agent tool must be omitted when desktop is disabled")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_UsesComputerUseModelNotParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
require.NoError(t, db.UpsertChatDesktopEnabled(chatdTestContext(t), true))
|
||||
// Provide an Anthropic key so the tool can proceed.
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{
|
||||
Anthropic: "test-anthropic-key",
|
||||
})
|
||||
|
||||
ctx := chatdTestContext(t)
|
||||
user, model := seedInternalChatDeps(ctx, t, db)
|
||||
|
||||
// The parent uses an OpenAI model.
|
||||
require.Equal(t, "openai", model.Provider,
|
||||
"seed helper must create an OpenAI model")
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-openai",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools := server.subagentTools(ctx, func() database.Chat { return parentChat })
|
||||
tool := findToolByName(tools, "spawn_computer_use_agent")
|
||||
require.NotNil(t, tool)
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
||||
ID: "call-3",
|
||||
Name: "spawn_computer_use_agent",
|
||||
Input: `{"prompt":"take a screenshot"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, resp.IsError, "expected success but got: %s", resp.Content)
|
||||
|
||||
// Parse the response to get the child chat ID.
|
||||
var result map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(resp.Content), &result))
|
||||
childIDStr, ok := result["chat_id"].(string)
|
||||
require.True(t, ok, "response must contain chat_id")
|
||||
|
||||
childID, err := uuid.Parse(childIDStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
childChat, err := db.GetChatByID(ctx, childID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The child must have Mode=computer_use which causes
|
||||
// runChat to override the model to the predefined computer
|
||||
// use model instead of using the parent's model config.
|
||||
require.True(t, childChat.Mode.Valid)
|
||||
assert.Equal(t, database.ChatModeComputerUse, childChat.Mode.ChatMode)
|
||||
|
||||
// The predefined computer use model is Anthropic, which
|
||||
// differs from the parent's OpenAI model. This confirms
|
||||
// that the child will not inherit the parent's model at
|
||||
// runtime.
|
||||
assert.NotEqual(t, model.Provider, chattool.ComputerUseModelProvider,
|
||||
"computer use model provider must differ from parent model provider")
|
||||
assert.Equal(t, "anthropic", chattool.ComputerUseModelProvider)
|
||||
assert.NotEmpty(t, chattool.ComputerUseModelName)
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
package chatd_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSpawnComputerUseAgent_CreatesChildWithChatMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a parent chat.
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate what spawn_computer_use_agent does: set ChatMode
|
||||
// to computer_use and provide a system prompt.
|
||||
prompt := "Use the desktop to open Firefox"
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: "Computer use instructions\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify parent-child relationship.
|
||||
require.True(t, child.ParentChatID.Valid)
|
||||
require.Equal(t, parent.ID, child.ParentChatID.UUID)
|
||||
|
||||
// Verify the chat type is set correctly.
|
||||
require.True(t, child.Mode.Valid)
|
||||
assert.Equal(t, database.ChatModeComputerUse, child.Mode.ChatMode)
|
||||
|
||||
// Confirm via a fresh DB read as well.
|
||||
got, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.Mode.Valid)
|
||||
assert.Equal(t, database.ChatModeComputerUse, got.Mode.ChatMode)
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_SystemPromptFormat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt := "Navigate to settings page"
|
||||
systemPrompt := "Computer use instructions\n\n" + prompt
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use-format",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: systemPrompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The system message raw content is a JSON-encoded string.
|
||||
// It should contain the system prompt with the user prompt.
|
||||
var rawSystemContent string
|
||||
for _, msg := range messages {
|
||||
if msg.Role != "system" {
|
||||
continue
|
||||
}
|
||||
if msg.Content.Valid {
|
||||
rawSystemContent = string(msg.Content.RawMessage)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, rawSystemContent, prompt,
|
||||
"system prompt raw content should contain the user prompt")
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_ChildIsListedUnderParent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt := "Check the UI layout"
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use-child",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: "Computer use instructions\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the child is linked to the parent.
|
||||
fetchedChild, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, fetchedChild.ParentChatID.Valid)
|
||||
assert.Equal(t, parent.ID, fetchedChild.ParentChatID.UUID)
|
||||
}
|
||||
|
||||
func TestSpawnComputerUseAgent_RootChatIDPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newTestServer(t, db, ps, uuid.New())
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Create a root parent chat (no parent of its own).
|
||||
parent, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "root-parent",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
prompt := "Take a screenshot"
|
||||
|
||||
child, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: parent.OwnerID,
|
||||
ParentChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
RootChatID: uuid.NullUUID{
|
||||
UUID: parent.ID,
|
||||
Valid: true,
|
||||
},
|
||||
ModelConfigID: model.ID,
|
||||
Title: "computer-use-root-test",
|
||||
ChatMode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
SystemPrompt: "Computer use instructions\n\n" + prompt,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When the parent has no RootChatID, the child's RootChatID
|
||||
// should point to the parent.
|
||||
require.True(t, child.RootChatID.Valid)
|
||||
assert.Equal(t, parent.ID, child.RootChatID.UUID)
|
||||
|
||||
// Verify chat was retrieved correctly from the DB.
|
||||
got, err := db.GetChatByID(ctx, child.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, got.RootChatID.Valid)
|
||||
assert.Equal(t, parent.ID, got.RootChatID.UUID)
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ComputeUsagePeriodBounds returns the UTC-aligned start and end bounds for the
|
||||
// active usage-limit period containing now.
|
||||
func ComputeUsagePeriodBounds(now time.Time, period codersdk.ChatUsageLimitPeriod) (start, end time.Time) {
|
||||
utcNow := now.UTC()
|
||||
|
||||
switch period {
|
||||
case codersdk.ChatUsageLimitPeriodDay:
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||
end = start.AddDate(0, 0, 1)
|
||||
case codersdk.ChatUsageLimitPeriodWeek:
|
||||
// Walk backward to Monday of the current ISO week.
|
||||
// ISO 8601 weeks always start on Monday, so this never
|
||||
// crosses an ISO-week boundary.
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||
for start.Weekday() != time.Monday {
|
||||
start = start.AddDate(0, 0, -1)
|
||||
}
|
||||
end = start.AddDate(0, 0, 7)
|
||||
case codersdk.ChatUsageLimitPeriodMonth:
|
||||
start = time.Date(utcNow.Year(), utcNow.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
end = start.AddDate(0, 1, 0)
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown chat usage limit period: %q", period))
|
||||
}
|
||||
|
||||
return start, end
|
||||
}
|
||||
|
||||
// ResolveUsageLimitStatus resolves the current usage-limit status for userID.
|
||||
//
|
||||
// Note: There is a potential race condition where two concurrent messages
|
||||
// from the same user can both pass the limit check if processed in
|
||||
// parallel, allowing brief overage. This is acceptable because:
|
||||
// - Cost is only known after the LLM API returns.
|
||||
// - Overage is bounded by message cost × concurrency.
|
||||
// - Fail-open is the deliberate design choice for this feature.
|
||||
//
|
||||
// Architecture note: today this path enforces one period globally
|
||||
// (day/week/month) from config.
|
||||
// To support simultaneous periods, add nullable
|
||||
// daily/weekly/monthly_limit_micros columns on override tables, where NULL
|
||||
// means no limit for that period.
|
||||
// Then scan spend once over the widest active window with conditional SUMs
|
||||
// for each period and compare each spend/limit pair Go-side, blocking on
|
||||
// whichever period is tightest.
|
||||
func ResolveUsageLimitStatus(ctx context.Context, db database.Store, userID uuid.UUID, now time.Time) (*codersdk.ChatUsageLimitStatus, error) {
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon access for
|
||||
// deployment config reads and cross-user chat spend aggregation.
|
||||
authCtx := dbauthz.AsChatd(ctx)
|
||||
|
||||
config, err := db.GetChatUsageLimitConfig(authCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if !config.Enabled {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
|
||||
period, ok := mapDBPeriodToSDK(config.Period)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("invalid chat usage limit period %q", config.Period)
|
||||
}
|
||||
|
||||
// Resolve effective limit in a single query:
|
||||
// individual override > group limit > global default.
|
||||
effectiveLimit, err := db.ResolveUserChatSpendLimit(authCtx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// -1 means limits are disabled (shouldn't happen since we checked above,
|
||||
// but handle gracefully).
|
||||
if effectiveLimit < 0 {
|
||||
return nil, nil //nolint:nilnil // Nil status cleanly signals disabled limits.
|
||||
}
|
||||
|
||||
start, end := ComputeUsagePeriodBounds(now, period)
|
||||
|
||||
spendTotal, err := db.GetUserChatSpendInPeriod(authCtx, database.GetUserChatSpendInPeriodParams{
|
||||
UserID: userID,
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &codersdk.ChatUsageLimitStatus{
|
||||
IsLimited: true,
|
||||
Period: period,
|
||||
SpendLimitMicros: &effectiveLimit,
|
||||
CurrentSpend: spendTotal,
|
||||
PeriodStart: start,
|
||||
PeriodEnd: end,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mapDBPeriodToSDK(dbPeriod string) (codersdk.ChatUsageLimitPeriod, bool) {
|
||||
switch dbPeriod {
|
||||
case string(codersdk.ChatUsageLimitPeriodDay):
|
||||
return codersdk.ChatUsageLimitPeriodDay, true
|
||||
case string(codersdk.ChatUsageLimitPeriodWeek):
|
||||
return codersdk.ChatUsageLimitPeriodWeek, true
|
||||
case string(codersdk.ChatUsageLimitPeriodMonth):
|
||||
return codersdk.ChatUsageLimitPeriodMonth, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package chatd //nolint:testpackage // Keeps chatd unit tests in the package.
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestComputeUsagePeriodBounds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newYork, err := time.LoadLocation("America/New_York")
|
||||
if err != nil {
|
||||
t.Fatalf("load America/New_York: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
now time.Time
|
||||
period codersdk.ChatUsageLimitPeriod
|
||||
wantStart time.Time
|
||||
wantEnd time.Time
|
||||
}{
|
||||
{
|
||||
name: "day/mid_day",
|
||||
now: time.Date(2025, time.June, 15, 14, 30, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/midnight_exactly",
|
||||
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/end_of_day",
|
||||
now: time.Date(2025, time.June, 15, 23, 59, 59, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/wednesday",
|
||||
now: time.Date(2025, time.June, 11, 10, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/monday",
|
||||
now: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/sunday",
|
||||
now: time.Date(2025, time.June, 15, 23, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2025, time.June, 9, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "week/year_boundary",
|
||||
now: time.Date(2024, time.December, 31, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodWeek,
|
||||
wantStart: time.Date(2024, time.December, 30, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.January, 6, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/mid_month",
|
||||
now: time.Date(2025, time.June, 15, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/first_day",
|
||||
now: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/last_day",
|
||||
now: time.Date(2025, time.June, 30, 23, 59, 59, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.June, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.July, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/february",
|
||||
now: time.Date(2025, time.February, 15, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2025, time.February, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.March, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "month/leap_year_february",
|
||||
now: time.Date(2024, time.February, 29, 12, 0, 0, 0, time.UTC),
|
||||
period: codersdk.ChatUsageLimitPeriodMonth,
|
||||
wantStart: time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2024, time.March, 1, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
name: "day/non_utc_timezone",
|
||||
now: time.Date(2025, time.June, 15, 22, 0, 0, 0, newYork),
|
||||
period: codersdk.ChatUsageLimitPeriodDay,
|
||||
wantStart: time.Date(2025, time.June, 16, 0, 0, 0, 0, time.UTC),
|
||||
wantEnd: time.Date(2025, time.June, 17, 0, 0, 0, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
start, end := ComputeUsagePeriodBounds(tc.now, tc.period)
|
||||
if !start.Equal(tc.wantStart) {
|
||||
t.Errorf("start: got %v, want %v", start, tc.wantStart)
|
||||
}
|
||||
if !end.Equal(tc.wantEnd) {
|
||||
t.Errorf("end: got %v, want %v", end, tc.wantEnd)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+195
-1310
File diff suppressed because it is too large
Load Diff
+158
-1381
File diff suppressed because it is too large
Load Diff
+16
-79
@@ -10,7 +10,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
httppprof "net/http/pprof"
|
||||
"net/url"
|
||||
@@ -45,7 +44,6 @@ import (
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/agentapi"
|
||||
"github.com/coder/coder/v2/coderd/agentapi/metadatabatcher"
|
||||
"github.com/coder/coder/v2/coderd/aiseats"
|
||||
_ "github.com/coder/coder/v2/coderd/apidoc" // Used for swagger docs.
|
||||
"github.com/coder/coder/v2/coderd/appearance"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
@@ -629,11 +627,8 @@ func New(options *Options) *API {
|
||||
options.Database,
|
||||
options.Pubsub,
|
||||
),
|
||||
dbRolluper: options.DatabaseRolluper,
|
||||
ProfileCollector: defaultProfileCollector{},
|
||||
AISeatTracker: aiseats.Noop{},
|
||||
dbRolluper: options.DatabaseRolluper,
|
||||
}
|
||||
|
||||
api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider(
|
||||
ctx,
|
||||
options.Logger.Named("workspaceapps"),
|
||||
@@ -767,27 +762,17 @@ func New(options *Options) *API {
|
||||
}
|
||||
api.agentProvider = stn
|
||||
|
||||
maxChatsPerAcquire := options.DeploymentValues.AI.Chat.AcquireBatchSize.Value()
|
||||
if maxChatsPerAcquire > math.MaxInt32 {
|
||||
maxChatsPerAcquire = math.MaxInt32
|
||||
}
|
||||
if maxChatsPerAcquire < math.MinInt32 {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above.
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
Logger: options.Logger.Named("chats"),
|
||||
Database: options.Database,
|
||||
ReplicaID: api.ID,
|
||||
SubscribeFn: options.ChatSubscribeFn,
|
||||
ProviderAPIKeys: chatProviderAPIKeysFromDeploymentValues(options.DeploymentValues),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
})
|
||||
gitSyncLogger := options.Logger.Named("gitsync")
|
||||
refresher := gitsync.NewRefresher(
|
||||
@@ -1153,16 +1138,6 @@ func New(options *Options) *API {
|
||||
r.Post("/", api.postChats)
|
||||
r.Get("/models", api.listChatModels)
|
||||
r.Get("/watch", api.watchChats)
|
||||
r.Route("/cost", func(r chi.Router) {
|
||||
r.Get("/users", api.chatCostUsers)
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractUserParam(options.Database))
|
||||
r.Get("/summary", api.chatCostSummary)
|
||||
})
|
||||
})
|
||||
r.Route("/insights", func(r chi.Router) {
|
||||
r.Get("/pull-requests", api.prInsights)
|
||||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
|
||||
r.Post("/", api.postChatFile)
|
||||
@@ -1171,8 +1146,6 @@ func New(options *Options) *API {
|
||||
r.Route("/config", func(r chi.Router) {
|
||||
r.Get("/system-prompt", api.getChatSystemPrompt)
|
||||
r.Put("/system-prompt", api.putChatSystemPrompt)
|
||||
r.Get("/desktop-enabled", api.getChatDesktopEnabled)
|
||||
r.Put("/desktop-enabled", api.putChatDesktopEnabled)
|
||||
r.Get("/user-prompt", api.getUserChatCustomPrompt)
|
||||
r.Put("/user-prompt", api.putUserChatCustomPrompt)
|
||||
})
|
||||
@@ -1194,32 +1167,17 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteChatModelConfig)
|
||||
})
|
||||
})
|
||||
r.Route("/usage-limits", func(r chi.Router) {
|
||||
r.Get("/", api.getChatUsageLimitConfig)
|
||||
r.Put("/", api.updateChatUsageLimitConfig)
|
||||
r.Get("/status", api.getMyChatUsageLimitStatus)
|
||||
r.Route("/overrides/{user}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertChatUsageLimitOverride)
|
||||
r.Delete("/", api.deleteChatUsageLimitOverride)
|
||||
})
|
||||
r.Route("/group-overrides/{group}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertChatUsageLimitGroupOverride)
|
||||
r.Delete("/", api.deleteChatUsageLimitGroupOverride)
|
||||
})
|
||||
})
|
||||
r.Route("/{chat}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractChatParam(options.Database))
|
||||
r.Get("/", api.getChat)
|
||||
r.Patch("/", api.patchChat)
|
||||
r.Get("/messages", api.getChatMessages)
|
||||
r.Get("/git/watch", api.watchChatGit)
|
||||
r.Post("/archive", api.archiveChat)
|
||||
r.Post("/unarchive", api.unarchiveChat)
|
||||
r.Post("/messages", api.postChatMessages)
|
||||
r.Patch("/messages/{message}", api.patchChatMessage)
|
||||
r.Route("/stream", func(r chi.Router) {
|
||||
r.Get("/", api.streamChat)
|
||||
r.Get("/desktop", api.watchChatDesktop)
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Get("/stream", api.streamChat)
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Get("/diff-status", api.getChatDiffStatus)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
@@ -1236,13 +1194,6 @@ func New(options *Options) *API {
|
||||
// MCP HTTP transport endpoint with mandatory authentication
|
||||
r.Mount("/http", api.mcpHTTPHandler())
|
||||
})
|
||||
r.Route("/watch-all-workspacebuilds", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
httpmw.RequireExperiment(api.Experiments, codersdk.ExperimentWorkspaceBuildUpdates),
|
||||
)
|
||||
r.Get("/", api.watchAllWorkspaceBuilds)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
@@ -1780,8 +1731,6 @@ func New(options *Options) *API {
|
||||
}
|
||||
r.Method("GET", "/expvar", expvar.Handler()) // contains DERP metrics as well as cmdline and memstats
|
||||
|
||||
r.Post("/profile", api.debugCollectProfile)
|
||||
|
||||
r.Route("/pprof", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
// Some of the pprof handlers strip the `/debug/pprof`
|
||||
@@ -2066,20 +2015,9 @@ type API struct {
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatDaemon handles background processing of pending chats.
|
||||
chatDaemon *chatd.Server
|
||||
// AISeatTracker records AI seat usage.
|
||||
AISeatTracker aiseats.SeatTracker
|
||||
// gitSyncWorker refreshes stale chat diff statuses in the
|
||||
// background.
|
||||
gitSyncWorker *gitsync.Worker
|
||||
|
||||
// ProfileCollector abstracts the runtime/pprof and runtime/trace
|
||||
// calls used by the /debug/profile endpoint. Tests override this
|
||||
// with a stub to avoid process-global side-effects.
|
||||
ProfileCollector ProfileCollector
|
||||
// ProfileCollecting is used as a concurrency guard so that only one
|
||||
// profile collection (via /debug/profile) can run at a time. The CPU
|
||||
// profiler is process-global, so concurrent collections would fail.
|
||||
ProfileCollecting atomic.Bool
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@@ -2280,7 +2218,6 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n
|
||||
provisionerdserver.Options{
|
||||
OIDCConfig: api.OIDCConfig,
|
||||
ExternalAuthConfigs: api.ExternalAuthConfigs,
|
||||
AISeatTracker: api.AISeatTracker,
|
||||
Clock: api.Clock,
|
||||
HeartbeatFn: options.heartbeatFn,
|
||||
},
|
||||
|
||||
@@ -879,15 +879,6 @@ func createAnotherUserRetry(t testing.TB, client *codersdk.Client, organizationI
|
||||
m(&req)
|
||||
}
|
||||
|
||||
// Service accounts cannot have a password or email and must
|
||||
// use login_type=none. Enforce this after mutators so callers
|
||||
// only need to set ServiceAccount=true.
|
||||
if req.ServiceAccount {
|
||||
req.Password = ""
|
||||
req.Email = ""
|
||||
req.UserLoginType = codersdk.LoginTypeNone
|
||||
}
|
||||
|
||||
user, err := client.CreateUserWithOrgs(context.Background(), req)
|
||||
var apiError *codersdk.Error
|
||||
// If the user already exists by username or email conflict, try again up to "retries" times.
|
||||
|
||||
@@ -13,64 +13,32 @@ var _ usage.Inserter = (*UsageInserter)(nil)
|
||||
|
||||
type UsageInserter struct {
|
||||
sync.Mutex
|
||||
discreteEvents []usagetypes.DiscreteEvent
|
||||
heartbeatEvents []usagetypes.HeartbeatEvent
|
||||
seenHeartbeats map[string]struct{}
|
||||
events []usagetypes.DiscreteEvent
|
||||
}
|
||||
|
||||
func NewUsageInserter() *UsageInserter {
|
||||
return &UsageInserter{
|
||||
discreteEvents: []usagetypes.DiscreteEvent{},
|
||||
seenHeartbeats: map[string]struct{}{},
|
||||
heartbeatEvents: []usagetypes.HeartbeatEvent{},
|
||||
events: []usagetypes.DiscreteEvent{},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UsageInserter) InsertDiscreteUsageEvent(_ context.Context, _ database.Store, event usagetypes.DiscreteEvent) error {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
u.discreteEvents = append(u.discreteEvents, event)
|
||||
u.events = append(u.events, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UsageInserter) InsertHeartbeatUsageEvent(_ context.Context, _ database.Store, id string, event usagetypes.HeartbeatEvent) error {
|
||||
func (u *UsageInserter) GetEvents() []usagetypes.DiscreteEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
if _, seen := u.seenHeartbeats[id]; seen {
|
||||
return nil
|
||||
}
|
||||
|
||||
u.seenHeartbeats[id] = struct{}{}
|
||||
u.heartbeatEvents = append(u.heartbeatEvents, event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UsageInserter) GetHeartbeatEvents() []usagetypes.HeartbeatEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
eventsCopy := make([]usagetypes.HeartbeatEvent, len(u.heartbeatEvents))
|
||||
copy(eventsCopy, u.heartbeatEvents)
|
||||
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.events))
|
||||
copy(eventsCopy, u.events)
|
||||
return eventsCopy
|
||||
}
|
||||
|
||||
func (u *UsageInserter) GetDiscreteEvents() []usagetypes.DiscreteEvent {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
eventsCopy := make([]usagetypes.DiscreteEvent, len(u.discreteEvents))
|
||||
copy(eventsCopy, u.discreteEvents)
|
||||
return eventsCopy
|
||||
}
|
||||
|
||||
func (u *UsageInserter) TotalEventCount() int {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
return len(u.discreteEvents) + len(u.heartbeatEvents)
|
||||
}
|
||||
|
||||
func (u *UsageInserter) Reset() {
|
||||
u.Lock()
|
||||
defer u.Unlock()
|
||||
u.seenHeartbeats = map[string]struct{}{}
|
||||
u.discreteEvents = []usagetypes.DiscreteEvent{}
|
||||
u.heartbeatEvents = []usagetypes.HeartbeatEvent{}
|
||||
u.events = []usagetypes.DiscreteEvent{}
|
||||
}
|
||||
|
||||
@@ -6,27 +6,22 @@ type CheckConstraint string
|
||||
|
||||
// CheckConstraint enums.
|
||||
const (
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config
|
||||
CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs
|
||||
CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs
|
||||
CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers
|
||||
CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users
|
||||
CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents
|
||||
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
|
||||
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
|
||||
CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces
|
||||
CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces
|
||||
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
|
||||
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
|
||||
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
|
||||
)
|
||||
|
||||
+259
-111
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/hcl/v2"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
@@ -21,7 +22,6 @@ import (
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/render"
|
||||
@@ -195,14 +195,13 @@ func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser
|
||||
|
||||
func ReducedUser(user database.User) codersdk.ReducedUser {
|
||||
return codersdk.ReducedUser{
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
IsServiceAccount: user.IsServiceAccount,
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
Status: codersdk.UserStatus(user.Status),
|
||||
LoginType: codersdk.LoginType(user.LoginType),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1070,10 +1069,10 @@ func ChatMessage(m database.ChatMessage) codersdk.ChatMessage {
|
||||
CreatedBy: createdBy,
|
||||
ModelConfigID: modelConfigID,
|
||||
CreatedAt: m.CreatedAt,
|
||||
Role: codersdk.ChatMessageRole(m.Role),
|
||||
Role: m.Role,
|
||||
}
|
||||
if m.Content.Valid {
|
||||
parts, err := chatMessageParts(m)
|
||||
parts, err := chatMessageParts(m.Role, m.Content)
|
||||
if err == nil {
|
||||
msg.Content = parts
|
||||
}
|
||||
@@ -1115,15 +1114,9 @@ func chatMessageUsage(m database.ChatMessage) *codersdk.ChatMessageUsage {
|
||||
|
||||
// ChatQueuedMessage converts a queued message to its SDK representation.
|
||||
func ChatQueuedMessage(message database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
|
||||
// Queued messages are always written by current code via
|
||||
// MarshalParts, so they are always current content version.
|
||||
parts, err := chatMessageParts(database.ChatMessage{
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: message.Content,
|
||||
Valid: len(message.Content) > 0,
|
||||
},
|
||||
ContentVersion: chatprompt.CurrentContentVersion,
|
||||
parts, err := chatMessageParts(string(fantasy.MessageRoleUser), pqtype.NullRawMessage{
|
||||
RawMessage: message.Content,
|
||||
Valid: len(message.Content) > 0,
|
||||
})
|
||||
if err != nil {
|
||||
parts = nil
|
||||
@@ -1147,16 +1140,254 @@ func ChatQueuedMessages(messages []database.ChatQueuedMessage) []codersdk.ChatQu
|
||||
return out
|
||||
}
|
||||
|
||||
func chatMessageParts(m database.ChatMessage) ([]codersdk.ChatMessagePart, error) {
|
||||
parts, err := chatprompt.ParseContent(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func chatMessageParts(role string, raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error) {
|
||||
switch role {
|
||||
case string(fantasy.MessageRoleSystem):
|
||||
content, err := parseSystemContent(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: content,
|
||||
}}, nil
|
||||
case string(fantasy.MessageRoleUser), string(fantasy.MessageRoleAssistant):
|
||||
content, err := parseContentBlocks(role, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawBlocks []json.RawMessage
|
||||
_ = json.Unmarshal(raw.RawMessage, &rawBlocks)
|
||||
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(content))
|
||||
for i, block := range content {
|
||||
part := contentBlockToPart(block)
|
||||
if part.Type == "" {
|
||||
continue
|
||||
}
|
||||
if i < len(rawBlocks) {
|
||||
if part.Type == codersdk.ChatMessagePartTypeFile {
|
||||
if fid, err := chatprompt.ExtractFileID(rawBlocks[i]); err == nil {
|
||||
part.FileID = uuid.NullUUID{UUID: fid, Valid: true}
|
||||
}
|
||||
// When a file_id is present, omit inline data
|
||||
// from the response. Clients fetch content via
|
||||
// the GET /chats/files/{id} endpoint instead.
|
||||
if part.FileID.Valid {
|
||||
part.Data = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
return parts, nil
|
||||
case string(fantasy.MessageRoleTool):
|
||||
results, err := parseToolResults(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(results))
|
||||
for _, result := range results {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolName: result.ToolName,
|
||||
Result: result.Result,
|
||||
IsError: result.IsError,
|
||||
ProviderExecuted: result.ProviderExecuted,
|
||||
})
|
||||
}
|
||||
return parts, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
// Strip internal-only fields before API responses.
|
||||
for i := range parts {
|
||||
parts[i].StripInternal()
|
||||
}
|
||||
|
||||
func parseSystemContent(raw pqtype.NullRawMessage) (string, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return parts, nil
|
||||
var content string
|
||||
if err := json.Unmarshal(raw.RawMessage, &content); err != nil {
|
||||
return "", xerrors.Errorf("parse system content: %w", err)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func parseContentBlocks(role string, raw pqtype.NullRawMessage) ([]fantasy.Content, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if role == string(fantasy.MessageRoleUser) {
|
||||
var text string
|
||||
if err := json.Unmarshal(raw.RawMessage, &text); err == nil {
|
||||
return []fantasy.Content{
|
||||
fantasy.TextContent{Text: text},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
var blocks []json.RawMessage
|
||||
if err := json.Unmarshal(raw.RawMessage, &blocks); err != nil {
|
||||
return nil, xerrors.Errorf("parse content blocks: %w", err)
|
||||
}
|
||||
|
||||
content := make([]fantasy.Content, 0, len(blocks))
|
||||
for _, block := range blocks {
|
||||
decoded, err := fantasy.UnmarshalContent(block)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse content block: %w", err)
|
||||
}
|
||||
content = append(content, decoded)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// toolResultRow is used only for extracting top-level fields from
|
||||
// persisted tool result JSON. The result payload is kept as raw JSON.
|
||||
type toolResultRow struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
ProviderExecuted bool `json:"provider_executed,omitempty"`
|
||||
}
|
||||
|
||||
func parseToolResults(raw pqtype.NullRawMessage) ([]toolResultRow, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var results []toolResultRow
|
||||
if err := json.Unmarshal(raw.RawMessage, &results); err != nil {
|
||||
return nil, xerrors.Errorf("parse tool results: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func contentBlockToPart(block fantasy.Content) codersdk.ChatMessagePart {
|
||||
switch value := block.(type) {
|
||||
case fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.TextContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeText,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case *fantasy.ReasoningContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeReasoning,
|
||||
Text: value.Text,
|
||||
}
|
||||
case fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
ProviderExecuted: value.ProviderExecuted,
|
||||
}
|
||||
case *fantasy.ToolCallContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: value.ToolCallID,
|
||||
ToolName: value.ToolName,
|
||||
Args: []byte(value.Input),
|
||||
ProviderExecuted: value.ProviderExecuted,
|
||||
}
|
||||
case fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case *fantasy.SourceContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeSource,
|
||||
SourceID: value.ID,
|
||||
URL: value.URL,
|
||||
Title: value.Title,
|
||||
}
|
||||
case fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case *fantasy.FileContent:
|
||||
return codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeFile,
|
||||
MediaType: value.MediaType,
|
||||
Data: value.Data,
|
||||
}
|
||||
case fantasy.ToolResultContent:
|
||||
part := chatprompt.ToolResultToPart(
|
||||
value.ToolCallID,
|
||||
value.ToolName,
|
||||
toolResultOutputToRawJSON(value.Result),
|
||||
toolResultOutputIsError(value.Result),
|
||||
)
|
||||
part.ProviderExecuted = value.ProviderExecuted
|
||||
return part
|
||||
case *fantasy.ToolResultContent:
|
||||
part := chatprompt.ToolResultToPart(
|
||||
value.ToolCallID,
|
||||
value.ToolName,
|
||||
toolResultOutputToRawJSON(value.Result),
|
||||
toolResultOutputIsError(value.Result),
|
||||
)
|
||||
part.ProviderExecuted = value.ProviderExecuted
|
||||
return part
|
||||
default:
|
||||
return codersdk.ChatMessagePart{}
|
||||
}
|
||||
}
|
||||
|
||||
func toolResultOutputToRawJSON(output fantasy.ToolResultOutputContent) json.RawMessage {
|
||||
switch v := output.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
if v.Error != nil {
|
||||
data, _ := json.Marshal(map[string]any{"error": v.Error.Error()})
|
||||
return data
|
||||
}
|
||||
return json.RawMessage(`{"error":""}`)
|
||||
case fantasy.ToolResultOutputContentText:
|
||||
raw := json.RawMessage(v.Text)
|
||||
if json.Valid(raw) {
|
||||
return raw
|
||||
}
|
||||
data, _ := json.Marshal(map[string]any{"output": v.Text})
|
||||
return data
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
data, _ := json.Marshal(map[string]any{
|
||||
"data": v.Data,
|
||||
"mime_type": v.MediaType,
|
||||
"text": v.Text,
|
||||
})
|
||||
return data
|
||||
default:
|
||||
return json.RawMessage(`{}`)
|
||||
}
|
||||
}
|
||||
|
||||
func toolResultOutputIsError(output fantasy.ToolResultOutputContent) bool {
|
||||
_, ok := output.(fantasy.ToolResultOutputContentError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
@@ -1166,86 +1397,3 @@ func nullInt64Ptr(v sql.NullInt64) *int64 {
|
||||
value := v.Int64
|
||||
return &value
|
||||
}
|
||||
|
||||
// ChatDiffStatus converts a database.ChatDiffStatus to a
|
||||
// codersdk.ChatDiffStatus. When status is nil an empty value
|
||||
// containing only the chatID is returned.
|
||||
func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.ChatDiffStatus {
|
||||
result := codersdk.ChatDiffStatus{
|
||||
ChatID: chatID,
|
||||
}
|
||||
if status == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
result.ChatID = status.ChatID
|
||||
if status.Url.Valid {
|
||||
u := strings.TrimSpace(status.Url.String)
|
||||
if u != "" {
|
||||
result.URL = &u
|
||||
}
|
||||
}
|
||||
if result.URL == nil {
|
||||
// Try to build a branch URL from the stored origin.
|
||||
// Since this function does not have access to the API
|
||||
// instance, we construct a GitHub provider directly as
|
||||
// a best-effort fallback.
|
||||
// TODO: This uses the default github.com API base URL,
|
||||
// so branch URLs for GitHub Enterprise instances will
|
||||
// be incorrect. To fix this, this function would need
|
||||
// access to the external auth configs.
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
if gp != nil {
|
||||
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
|
||||
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
|
||||
if branchURL != "" {
|
||||
result.URL = &branchURL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if status.PullRequestState.Valid {
|
||||
pullRequestState := strings.TrimSpace(status.PullRequestState.String)
|
||||
if pullRequestState != "" {
|
||||
result.PullRequestState = &pullRequestState
|
||||
}
|
||||
}
|
||||
result.PullRequestTitle = status.PullRequestTitle
|
||||
result.PullRequestDraft = status.PullRequestDraft
|
||||
result.ChangesRequested = status.ChangesRequested
|
||||
result.Additions = status.Additions
|
||||
result.Deletions = status.Deletions
|
||||
result.ChangedFiles = status.ChangedFiles
|
||||
if status.AuthorLogin.Valid {
|
||||
result.AuthorLogin = &status.AuthorLogin.String
|
||||
}
|
||||
if status.AuthorAvatarUrl.Valid {
|
||||
result.AuthorAvatarURL = &status.AuthorAvatarUrl.String
|
||||
}
|
||||
if status.BaseBranch.Valid {
|
||||
result.BaseBranch = &status.BaseBranch.String
|
||||
}
|
||||
if status.HeadBranch.Valid {
|
||||
result.HeadBranch = &status.HeadBranch.String
|
||||
}
|
||||
if status.PrNumber.Valid {
|
||||
result.PRNumber = &status.PrNumber.Int32
|
||||
}
|
||||
if status.Commits.Valid {
|
||||
result.Commits = &status.Commits.Int32
|
||||
}
|
||||
if status.Approved.Valid {
|
||||
result.Approved = &status.Approved.Bool
|
||||
}
|
||||
if status.ReviewerCount.Valid {
|
||||
result.ReviewerCount = &status.ReviewerCount.Int32
|
||||
}
|
||||
if status.RefreshedAt.Valid {
|
||||
refreshedAt := status.RefreshedAt.Time
|
||||
result.RefreshedAt = &refreshedAt
|
||||
}
|
||||
staleAt := status.StaleAt
|
||||
result.StaleAt = &staleAt
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -467,7 +467,7 @@ func TestChatMessage_PreservesProviderExecutedOnToolResults(t *testing.T) {
|
||||
dbMsg := database.ChatMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Role: string(fantasy.MessageRoleAssistant),
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: rawContent,
|
||||
Valid: true,
|
||||
@@ -495,9 +495,8 @@ func TestChatMessage_PreservesProviderExecutedOnToolResults(t *testing.T) {
|
||||
func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Queued messages are always written via MarshalParts (SDK format).
|
||||
rawContent, err := json.Marshal([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("queued text"),
|
||||
rawContent, err := json.Marshal([]fantasy.Content{
|
||||
fantasy.TextContent{Text: "queued text"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -513,15 +512,35 @@ func TestChatQueuedMessage_ParsesUserContentParts(t *testing.T) {
|
||||
require.Equal(t, "queued text", queued.Content[0].Text)
|
||||
}
|
||||
|
||||
func TestChatQueuedMessage_MalformedContent(t *testing.T) {
|
||||
func TestChatQueuedMessage_FallsBackToTextForLegacyContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
Content: json.RawMessage(`{"unexpected":"shape"}`),
|
||||
CreatedAt: time.Now(),
|
||||
t.Run("legacy_string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
Content: json.RawMessage(`"legacy queued text"`),
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
require.Len(t, queued.Content, 1)
|
||||
require.Equal(t, codersdk.ChatMessagePartTypeText, queued.Content[0].Type)
|
||||
require.Equal(t, "legacy queued text", queued.Content[0].Text)
|
||||
})
|
||||
|
||||
require.Empty(t, queued.Content)
|
||||
t.Run("malformed_payload", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
raw := json.RawMessage(`{"unexpected":"shape"}`)
|
||||
queued := db2sdk.ChatQueuedMessage(database.ChatQueuedMessage{
|
||||
ID: 1,
|
||||
ChatID: uuid.New(),
|
||||
Content: raw,
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
require.Empty(t, queued.Content)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1264,7 +1264,7 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID uuid.UUID, added, re
|
||||
// System roles are stored in the database but have a fixed, code-defined
|
||||
// meaning. Do not rewrite the name for them so the static "who can assign
|
||||
// what" mapping applies.
|
||||
if !rolestore.IsSystemRoleName(roleName.Name) {
|
||||
if !rbac.SystemRoleName(roleName.Name) {
|
||||
// To support a dynamic mapping of what roles can assign what, we need
|
||||
// to store this in the database. For now, just use a static role so
|
||||
// owners and org admins can assign roles.
|
||||
@@ -1726,13 +1726,6 @@ func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountCon
|
||||
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.CountEnabledModelsWithoutPricing(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -1824,6 +1817,18 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u
|
||||
return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -1849,20 +1854,6 @@ func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.Dele
|
||||
return q.db.DeleteChatQueuedMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatUsageLimitGroupOverride(ctx, groupID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteChatUsageLimitUserOverride(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceCryptoKey); err != nil {
|
||||
return database.CryptoKey{}, err
|
||||
@@ -2133,12 +2124,12 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, params database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
// This is a system-only function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, params)
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
@@ -2336,13 +2327,6 @@ func (q *querier) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Tim
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetAPIKeysLastUsedAfter)(ctx, lastUsed)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceLicense); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -2442,45 +2426,6 @@ func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (datab
|
||||
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatCostPerChat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatCostPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatCostPerUser(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
|
||||
return database.GetChatCostSummaryRow{}, err
|
||||
}
|
||||
return q.db.GetChatCostSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
// The desktop-enabled flag is a deployment-wide setting read by any
|
||||
// authenticated chat user and by chatd when deciding whether to expose
|
||||
// computer-use tooling. We only require that an explicit actor is present
|
||||
// in the context so unauthenticated calls fail closed.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return false, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatDesktopEnabled(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
@@ -2559,14 +2504,6 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
|
||||
return q.db.GetChatMessagesByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatIDDescPaginated(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
@@ -2631,33 +2568,8 @@ func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
return q.db.GetChatSystemPrompt(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitConfig(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetChatUsageLimitGroupOverrideRow{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitGroupOverride(ctx, groupID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetChatUsageLimitUserOverrideRow{}, err
|
||||
}
|
||||
return q.db.GetChatUsageLimitUserOverride(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceChat.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.GetAuthorizedChats(ctx, arg, prep)
|
||||
func (q *querier) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByOwnerID)(ctx, ownerID)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
@@ -3147,34 +3059,6 @@ func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg da
|
||||
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsRecentPRs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.GetPRInsightsSummaryRow{}, err
|
||||
}
|
||||
return q.db.GetPRInsightsSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsTimeSeries(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
|
||||
if err != nil {
|
||||
@@ -3838,13 +3722,6 @@ func (q *querier) GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID)
|
||||
return q.db.GetUserChatCustomPrompt(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetUserChatSpendInPeriod(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
@@ -3852,13 +3729,6 @@ func (q *querier) GetUserCount(ctx context.Context, includeSystem bool) (int64,
|
||||
return q.db.GetUserCount(ctx, includeSystem)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.GetUserGroupSpendLimit(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
// Used by insights endpoints. Need to check both for auditors and for regular users with template acl perms.
|
||||
if err := q.authorizeContext(ctx, policy.ActionViewInsights, rbac.ResourceTemplate); err != nil {
|
||||
@@ -4528,13 +4398,6 @@ func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.I
|
||||
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
return database.AIBridgeModelThought{}, err
|
||||
}
|
||||
return q.db.InsertAIBridgeModelThought(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
// All aibridge_token_usages records belong to the initiator of their associated interception.
|
||||
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, arg.InterceptionID); err != nil {
|
||||
@@ -4591,16 +4454,16 @@ func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFil
|
||||
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
||||
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
// Authorize create on the parent chat (using update permission).
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return nil, err
|
||||
return database.ChatMessage{}, err
|
||||
}
|
||||
return q.db.InsertChatMessages(ctx, arg)
|
||||
return q.db.InsertChatMessage(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
|
||||
@@ -5224,20 +5087,6 @@ func (q *querier) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context,
|
||||
return q.db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, interceptionIDs)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListChatUsageLimitGroupOverrides(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListChatUsageLimitOverrides(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.ListProvisionerKeysByOrganization)(ctx, organizationID)
|
||||
}
|
||||
@@ -5357,13 +5206,6 @@ func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveU
|
||||
return q.db.RemoveUserFromGroups(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(userID.String())); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.ResolveUserChatSpendLimit(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
@@ -5379,32 +5221,6 @@ func (q *querier) SelectUsageEventsForPublishing(ctx context.Context, arg time.T
|
||||
return q.db.SelectUsageEventsForPublishing(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
|
||||
msg, err := q.db.GetChatMessageByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
chat, err := q.db.GetChatByID(ctx, msg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteChatMessageByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
@@ -6573,13 +6389,6 @@ func (q *querier) UpdateWorkspacesTTLByTemplateID(ctx context.Context, arg datab
|
||||
return q.db.UpdateWorkspacesTTLByTemplateID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return q.db.UpsertAISeatState(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertAnnouncementBanners(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
@@ -6601,13 +6410,6 @@ func (q *querier) UpsertBoundaryUsageStats(ctx context.Context, arg database.Ups
|
||||
return q.db.UpsertBoundaryUsageStats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
// Authorize update on the parent chat.
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -6639,27 +6441,6 @@ func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) erro
|
||||
return q.db.UpsertChatSystemPrompt(ctx, value)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatUsageLimitConfig{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.UpsertChatUsageLimitGroupOverrideRow{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitGroupOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.UpsertChatUsageLimitUserOverrideRow{}, err
|
||||
}
|
||||
return q.db.UpsertChatUsageLimitUserOverride(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceConnectionLog); err != nil {
|
||||
return database.ConnectionLog{}, err
|
||||
@@ -6847,13 +6628,6 @@ func (q *querier) UpsertWorkspaceAppAuditSession(ctx context.Context, arg databa
|
||||
return q.db.UpsertWorkspaceAppAuditSession(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceUsageEvent); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return q.db.UsageEventExistsByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) ValidateGroupIDs(ctx context.Context, groupIDs []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
// This check is probably overly restrictive, but the "correct" check isn't
|
||||
// necessarily obvious. It's only used as a verification check for ACLs right
|
||||
@@ -6949,7 +6723,3 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
return q.GetChats(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -401,27 +401,16 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UnarchiveChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("SoftDeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
s.Run("DeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.SoftDeleteChatMessagesAfterIDParams{
|
||||
arg := database.DeleteChatMessagesAfterIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 123,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteChatMessagesAfterID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("SoftDeleteChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := database.ChatMessage{
|
||||
ID: 456,
|
||||
ChatID: chat.ID,
|
||||
}
|
||||
dbm.EXPECT().GetChatMessageByID(gomock.Any(), msg.ID).Return(msg, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), msg.ID).Return(nil).AnyTimes()
|
||||
check.Args(msg.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
id := uuid.New()
|
||||
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
|
||||
@@ -449,85 +438,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
rows := []database.GetChatCostPerChatRow{{
|
||||
RootChatID: uuid.New(),
|
||||
ChatTitle: "chat-cost",
|
||||
TotalCostMicros: 123,
|
||||
MessageCount: 4,
|
||||
TotalInputTokens: 55,
|
||||
TotalOutputTokens: 89,
|
||||
}}
|
||||
dbm.EXPECT().GetChatCostPerChat(gomock.Any(), arg).Return(rows, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatCostPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerModelParams{
|
||||
OwnerID: uuid.New(),
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
rows := []database.GetChatCostPerModelRow{{
|
||||
ModelConfigID: uuid.New(),
|
||||
DisplayName: "GPT 4.1",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4.1",
|
||||
TotalCostMicros: 456,
|
||||
MessageCount: 7,
|
||||
TotalInputTokens: 144,
|
||||
TotalOutputTokens: 233,
|
||||
}}
|
||||
dbm.EXPECT().GetChatCostPerModel(gomock.Any(), arg).Return(rows, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatCostPerUser", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerUserParams{
|
||||
PageOffset: 0,
|
||||
PageLimit: 25,
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
Username: "cost-user",
|
||||
}
|
||||
rows := []database.GetChatCostPerUserRow{{
|
||||
UserID: uuid.New(),
|
||||
Username: "cost-user",
|
||||
Name: "Cost User",
|
||||
AvatarURL: "https://example.com/avatar.png",
|
||||
TotalCostMicros: 789,
|
||||
MessageCount: 11,
|
||||
ChatCount: 3,
|
||||
TotalInputTokens: 377,
|
||||
TotalOutputTokens: 610,
|
||||
TotalCount: 1,
|
||||
}}
|
||||
dbm.EXPECT().GetChatCostPerUser(gomock.Any(), arg).Return(rows, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatCostSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostSummaryParams{
|
||||
OwnerID: uuid.New(),
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
row := database.GetChatCostSummaryRow{
|
||||
TotalCostMicros: 987,
|
||||
PricedMessageCount: 12,
|
||||
UnpricedMessageCount: 2,
|
||||
TotalInputTokens: 400,
|
||||
TotalOutputTokens: 800,
|
||||
}
|
||||
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
|
||||
}))
|
||||
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
|
||||
}))
|
||||
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
@@ -573,18 +483,10 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
arg := database.GetChatMessagesByChatIDDescPaginatedParams{ChatID: chat.ID, BeforeID: 0, LimitVal: 50}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetLastChatMessageByRole", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
arg := database.GetLastChatMessageByRoleParams{ChatID: chat.ID, Role: database.ChatMessageRoleAssistant}
|
||||
arg := database.GetLastChatMessageByRoleParams{ChatID: chat.ID, Role: "assistant"}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetLastChatMessageByRole(gomock.Any(), arg).Return(msg, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msg)
|
||||
@@ -629,17 +531,12 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB})
|
||||
}))
|
||||
s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
s.Run("GetAuthorizedChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
params := database.GetChatsParams{}
|
||||
dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.Chat{}, nil).AnyTimes()
|
||||
// No asserts here because it re-routes through GetChats which uses SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
s.Run("GetChatsByOwnerID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
c1 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
c2 := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
params := database.GetChatsByOwnerIDParams{OwnerID: c1.OwnerID}
|
||||
dbm.EXPECT().GetChatsByOwnerID(gomock.Any(), params).Return([]database.Chat{c1, c2}, nil).AnyTimes()
|
||||
check.Args(params).Asserts(c1, policy.ActionRead, c2, policy.ActionRead).Returns([]database.Chat{c1, c2})
|
||||
}))
|
||||
s.Run("GetChatQueuedMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -652,10 +549,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatSystemPrompt(gomock.Any()).Return("prompt", nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("GetEnabledChatModelConfigs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
configA := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
configB := testutil.Fake(s.T(), faker, database.ChatModelConfig{})
|
||||
@@ -686,13 +579,13 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
|
||||
}))
|
||||
s.Run("InsertChatMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessagesParams{ChatID: chat.ID})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().InsertChatMessages(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msgs)
|
||||
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
|
||||
}))
|
||||
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
@@ -865,146 +758,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
|
||||
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("UpsertChatDesktopEnabled", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes()
|
||||
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetUserChatSpendInPeriod", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetUserChatSpendInPeriodParams{
|
||||
UserID: uuid.New(),
|
||||
StartTime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndTime: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
spend := int64(123)
|
||||
dbm.EXPECT().GetUserChatSpendInPeriod(gomock.Any(), arg).Return(spend, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.UserID.String()), policy.ActionRead).Returns(spend)
|
||||
}))
|
||||
s.Run("GetUserGroupSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
limit := int64(456)
|
||||
dbm.EXPECT().GetUserGroupSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
|
||||
}))
|
||||
s.Run("ResolveUserChatSpendLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
limit := int64(789)
|
||||
dbm.EXPECT().ResolveUserChatSpendLimit(gomock.Any(), userID).Return(limit, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceChat.WithOwner(userID.String()), policy.ActionRead).Returns(limit)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
config := database.ChatUsageLimitConfig{
|
||||
ID: 1,
|
||||
Singleton: true,
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 1_000_000,
|
||||
Period: "monthly",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(config, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(config)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
groupID := uuid.New()
|
||||
override := database.GetChatUsageLimitGroupOverrideRow{
|
||||
GroupID: groupID,
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 2_000_000, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(override, nil).AnyTimes()
|
||||
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
|
||||
}))
|
||||
s.Run("GetChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
override := database.GetChatUsageLimitUserOverrideRow{
|
||||
UserID: userID,
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 3_000_000, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().GetChatUsageLimitUserOverride(gomock.Any(), userID).Return(override, nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(override)
|
||||
}))
|
||||
s.Run("ListChatUsageLimitGroupOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
overrides := []database.ListChatUsageLimitGroupOverridesRow{{
|
||||
GroupID: uuid.New(),
|
||||
GroupName: "group-name",
|
||||
GroupDisplayName: "Group Name",
|
||||
GroupAvatarUrl: "https://example.com/group.png",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 4_000_000, Valid: true},
|
||||
MemberCount: 5,
|
||||
}}
|
||||
dbm.EXPECT().ListChatUsageLimitGroupOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
|
||||
}))
|
||||
s.Run("ListChatUsageLimitOverrides", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
overrides := []database.ListChatUsageLimitOverridesRow{{
|
||||
UserID: uuid.New(),
|
||||
Username: "usage-limit-user",
|
||||
Name: "Usage Limit User",
|
||||
AvatarURL: "https://example.com/avatar.png",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: 5_000_000, Valid: true},
|
||||
}}
|
||||
dbm.EXPECT().ListChatUsageLimitOverrides(gomock.Any()).Return(overrides, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(overrides)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitConfig", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
arg := database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
DefaultLimitMicros: 6_000_000,
|
||||
Period: "monthly",
|
||||
}
|
||||
config := database.ChatUsageLimitConfig{
|
||||
ID: 1,
|
||||
Singleton: true,
|
||||
Enabled: arg.Enabled,
|
||||
DefaultLimitMicros: arg.DefaultLimitMicros,
|
||||
Period: arg.Period,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitConfig(gomock.Any(), arg).Return(config, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertChatUsageLimitGroupOverrideParams{
|
||||
SpendLimitMicros: 7_000_000,
|
||||
GroupID: uuid.New(),
|
||||
}
|
||||
override := database.UpsertChatUsageLimitGroupOverrideRow{
|
||||
GroupID: arg.GroupID,
|
||||
Name: "group",
|
||||
DisplayName: "Group",
|
||||
AvatarURL: "",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitGroupOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
|
||||
}))
|
||||
s.Run("UpsertChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpsertChatUsageLimitUserOverrideParams{
|
||||
SpendLimitMicros: 8_000_000,
|
||||
UserID: uuid.New(),
|
||||
}
|
||||
override := database.UpsertChatUsageLimitUserOverrideRow{
|
||||
UserID: arg.UserID,
|
||||
Username: "user",
|
||||
Name: "User",
|
||||
AvatarURL: "",
|
||||
SpendLimitMicros: sql.NullInt64{Int64: arg.SpendLimitMicros, Valid: true},
|
||||
}
|
||||
dbm.EXPECT().UpsertChatUsageLimitUserOverride(gomock.Any(), arg).Return(override, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(override)
|
||||
}))
|
||||
s.Run("DeleteChatUsageLimitGroupOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
groupID := uuid.New()
|
||||
dbm.EXPECT().DeleteChatUsageLimitGroupOverride(gomock.Any(), groupID).Return(nil).AnyTimes()
|
||||
check.Args(groupID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("DeleteChatUsageLimitUserOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
userID := uuid.New()
|
||||
dbm.EXPECT().DeleteChatUsageLimitUserOverride(gomock.Any(), userID).Return(nil).AnyTimes()
|
||||
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
@@ -1327,14 +1080,6 @@ func (s *MethodTestSuite) TestProvisionerJob() {
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestLicense() {
|
||||
s.Run("GetActiveAISeatCount", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetActiveAISeatCount(gomock.Any()).Return(int64(100), nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceLicense, policy.ActionRead).Returns(int64(100))
|
||||
}))
|
||||
s.Run("UpsertAISeatState", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertAISeatState(gomock.Any(), gomock.Any()).Return(true, nil).AnyTimes()
|
||||
check.Args(database.UpsertAISeatStateParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
|
||||
}))
|
||||
s.Run("GetLicenses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
a := database.License{ID: 1}
|
||||
b := database.License{ID: 2}
|
||||
@@ -1504,7 +1249,7 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
ShareableWorkspaceOwners: database.ShareableWorkspaceOwnersNone,
|
||||
WorkspaceSharingDisabled: true,
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
|
||||
@@ -1935,26 +1680,6 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetTemplateInsightsByTemplate(gomock.Any(), arg).Return([]database.GetTemplateInsightsByTemplateRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceTemplate, policy.ActionViewInsights)
|
||||
}))
|
||||
s.Run("GetPRInsightsSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsSummaryParams{}
|
||||
dbm.EXPECT().GetPRInsightsSummary(gomock.Any(), arg).Return(database.GetPRInsightsSummaryRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsTimeSeries", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsTimeSeriesParams{}
|
||||
dbm.EXPECT().GetPRInsightsTimeSeries(gomock.Any(), arg).Return([]database.GetPRInsightsTimeSeriesRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsPerModelParams{}
|
||||
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsRecentPRsParams{}
|
||||
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetTelemetryTaskEventsParams{}
|
||||
dbm.EXPECT().GetTelemetryTaskEvents(gomock.Any(), arg).Return([]database.GetTelemetryTaskEventsRow{}, nil).AnyTimes()
|
||||
@@ -2443,12 +2168,9 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
check.Args(w.ID).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.DeleteWorkspaceACLsByOrganizationParams{
|
||||
OrganizationID: uuid.New(),
|
||||
ExcludeServiceAccounts: false,
|
||||
}
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
orgID := uuid.New()
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
|
||||
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
@@ -5154,12 +4876,6 @@ func (s *MethodTestSuite) TestUsageEvents() {
|
||||
check.Args(params).Asserts(rbac.ResourceUsageEvent, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("UsageEventExistsByID", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
id := uuid.NewString()
|
||||
db.EXPECT().UsageEventExistsByID(gomock.Any(), id).Return(true, nil)
|
||||
check.Args(id).Asserts(rbac.ResourceUsageEvent, policy.ActionRead)
|
||||
}))
|
||||
|
||||
s.Run("SelectUsageEventsForPublishing", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
now := dbtime.Now()
|
||||
db.EXPECT().SelectUsageEventsForPublishing(gomock.Any(), now).Return([]database.UsageEvent{}, nil)
|
||||
@@ -5220,17 +4936,6 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params).Asserts(intc, policy.ActionCreate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeModelThought", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intID).Return(intc, nil).AnyTimes() // Validation.
|
||||
|
||||
params := database.InsertAIBridgeModelThoughtParams{InterceptionID: intc.ID}
|
||||
expected := testutil.Fake(s.T(), faker, database.AIBridgeModelThought{InterceptionID: intc.ID})
|
||||
db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), params).Return(expected, nil).AnyTimes()
|
||||
check.Args(params).Asserts(intc, policy.ActionUpdate)
|
||||
}))
|
||||
|
||||
s.Run("InsertAIBridgeTokenUsage", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
intID := uuid.UUID{2}
|
||||
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intID})
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/v2/coderd/rbac/rolestore"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
)
|
||||
|
||||
@@ -144,7 +143,7 @@ func (s *MethodTestSuite) Mocked(testCaseF func(dmb *dbmock.MockStore, faker *go
|
||||
UUID: pair.OrganizationID,
|
||||
Valid: pair.OrganizationID != uuid.Nil,
|
||||
},
|
||||
IsSystem: rolestore.IsSystemRoleName(pair.Name),
|
||||
IsSystem: rbac.SystemRoleName(pair.Name),
|
||||
ID: uuid.New(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -650,26 +650,34 @@ func Organization(t testing.TB, db database.Store, orig database.Organization) d
|
||||
})
|
||||
require.NoError(t, err, "insert organization")
|
||||
|
||||
// Populate the placeholder system roles (created by DB
|
||||
// trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileSystemRole needs the system:update
|
||||
// Populate the placeholder organization-member system role (created by
|
||||
// DB trigger/migration) so org members have expected permissions.
|
||||
//nolint:gocritic // ReconcileOrgMemberRole needs the system:update
|
||||
// permission that `genCtx` does not have.
|
||||
sysCtx := dbauthz.AsSystemRestricted(genCtx)
|
||||
for roleName := range rolestore.SystemRoleNames {
|
||||
role := database.CustomRole{
|
||||
Name: roleName,
|
||||
OrganizationID: uuid.NullUUID{UUID: org.ID, Valid: true},
|
||||
}
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateSystemRole(sysCtx, db, org, roleName)
|
||||
require.NoError(t, err, "create role "+roleName)
|
||||
_, _, err = rolestore.ReconcileSystemRole(sysCtx, db, role, org)
|
||||
}
|
||||
require.NoError(t, err, "reconcile role "+roleName)
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// The trigger that creates the placeholder role didn't run (e.g.,
|
||||
// triggers were disabled in the test). Create the role manually.
|
||||
err = rolestore.CreateOrgMemberRole(sysCtx, db, org)
|
||||
require.NoError(t, err, "create organization-member role")
|
||||
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, db, database.CustomRole{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: uuid.NullUUID{
|
||||
UUID: org.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, org.WorkspaceSharingDisabled)
|
||||
}
|
||||
require.NoError(t, err, "reconcile organization-member role")
|
||||
|
||||
return org
|
||||
}
|
||||
|
||||
@@ -288,14 +288,6 @@ func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountEnabledModelsWithoutPricing(ctx)
|
||||
m.queryLatencies.WithLabelValues("CountEnabledModelsWithoutPricing").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountEnabledModelsWithoutPricing").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountInProgressPrebuilds(ctx)
|
||||
@@ -384,6 +376,14 @@ func (m queryMetricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.C
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatMessagesAfterID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatMessagesAfterID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatModelConfigByID(ctx, id)
|
||||
@@ -408,22 +408,6 @@ func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg data
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatUsageLimitGroupOverride(ctx, groupID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitGroupOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteChatUsageLimitUserOverride(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("DeleteChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatUsageLimitUserOverride").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteCryptoKey(ctx, arg)
|
||||
@@ -688,11 +672,10 @@ func (m queryMetricsStore) DeleteWorkspaceACLByID(ctx context.Context, id uuid.U
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (m queryMetricsStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, arg)
|
||||
r0 := m.s.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
m.queryLatencies.WithLabelValues("DeleteWorkspaceACLsByOrganization").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteWorkspaceACLsByOrganization").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -888,14 +871,6 @@ func (m queryMetricsStore) GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveAISeatCount(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetActiveAISeatCount").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveAISeatCount").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -1008,46 +983,6 @@ func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUI
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostPerChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostPerChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostPerModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostPerModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostPerUser(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostPerUser").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerUser").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostSummary(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostSummary").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostSummary").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDesktopEnabled(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatDesktopEnabled").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
@@ -1096,14 +1031,6 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDDescPaginated").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDDescPaginated").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
@@ -1168,35 +1095,11 @@ func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
func (m queryMetricsStore) GetChatsByOwnerID(ctx context.Context, ownerID database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitConfig(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitGroupOverride(ctx, groupID)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitGroupOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatUsageLimitUserOverride(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatUsageLimitUserOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChats").Inc()
|
||||
r0, r1 := m.s.GetChatsByOwnerID(ctx, ownerID)
|
||||
m.queryLatencies.WithLabelValues("GetChatsByOwnerID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsByOwnerID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -1736,38 +1639,6 @@ func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Contex
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsPerModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsPerModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPerModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsSummary(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsSummary").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsSummary").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsTimeSeries(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsTimeSeries").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsTimeSeries").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetParameterSchemasByJobID(ctx, jobID)
|
||||
@@ -2352,14 +2223,6 @@ func (m queryMetricsStore) GetUserChatCustomPrompt(ctx context.Context, userID u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetUserChatSpendInPeriod").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatSpendInPeriod").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserCount(ctx, includeSystem)
|
||||
@@ -2368,14 +2231,6 @@ func (m queryMetricsStore) GetUserCount(ctx context.Context, includeSystem bool)
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserGroupSpendLimit(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetUserGroupSpendLimit").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserGroupSpendLimit").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserLatencyInsights(ctx, arg)
|
||||
@@ -2984,14 +2839,6 @@ func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeModelThought(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertAIBridgeModelThought").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertAIBridgeModelThought").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeTokenUsage(ctx, arg)
|
||||
@@ -3056,11 +2903,11 @@ func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.Inse
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
||||
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertChatMessages(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessages").Inc()
|
||||
r0, r1 := m.s.InsertChatMessage(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -3616,22 +3463,6 @@ func (m queryMetricsStore) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatUsageLimitGroupOverrides(ctx)
|
||||
m.queryLatencies.WithLabelValues("ListChatUsageLimitGroupOverrides").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitGroupOverrides").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListChatUsageLimitOverrides(ctx)
|
||||
m.queryLatencies.WithLabelValues("ListChatUsageLimitOverrides").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListChatUsageLimitOverrides").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListProvisionerKeysByOrganization(ctx, organizationID)
|
||||
@@ -3744,14 +3575,6 @@ func (m queryMetricsStore) RemoveUserFromGroups(ctx context.Context, arg databas
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ResolveUserChatSpendLimit(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ResolveUserChatSpendLimit").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ResolveUserChatSpendLimit").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||
@@ -3768,22 +3591,6 @@ func (m queryMetricsStore) SelectUsageEventsForPublishing(ctx context.Context, n
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteChatMessageByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteChatMessageByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteChatMessageByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteChatMessagesAfterID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteChatMessagesAfterID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
@@ -4020,7 +3827,6 @@ func (m queryMetricsStore) UpdateOrganizationWorkspaceSharingSettings(ctx contex
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateOrganizationWorkspaceSharingSettings(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateOrganizationWorkspaceSharingSettings").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateOrganizationWorkspaceSharingSettings").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4568,14 +4374,6 @@ func (m queryMetricsStore) UpdateWorkspacesTTLByTemplateID(ctx context.Context,
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertAISeatState(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertAISeatState").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertAISeatState").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertAnnouncementBanners(ctx, value)
|
||||
@@ -4600,14 +4398,6 @@ func (m queryMetricsStore) UpsertBoundaryUsageStats(ctx context.Context, arg dat
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatDesktopEnabled(ctx, enableDesktop)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatDesktopEnabled").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatDesktopEnabled").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatDiffStatus(ctx, arg)
|
||||
@@ -4632,30 +4422,6 @@ func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value str
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitConfig(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitConfig").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitConfig").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitGroupOverride(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitGroupOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitGroupOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertChatUsageLimitUserOverride(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatUsageLimitUserOverride").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatUsageLimitUserOverride").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpsertConnectionLog(ctx, arg)
|
||||
@@ -4832,14 +4598,6 @@ func (m queryMetricsStore) UpsertWorkspaceAppAuditSession(ctx context.Context, a
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UsageEventExistsByID(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("UsageEventExistsByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UsageEventExistsByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ValidateGroupIDs(ctx, groupIds)
|
||||
@@ -4959,11 +4717,3 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeModels").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("GetAuthorizedChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAuthorizedChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -424,21 +424,6 @@ func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// CountEnabledModelsWithoutPricing mocks base method.
|
||||
func (m *MockStore) CountEnabledModelsWithoutPricing(ctx context.Context) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountEnabledModelsWithoutPricing", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountEnabledModelsWithoutPricing indicates an expected call of CountEnabledModelsWithoutPricing.
|
||||
func (mr *MockStoreMockRecorder) CountEnabledModelsWithoutPricing(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEnabledModelsWithoutPricing", reflect.TypeOf((*MockStore)(nil).CountEnabledModelsWithoutPricing), ctx)
|
||||
}
|
||||
|
||||
// CountInProgressPrebuilds mocks base method.
|
||||
func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -598,6 +583,20 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(ctx, us
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteChatMessagesAfterID mocks base method.
|
||||
func (m *MockStore) DeleteChatMessagesAfterID(ctx context.Context, arg database.DeleteChatMessagesAfterIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatMessagesAfterID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatMessagesAfterID indicates an expected call of DeleteChatMessagesAfterID.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).DeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatModelConfigByID mocks base method.
|
||||
func (m *MockStore) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -640,34 +639,6 @@ func (mr *MockStoreMockRecorder) DeleteChatQueuedMessage(ctx, arg any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatQueuedMessage", reflect.TypeOf((*MockStore)(nil).DeleteChatQueuedMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatUsageLimitGroupOverride", ctx, groupID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitGroupOverride indicates an expected call of DeleteChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitGroupOverride), ctx, groupID)
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteChatUsageLimitUserOverride", ctx, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteChatUsageLimitUserOverride indicates an expected call of DeleteChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) DeleteChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).DeleteChatUsageLimitUserOverride), ctx, userID)
|
||||
}
|
||||
|
||||
// DeleteCryptoKey mocks base method.
|
||||
func (m *MockStore) DeleteCryptoKey(ctx context.Context, arg database.DeleteCryptoKeyParams) (database.CryptoKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1141,17 +1112,17 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, arg database.DeleteWorkspaceACLsByOrganizationParams) error {
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, arg any) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceAgentPortShare mocks base method.
|
||||
@@ -1507,21 +1478,6 @@ func (mr *MockStoreMockRecorder) GetAPIKeysLastUsedAfter(ctx, lastUsed any) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAPIKeysLastUsedAfter", reflect.TypeOf((*MockStore)(nil).GetAPIKeysLastUsedAfter), ctx, lastUsed)
|
||||
}
|
||||
|
||||
// GetActiveAISeatCount mocks base method.
|
||||
func (m *MockStore) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveAISeatCount", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveAISeatCount indicates an expected call of GetActiveAISeatCount.
|
||||
func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1717,21 +1673,6 @@ func (mr *MockStoreMockRecorder) GetAuthorizedAuditLogsOffset(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedAuditLogsOffset", reflect.TypeOf((*MockStore)(nil).GetAuthorizedAuditLogsOffset), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedChats mocks base method.
|
||||
func (m *MockStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizedChats", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthorizedChats indicates an expected call of GetAuthorizedChats.
|
||||
func (mr *MockStoreMockRecorder) GetAuthorizedChats(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizedChats", reflect.TypeOf((*MockStore)(nil).GetAuthorizedChats), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// GetAuthorizedConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1837,81 +1778,6 @@ func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatCostPerChat mocks base method.
|
||||
func (m *MockStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostPerChat", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetChatCostPerChatRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostPerChat indicates an expected call of GetChatCostPerChat.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostPerChat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerChat", reflect.TypeOf((*MockStore)(nil).GetChatCostPerChat), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatCostPerModel mocks base method.
|
||||
func (m *MockStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostPerModel", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetChatCostPerModelRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostPerModel indicates an expected call of GetChatCostPerModel.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostPerModel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerModel", reflect.TypeOf((*MockStore)(nil).GetChatCostPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatCostPerUser mocks base method.
|
||||
func (m *MockStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostPerUser", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetChatCostPerUserRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostPerUser indicates an expected call of GetChatCostPerUser.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostPerUser(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerUser", reflect.TypeOf((*MockStore)(nil).GetChatCostPerUser), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatCostSummary mocks base method.
|
||||
func (m *MockStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostSummary", ctx, arg)
|
||||
ret0, _ := ret[0].(database.GetChatCostSummaryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostSummary indicates an expected call of GetChatCostSummary.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) GetChatDesktopEnabled(ctx context.Context) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatDesktopEnabled", ctx)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatDesktopEnabled indicates an expected call of GetChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) GetChatDesktopEnabled(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).GetChatDesktopEnabled), ctx)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2002,21 +1868,6 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDDescPaginated mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatIDDescPaginated", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDDescPaginated indicates an expected call of GetChatMessagesByChatIDDescPaginated.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDDescPaginated(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDDescPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDDescPaginated), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesForPromptByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2137,64 +1988,19 @@ func (mr *MockStoreMockRecorder) GetChatSystemPrompt(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).GetChatSystemPrompt), ctx)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitConfig(ctx context.Context) (database.ChatUsageLimitConfig, error) {
|
||||
// GetChatsByOwnerID mocks base method.
|
||||
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitConfig", ctx)
|
||||
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitConfig indicates an expected call of GetChatUsageLimitConfig.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitConfig(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitConfig), ctx)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) (database.GetChatUsageLimitGroupOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitGroupOverride", ctx, groupID)
|
||||
ret0, _ := ret[0].(database.GetChatUsageLimitGroupOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitGroupOverride indicates an expected call of GetChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitGroupOverride(ctx, groupID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitGroupOverride), ctx, groupID)
|
||||
}
|
||||
|
||||
// GetChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) GetChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) (database.GetChatUsageLimitUserOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatUsageLimitUserOverride", ctx, userID)
|
||||
ret0, _ := ret[0].(database.GetChatUsageLimitUserOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatUsageLimitUserOverride indicates an expected call of GetChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) GetChatUsageLimitUserOverride(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).GetChatUsageLimitUserOverride), ctx, userID)
|
||||
}
|
||||
|
||||
// GetChats mocks base method.
|
||||
func (m *MockStore) GetChats(ctx context.Context, arg database.GetChatsParams) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChats", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChats indicates an expected call of GetChats.
|
||||
func (mr *MockStoreMockRecorder) GetChats(ctx, arg any) *gomock.Call {
|
||||
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChats", reflect.TypeOf((*MockStore)(nil).GetChats), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
@@ -3202,66 +3008,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsPerModel mocks base method.
|
||||
func (m *MockStore) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRInsightsPerModelParams) ([]database.GetPRInsightsPerModelRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsPerModel", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsPerModelRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsPerModel indicates an expected call of GetPRInsightsPerModel.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs mocks base method.
|
||||
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary mocks base method.
|
||||
func (m *MockStore) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsSummary", ctx, arg)
|
||||
ret0, _ := ret[0].(database.GetPRInsightsSummaryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary indicates an expected call of GetPRInsightsSummary.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsSummary(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsSummary", reflect.TypeOf((*MockStore)(nil).GetPRInsightsSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsTimeSeries mocks base method.
|
||||
func (m *MockStore) GetPRInsightsTimeSeries(ctx context.Context, arg database.GetPRInsightsTimeSeriesParams) ([]database.GetPRInsightsTimeSeriesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsTimeSeries", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsTimeSeriesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsTimeSeries indicates an expected call of GetPRInsightsTimeSeries.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsTimeSeries(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsTimeSeries", reflect.TypeOf((*MockStore)(nil).GetPRInsightsTimeSeries), ctx, arg)
|
||||
}
|
||||
|
||||
// GetParameterSchemasByJobID mocks base method.
|
||||
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4387,21 +4133,6 @@ func (mr *MockStoreMockRecorder) GetUserChatCustomPrompt(ctx, userID any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).GetUserChatCustomPrompt), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod mocks base method.
|
||||
func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserChatSpendInPeriod", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserChatSpendInPeriod indicates an expected call of GetUserChatSpendInPeriod.
|
||||
func (mr *MockStoreMockRecorder) GetUserChatSpendInPeriod(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatSpendInPeriod", reflect.TypeOf((*MockStore)(nil).GetUserChatSpendInPeriod), ctx, arg)
|
||||
}
|
||||
|
||||
// GetUserCount mocks base method.
|
||||
func (m *MockStore) GetUserCount(ctx context.Context, includeSystem bool) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4417,21 +4148,6 @@ func (mr *MockStoreMockRecorder) GetUserCount(ctx, includeSystem any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserCount", reflect.TypeOf((*MockStore)(nil).GetUserCount), ctx, includeSystem)
|
||||
}
|
||||
|
||||
// GetUserGroupSpendLimit mocks base method.
|
||||
func (m *MockStore) GetUserGroupSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserGroupSpendLimit", ctx, userID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserGroupSpendLimit indicates an expected call of GetUserGroupSpendLimit.
|
||||
func (mr *MockStoreMockRecorder) GetUserGroupSpendLimit(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserGroupSpendLimit", reflect.TypeOf((*MockStore)(nil).GetUserGroupSpendLimit), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserLatencyInsights mocks base method.
|
||||
func (m *MockStore) GetUserLatencyInsights(ctx context.Context, arg database.GetUserLatencyInsightsParams) ([]database.GetUserLatencyInsightsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5586,21 +5302,6 @@ func (mr *MockStoreMockRecorder) InsertAIBridgeInterception(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeInterception", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeInterception), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeModelThought(ctx context.Context, arg database.InsertAIBridgeModelThoughtParams) (database.AIBridgeModelThought, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertAIBridgeModelThought", ctx, arg)
|
||||
ret0, _ := ret[0].(database.AIBridgeModelThought)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertAIBridgeModelThought indicates an expected call of InsertAIBridgeModelThought.
|
||||
func (mr *MockStoreMockRecorder) InsertAIBridgeModelThought(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeModelThought", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeModelThought), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertAIBridgeTokenUsage mocks base method.
|
||||
func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) (database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5721,19 +5422,19 @@ func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatMessages mocks base method.
|
||||
func (m *MockStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
||||
// InsertChatMessage mocks base method.
|
||||
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "InsertChatMessages", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret := m.ctrl.Call(m, "InsertChatMessage", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertChatMessages indicates an expected call of InsertChatMessages.
|
||||
func (mr *MockStoreMockRecorder) InsertChatMessages(ctx, arg any) *gomock.Call {
|
||||
// InsertChatMessage indicates an expected call of InsertChatMessage.
|
||||
func (mr *MockStoreMockRecorder) InsertChatMessage(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessages", reflect.TypeOf((*MockStore)(nil).InsertChatMessages), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessage", reflect.TypeOf((*MockStore)(nil).InsertChatMessage), ctx, arg)
|
||||
}
|
||||
|
||||
// InsertChatModelConfig mocks base method.
|
||||
@@ -6786,36 +6487,6 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatUsageLimitGroupOverrides", ctx)
|
||||
ret0, _ := ret[0].([]database.ListChatUsageLimitGroupOverridesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides indicates an expected call of ListChatUsageLimitGroupOverrides.
|
||||
func (mr *MockStoreMockRecorder) ListChatUsageLimitGroupOverrides(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitGroupOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitGroupOverrides), ctx)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitOverrides(ctx context.Context) ([]database.ListChatUsageLimitOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListChatUsageLimitOverrides", ctx)
|
||||
ret0, _ := ret[0].([]database.ListChatUsageLimitOverridesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListChatUsageLimitOverrides indicates an expected call of ListChatUsageLimitOverrides.
|
||||
func (mr *MockStoreMockRecorder) ListChatUsageLimitOverrides(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListChatUsageLimitOverrides", reflect.TypeOf((*MockStore)(nil).ListChatUsageLimitOverrides), ctx)
|
||||
}
|
||||
|
||||
// ListProvisionerKeysByOrganization mocks base method.
|
||||
func (m *MockStore) ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7054,21 +6725,6 @@ func (mr *MockStoreMockRecorder) RemoveUserFromGroups(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), ctx, arg)
|
||||
}
|
||||
|
||||
// ResolveUserChatSpendLimit mocks base method.
|
||||
func (m *MockStore) ResolveUserChatSpendLimit(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ResolveUserChatSpendLimit", ctx, userID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ResolveUserChatSpendLimit indicates an expected call of ResolveUserChatSpendLimit.
|
||||
func (mr *MockStoreMockRecorder) ResolveUserChatSpendLimit(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveUserChatSpendLimit", reflect.TypeOf((*MockStore)(nil).ResolveUserChatSpendLimit), ctx, userID)
|
||||
}
|
||||
|
||||
// RevokeDBCryptKey mocks base method.
|
||||
func (m *MockStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7098,34 +6754,6 @@ func (mr *MockStoreMockRecorder) SelectUsageEventsForPublishing(ctx, now any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SelectUsageEventsForPublishing", reflect.TypeOf((*MockStore)(nil).SelectUsageEventsForPublishing), ctx, now)
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessageByID mocks base method.
|
||||
func (m *MockStore) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteChatMessageByID", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessageByID indicates an expected call of SoftDeleteChatMessageByID.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteChatMessageByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessageByID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessagesAfterID mocks base method.
|
||||
func (m *MockStore) SoftDeleteChatMessagesAfterID(ctx context.Context, arg database.SoftDeleteChatMessagesAfterIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteChatMessagesAfterID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteChatMessagesAfterID indicates an expected call of SoftDeleteChatMessagesAfterID.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8541,21 +8169,6 @@ func (mr *MockStoreMockRecorder) UpdateWorkspacesTTLByTemplateID(ctx, arg any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspacesTTLByTemplateID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspacesTTLByTemplateID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertAISeatState mocks base method.
|
||||
func (m *MockStore) UpsertAISeatState(ctx context.Context, arg database.UpsertAISeatStateParams) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertAISeatState", ctx, arg)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertAISeatState indicates an expected call of UpsertAISeatState.
|
||||
func (mr *MockStoreMockRecorder) UpsertAISeatState(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAISeatState", reflect.TypeOf((*MockStore)(nil).UpsertAISeatState), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertAnnouncementBanners mocks base method.
|
||||
func (m *MockStore) UpsertAnnouncementBanners(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8599,20 +8212,6 @@ func (mr *MockStoreMockRecorder) UpsertBoundaryUsageStats(ctx, arg any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertBoundaryUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertBoundaryUsageStats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled mocks base method.
|
||||
func (m *MockStore) UpsertChatDesktopEnabled(ctx context.Context, enableDesktop bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDesktopEnabled", ctx, enableDesktop)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatDesktopEnabled indicates an expected call of UpsertChatDesktopEnabled.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDesktopEnabled(ctx, enableDesktop any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDesktopEnabled", reflect.TypeOf((*MockStore)(nil).UpsertChatDesktopEnabled), ctx, enableDesktop)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8657,51 +8256,6 @@ func (mr *MockStoreMockRecorder) UpsertChatSystemPrompt(ctx, value any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatSystemPrompt), ctx, value)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitConfig(ctx context.Context, arg database.UpsertChatUsageLimitConfigParams) (database.ChatUsageLimitConfig, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitConfig", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatUsageLimitConfig)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitConfig indicates an expected call of UpsertChatUsageLimitConfig.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitConfig(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitConfig", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitConfig), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitGroupOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitGroupOverride(ctx context.Context, arg database.UpsertChatUsageLimitGroupOverrideParams) (database.UpsertChatUsageLimitGroupOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitGroupOverride", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UpsertChatUsageLimitGroupOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitGroupOverride indicates an expected call of UpsertChatUsageLimitGroupOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitGroupOverride(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitGroupOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitGroupOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitUserOverride mocks base method.
|
||||
func (m *MockStore) UpsertChatUsageLimitUserOverride(ctx context.Context, arg database.UpsertChatUsageLimitUserOverrideParams) (database.UpsertChatUsageLimitUserOverrideRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatUsageLimitUserOverride", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UpsertChatUsageLimitUserOverrideRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatUsageLimitUserOverride indicates an expected call of UpsertChatUsageLimitUserOverride.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatUsageLimitUserOverride(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatUsageLimitUserOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatUsageLimitUserOverride), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertConnectionLog mocks base method.
|
||||
func (m *MockStore) UpsertConnectionLog(ctx context.Context, arg database.UpsertConnectionLogParams) (database.ConnectionLog, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9019,21 +8573,6 @@ func (mr *MockStoreMockRecorder) UpsertWorkspaceAppAuditSession(ctx, arg any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertWorkspaceAppAuditSession", reflect.TypeOf((*MockStore)(nil).UpsertWorkspaceAppAuditSession), ctx, arg)
|
||||
}
|
||||
|
||||
// UsageEventExistsByID mocks base method.
|
||||
func (m *MockStore) UsageEventExistsByID(ctx context.Context, id string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UsageEventExistsByID", ctx, id)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UsageEventExistsByID indicates an expected call of UsageEventExistsByID.
|
||||
func (mr *MockStoreMockRecorder) UsageEventExistsByID(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageEventExistsByID", reflect.TypeOf((*MockStore)(nil).UsageEventExistsByID), ctx, id)
|
||||
}
|
||||
|
||||
// ValidateGroupIDs mocks base method.
|
||||
func (m *MockStore) ValidateGroupIDs(ctx context.Context, groupIds []uuid.UUID) (database.ValidateGroupIDsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+17
-141
@@ -10,11 +10,6 @@ CREATE TYPE agent_key_scope_enum AS ENUM (
|
||||
'no_user_data'
|
||||
);
|
||||
|
||||
CREATE TYPE ai_seat_usage_reason AS ENUM (
|
||||
'aibridge',
|
||||
'task'
|
||||
);
|
||||
|
||||
CREATE TYPE api_key_scope AS ENUM (
|
||||
'coder:all',
|
||||
'coder:application_connect',
|
||||
@@ -270,23 +265,12 @@ CREATE TYPE build_reason AS ENUM (
|
||||
'task_resume'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_message_role AS ENUM (
|
||||
'system',
|
||||
'user',
|
||||
'assistant',
|
||||
'tool'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_message_visibility AS ENUM (
|
||||
'user',
|
||||
'model',
|
||||
'both'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_mode AS ENUM (
|
||||
'computer_use'
|
||||
);
|
||||
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
@@ -508,14 +492,7 @@ CREATE TYPE resource_type AS ENUM (
|
||||
'workspace_agent',
|
||||
'workspace_app',
|
||||
'prebuilds_settings',
|
||||
'task',
|
||||
'ai_seat'
|
||||
);
|
||||
|
||||
CREATE TYPE shareable_workspace_owners AS ENUM (
|
||||
'none',
|
||||
'everyone',
|
||||
'service_accounts'
|
||||
'task'
|
||||
);
|
||||
|
||||
CREATE TYPE startup_script_behavior AS ENUM (
|
||||
@@ -620,35 +597,28 @@ CREATE FUNCTION aggregate_usage_event() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
-- Check for supported event types and throw error for unknown types.
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1', 'hb_ai_seats_v1') THEN
|
||||
-- Check for supported event types and throw error for unknown types
|
||||
IF NEW.event_type NOT IN ('dc_managed_agents_v1') THEN
|
||||
RAISE EXCEPTION 'Unhandled usage event type in aggregate_usage_event: %', NEW.event_type;
|
||||
END IF;
|
||||
|
||||
INSERT INTO usage_events_daily (day, event_type, usage_data)
|
||||
VALUES (
|
||||
-- Extract the date from the created_at timestamp, always using UTC for
|
||||
-- consistency
|
||||
date_trunc('day', NEW.created_at AT TIME ZONE 'UTC')::date,
|
||||
NEW.event_type,
|
||||
NEW.event_data
|
||||
)
|
||||
ON CONFLICT (day, event_type) DO UPDATE SET
|
||||
usage_data = CASE
|
||||
-- Handle simple counter events by summing the count.
|
||||
-- Handle simple counter events by summing the count
|
||||
WHEN NEW.event_type IN ('dc_managed_agents_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0) +
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
-- Heartbeat events: keep the max value seen that day
|
||||
WHEN NEW.event_type IN ('hb_ai_seats_v1') THEN
|
||||
jsonb_build_object(
|
||||
'count',
|
||||
GREATEST(
|
||||
COALESCE((usage_events_daily.usage_data->>'count')::bigint, 0),
|
||||
COALESCE((NEW.event_data->>'count')::bigint, 0)
|
||||
)
|
||||
)
|
||||
END;
|
||||
|
||||
RETURN NEW;
|
||||
@@ -805,7 +775,7 @@ BEGIN
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE FUNCTION insert_organization_system_roles() RETURNS trigger
|
||||
CREATE FUNCTION insert_org_member_system_role() RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
@@ -820,8 +790,7 @@ BEGIN
|
||||
is_system,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES
|
||||
(
|
||||
) VALUES (
|
||||
'organization-member',
|
||||
'',
|
||||
NEW.id,
|
||||
@@ -832,18 +801,6 @@ BEGIN
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
),
|
||||
(
|
||||
'organization-service-account',
|
||||
'',
|
||||
NEW.id,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
'[]'::jsonb,
|
||||
true,
|
||||
NOW(),
|
||||
NOW()
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
@@ -1078,15 +1035,6 @@ BEGIN
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE TABLE ai_seat_state (
|
||||
user_id uuid NOT NULL,
|
||||
first_used_at timestamp with time zone NOT NULL,
|
||||
last_used_at timestamp with time zone NOT NULL,
|
||||
last_event_type ai_seat_usage_reason NOT NULL,
|
||||
last_event_description text NOT NULL,
|
||||
updated_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE aibridge_interceptions (
|
||||
id uuid NOT NULL,
|
||||
initiator_id uuid NOT NULL,
|
||||
@@ -1112,15 +1060,6 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
metadata jsonb,
|
||||
created_at timestamp with time zone NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_model_thoughts IS 'Audit log of model thinking in intercepted requests in AI Bridge';
|
||||
|
||||
CREATE TABLE aibridge_token_usages (
|
||||
id uuid NOT NULL,
|
||||
interception_id uuid NOT NULL,
|
||||
@@ -1250,15 +1189,7 @@ CREATE TABLE chat_diff_statuses (
|
||||
git_branch text DEFAULT ''::text NOT NULL,
|
||||
git_remote_origin text DEFAULT ''::text NOT NULL,
|
||||
pull_request_title text DEFAULT ''::text NOT NULL,
|
||||
pull_request_draft boolean DEFAULT false NOT NULL,
|
||||
author_login text,
|
||||
author_avatar_url text,
|
||||
base_branch text,
|
||||
pr_number integer,
|
||||
commits integer,
|
||||
approved boolean,
|
||||
reviewer_count integer,
|
||||
head_branch text
|
||||
pull_request_draft boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE chat_files (
|
||||
@@ -1276,7 +1207,7 @@ CREATE TABLE chat_messages (
|
||||
chat_id uuid NOT NULL,
|
||||
model_config_id uuid,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
role chat_message_role NOT NULL,
|
||||
role text NOT NULL,
|
||||
content jsonb,
|
||||
visibility chat_message_visibility DEFAULT 'both'::chat_message_visibility NOT NULL,
|
||||
input_tokens bigint,
|
||||
@@ -1287,11 +1218,7 @@ CREATE TABLE chat_messages (
|
||||
cache_read_tokens bigint,
|
||||
context_limit bigint,
|
||||
compressed boolean DEFAULT false NOT NULL,
|
||||
created_by uuid,
|
||||
content_version smallint NOT NULL,
|
||||
total_cost_micros bigint,
|
||||
runtime_ms bigint,
|
||||
deleted boolean DEFAULT false NOT NULL
|
||||
created_by uuid
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
@@ -1355,28 +1282,6 @@ CREATE SEQUENCE chat_queued_messages_id_seq
|
||||
|
||||
ALTER SEQUENCE chat_queued_messages_id_seq OWNED BY chat_queued_messages.id;
|
||||
|
||||
CREATE TABLE chat_usage_limit_config (
|
||||
id bigint NOT NULL,
|
||||
singleton boolean DEFAULT true NOT NULL,
|
||||
enabled boolean DEFAULT false NOT NULL,
|
||||
default_limit_micros bigint DEFAULT 0 NOT NULL,
|
||||
period text DEFAULT 'month'::text NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
CONSTRAINT chat_usage_limit_config_default_limit_micros_check CHECK ((default_limit_micros >= 0)),
|
||||
CONSTRAINT chat_usage_limit_config_period_check CHECK ((period = ANY (ARRAY['day'::text, 'week'::text, 'month'::text]))),
|
||||
CONSTRAINT chat_usage_limit_config_singleton_check CHECK (singleton)
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_usage_limit_config_id_seq
|
||||
START WITH 1
|
||||
INCREMENT BY 1
|
||||
NO MINVALUE
|
||||
NO MAXVALUE
|
||||
CACHE 1;
|
||||
|
||||
ALTER SEQUENCE chat_usage_limit_config_id_seq OWNED BY chat_usage_limit_config.id;
|
||||
|
||||
CREATE TABLE chats (
|
||||
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||
owner_id uuid NOT NULL,
|
||||
@@ -1392,8 +1297,7 @@ CREATE TABLE chats (
|
||||
root_chat_id uuid,
|
||||
last_model_config_id uuid NOT NULL,
|
||||
archived boolean DEFAULT false NOT NULL,
|
||||
last_error text,
|
||||
mode chat_mode
|
||||
last_error text
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -1533,9 +1437,7 @@ CREATE TABLE groups (
|
||||
avatar_url text DEFAULT ''::text NOT NULL,
|
||||
quota_allowance integer DEFAULT 0 NOT NULL,
|
||||
display_name text DEFAULT ''::text NOT NULL,
|
||||
source group_source DEFAULT 'user'::group_source NOT NULL,
|
||||
chat_spend_limit_micros bigint,
|
||||
CONSTRAINT groups_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0)))
|
||||
source group_source DEFAULT 'user'::group_source NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN groups.display_name IS 'Display name is a custom, human-friendly group name that user can set. This is not required to be unique and can be the empty string.';
|
||||
@@ -1570,9 +1472,7 @@ CREATE TABLE users (
|
||||
one_time_passcode_expires_at timestamp with time zone,
|
||||
is_system boolean DEFAULT false NOT NULL,
|
||||
is_service_account boolean DEFAULT false NOT NULL,
|
||||
chat_spend_limit_micros bigint,
|
||||
CONSTRAINT one_time_passcode_set CHECK ((((hashed_one_time_passcode IS NULL) AND (one_time_passcode_expires_at IS NULL)) OR ((hashed_one_time_passcode IS NOT NULL) AND (one_time_passcode_expires_at IS NOT NULL)))),
|
||||
CONSTRAINT users_chat_spend_limit_micros_check CHECK (((chat_spend_limit_micros IS NULL) OR (chat_spend_limit_micros > 0))),
|
||||
CONSTRAINT users_email_not_empty CHECK (((is_service_account = true) = (email = ''::text))),
|
||||
CONSTRAINT users_service_account_login_type CHECK (((is_service_account = false) OR (login_type = 'none'::login_type))),
|
||||
CONSTRAINT users_username_min_length CHECK ((length(username) >= 1))
|
||||
@@ -1860,11 +1760,9 @@ CREATE TABLE organizations (
|
||||
display_name text NOT NULL,
|
||||
icon text DEFAULT ''::text NOT NULL,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
shareable_workspace_owners shareable_workspace_owners DEFAULT 'everyone'::shareable_workspace_owners NOT NULL
|
||||
workspace_sharing_disabled boolean DEFAULT false NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON COLUMN organizations.shareable_workspace_owners IS 'Controls whose workspaces can be shared: none, everyone, or service_accounts.';
|
||||
|
||||
CREATE TABLE parameter_schemas (
|
||||
id uuid NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL,
|
||||
@@ -2664,7 +2562,7 @@ CREATE TABLE usage_events (
|
||||
publish_started_at timestamp with time zone,
|
||||
published_at timestamp with time zone,
|
||||
failure_message text,
|
||||
CONSTRAINT usage_event_type_check CHECK ((event_type = ANY (ARRAY['dc_managed_agents_v1'::text, 'hb_ai_seats_v1'::text])))
|
||||
CONSTRAINT usage_event_type_check CHECK ((event_type = 'dc_managed_agents_v1'::text))
|
||||
);
|
||||
|
||||
COMMENT ON TABLE usage_events IS 'usage_events contains usage data that is collected from the product and potentially shipped to the usage collector service.';
|
||||
@@ -3221,8 +3119,6 @@ ALTER TABLE ONLY chat_messages ALTER COLUMN id SET DEFAULT nextval('chat_message
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages ALTER COLUMN id SET DEFAULT nextval('chat_queued_messages_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config ALTER COLUMN id SET DEFAULT nextval('chat_usage_limit_config_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY licenses ALTER COLUMN id SET DEFAULT nextval('licenses_id_seq'::regclass);
|
||||
|
||||
ALTER TABLE ONLY provisioner_job_logs ALTER COLUMN id SET DEFAULT nextval('provisioner_job_logs_id_seq'::regclass);
|
||||
@@ -3238,9 +3134,6 @@ ALTER TABLE ONLY workspace_resource_metadata ALTER COLUMN id SET DEFAULT nextval
|
||||
ALTER TABLE ONLY workspace_agent_stats
|
||||
ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY ai_seat_state
|
||||
ADD CONSTRAINT ai_seat_state_pkey PRIMARY KEY (user_id);
|
||||
|
||||
ALTER TABLE ONLY aibridge_interceptions
|
||||
ADD CONSTRAINT aibridge_interceptions_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3283,12 +3176,6 @@ ALTER TABLE ONLY chat_providers
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config
|
||||
ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id);
|
||||
|
||||
ALTER TABLE ONLY chat_usage_limit_config
|
||||
ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton);
|
||||
|
||||
ALTER TABLE ONLY chats
|
||||
ADD CONSTRAINT chats_pkey PRIMARY KEY (id);
|
||||
|
||||
@@ -3601,8 +3488,6 @@ CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptio
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_root_id ON aibridge_interceptions USING btree (thread_root_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_model_thoughts_interception_id ON aibridge_model_thoughts USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_interception_id ON aibridge_token_usages USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_token_usages_provider_response_id ON aibridge_token_usages USING btree (provider_response_id);
|
||||
@@ -3639,11 +3524,7 @@ CREATE INDEX idx_chat_messages_chat ON chat_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_id, created_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::chat_message_role) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility])));
|
||||
|
||||
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
|
||||
|
||||
CREATE INDEX idx_chat_messages_owner_spend ON chat_messages USING btree (chat_id, created_at) WHERE (total_cost_micros IS NOT NULL);
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::text) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility])));
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
|
||||
|
||||
@@ -3719,8 +3600,6 @@ CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree
|
||||
|
||||
CREATE UNIQUE INDEX idx_unique_preset_name ON template_version_presets USING btree (name, template_version_id);
|
||||
|
||||
CREATE INDEX idx_usage_events_ai_seats ON usage_events USING btree (event_type, created_at) WHERE (event_type = 'hb_ai_seats_v1'::text);
|
||||
|
||||
CREATE INDEX idx_usage_events_select_for_publishing ON usage_events USING btree (published_at, publish_started_at, created_at);
|
||||
|
||||
CREATE INDEX idx_user_deleted_deleted_at ON user_deleted USING btree (deleted_at);
|
||||
@@ -3895,7 +3774,7 @@ CREATE TRIGGER trigger_delete_oauth2_provider_app_token AFTER DELETE ON oauth2_p
|
||||
|
||||
CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted();
|
||||
|
||||
CREATE TRIGGER trigger_insert_organization_system_roles AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_organization_system_roles();
|
||||
CREATE TRIGGER trigger_insert_org_member_system_role AFTER INSERT ON organizations FOR EACH ROW EXECUTE FUNCTION insert_org_member_system_role();
|
||||
|
||||
CREATE TRIGGER trigger_nullify_next_start_at_on_workspace_autostart_modificati AFTER UPDATE ON workspaces FOR EACH ROW EXECUTE FUNCTION nullify_next_start_at_on_workspace_autostart_modification();
|
||||
|
||||
@@ -3913,9 +3792,6 @@ COMMENT ON TRIGGER workspace_agent_name_unique_trigger ON workspace_agents IS 'U
|
||||
the uniqueness requirement. A trigger allows us to enforce uniqueness going
|
||||
forward without requiring a migration to clean up historical data.';
|
||||
|
||||
ALTER TABLE ONLY ai_seat_state
|
||||
ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY aibridge_interceptions
|
||||
ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ type ForeignKeyConstraint string
|
||||
|
||||
// ForeignKeyConstraint enums.
|
||||
const (
|
||||
ForeignKeyAiSeatStateUserID ForeignKeyConstraint = "ai_seat_state_user_id_fkey" // ALTER TABLE ONLY ai_seat_state ADD CONSTRAINT ai_seat_state_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyAibridgeInterceptionsInitiatorID ForeignKeyConstraint = "aibridge_interceptions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_interceptions ADD CONSTRAINT aibridge_interceptions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id);
|
||||
ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatDiffStatusesChatID ForeignKeyConstraint = "chat_diff_statuses_chat_id_fkey" // ALTER TABLE ONLY chat_diff_statuses ADD CONSTRAINT chat_diff_statuses_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -26,7 +26,6 @@ func TestCustomQueriesSyncedRowScan(t *testing.T) {
|
||||
"GetTemplatesWithFilter": "GetAuthorizedTemplates",
|
||||
"GetWorkspaces": "GetAuthorizedWorkspaces",
|
||||
"GetUsers": "GetAuthorizedUsers",
|
||||
"GetChats": "GetAuthorizedChats",
|
||||
}
|
||||
|
||||
// Scan custom
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
ALTER TABLE chat_messages DROP COLUMN content_version;
|
||||
|
||||
DROP INDEX idx_chat_messages_compressed_summary_boundary;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
ALTER COLUMN role TYPE text
|
||||
USING (role::text);
|
||||
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary
|
||||
ON chat_messages(chat_id, created_at DESC, id DESC)
|
||||
WHERE compressed = TRUE
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both');
|
||||
|
||||
DROP TYPE chat_message_role;
|
||||
@@ -1,32 +0,0 @@
|
||||
-- Add chat_message_role enum.
|
||||
CREATE TYPE chat_message_role AS ENUM (
|
||||
'system',
|
||||
'user',
|
||||
'assistant',
|
||||
'tool'
|
||||
);
|
||||
|
||||
-- Drop the partial index that references role as text before
|
||||
-- converting the column type.
|
||||
DROP INDEX idx_chat_messages_compressed_summary_boundary;
|
||||
|
||||
-- Convert role column from text to enum.
|
||||
ALTER TABLE chat_messages
|
||||
ALTER COLUMN role TYPE chat_message_role
|
||||
USING (role::chat_message_role);
|
||||
|
||||
-- Recreate the partial index with enum-typed comparison.
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary
|
||||
ON chat_messages(chat_id, created_at DESC, id DESC)
|
||||
WHERE compressed = TRUE
|
||||
AND role = 'system'
|
||||
AND visibility IN ('model', 'both');
|
||||
|
||||
-- Add content_version column. Default 0 backfills existing rows.
|
||||
-- The default is then dropped so future inserts must specify the
|
||||
-- version explicitly.
|
||||
ALTER TABLE chat_messages
|
||||
ADD COLUMN content_version smallint NOT NULL DEFAULT 0;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
ALTER COLUMN content_version DROP DEFAULT;
|
||||
@@ -1,3 +0,0 @@
|
||||
DROP INDEX IF EXISTS idx_chat_messages_created_at;
|
||||
|
||||
ALTER TABLE chat_messages DROP COLUMN total_cost_micros;
|
||||
@@ -1,68 +0,0 @@
|
||||
ALTER TABLE chat_messages ADD COLUMN total_cost_micros BIGINT;
|
||||
|
||||
WITH message_costs AS (
|
||||
SELECT
|
||||
msg.id,
|
||||
ROUND(
|
||||
COALESCE(msg.input_tokens, 0)::numeric * COALESCE(pricing.input_price, 0)
|
||||
+ COALESCE(msg.output_tokens, 0)::numeric * COALESCE(pricing.output_price, 0)
|
||||
+ COALESCE(msg.cache_read_tokens, 0)::numeric * COALESCE(pricing.cache_read_price, 0)
|
||||
+ COALESCE(msg.cache_creation_tokens, 0)::numeric * COALESCE(pricing.cache_write_price, 0)
|
||||
)::bigint AS total_cost_micros
|
||||
FROM
|
||||
chat_messages AS msg
|
||||
JOIN
|
||||
chat_model_configs AS cfg
|
||||
ON
|
||||
cfg.id = msg.model_config_id
|
||||
CROSS JOIN LATERAL (
|
||||
SELECT
|
||||
COALESCE(
|
||||
(cfg.options -> 'cost' ->> 'input_price_per_million_tokens')::numeric,
|
||||
(cfg.options ->> 'input_price_per_million_tokens')::numeric
|
||||
) AS input_price,
|
||||
COALESCE(
|
||||
(cfg.options -> 'cost' ->> 'output_price_per_million_tokens')::numeric,
|
||||
(cfg.options ->> 'output_price_per_million_tokens')::numeric
|
||||
) AS output_price,
|
||||
COALESCE(
|
||||
(cfg.options -> 'cost' ->> 'cache_read_price_per_million_tokens')::numeric,
|
||||
(cfg.options ->> 'cache_read_price_per_million_tokens')::numeric
|
||||
) AS cache_read_price,
|
||||
COALESCE(
|
||||
(cfg.options -> 'cost' ->> 'cache_write_price_per_million_tokens')::numeric,
|
||||
(cfg.options ->> 'cache_write_price_per_million_tokens')::numeric
|
||||
) AS cache_write_price
|
||||
) AS pricing
|
||||
WHERE
|
||||
msg.total_cost_micros IS NULL
|
||||
AND (
|
||||
msg.input_tokens IS NOT NULL
|
||||
OR msg.output_tokens IS NOT NULL
|
||||
OR msg.reasoning_tokens IS NOT NULL
|
||||
OR msg.cache_creation_tokens IS NOT NULL
|
||||
OR msg.cache_read_tokens IS NOT NULL
|
||||
)
|
||||
AND (
|
||||
pricing.input_price IS NOT NULL
|
||||
OR pricing.output_price IS NOT NULL
|
||||
OR pricing.cache_read_price IS NOT NULL
|
||||
OR pricing.cache_write_price IS NOT NULL
|
||||
)
|
||||
AND (
|
||||
(msg.input_tokens IS NOT NULL AND pricing.input_price IS NOT NULL)
|
||||
OR (msg.output_tokens IS NOT NULL AND pricing.output_price IS NOT NULL)
|
||||
OR (msg.cache_read_tokens IS NOT NULL AND pricing.cache_read_price IS NOT NULL)
|
||||
OR (msg.cache_creation_tokens IS NOT NULL AND pricing.cache_write_price IS NOT NULL)
|
||||
)
|
||||
)
|
||||
UPDATE
|
||||
chat_messages AS msg
|
||||
SET
|
||||
total_cost_micros = message_costs.total_cost_micros
|
||||
FROM
|
||||
message_costs
|
||||
WHERE
|
||||
msg.id = message_costs.id;
|
||||
|
||||
CREATE INDEX idx_chat_messages_created_at ON chat_messages (created_at);
|
||||
@@ -1,2 +0,0 @@
|
||||
ALTER TABLE chats DROP COLUMN mode;
|
||||
DROP TYPE IF EXISTS chat_mode;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user