Compare commits

..

6 Commits

Author SHA1 Message Date
Muhammad Atif Ali e524206c78 docs(ai-bridge): restructure into multi-page documentation
Split AI Bridge documentation from single 530-line file into focused,
audience-specific pages for better organization and maintainability.

## Changes

### Structure
- Delete single-page docs/ai-coder/ai-bridge.md
- Create docs/ai-coder/ai-bridge/ directory with 7 pages:
  - index.md: Main overview, setup, general config, SDKs, templates
  - roo-code.md: Roo Code configuration
  - claude-code.md: Claude Code CLI and VSCode
  - cursor.md: Cursor IDE
  - github-copilot.md: GitHub Copilot (VSCode + CLI status)
  - goose.md: Goose Desktop and CLI
  - other-clients.md: Kilo Code and unsupported clients

### Content Organization
- Platform teams: Server setup in index.md
- Template admins: Template patterns in index.md + client pages
- End users: Client-specific configuration in dedicated pages
- Developers: SDK examples in index.md

### Quality Improvements
- Fix all markdown linting errors (8 total)
- Remove trailing spaces from all files
- Fix broken internal links
- Update client table with links to new pages
- Use proper heading hierarchy (h4 instead of h5)
- Use auto-numbering for ordered lists (1/1/1 style)

### Benefits
- Easier maintenance (update one client without affecting others)
- Better user experience (focused guides vs scrolling large file)
- Clear audience separation (platform/admin/end-user)
- Improved navigation (table links to dedicated pages)
2025-10-22 10:48:14 +05:00
Muhammad Atif Ali c674b937a0 docs(ai-bridge): fix Anthropic base URL - remove /v1 suffix
Corrected Anthropic base URL configuration based on official dogfood
template (commit 61fba2db47):

- Anthropic SDK appends /v1/messages to base URL
- Therefore base URL should NOT include /v1
- Verified from dogfood/coder/main.tf which uses:
  ANTHROPIC_BASE_URL: "https://dev.coder.com/api/experimental/aibridge/anthropic"

Final correct URLs:
- OpenAI: .../openai/v1 (SDK appends /chat/completions)
- Anthropic: .../anthropic (SDK appends /v1/messages)

This ensures compatibility with official Anthropic Python SDK and
tools like Claude Code that follow the same pattern.
2025-10-22 10:48:14 +05:00
Muhammad Atif Ali c2a2725128 docs(ai-bridge): fix base URLs to include /v1 suffix
Corrected all AI Bridge client configuration URLs to include the /v1
suffix based on source code verification:

- OpenAI: /api/experimental/aibridge/openai/v1 (was missing /v1)
- Anthropic: /api/experimental/aibridge/anthropic/v1 (was missing /v1)

Verified from enterprise/x/aibridged/aibridged_test.go which shows:
- /openai/v1/chat/completions
- /anthropic/v1/messages

Updated in all configuration examples:
- Setting Base URLs section
- Claude Code CLI example
- Custom Python scripts example
- Terraform template example
2025-10-22 10:48:14 +05:00
Atif Ali 5224a14ed7 fixup! 2025-10-22 10:48:14 +05:00
Muhammad Atif Ali f8d94c9ecf docs(ai-bridge): add client configuration section with template examples
- Add comprehensive Client Configuration section explaining how to set
  OPENAI_BASE_URL and ANTHROPIC_BASE_URL to point to AI Bridge
- Document authentication using Coder session tokens instead of provider keys
- Add configuration examples for Claude Code CLI and custom scripts
- Add Pre-configuring in Coder Templates subsection showing how to use
  data.coder_workspace_owner.me.session_token for automatic setup
- Include Terraform example for configuring AI agents in Tasks templates
- Fix unbalanced code fence that was causing markdown linting errors
2025-10-22 10:48:14 +05:00
Muhammad Atif Ali 420c4c28fc docs: add client configuration section and examples to AI Bridge 2025-10-22 10:48:14 +05:00
379 changed files with 8431 additions and 18289 deletions
+2 -10
View File
@@ -91,9 +91,6 @@
## Systematic Debugging Approach
YOU MUST ALWAYS find the root cause of any issue you are debugging
YOU MUST NEVER fix a symptom or add a workaround instead of finding a root cause, even if it is faster.
### Multi-Issue Problem Solving
When facing multiple failing tests or complex integration issues:
@@ -101,21 +98,16 @@ When facing multiple failing tests or complex integration issues:
1. **Identify Root Causes**:
- Run failing tests individually to isolate issues
- Use LSP tools to trace through call chains
- Read Error Messages Carefully: Check both compilation and runtime errors
- Reproduce Consistently: Ensure you can reliably reproduce the issue before investigating
- Check Recent Changes: What changed that could have caused this? Git diff, recent commits, etc.
- When You Don't Know: Say "I don't understand X" rather than pretending to know
- Check both compilation and runtime errors
2. **Fix in Logical Order**:
- Address compilation issues first (imports, syntax)
- Fix authorization and RBAC issues next
- Resolve business logic and validation issues
- Handle edge cases and race conditions last
- IF your first fix doesn't work, STOP and re-analyze rather than adding more fixes
3. **Verification Strategy**:
- Always Test each fix individually before moving to next issue
- Verify Before Continuing: Did your test work? If not, form new hypothesis - don't add more fixes
- Test each fix individually before moving to next issue
- Use `make lint` and `make gen` after database changes
- Verify RFC compliance with actual specifications
- Run comprehensive test suites before considering complete
+3 -7
View File
@@ -40,15 +40,11 @@
- Use proper error types
- Pattern: `xerrors.Errorf("failed to X: %w", err)`
## Naming Conventions
### Naming Conventions
- Names MUST tell what code does, not how it's implemented or its history
- Follow Go and TypeScript naming conventions
- When changing code, never document the old behavior or the behavior change
- NEVER use implementation details in names (e.g., "ZodValidator", "MCPWrapper", "JSONParser")
- NEVER use temporal/historical context in names (e.g., "LegacyHandler", "UnifiedTool", "ImprovedInterface", "EnhancedParser")
- NEVER use pattern names unless they add clarity (e.g., prefer "Tool" over "ToolFactory")
- Use clear, descriptive names
- Abbreviate only when obvious
- Follow Go and TypeScript naming conventions
### Comments
+1 -1
View File
@@ -4,7 +4,7 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.24.10"
default: "1.24.6"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
@@ -0,0 +1,34 @@
app = "sao-paulo-coder"
primary_region = "gru"
[experimental]
entrypoint = ["/bin/sh", "-c", "CODER_DERP_SERVER_RELAY_URL=\"http://[${FLY_PRIVATE_IP}]:3000\" /opt/coder wsproxy server"]
auto_rollback = true
[build]
image = "ghcr.io/coder/coder-preview:main"
[env]
CODER_ACCESS_URL = "https://sao-paulo.fly.dev.coder.com"
CODER_HTTP_ADDRESS = "0.0.0.0:3000"
CODER_PRIMARY_ACCESS_URL = "https://dev.coder.com"
CODER_WILDCARD_ACCESS_URL = "*--apps.sao-paulo.fly.dev.coder.com"
CODER_VERBOSE = "true"
[http_service]
internal_port = 3000
force_https = true
auto_stop_machines = true
auto_start_machines = true
min_machines_running = 0
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
[http_service.concurrency]
type = "requests"
soft_limit = 50
hard_limit = 100
[[vm]]
cpu_kind = "shared"
cpus = 2
memory_mb = 512
+26 -12
View File
@@ -181,7 +181,7 @@ jobs:
echo "LINT_CACHE_DIR=$dir" >> "$GITHUB_ENV"
- name: golangci-lint cache
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4
with:
path: |
${{ env.LINT_CACHE_DIR }}
@@ -191,7 +191,7 @@ jobs:
# Check for any typos
- name: Check for typos
uses: crate-ci/typos@80c8a4945eec0f6d464eaf9e65ed98ef085283d1 # v1.38.1
uses: crate-ci/typos@85f62a8a84f939ae994ab3763f01a0296d61a7ee # v1.36.2
with:
config: .github/workflows/typos.toml
@@ -376,6 +376,13 @@ jobs:
id: go-paths
uses: ./.github/actions/setup-go-paths
- name: Download Go Build Cache
id: download-go-build-cache
uses: ./.github/actions/test-cache/download
with:
key-prefix: test-go-build-${{ runner.os }}-${{ runner.arch }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Setup Go
uses: ./.github/actions/setup-go
with:
@@ -383,7 +390,8 @@ jobs:
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
use-cache: true
# Cache is already downloaded above
use-cache: false
- name: Setup Terraform
uses: ./.github/actions/setup-tf
@@ -492,11 +500,17 @@ jobs:
make test
- name: Upload failed test db dumps
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
- name: Upload Go Build Cache
uses: ./.github/actions/test-cache/upload
with:
cache-key: ${{ steps.download-go-build-cache.outputs.cache-key }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Upload Test Cache
uses: ./.github/actions/test-cache/upload
with:
@@ -748,7 +762,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -756,7 +770,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -792,7 +806,7 @@ jobs:
# the check to pass. This is desired in PRs, but not in mainline.
- name: Publish to Chromatic (non-mainline)
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@20c7e42e1b2f6becd5d188df9acb02f3e2f51519 # v13.2.0
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -824,7 +838,7 @@ jobs:
# infinitely "in progress" in mainline unless we re-review each build.
- name: Publish to Chromatic (mainline)
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@20c7e42e1b2f6becd5d188df9acb02f3e2f51519 # v13.2.0
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -1022,7 +1036,7 @@ jobs:
- name: Upload build artifacts
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -1109,7 +1123,7 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -1187,7 +1201,7 @@ jobs:
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -1454,7 +1468,7 @@ jobs:
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: coder
path: |
+4 -2
View File
@@ -76,7 +76,7 @@ jobs:
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -92,7 +92,7 @@ jobs:
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Set up Flux CLI
uses: fluxcd/flux2/action@4a15fa6a023259353ef750acf1c98fe88407d4d0 # v2.7.2
uses: fluxcd/flux2/action@6bf37f6a560fd84982d67f853162e4b3c2235edb # v2.6.4
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.7.0"
@@ -163,10 +163,12 @@ jobs:
run: |
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
flyctl deploy --image "$IMAGE" --app sao-paulo-coder --config ./.github/fly-wsproxies/sao-paulo-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SAO_PAULO" --yes
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
env:
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
IMAGE: ${{ inputs.image }}
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SAO_PAULO: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
+1 -1
View File
@@ -48,7 +48,7 @@ jobs:
persist-credentials: false
- name: Docker login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
+1 -1
View File
@@ -30,7 +30,7 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@dbf178ceecb9304128c8e0648591d71208c6e2c9 # v45.0.7
- uses: tj-actions/changed-files@4563c729c555b4141fac99c80f699f571219b836 # v45.0.7
id: changed-files
with:
files: |
+3 -3
View File
@@ -36,11 +36,11 @@ jobs:
persist-credentials: false
- name: Setup Nix
uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
uses: nixbuild/nix-quick-install-action@1f095fee853b33114486cfdeae62fa099cda35a9 # v33
with:
# Pinning to 2.28 here, as Nix gets a "error: [json.exception.type_error.302] type must be array, but is string"
# on version 2.29 and above.
nix_version: "2.28.5"
nix_version: "2.28.4"
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
with:
@@ -82,7 +82,7 @@ jobs:
- name: Login to DockerHub
if: github.ref == 'refs/heads/main'
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
+5 -5
View File
@@ -189,7 +189,7 @@ jobs:
egress-policy: audit
- name: Find Comment
uses: peter-evans/find-comment@b30e6a3c0ed37e7c023ccd3f1db5c6c0b0c23aad # v4.0.0
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
id: fc
with:
issue-number: ${{ needs.get_info.outputs.PR_NUMBER }}
@@ -199,7 +199,7 @@ jobs:
- name: Comment on PR
id: comment_id
uses: peter-evans/create-or-update-comment@e8674b075228eee787fea43ef493e45ece1004c9 # v5.0.0
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
comment-id: ${{ steps.fc.outputs.comment-id }}
issue-number: ${{ needs.get_info.outputs.PR_NUMBER }}
@@ -248,7 +248,7 @@ jobs:
uses: ./.github/actions/setup-sqlc
- name: GHCR Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -491,7 +491,7 @@ jobs:
PASSWORD: ${{ steps.setup_deployment.outputs.password }}
- name: Find Comment
uses: peter-evans/find-comment@b30e6a3c0ed37e7c023ccd3f1db5c6c0b0c23aad # v4.0.0
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
id: fc
with:
issue-number: ${{ env.PR_NUMBER }}
@@ -500,7 +500,7 @@ jobs:
direction: last
- name: Comment on PR
uses: peter-evans/create-or-update-comment@e8674b075228eee787fea43ef493e45ece1004c9 # v5.0.0
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
env:
STATUS: ${{ needs.get_info.outputs.NEW == 'true' && 'Created' || 'Updated' }}
with:
+6 -6
View File
@@ -131,7 +131,7 @@ jobs:
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
- name: Upload build artifacts
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -239,7 +239,7 @@ jobs:
cat "$CODER_RELEASE_NOTES_FILE"
- name: Docker Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -327,7 +327,7 @@ jobs:
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -761,7 +761,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: release-artifacts
path: |
@@ -777,7 +777,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
@@ -785,7 +785,7 @@ jobs:
- name: Send repository-dispatch event
if: ${{ !inputs.dry_run }}
uses: peter-evans/repository-dispatch@5fc4efd1a4797ddb68ffd0714a238564e4cc0e6f # v4.0.0
uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0
with:
token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
repository: coder/packages
+3 -3
View File
@@ -30,7 +30,7 @@ jobs:
persist-credentials: false
- name: "Run analysis"
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
uses: ossf/scorecard-action@05b42c624433fc40578a4040d5cf5e36ddca8cde # v2.4.2
with:
results_file: results.sarif
results_format: sarif
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: SARIF file
path: results.sarif
@@ -47,6 +47,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
with:
sarif_file: results.sarif
+4 -4
View File
@@ -40,7 +40,7 @@ jobs:
uses: ./.github/actions/setup-go
- name: Initialize CodeQL
uses: github/codeql-action/init@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/init@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
with:
languages: go, javascript
@@ -50,7 +50,7 @@ jobs:
rm Makefile
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/analyze@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
- name: Send Slack notification on failure
if: ${{ failure() }}
@@ -154,13 +154,13 @@ jobs:
severity: "CRITICAL,HIGH"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
with:
sarif_file: trivy-results.sarif
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: trivy
path: trivy-results.sarif
+3 -3
View File
@@ -23,7 +23,7 @@ jobs:
egress-policy: audit
- name: stale
uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
@@ -125,7 +125,7 @@ jobs:
egress-policy: audit
- name: Delete PR Cleanup workflow runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
@@ -134,7 +134,7 @@ jobs:
delete_workflow_pattern: pr-cleanup.yaml
- name: Delete PR Deploy workflow skipped runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
+5 -5
View File
@@ -13,12 +13,12 @@ on:
template_name:
description: "Coder template to use for workspace"
required: true
default: "coder"
default: "traiage"
type: string
template_preset:
description: "Template preset to use"
required: true
default: "none"
default: "Default"
type: string
prefix:
description: "Prefix for workspace name"
@@ -66,8 +66,8 @@ jobs:
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
INPUTS_TEMPLATE_NAME: ${{ inputs.template_name || 'coder' }}
INPUTS_TEMPLATE_PRESET: ${{ inputs.template_preset || 'none'}}
INPUTS_TEMPLATE_NAME: ${{ inputs.template_name || 'traiage' }}
INPUTS_TEMPLATE_PRESET: ${{ inputs.template_preset || 'Default'}}
INPUTS_PREFIX: ${{ inputs.prefix || 'traiage' }}
GH_TOKEN: ${{ github.token }}
run: |
@@ -168,7 +168,7 @@ jobs:
echo "coder_username=${coder_username}" >> "${GITHUB_OUTPUT}"
- name: Checkout repository
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@v4
with:
persist-credentials: false
fetch-depth: 0
+1 -1
View File
@@ -31,7 +31,7 @@ jobs:
persist-credentials: false
- name: Check Markdown links
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
uses: umbrelladocs/action-linkspector@874d01cae9fd488e3077b08952093235bd626977 # v1.3.7
id: markdown-link-check
# checks all markdown files from /docs including all subfolders
with:
-5
View File
@@ -12,9 +12,6 @@ node_modules/
vendor/
yarn-error.log
# Test output files
test-output/
# VSCode settings.
**/.vscode/*
# Allow VSCode recommendations and default settings in project root.
@@ -89,5 +86,3 @@ result
__debug_bin*
**/.claude/settings.local.json
/.env
+19 -40
View File
@@ -1,41 +1,11 @@
# Coder Development Guidelines
You are an experienced, pragmatic software engineer. You don't over-engineer a solution when a simple one is possible.
Rule #1: If you want exception to ANY rule, YOU MUST STOP and get explicit permission first. BREAKING THE LETTER OR SPIRIT OF THE RULES IS FAILURE.
## Foundational rules
- Doing it right is better than doing it fast. You are not in a rush. NEVER skip steps or take shortcuts.
- Tedious, systematic work is often the correct solution. Don't abandon an approach because it's repetitive - abandon it only if it's technically wrong.
- Honesty is a core value.
## Our relationship
- Act as a critical peer reviewer. Your job is to disagree with me when I'm wrong, not to please me. Prioritize accuracy and reasoning over agreement.
- YOU MUST speak up immediately when you don't know something or we're in over our heads
- YOU MUST call out bad ideas, unreasonable expectations, and mistakes - I depend on this
- NEVER be agreeable just to be nice - I NEED your HONEST technical judgment
- NEVER write the phrase "You're absolutely right!" You are not a sycophant. We're working together because I value your opinion. Do not agree with me unless you can justify it with evidence or reasoning.
- YOU MUST ALWAYS STOP and ask for clarification rather than making assumptions.
- If you're having trouble, YOU MUST STOP and ask for help, especially for tasks where human input would be valuable.
- When you disagree with my approach, YOU MUST push back. Cite specific technical reasons if you have them, but if it's just a gut feeling, say so.
- If you're uncomfortable pushing back out loud, just say "Houston, we have a problem". I'll know what you mean
- We discuss architectutral decisions (framework changes, major refactoring, system design) together before implementation. Routine fixes and clear implementations don't need discussion.
## Proactiveness
When asked to do something, just do it - including obvious follow-up actions needed to complete the task properly.
Only pause to ask for confirmation when:
- Multiple valid approaches exist and the choice matters
- The action would delete or significantly restructure existing code
- You genuinely don't understand what's being asked
- Your partner asked a question (answer the question, don't jump to implementation)
@.claude/docs/WORKFLOWS.md
@.cursorrules
@README.md
@package.json
## Essential Commands
## 🚀 Essential Commands
| Task | Command | Notes |
|-------------------|--------------------------|----------------------------------|
@@ -51,13 +21,22 @@ Only pause to ask for confirmation when:
| **Format** | `make fmt` | Auto-format code |
| **Clean** | `make clean` | Clean build artifacts |
### Frontend Commands (site directory)
- `pnpm build` - Build frontend
- `pnpm dev` - Run development server
- `pnpm check` - Run code checks
- `pnpm format` - Format frontend code
- `pnpm lint` - Lint frontend code
- `pnpm test` - Run frontend tests
### Documentation Commands
- `pnpm run format-docs` - Format markdown tables in docs
- `pnpm run lint-docs` - Lint and fix markdown files
- `pnpm run storybook` - Run Storybook (from site directory)
## Critical Patterns
## 🔧 Critical Patterns
### Database Changes (ALWAYS FOLLOW)
@@ -99,7 +78,7 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
```
## Quick Reference
## 📋 Quick Reference
### Full workflows available in imported WORKFLOWS.md
@@ -109,14 +88,14 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
- [ ] Check if feature touches database - you'll need migrations
- [ ] Check if feature touches audit logs - update `enterprise/audit/table.go`
## Architecture
## 🏗️ Architecture
- **coderd**: Main API service
- **provisionerd**: Infrastructure provisioning
- **Agents**: Workspace services (SSH, port forwarding)
- **Database**: PostgreSQL with `dbauthz` authorization
## Testing
## 🧪 Testing
### Race Condition Prevention
@@ -133,21 +112,21 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
NEVER use `time.Sleep` to mitigate timing issues. If an issue
seems like it should use `time.Sleep`, read through https://github.com/coder/quartz and specifically the [README](https://github.com/coder/quartz/blob/main/README.md) to better understand how to handle timing issues.
## Code Style
## 🎯 Code Style
### Detailed guidelines in imported WORKFLOWS.md
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- Commit format: `type(scope): message`
## Detailed Development Guides
## 📚 Detailed Development Guides
@.claude/docs/OAUTH2.md
@.claude/docs/TESTING.md
@.claude/docs/TROUBLESHOOTING.md
@.claude/docs/DATABASE.md
## Common Pitfalls
## 🚨 Common Pitfalls
1. **Audit table errors** → Update `enterprise/audit/table.go`
2. **OAuth2 errors** → Return RFC-compliant format
+12
View File
@@ -18,6 +18,18 @@ coderd/rbac/ @Emyrk
scripts/apitypings/ @Emyrk
scripts/gensite/ @aslilac
site/ @aslilac @Parkreiner
site/src/hooks/ @Parkreiner
# These rules intentionally do not specify any owners. More specific rules
# override less specific rules, so these files are "ignored" by the site/ rule.
site/e2e/google/protobuf/timestampGenerated.ts
site/e2e/provisionerGenerated.ts
site/src/api/countriesGenerated.ts
site/src/api/rbacresourcesGenerated.ts
site/src/api/typesGenerated.ts
site/src/testHelpers/entities.ts
site/CLAUDE.md
# The blood and guts of the autostop algorithm, which is quite complex and
# requires elite ball knowledge of most of the scheduling code to make changes
# without inadvertently affecting other parts of the codebase.
+8 -18
View File
@@ -636,8 +636,8 @@ TAILNETTEST_MOCKS := \
tailnet/tailnettest/subscriptionmock.go
AIBRIDGED_MOCKS := \
enterprise/aibridged/aibridgedmock/clientmock.go \
enterprise/aibridged/aibridgedmock/poolmock.go
enterprise/x/aibridged/aibridgedmock/clientmock.go \
enterprise/x/aibridged/aibridgedmock/poolmock.go
GEN_FILES := \
tailnet/proto/tailnet.pb.go \
@@ -645,7 +645,7 @@ GEN_FILES := \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
enterprise/x/aibridged/proto/aibridged.pb.go \
$(DB_GEN_FILES) \
$(SITE_GEN_FILES) \
coderd/rbac/object_gen.go \
@@ -676,7 +676,6 @@ gen/db: $(DB_GEN_FILES)
.PHONY: gen/db
gen/golden-files: \
agent/unit/testdata/.gen-golden \
cli/testdata/.gen-golden \
coderd/.gen-golden \
coderd/notifications/.gen-golden \
@@ -697,7 +696,7 @@ gen/mark-fresh:
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
enterprise/x/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
$(DB_GEN_FILES) \
site/src/api/typesGenerated.ts \
@@ -768,8 +767,8 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
go generate ./codersdk/workspacesdk/agentconnmock/
touch "$@"
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
go generate ./enterprise/aibridged/aibridgedmock/
$(AIBRIDGED_MOCKS): enterprise/x/aibridged/client.go enterprise/x/aibridged/pool.go
go generate ./enterprise/x/aibridged/aibridgedmock/
touch "$@"
agent/agentcontainers/dcspec/dcspec_gen.go: \
@@ -822,13 +821,13 @@ vpn/vpn.pb.go: vpn/vpn.proto
--go_opt=paths=source_relative \
./vpn/vpn.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
enterprise/x/aibridged/proto/aibridged.pb.go: enterprise/x/aibridged/proto/aibridged.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./enterprise/aibridged/proto/aibridged.proto
./enterprise/x/aibridged/proto/aibridged.proto
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
@@ -953,10 +952,6 @@ clean/golden-files:
-type f -name '*.golden' -delete
.PHONY: clean/golden-files
agent/unit/testdata/.gen-golden: $(wildcard agent/unit/testdata/*.golden) $(GO_SRC_FILES) $(wildcard agent/unit/*_test.go)
TZ=UTC go test ./agent/unit -run="TestGraph" -update
touch "$@"
cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(wildcard cli/*.tpl) $(GO_SRC_FILES) $(wildcard cli/*_test.go)
TZ=UTC go test ./cli -run="Test(CommandHelp|ServerYAML|ErrorExamples|.*Golden)" -update
touch "$@"
@@ -1182,8 +1177,3 @@ endif
dogfood/coder/nix.hash: flake.nix flake.lock
sha256sum flake.nix flake.lock >./dogfood/coder/nix.hash
# Count the number of test databases created per test package.
count-test-databases:
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
.PHONY: count-test-databases
+1 -1
View File
@@ -1087,7 +1087,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
if err != nil {
return xerrors.Errorf("fetch metadata: %w", err)
}
a.logger.Info(ctx, "fetched manifest")
a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp))
manifest, err := agentsdk.ManifestFromProto(mp)
if err != nil {
a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err))
+33 -76
View File
@@ -3462,7 +3462,11 @@ func TestAgent_Metrics_SSH(t *testing.T) {
registry := prometheus.NewRegistry()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
// Make sure we always get a DERP connection for
// currently_reachable_peers.
DisableDirectConnections: true,
}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.PrometheusRegistry = registry
})
@@ -3477,31 +3481,16 @@ func TestAgent_Metrics_SSH(t *testing.T) {
err = session.Shell()
require.NoError(t, err)
expected := []struct {
Name string
Type proto.Stats_Metric_Type
CheckFn func(float64) error
Labels []*proto.Stats_Metric_Label
}{
expected := []*proto.Stats_Metric{
{
Name: "agent_reconnecting_pty_connections_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_reconnecting_pty_connections_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "agent_sessions_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 1 {
return nil
}
return xerrors.Errorf("expected 1, got %f", v)
},
Name: "agent_sessions_total",
Type: proto.Stats_Metric_COUNTER,
Value: 1,
Labels: []*proto.Stats_Metric_Label{
{
Name: "magic_type",
@@ -3514,44 +3503,24 @@ func TestAgent_Metrics_SSH(t *testing.T) {
},
},
{
Name: "agent_ssh_server_failed_connections_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_ssh_server_failed_connections_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "agent_ssh_server_sftp_connections_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_ssh_server_sftp_connections_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "agent_ssh_server_sftp_server_errors_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_ssh_server_sftp_server_errors_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
CheckFn: func(float64) error {
// We can't reliably ping a peer here, and networking is out of
// scope of this test, so we just test that the metric exists
// with the correct labels.
return nil
},
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
Value: 1,
Labels: []*proto.Stats_Metric_Label{
{
Name: "connection_type",
@@ -3560,11 +3529,9 @@ func TestAgent_Metrics_SSH(t *testing.T) {
},
},
{
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
CheckFn: func(float64) error {
return nil
},
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
Value: 0,
Labels: []*proto.Stats_Metric_Label{
{
Name: "connection_type",
@@ -3573,20 +3540,9 @@ func TestAgent_Metrics_SSH(t *testing.T) {
},
},
{
Name: "coderd_agentstats_startup_script_seconds",
Type: proto.Stats_Metric_GAUGE,
CheckFn: func(f float64) error {
if f >= 0 {
return nil
}
return xerrors.Errorf("expected >= 0, got %f", f)
},
Labels: []*proto.Stats_Metric_Label{
{
Name: "success",
Value: "true",
},
},
Name: "coderd_agentstats_startup_script_seconds",
Type: proto.Stats_Metric_GAUGE,
Value: 1,
},
}
@@ -3608,10 +3564,11 @@ func TestAgent_Metrics_SSH(t *testing.T) {
for _, m := range mf.GetMetric() {
assert.Equal(t, expected[i].Name, mf.GetName())
assert.Equal(t, expected[i].Type.String(), mf.GetType().String())
// Value is max expected
if expected[i].Type == proto.Stats_Metric_GAUGE {
assert.NoError(t, expected[i].CheckFn(m.GetGauge().GetValue()), "check fn for %s failed", expected[i].Name)
assert.GreaterOrEqualf(t, expected[i].Value, m.GetGauge().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetGauge().GetValue())
} else if expected[i].Type == proto.Stats_Metric_COUNTER {
assert.NoError(t, expected[i].CheckFn(m.GetCounter().GetValue()), "check fn for %s failed", expected[i].Name)
assert.GreaterOrEqualf(t, expected[i].Value, m.GetCounter().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetCounter().GetValue())
}
for j, lbl := range expected[i].Labels {
assert.Equal(t, m.GetLabel()[j], &promgo.LabelPair{
+2
View File
@@ -682,6 +682,8 @@ func (api *API) updaterLoop() {
} else {
prevErr = nil
}
default:
api.logger.Debug(api.ctx, "updater loop ticker skipped, update in progress")
}
return nil // Always nil to keep the ticker going.
-174
View File
@@ -1,174 +0,0 @@
package unit
import (
"fmt"
"sync"
"golang.org/x/xerrors"
"gonum.org/v1/gonum/graph/encoding/dot"
"gonum.org/v1/gonum/graph/simple"
"gonum.org/v1/gonum/graph/topo"
)
// Graph provides a bidirectional interface over gonum's directed graph implementation.
// While the underlying gonum graph is directed, we overlay bidirectional semantics
// by distinguishing between forward and reverse edges. Wanting and being wanted by
// other units are related but different concepts that have different graph traversal
// implications when Units update their status.
//
// The graph stores edge types to represent different relationships between units,
// allowing for domain-specific semantics beyond simple connectivity.
type Graph[EdgeType, VertexType comparable] struct {
mu sync.RWMutex
// The underlying gonum graph. It stores vertices and edges without knowing about the types of the vertices and edges.
gonumGraph *simple.DirectedGraph
// Maps vertices to their IDs so that a gonum vertex ID can be used to lookup the vertex type.
vertexToID map[VertexType]int64
// Maps vertex IDs to their types so that a vertex type can be used to lookup the gonum vertex ID.
idToVertex map[int64]VertexType
// The next ID to assign to a vertex.
nextID int64
// Store edge types by "fromID->toID" key. This is used to lookup the edge type for a given edge.
edgeTypes map[string]EdgeType
}
// Edge is a convenience type for representing an edge in the graph.
// It encapsulates the from and to vertices and the edge type itself.
type Edge[EdgeType, VertexType comparable] struct {
From VertexType
To VertexType
Edge EdgeType
}
// AddEdge adds an edge to the graph. It initializes the graph and metadata on first use,
// checks for cycles, and adds the edge to the gonum graph.
func (g *Graph[EdgeType, VertexType]) AddEdge(from, to VertexType, edge EdgeType) error {
g.mu.Lock()
defer g.mu.Unlock()
if g.gonumGraph == nil {
g.gonumGraph = simple.NewDirectedGraph()
g.vertexToID = make(map[VertexType]int64)
g.idToVertex = make(map[int64]VertexType)
g.edgeTypes = make(map[string]EdgeType)
g.nextID = 1
}
fromID := g.getOrCreateVertexID(from)
toID := g.getOrCreateVertexID(to)
if g.canReach(to, from) {
return xerrors.Errorf("adding edge (%v -> %v) would create a cycle", from, to)
}
g.gonumGraph.SetEdge(simple.Edge{F: simple.Node(fromID), T: simple.Node(toID)})
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
g.edgeTypes[edgeKey] = edge
return nil
}
// GetForwardAdjacentVertices returns all the edges that originate from the given vertex.
func (g *Graph[EdgeType, VertexType]) GetForwardAdjacentVertices(from VertexType) []Edge[EdgeType, VertexType] {
g.mu.RLock()
defer g.mu.RUnlock()
fromID, exists := g.vertexToID[from]
if !exists {
return []Edge[EdgeType, VertexType]{}
}
edges := []Edge[EdgeType, VertexType]{}
toNodes := g.gonumGraph.From(fromID)
for toNodes.Next() {
toID := toNodes.Node().ID()
to := g.idToVertex[toID]
// Get the edge type
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
edgeType := g.edgeTypes[edgeKey]
edges = append(edges, Edge[EdgeType, VertexType]{From: from, To: to, Edge: edgeType})
}
return edges
}
// GetReverseAdjacentVertices returns all the edges that terminate at the given vertex.
func (g *Graph[EdgeType, VertexType]) GetReverseAdjacentVertices(to VertexType) []Edge[EdgeType, VertexType] {
g.mu.RLock()
defer g.mu.RUnlock()
toID, exists := g.vertexToID[to]
if !exists {
return []Edge[EdgeType, VertexType]{}
}
edges := []Edge[EdgeType, VertexType]{}
fromNodes := g.gonumGraph.To(toID)
for fromNodes.Next() {
fromID := fromNodes.Node().ID()
from := g.idToVertex[fromID]
// Get the edge type
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
edgeType := g.edgeTypes[edgeKey]
edges = append(edges, Edge[EdgeType, VertexType]{From: from, To: to, Edge: edgeType})
}
return edges
}
// getOrCreateVertexID returns the ID for a vertex, creating it if it doesn't exist.
func (g *Graph[EdgeType, VertexType]) getOrCreateVertexID(vertex VertexType) int64 {
if id, exists := g.vertexToID[vertex]; exists {
return id
}
id := g.nextID
g.nextID++
g.vertexToID[vertex] = id
g.idToVertex[id] = vertex
// Add the node to the gonum graph
g.gonumGraph.AddNode(simple.Node(id))
return id
}
// canReach checks if there is a path from the start vertex to the end vertex.
func (g *Graph[EdgeType, VertexType]) canReach(start, end VertexType) bool {
if start == end {
return true
}
startID, startExists := g.vertexToID[start]
endID, endExists := g.vertexToID[end]
if !startExists || !endExists {
return false
}
// Use gonum's built-in path existence check
return topo.PathExistsIn(g.gonumGraph, simple.Node(startID), simple.Node(endID))
}
// ToDOT exports the graph to DOT format for visualization
func (g *Graph[EdgeType, VertexType]) ToDOT(name string) (string, error) {
g.mu.RLock()
defer g.mu.RUnlock()
if g.gonumGraph == nil {
return "", xerrors.New("graph is not initialized")
}
// Marshal the graph to DOT format
dotBytes, err := dot.Marshal(g.gonumGraph, name, "", " ")
if err != nil {
return "", xerrors.Errorf("failed to marshal graph to DOT: %w", err)
}
return string(dotBytes), nil
}
-454
View File
@@ -1,454 +0,0 @@
// Package unit_test provides tests for the unit package.
//
// DOT Graph Testing:
// The graph tests use golden files for DOT representation verification.
// To update the golden files:
// make gen/golden-files
//
// The golden files contain the expected DOT representation and can be easily
// inspected, version controlled, and updated when the graph structure changes.
package unit_test
import (
"bytes"
"flag"
"fmt"
"os"
"path/filepath"
"sync"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/cryptorand"
)
type testGraphEdge string
const (
testEdgeStarted testGraphEdge = "started"
testEdgeCompleted testGraphEdge = "completed"
)
type testGraphVertex struct {
Name string
}
type (
testGraph = unit.Graph[testGraphEdge, *testGraphVertex]
testEdge = unit.Edge[testGraphEdge, *testGraphVertex]
)
// randInt generates a random integer in the range [0, limit).
func randInt(limit int) int {
if limit <= 0 {
return 0
}
n, err := cryptorand.Int63n(int64(limit))
if err != nil {
return 0
}
return int(n)
}
// UpdateGoldenFiles indicates golden files should be updated.
// To update the golden files:
// make gen/golden-files
var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files")
// assertDOTGraph requires that the graph's DOT representation matches the golden file
func assertDOTGraph(t *testing.T, graph *testGraph, goldenName string) {
t.Helper()
dot, err := graph.ToDOT(goldenName)
require.NoError(t, err)
goldenFile := filepath.Join("testdata", goldenName+".golden")
if *UpdateGoldenFiles {
t.Logf("update golden file for: %q: %s", goldenName, goldenFile)
err := os.MkdirAll(filepath.Dir(goldenFile), 0o755)
require.NoError(t, err, "want no error creating golden file directory")
err = os.WriteFile(goldenFile, []byte(dot), 0o600)
require.NoError(t, err, "update golden file")
}
expected, err := os.ReadFile(goldenFile)
require.NoError(t, err, "read golden file, run \"make gen/golden-files\" and commit the changes")
// Normalize line endings for cross-platform compatibility
expected = normalizeLineEndings(expected)
normalizedDot := normalizeLineEndings([]byte(dot))
assert.Empty(t, cmp.Diff(string(expected), string(normalizedDot)), "golden file mismatch (-want +got): %s, run \"make gen/golden-files\", verify and commit the changes", goldenFile)
}
// normalizeLineEndings ensures that all line endings are normalized to \n.
// Required for Windows compatibility.
func normalizeLineEndings(content []byte) []byte {
content = bytes.ReplaceAll(content, []byte("\r\n"), []byte("\n"))
content = bytes.ReplaceAll(content, []byte("\r"), []byte("\n"))
return content
}
func TestGraph(t *testing.T) {
t.Parallel()
testFuncs := map[string]func(t *testing.T) *unit.Graph[testGraphEdge, *testGraphVertex]{
"ForwardAndReverseEdges": func(t *testing.T) *unit.Graph[testGraphEdge, *testGraphVertex] {
graph := &unit.Graph[testGraphEdge, *testGraphVertex]{}
unit1 := &testGraphVertex{Name: "unit1"}
unit2 := &testGraphVertex{Name: "unit2"}
unit3 := &testGraphVertex{Name: "unit3"}
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unit1, unit3, testEdgeStarted)
require.NoError(t, err)
// Check for forward edge
vertices := graph.GetForwardAdjacentVertices(unit1)
require.Len(t, vertices, 2)
// Unit 1 depends on the completion of Unit2
require.Contains(t, vertices, testEdge{
From: unit1,
To: unit2,
Edge: testEdgeCompleted,
})
// Unit 1 depends on the start of Unit3
require.Contains(t, vertices, testEdge{
From: unit1,
To: unit3,
Edge: testEdgeStarted,
})
// Check for reverse edges
unit2ReverseEdges := graph.GetReverseAdjacentVertices(unit2)
require.Len(t, unit2ReverseEdges, 1)
// Unit 2 must be completed before Unit 1 can start
require.Contains(t, unit2ReverseEdges, testEdge{
From: unit1,
To: unit2,
Edge: testEdgeCompleted,
})
unit3ReverseEdges := graph.GetReverseAdjacentVertices(unit3)
require.Len(t, unit3ReverseEdges, 1)
// Unit 3 must be started before Unit 1 can complete
require.Contains(t, unit3ReverseEdges, testEdge{
From: unit1,
To: unit3,
Edge: testEdgeStarted,
})
return graph
},
"SelfReference": func(t *testing.T) *testGraph {
graph := &testGraph{}
unit1 := &testGraphVertex{Name: "unit1"}
err := graph.AddEdge(unit1, unit1, testEdgeCompleted)
require.Error(t, err)
require.ErrorContains(t, err, fmt.Sprintf("adding edge (%v -> %v) would create a cycle", unit1, unit1))
return graph
},
"Cycle": func(t *testing.T) *testGraph {
graph := &testGraph{}
unit1 := &testGraphVertex{Name: "unit1"}
unit2 := &testGraphVertex{Name: "unit2"}
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unit2, unit1, testEdgeStarted)
require.Error(t, err)
require.ErrorContains(t, err, fmt.Sprintf("adding edge (%v -> %v) would create a cycle", unit2, unit1))
return graph
},
"MultipleDependenciesSameStatus": func(t *testing.T) *testGraph {
graph := &testGraph{}
unit1 := &testGraphVertex{Name: "unit1"}
unit2 := &testGraphVertex{Name: "unit2"}
unit3 := &testGraphVertex{Name: "unit3"}
unit4 := &testGraphVertex{Name: "unit4"}
// Unit1 depends on completion of both unit2 and unit3 (same status type)
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unit1, unit3, testEdgeCompleted)
require.NoError(t, err)
// Unit1 also depends on starting of unit4 (different status type)
err = graph.AddEdge(unit1, unit4, testEdgeStarted)
require.NoError(t, err)
// Check that unit1 has 3 forward dependencies
forwardEdges := graph.GetForwardAdjacentVertices(unit1)
require.Len(t, forwardEdges, 3)
// Verify all expected dependencies exist
expectedDependencies := []testEdge{
{From: unit1, To: unit2, Edge: testEdgeCompleted},
{From: unit1, To: unit3, Edge: testEdgeCompleted},
{From: unit1, To: unit4, Edge: testEdgeStarted},
}
for _, expected := range expectedDependencies {
require.Contains(t, forwardEdges, expected)
}
// Check reverse dependencies
unit2ReverseEdges := graph.GetReverseAdjacentVertices(unit2)
require.Len(t, unit2ReverseEdges, 1)
require.Contains(t, unit2ReverseEdges, testEdge{
From: unit1, To: unit2, Edge: testEdgeCompleted,
})
unit3ReverseEdges := graph.GetReverseAdjacentVertices(unit3)
require.Len(t, unit3ReverseEdges, 1)
require.Contains(t, unit3ReverseEdges, testEdge{
From: unit1, To: unit3, Edge: testEdgeCompleted,
})
unit4ReverseEdges := graph.GetReverseAdjacentVertices(unit4)
require.Len(t, unit4ReverseEdges, 1)
require.Contains(t, unit4ReverseEdges, testEdge{
From: unit1, To: unit4, Edge: testEdgeStarted,
})
return graph
},
}
for testName, testFunc := range testFuncs {
var graph *testGraph
t.Run(testName, func(t *testing.T) {
t.Parallel()
graph = testFunc(t)
assertDOTGraph(t, graph, testName)
})
}
}
func TestGraphThreadSafety(t *testing.T) {
t.Parallel()
t.Run("ConcurrentReadWrite", func(t *testing.T) {
t.Parallel()
graph := &testGraph{}
var wg sync.WaitGroup
const numWriters = 50
const numReaders = 100
const operationsPerWriter = 1000
const operationsPerReader = 2000
barrier := make(chan struct{})
// Launch writers
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(writerID int) {
defer wg.Done()
<-barrier
for j := 0; j < operationsPerWriter; j++ {
from := &testGraphVertex{Name: fmt.Sprintf("writer-%d-%d", writerID, j)}
to := &testGraphVertex{Name: fmt.Sprintf("writer-%d-%d", writerID, j+1)}
graph.AddEdge(from, to, testEdgeCompleted)
}
}(i)
}
// Launch readers
readerResults := make([]struct {
panicked bool
readCount int
}, numReaders)
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func(readerID int) {
defer wg.Done()
<-barrier
defer func() {
if r := recover(); r != nil {
readerResults[readerID].panicked = true
}
}()
readCount := 0
for j := 0; j < operationsPerReader; j++ {
// Create a test vertex and read
testUnit := &testGraphVertex{Name: fmt.Sprintf("test-reader-%d-%d", readerID, j)}
forwardEdges := graph.GetForwardAdjacentVertices(testUnit)
reverseEdges := graph.GetReverseAdjacentVertices(testUnit)
// Just verify no panics (results may be nil for non-existent vertices)
_ = forwardEdges
_ = reverseEdges
readCount++
}
readerResults[readerID].readCount = readCount
}(i)
}
close(barrier)
wg.Wait()
// Verify no panics occurred in readers
for i, result := range readerResults {
require.False(t, result.panicked, "reader %d panicked", i)
require.Equal(t, operationsPerReader, result.readCount, "reader %d should have performed expected reads", i)
}
})
t.Run("ConcurrentCycleDetection", func(t *testing.T) {
t.Parallel()
graph := &testGraph{}
// Pre-create chain: A→B→C→D
unitA := &testGraphVertex{Name: "A"}
unitB := &testGraphVertex{Name: "B"}
unitC := &testGraphVertex{Name: "C"}
unitD := &testGraphVertex{Name: "D"}
err := graph.AddEdge(unitA, unitB, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unitB, unitC, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unitC, unitD, testEdgeCompleted)
require.NoError(t, err)
barrier := make(chan struct{})
var wg sync.WaitGroup
const numGoroutines = 50
cycleErrors := make([]error, numGoroutines)
// Launch goroutines trying to add D→A (creates cycle)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
<-barrier
err := graph.AddEdge(unitD, unitA, testEdgeCompleted)
cycleErrors[goroutineID] = err
}(i)
}
close(barrier)
wg.Wait()
// Verify all attempts correctly returned cycle error
for i, err := range cycleErrors {
require.Error(t, err, "goroutine %d should have detected cycle", i)
require.Contains(t, err.Error(), "would create a cycle")
}
// Verify graph remains valid (original chain intact)
dot, err := graph.ToDOT("test")
require.NoError(t, err)
require.NotEmpty(t, dot)
})
t.Run("ConcurrentToDOT", func(t *testing.T) {
t.Parallel()
graph := &testGraph{}
// Pre-populate graph
for i := 0; i < 20; i++ {
from := &testGraphVertex{Name: fmt.Sprintf("dot-unit-%d", i)}
to := &testGraphVertex{Name: fmt.Sprintf("dot-unit-%d", i+1)}
err := graph.AddEdge(from, to, testEdgeCompleted)
require.NoError(t, err)
}
barrier := make(chan struct{})
var wg sync.WaitGroup
const numReaders = 100
const numWriters = 20
dotResults := make([]string, numReaders)
// Launch readers calling ToDOT
dotErrors := make([]error, numReaders)
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func(readerID int) {
defer wg.Done()
<-barrier
dot, err := graph.ToDOT(fmt.Sprintf("test-%d", readerID))
dotErrors[readerID] = err
if err == nil {
dotResults[readerID] = dot
}
}(i)
}
// Launch writers adding edges
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(writerID int) {
defer wg.Done()
<-barrier
from := &testGraphVertex{Name: fmt.Sprintf("writer-dot-%d", writerID)}
to := &testGraphVertex{Name: fmt.Sprintf("writer-dot-target-%d", writerID)}
graph.AddEdge(from, to, testEdgeCompleted)
}(i)
}
close(barrier)
wg.Wait()
// Verify no errors occurred during DOT generation
for i, err := range dotErrors {
require.NoError(t, err, "DOT generation error at index %d", i)
}
// Verify all DOT results are valid
for i, dot := range dotResults {
require.NotEmpty(t, dot, "DOT result %d should not be empty", i)
}
})
}
func BenchmarkGraph_ConcurrentMixedOperations(b *testing.B) {
graph := &testGraph{}
var wg sync.WaitGroup
const numGoroutines = 200
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Launch goroutines performing random operations
for j := 0; j < numGoroutines; j++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
operationCount := 0
for operationCount < 50 {
operation := float32(randInt(100)) / 100.0
if operation < 0.6 { // 60% reads
// Read operation
testUnit := &testGraphVertex{Name: fmt.Sprintf("bench-read-%d-%d", goroutineID, operationCount)}
forwardEdges := graph.GetForwardAdjacentVertices(testUnit)
reverseEdges := graph.GetReverseAdjacentVertices(testUnit)
// Just verify no panics (results may be nil for non-existent vertices)
_ = forwardEdges
_ = reverseEdges
} else { // 40% writes
// Write operation
from := &testGraphVertex{Name: fmt.Sprintf("bench-write-%d-%d", goroutineID, operationCount)}
to := &testGraphVertex{Name: fmt.Sprintf("bench-write-target-%d-%d", goroutineID, operationCount)}
graph.AddEdge(from, to, testEdgeCompleted)
}
operationCount++
}
}(j)
}
wg.Wait()
}
}
-8
View File
@@ -1,8 +0,0 @@
strict digraph Cycle {
// Node definitions.
1;
2;
// Edge definitions.
1 -> 2;
}
-10
View File
@@ -1,10 +0,0 @@
strict digraph ForwardAndReverseEdges {
// Node definitions.
1;
2;
3;
// Edge definitions.
1 -> 2;
1 -> 3;
}
@@ -1,12 +0,0 @@
strict digraph MultipleDependenciesSameStatus {
// Node definitions.
1;
2;
3;
4;
// Edge definitions.
1 -> 2;
1 -> 3;
1 -> 4;
}
-4
View File
@@ -1,4 +0,0 @@
strict digraph SelfReference {
// Node definitions.
1;
}
-78
View File
@@ -1,78 +0,0 @@
package cli
import (
"encoding/csv"
"strings"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
var (
_ pflag.SliceValue = &AllowListFlag{}
_ pflag.Value = &AllowListFlag{}
)
// AllowListFlag implements pflag.SliceValue for codersdk.APIAllowListTarget entries.
type AllowListFlag []codersdk.APIAllowListTarget
func AllowListFlagOf(al *[]codersdk.APIAllowListTarget) *AllowListFlag {
return (*AllowListFlag)(al)
}
func (a AllowListFlag) String() string {
return strings.Join(a.GetSlice(), ",")
}
func (a AllowListFlag) Value() []codersdk.APIAllowListTarget {
return []codersdk.APIAllowListTarget(a)
}
func (AllowListFlag) Type() string { return "allow-list" }
func (a *AllowListFlag) Set(set string) error {
values, err := csv.NewReader(strings.NewReader(set)).Read()
if err != nil {
return xerrors.Errorf("parse allow list entries as csv: %w", err)
}
for _, v := range values {
if err := a.Append(v); err != nil {
return err
}
}
return nil
}
func (a *AllowListFlag) Append(value string) error {
value = strings.TrimSpace(value)
if value == "" {
return xerrors.New("allow list entry cannot be empty")
}
var target codersdk.APIAllowListTarget
if err := target.UnmarshalText([]byte(value)); err != nil {
return err
}
*a = append(*a, target)
return nil
}
func (a *AllowListFlag) Replace(items []string) error {
*a = []codersdk.APIAllowListTarget{}
for _, item := range items {
if err := a.Append(item); err != nil {
return err
}
}
return nil
}
func (a *AllowListFlag) GetSlice() []string {
out := make([]string, len(*a))
for i, entry := range *a {
out[i] = entry.String()
}
return out
}
+9 -18
View File
@@ -296,23 +296,22 @@ func renderTable(out any, sort string, headers table.Row, filterColumns []string
// returned. If the table tag is malformed, an error is returned.
//
// The returned name is transformed from "snake_case" to "normal text".
func parseTableStructTag(field reflect.StructField) (name string, defaultSort, noSortOpt, recursive, skipParentName, emptyNil bool, err error) {
func parseTableStructTag(field reflect.StructField) (name string, defaultSort, noSortOpt, recursive, skipParentName bool, err error) {
tags, err := structtag.Parse(string(field.Tag))
if err != nil {
return "", false, false, false, false, false, xerrors.Errorf("parse struct field tag %q: %w", string(field.Tag), err)
return "", false, false, false, false, xerrors.Errorf("parse struct field tag %q: %w", string(field.Tag), err)
}
tag, err := tags.Get("table")
if err != nil || tag.Name == "-" {
// tags.Get only returns an error if the tag is not found.
return "", false, false, false, false, false, nil
return "", false, false, false, false, nil
}
defaultSortOpt := false
noSortOpt = false
recursiveOpt := false
skipParentNameOpt := false
emptyNilOpt := false
for _, opt := range tag.Options {
switch opt {
case "default_sort":
@@ -327,14 +326,12 @@ func parseTableStructTag(field reflect.StructField) (name string, defaultSort, n
// make sure the child name is unique across all nested structs in the parent.
recursiveOpt = true
skipParentNameOpt = true
case "empty_nil":
emptyNilOpt = true
default:
return "", false, false, false, false, false, xerrors.Errorf("unknown option %q in struct field tag", opt)
return "", false, false, false, false, xerrors.Errorf("unknown option %q in struct field tag", opt)
}
}
return strings.ReplaceAll(tag.Name, "_", " "), defaultSortOpt, noSortOpt, recursiveOpt, skipParentNameOpt, emptyNilOpt, nil
return strings.ReplaceAll(tag.Name, "_", " "), defaultSortOpt, noSortOpt, recursiveOpt, skipParentNameOpt, nil
}
func isStructOrStructPointer(t reflect.Type) bool {
@@ -361,7 +358,7 @@ func typeToTableHeaders(t reflect.Type, requireDefault bool) ([]string, string,
noSortOpt := false
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
name, defaultSort, noSort, recursive, skip, _, err := parseTableStructTag(field)
name, defaultSort, noSort, recursive, skip, err := parseTableStructTag(field)
if err != nil {
return nil, "", xerrors.Errorf("parse struct tags for field %q in type %q: %w", field.Name, t.String(), err)
}
@@ -438,7 +435,7 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
for i := 0; i < val.NumField(); i++ {
field := val.Type().Field(i)
fieldVal := val.Field(i)
name, _, _, recursive, skip, emptyNil, err := parseTableStructTag(field)
name, _, _, recursive, skip, err := parseTableStructTag(field)
if err != nil {
return nil, xerrors.Errorf("parse struct tags for field %q in type %T: %w", field.Name, val, err)
}
@@ -446,14 +443,8 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
continue
}
fieldType := field.Type
// If empty_nil is set and this is a nil pointer, use a zero value.
if emptyNil && fieldVal.Kind() == reflect.Pointer && fieldVal.IsNil() {
fieldVal = reflect.New(fieldType.Elem())
}
// Recurse if it's a struct.
fieldType := field.Type
if recursive {
if !isStructOrStructPointer(fieldType) {
return nil, xerrors.Errorf("field %q in type %q is marked as recursive but does not contain a struct or a pointer to a struct", field.Name, fieldType.String())
@@ -476,7 +467,7 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
}
// Otherwise, we just use the field value.
row[name] = fieldVal.Interface()
row[name] = val.Field(i).Interface()
}
return row, nil
-72
View File
@@ -400,78 +400,6 @@ foo <nil> 10 [a, b, c] foo1 11 foo2 12 fo
})
})
})
t.Run("EmptyNil", func(t *testing.T) {
t.Parallel()
type emptyNilTest struct {
Name string `table:"name,default_sort"`
EmptyOnNil *string `table:"empty_on_nil,empty_nil"`
NormalBehavior *string `table:"normal_behavior"`
}
value := "value"
in := []emptyNilTest{
{
Name: "has_value",
EmptyOnNil: &value,
NormalBehavior: &value,
},
{
Name: "has_nil",
EmptyOnNil: nil,
NormalBehavior: nil,
},
}
expected := `
NAME EMPTY ON NIL NORMAL BEHAVIOR
has_nil <nil>
has_value value value
`
out, err := cliui.DisplayTable(in, "", nil)
log.Println("rendered table:\n" + out)
require.NoError(t, err)
compareTables(t, expected, out)
})
t.Run("EmptyNilWithRecursiveInline", func(t *testing.T) {
t.Parallel()
type nestedData struct {
Name string `table:"name"`
}
type inlineTest struct {
Nested *nestedData `table:"ignored,recursive_inline,empty_nil"`
Count int `table:"count,default_sort"`
}
in := []inlineTest{
{
Nested: &nestedData{
Name: "alice",
},
Count: 1,
},
{
Nested: nil,
Count: 2,
},
}
expected := `
NAME COUNT
alice 1
2
`
out, err := cliui.DisplayTable(in, "", nil)
log.Println("rendered table:\n" + out)
require.NoError(t, err)
compareTables(t, expected, out)
})
}
// compareTables normalizes the incoming table lines
+6
View File
@@ -185,6 +185,9 @@ func TestDelete(t *testing.T) {
t.Run("WarnNoProvisioners", func(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
store, ps, db := dbtestutil.NewDBWithSQLDB(t)
client, closeDaemon := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{
@@ -225,6 +228,9 @@ func TestDelete(t *testing.T) {
t.Run("Prebuilt workspace delete permissions", func(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
// Setup
db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure())
+324 -12
View File
@@ -29,6 +29,7 @@ import (
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/httpapi"
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
@@ -39,6 +40,7 @@ import (
"github.com/coder/coder/v2/scaletest/dashboard"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/loadtestutil"
"github.com/coder/coder/v2/scaletest/notifications"
"github.com/coder/coder/v2/scaletest/reconnectingpty"
"github.com/coder/coder/v2/scaletest/workspacebuild"
"github.com/coder/coder/v2/scaletest/workspacetraffic"
@@ -64,7 +66,6 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
r.scaletestWorkspaceTraffic(),
r.scaletestAutostart(),
r.scaletestNotifications(),
r.scaletestSMTP(),
},
}
@@ -1920,6 +1921,259 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
return cmd
}
func (r *RootCmd) scaletestNotifications() *serpent.Command {
var (
userCount int64
ownerUserPercentage float64
notificationTimeout time.Duration
dialTimeout time.Duration
noCleanup bool
tracingFlags = &scaletestTracingFlags{}
// This test requires unlimited concurrency.
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
Use: "notifications",
Short: "Simulate notification delivery by creating many users listening to notifications.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
defer stop()
ctx = notifyCtx
me, err := requireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
if userCount <= 0 {
return xerrors.Errorf("--user-count must be greater than 0")
}
if ownerUserPercentage < 0 || ownerUserPercentage > 100 {
return xerrors.Errorf("--owner-user-percentage must be between 0 and 100")
}
ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100)
if ownerUserCount == 0 && ownerUserPercentage > 0 {
ownerUserCount = 1
}
regularUserCount := userCount - ownerUserCount
_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
_, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount)
_, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage)
_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage)
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags")
}
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := notifications.NewMetrics(reg)
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")
dialBarrier := &sync.WaitGroup{}
ownerWatchBarrier := &sync.WaitGroup{}
dialBarrier.Add(int(userCount))
ownerWatchBarrier.Add(int(ownerUserCount))
expectedNotifications := map[uuid.UUID]chan time.Time{
notificationsLib.TemplateUserAccountCreated: make(chan time.Time, 1),
notificationsLib.TemplateUserAccountDeleted: make(chan time.Time, 1),
}
configs := make([]notifications.Config, 0, userCount)
for range ownerUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{codersdk.RoleOwner},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
ExpectedNotifications: expectedNotifications,
Metrics: metrics,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
for range regularUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
Metrics: metrics,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
go triggerUserNotifications(
ctx,
logger,
client,
me.OrganizationIDs[0],
dialBarrier,
dialTimeout,
expectedNotifications,
)
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i, config := range configs {
id := strconv.Itoa(i)
name := fmt.Sprintf("notifications-%s", id)
var runner harness.Runnable = notifications.NewRunner(client, config)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
spanName: name,
runner: runner,
}
}
th.AddRun(name, id, runner)
}
_, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...")
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "user-count",
FlagShorthand: "c",
Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT",
Description: "Required: Total number of users to create.",
Value: serpent.Int64Of(&userCount),
Required: true,
},
{
Flag: "owner-user-percentage",
Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE",
Default: "20.0",
Description: "Percentage of users to assign Owner role to (0-100).",
Value: serpent.Float64Of(&ownerUserPercentage),
},
{
Flag: "notification-timeout",
Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT",
Default: "5m",
Description: "How long to wait for notifications after triggering.",
Value: serpent.DurationOf(&notificationTimeout),
},
{
Flag: "dial-timeout",
Env: "CODER_SCALETEST_DIAL_TIMEOUT",
Default: "2m",
Description: "Timeout for dialing the notification websocket endpoint.",
Value: serpent.DurationOf(&dialTimeout),
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes.",
Value: serpent.BoolOf(&noCleanup),
},
}
tracingFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
type runnableTraceWrapper struct {
tracer trace.Tracer
spanName string
@@ -1929,9 +2183,8 @@ type runnableTraceWrapper struct {
}
var (
_ harness.Runnable = &runnableTraceWrapper{}
_ harness.Cleanable = &runnableTraceWrapper{}
_ harness.Collectable = &runnableTraceWrapper{}
_ harness.Runnable = &runnableTraceWrapper{}
_ harness.Cleanable = &runnableTraceWrapper{}
)
func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
@@ -1973,14 +2226,6 @@ func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string, logs io.W
return c.Cleanup(ctx, id, logs)
}
func (r *runnableTraceWrapper) GetMetrics() map[string]any {
c, ok := r.runner.(harness.Collectable)
if !ok {
return nil
}
return c.GetMetrics()
}
func getScaletestWorkspaces(ctx context.Context, client *codersdk.Client, owner, template string) ([]codersdk.Workspace, int, error) {
var (
pageNumber = 0
@@ -2129,6 +2374,73 @@ func parseTargetRange(name, targets string) (start, end int, err error) {
return start, end, nil
}
// triggerUserNotifications waits for all test users to connect,
// then creates and deletes a test user to trigger notification events for testing.
func triggerUserNotifications(
ctx context.Context,
logger slog.Logger,
client *codersdk.Client,
orgID uuid.UUID,
dialBarrier *sync.WaitGroup,
dialTimeout time.Duration,
expectedNotifications map[uuid.UUID]chan time.Time,
) {
logger.Info(ctx, "waiting for all users to connect")
// Wait for all users to connect
waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second)
defer cancel()
done := make(chan struct{})
go func() {
dialBarrier.Wait()
close(done)
}()
select {
case <-done:
logger.Info(ctx, "all users connected")
case <-waitCtx.Done():
if waitCtx.Err() == context.DeadlineExceeded {
logger.Error(ctx, "timeout waiting for users to connect")
} else {
logger.Info(ctx, "context canceled while waiting for users")
}
return
}
const (
triggerUsername = "scaletest-trigger-user"
triggerEmail = "scaletest-trigger@example.com"
)
logger.Info(ctx, "creating test user to test notifications",
slog.F("username", triggerUsername),
slog.F("email", triggerEmail),
slog.F("org_id", orgID))
testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{orgID},
Username: triggerUsername,
Email: triggerEmail,
Password: "test-password-123",
})
if err != nil {
logger.Error(ctx, "create test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now()
err = client.DeleteUser(ctx, testUser.ID)
if err != nil {
logger.Error(ctx, "delete test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now()
close(expectedNotifications[notificationsLib.TemplateUserAccountCreated])
close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted])
}
func createWorkspaceAppConfig(client *codersdk.Client, appHost, app string, workspace codersdk.Workspace, agent codersdk.WorkspaceAgent) (workspacetraffic.AppConfig, error) {
if app == "" {
return workspacetraffic.AppConfig{}, nil
+1 -12
View File
@@ -27,7 +27,6 @@ const (
func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
var (
templateName string
provisionerTags []string
numEvals int64
tracingFlags = &scaletestTracingFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
@@ -57,11 +56,6 @@ func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
return xerrors.Errorf("template cannot be empty")
}
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
org, err := orgContext.Selected(inv, client)
if err != nil {
return err
@@ -105,7 +99,7 @@ func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
}()
tracer := tracerProvider.Tracer(scaletestTracerName)
partitions, err := dynamicparameters.SetupPartitions(ctx, client, org.ID, templateName, tags, numEvals, logger)
partitions, err := dynamicparameters.SetupPartitions(ctx, client, org.ID, templateName, numEvals, logger)
if err != nil {
return xerrors.Errorf("setup dynamic parameters partitions: %w", err)
}
@@ -166,11 +160,6 @@ func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
Default: "100",
Value: serpent.Int64Of(&numEvals),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}
orgContext.AttachOptions(cmd)
output.attach(&cmd.Options)
-447
View File
@@ -1,447 +0,0 @@
//go:build !slim
package cli
import (
"context"
"fmt"
"net/http"
"os/signal"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/xerrors"
"cdr.dev/slog"
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/createusers"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/notifications"
"github.com/coder/serpent"
)
func (r *RootCmd) scaletestNotifications() *serpent.Command {
var (
userCount int64
ownerUserPercentage float64
notificationTimeout time.Duration
dialTimeout time.Duration
noCleanup bool
smtpAPIURL string
tracingFlags = &scaletestTracingFlags{}
// This test requires unlimited concurrency.
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
Use: "notifications",
Short: "Simulate notification delivery by creating many users listening to notifications.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
defer stop()
ctx = notifyCtx
me, err := requireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
if userCount <= 0 {
return xerrors.Errorf("--user-count must be greater than 0")
}
if ownerUserPercentage < 0 || ownerUserPercentage > 100 {
return xerrors.Errorf("--owner-user-percentage must be between 0 and 100")
}
if smtpAPIURL != "" && !strings.HasPrefix(smtpAPIURL, "http://") && !strings.HasPrefix(smtpAPIURL, "https://") {
return xerrors.Errorf("--smtp-api-url must start with http:// or https://")
}
ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100)
if ownerUserCount == 0 && ownerUserPercentage > 0 {
ownerUserCount = 1
}
regularUserCount := userCount - ownerUserCount
_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
_, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount)
_, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage)
_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage)
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags")
}
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := notifications.NewMetrics(reg)
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")
dialBarrier := &sync.WaitGroup{}
ownerWatchBarrier := &sync.WaitGroup{}
dialBarrier.Add(int(userCount))
ownerWatchBarrier.Add(int(ownerUserCount))
expectedNotificationIDs := map[uuid.UUID]struct{}{
notificationsLib.TemplateUserAccountCreated: {},
notificationsLib.TemplateUserAccountDeleted: {},
}
triggerTimes := make(map[uuid.UUID]chan time.Time, len(expectedNotificationIDs))
for id := range expectedNotificationIDs {
triggerTimes[id] = make(chan time.Time, 1)
}
configs := make([]notifications.Config, 0, userCount)
for range ownerUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{codersdk.RoleOwner},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
ExpectedNotificationsIDs: expectedNotificationIDs,
Metrics: metrics,
SMTPApiURL: smtpAPIURL,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
for range regularUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
Metrics: metrics,
SMTPApiURL: smtpAPIURL,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
go triggerUserNotifications(
ctx,
logger,
client,
me.OrganizationIDs[0],
dialBarrier,
dialTimeout,
triggerTimes,
)
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i, config := range configs {
id := strconv.Itoa(i)
name := fmt.Sprintf("notifications-%s", id)
var runner harness.Runnable = notifications.NewRunner(client, config)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
spanName: name,
runner: runner,
}
}
th.AddRun(name, id, runner)
}
_, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...")
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
if err := computeNotificationLatencies(ctx, logger, triggerTimes, res, metrics); err != nil {
return xerrors.Errorf("compute notification latencies: %w", err)
}
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "user-count",
FlagShorthand: "c",
Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT",
Description: "Required: Total number of users to create.",
Value: serpent.Int64Of(&userCount),
Required: true,
},
{
Flag: "owner-user-percentage",
Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE",
Default: "20.0",
Description: "Percentage of users to assign Owner role to (0-100).",
Value: serpent.Float64Of(&ownerUserPercentage),
},
{
Flag: "notification-timeout",
Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT",
Default: "5m",
Description: "How long to wait for notifications after triggering.",
Value: serpent.DurationOf(&notificationTimeout),
},
{
Flag: "dial-timeout",
Env: "CODER_SCALETEST_DIAL_TIMEOUT",
Default: "2m",
Description: "Timeout for dialing the notification websocket endpoint.",
Value: serpent.DurationOf(&dialTimeout),
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "smtp-api-url",
Env: "CODER_SCALETEST_SMTP_API_URL",
Description: "SMTP mock HTTP API address.",
Value: serpent.StringOf(&smtpAPIURL),
},
}
tracingFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
func computeNotificationLatencies(
ctx context.Context,
logger slog.Logger,
expectedNotifications map[uuid.UUID]chan time.Time,
results harness.Results,
metrics *notifications.Metrics,
) error {
triggerTimes := make(map[uuid.UUID]time.Time)
for notificationID, triggerTimeChan := range expectedNotifications {
select {
case triggerTime := <-triggerTimeChan:
triggerTimes[notificationID] = triggerTime
logger.Info(ctx, "received trigger time",
slog.F("notification_id", notificationID),
slog.F("trigger_time", triggerTime))
default:
logger.Warn(ctx, "no trigger time received for notification",
slog.F("notification_id", notificationID))
}
}
if len(triggerTimes) == 0 {
logger.Warn(ctx, "no trigger times available, skipping latency computation")
return nil
}
var totalLatencies int
for runID, runResult := range results.Runs {
if runResult.Error != nil {
logger.Debug(ctx, "skipping failed run for latency computation",
slog.F("run_id", runID))
continue
}
if runResult.Metrics == nil {
continue
}
// Process websocket notifications.
if wsReceiptTimes, ok := runResult.Metrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok {
for notificationID, receiptTime := range wsReceiptTimes {
if triggerTime, ok := triggerTimes[notificationID]; ok {
latency := receiptTime.Sub(triggerTime)
metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeWebsocket)
totalLatencies++
logger.Debug(ctx, "computed websocket latency",
slog.F("run_id", runID),
slog.F("notification_id", notificationID),
slog.F("latency", latency))
}
}
}
// Process SMTP notifications
if smtpReceiptTimes, ok := runResult.Metrics[notifications.SMTPNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok {
for notificationID, receiptTime := range smtpReceiptTimes {
if triggerTime, ok := triggerTimes[notificationID]; ok {
latency := receiptTime.Sub(triggerTime)
metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeSMTP)
totalLatencies++
logger.Debug(ctx, "computed SMTP latency",
slog.F("run_id", runID),
slog.F("notification_id", notificationID),
slog.F("latency", latency))
}
}
}
}
logger.Info(ctx, "finished computing notification latencies",
slog.F("total_runs", results.TotalRuns),
slog.F("total_latencies_computed", totalLatencies))
return nil
}
// triggerUserNotifications waits for all test users to connect,
// then creates and deletes a test user to trigger notification events for testing.
func triggerUserNotifications(
ctx context.Context,
logger slog.Logger,
client *codersdk.Client,
orgID uuid.UUID,
dialBarrier *sync.WaitGroup,
dialTimeout time.Duration,
expectedNotifications map[uuid.UUID]chan time.Time,
) {
logger.Info(ctx, "waiting for all users to connect")
// Wait for all users to connect
waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second)
defer cancel()
done := make(chan struct{})
go func() {
dialBarrier.Wait()
close(done)
}()
select {
case <-done:
logger.Info(ctx, "all users connected")
case <-waitCtx.Done():
if waitCtx.Err() == context.DeadlineExceeded {
logger.Error(ctx, "timeout waiting for users to connect")
} else {
logger.Info(ctx, "context canceled while waiting for users")
}
return
}
const (
triggerUsername = "scaletest-trigger-user"
triggerEmail = "scaletest-trigger@example.com"
)
logger.Info(ctx, "creating test user to test notifications",
slog.F("username", triggerUsername),
slog.F("email", triggerEmail),
slog.F("org_id", orgID))
testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{orgID},
Username: triggerUsername,
Email: triggerEmail,
Password: "test-password-123",
})
if err != nil {
logger.Error(ctx, "create test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now()
err = client.DeleteUser(ctx, testUser.ID)
if err != nil {
logger.Error(ctx, "delete test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now()
close(expectedNotifications[notificationsLib.TemplateUserAccountCreated])
close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted])
}
-112
View File
@@ -1,112 +0,0 @@
//go:build !slim
package cli
import (
"fmt"
"os/signal"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/scaletest/smtpmock"
"github.com/coder/serpent"
)
func (*RootCmd) scaletestSMTP() *serpent.Command {
var (
hostAddress string
smtpPort int64
apiPort int64
purgeAtCount int64
)
cmd := &serpent.Command{
Use: "smtp",
Short: "Start a mock SMTP server for testing",
Long: `Start a mock SMTP server with an HTTP API server that can be used to purge
messages and get messages by email.`,
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
defer stop()
ctx = notifyCtx
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)
config := smtpmock.Config{
HostAddress: hostAddress,
SMTPPort: int(smtpPort),
APIPort: int(apiPort),
Logger: logger,
}
srv := new(smtpmock.Server)
if err := srv.Start(ctx, config); err != nil {
return xerrors.Errorf("start mock SMTP server: %w", err)
}
defer func() {
_ = srv.Stop()
}()
_, _ = fmt.Fprintf(inv.Stdout, "Mock SMTP server started on %s\n", srv.SMTPAddress())
_, _ = fmt.Fprintf(inv.Stdout, "HTTP API server started on %s\n", srv.APIAddress())
if purgeAtCount > 0 {
_, _ = fmt.Fprintf(inv.Stdout, " Auto-purge when message count reaches %d\n", purgeAtCount)
}
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
_, _ = fmt.Fprintf(inv.Stdout, "\nTotal messages received since last purge: %d\n", srv.MessageCount())
return nil
case <-ticker.C:
count := srv.MessageCount()
if count > 0 {
_, _ = fmt.Fprintf(inv.Stdout, "Messages received: %d\n", count)
}
if purgeAtCount > 0 && int64(count) >= purgeAtCount {
_, _ = fmt.Fprintf(inv.Stdout, "Message count (%d) reached threshold (%d). Purging...\n", count, purgeAtCount)
srv.Purge()
continue
}
}
}
},
}
cmd.Options = []serpent.Option{
{
Flag: "host-address",
Env: "CODER_SCALETEST_SMTP_HOST_ADDRESS",
Default: "localhost",
Description: "Host address to bind the mock SMTP and API servers.",
Value: serpent.StringOf(&hostAddress),
},
{
Flag: "smtp-port",
Env: "CODER_SCALETEST_SMTP_PORT",
Description: "Port for the mock SMTP server. Uses a random port if not specified.",
Value: serpent.Int64Of(&smtpPort),
},
{
Flag: "api-port",
Env: "CODER_SCALETEST_SMTP_API_PORT",
Description: "Port for the HTTP API server. Uses a random port if not specified.",
Value: serpent.Int64Of(&apiPort),
},
{
Flag: "purge-at-count",
Env: "CODER_SCALETEST_SMTP_PURGE_AT_COUNT",
Default: "100000",
Description: "Maximum number of messages to keep before auto-purging. Set to 0 to disable.",
Value: serpent.Int64Of(&purgeAtCount),
},
}
return cmd
}
+34 -10
View File
@@ -5,6 +5,7 @@ import (
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/pretty"
@@ -46,19 +47,43 @@ func (r *RootCmd) taskDelete() *serpent.Command {
}
exp := codersdk.NewExperimentalClient(client)
var tasks []codersdk.Task
type toDelete struct {
ID uuid.UUID
Owner string
Display string
}
var items []toDelete
for _, identifier := range inv.Args {
task, err := exp.TaskByIdentifier(ctx, identifier)
identifier = strings.TrimSpace(identifier)
if identifier == "" {
return xerrors.New("task identifier cannot be empty or whitespace")
}
// Check task identifier, try UUID first.
if id, err := uuid.Parse(identifier); err == nil {
task, err := exp.TaskByID(ctx, id)
if err != nil {
return xerrors.Errorf("resolve task %q: %w", identifier, err)
}
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
items = append(items, toDelete{ID: id, Display: display, Owner: task.OwnerName})
continue
}
// Non-UUID, treat as a workspace identifier (name or owner/name).
ws, err := namedWorkspace(ctx, client, identifier)
if err != nil {
return xerrors.Errorf("resolve task %q: %w", identifier, err)
}
tasks = append(tasks, task)
display := ws.FullName()
items = append(items, toDelete{ID: ws.ID, Display: display, Owner: ws.OwnerName})
}
// Confirm deletion of the tasks.
var displayList []string
for _, task := range tasks {
displayList = append(displayList, fmt.Sprintf("%s/%s", task.OwnerName, task.Name))
for _, it := range items {
displayList = append(displayList, it.Display)
}
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Delete these tasks: %s?", pretty.Sprint(cliui.DefaultStyles.Code, strings.Join(displayList, ", "))),
@@ -69,13 +94,12 @@ func (r *RootCmd) taskDelete() *serpent.Command {
return err
}
for i, task := range tasks {
display := displayList[i]
if err := exp.DeleteTask(ctx, task.OwnerName, task.ID); err != nil {
return xerrors.Errorf("delete task %q: %w", display, err)
for _, item := range items {
if err := exp.DeleteTask(ctx, item.Owner, item.ID); err != nil {
return xerrors.Errorf("delete task %q: %w", item.Display, err)
}
_, _ = fmt.Fprintln(
inv.Stdout, "Deleted task "+pretty.Sprint(cliui.DefaultStyles.Keyword, display)+" at "+cliui.Timestamp(time.Now()),
inv.Stdout, "Deleted task "+pretty.Sprint(cliui.DefaultStyles.Keyword, item.Display)+" at "+cliui.Timestamp(time.Now()),
)
}
+17 -41
View File
@@ -56,18 +56,12 @@ func TestExpTaskDelete(t *testing.T) {
taskID := uuid.MustParse(id1)
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/exists":
c.nameResolves.Add(1)
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: taskID,
Name: "exists",
OwnerName: "me",
}},
Count: 1,
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
ID: taskID,
Name: "exists",
OwnerName: "me",
})
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id1:
c.deleteCalls.Add(1)
@@ -110,18 +104,12 @@ func TestExpTaskDelete(t *testing.T) {
firstID := uuid.MustParse(id3)
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/first":
c.nameResolves.Add(1)
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: firstID,
Name: "first",
OwnerName: "me",
}},
Count: 1,
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
ID: firstID,
Name: "first",
OwnerName: "me",
})
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/"+id4:
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{
@@ -151,14 +139,8 @@ func TestExpTaskDelete(t *testing.T) {
buildHandler: func(_ *testCounters) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{},
Count: 0,
})
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/doesnotexist":
httpapi.ResourceNotFound(w)
default:
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
}
@@ -174,18 +156,12 @@ func TestExpTaskDelete(t *testing.T) {
taskID := uuid.MustParse(id5)
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/bad":
c.nameResolves.Add(1)
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: taskID,
Name: "bad",
OwnerName: "me",
}},
Count: 1,
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
ID: taskID,
Name: "bad",
OwnerName: "me",
})
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id5:
httpapi.InternalServerError(w, xerrors.New("boom"))
+3 -4
View File
@@ -8,7 +8,6 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
@@ -99,10 +98,10 @@ func (r *RootCmd) taskList() *serpent.Command {
Options: serpent.OptionSet{
{
Name: "status",
Description: "Filter by task status.",
Description: "Filter by task status (e.g. running, failed, etc).",
Flag: "status",
Default: "",
Value: serpent.EnumOf(&statusFilter, slice.ToStrings(codersdk.AllTaskStatuses())...),
Value: serpent.StringOf(&statusFilter),
},
{
Name: "all",
@@ -144,7 +143,7 @@ func (r *RootCmd) taskList() *serpent.Command {
tasks, err := exp.Tasks(ctx, &codersdk.TasksFilter{
Owner: targetUser,
Status: codersdk.TaskStatus(statusFilter),
Status: statusFilter,
})
if err != nil {
return xerrors.Errorf("list tasks: %w", err)
+15 -36
View File
@@ -22,7 +22,6 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
@@ -30,7 +29,7 @@ import (
)
// makeAITask creates an AI-task workspace.
func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UUID, transition database.WorkspaceTransition, prompt string) database.Task {
func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UUID, transition database.WorkspaceTransition, prompt string) (workspace database.WorkspaceTable) {
t.Helper()
tv := dbfake.TemplateVersion(t, db).
@@ -92,27 +91,7 @@ func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UU
)
require.NoError(t, err)
// Create a task record in the tasks table for the new data model.
task := dbgen.Task(t, db, database.TaskTable{
OrganizationID: orgID,
OwnerID: ownerID,
Name: build.Workspace.Name,
WorkspaceID: uuid.NullUUID{UUID: build.Workspace.ID, Valid: true},
TemplateVersionID: tv.TemplateVersion.ID,
TemplateParameters: []byte("{}"),
Prompt: prompt,
CreatedAt: dbtime.Now(),
})
// Link the task to the workspace app.
dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
TaskID: task.ID,
WorkspaceBuildNumber: build.Build.BuildNumber,
WorkspaceAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
})
return task
return build.Workspace
}
func TestExpTaskList(t *testing.T) {
@@ -149,7 +128,7 @@ func TestExpTaskList(t *testing.T) {
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
wantPrompt := "build me a web app"
task := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, wantPrompt)
ws := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, wantPrompt)
inv, root := clitest.New(t, "exp", "task", "list", "--column", "id,name,status,initial prompt")
clitest.SetupConfig(t, memberClient, root)
@@ -161,8 +140,8 @@ func TestExpTaskList(t *testing.T) {
require.NoError(t, err)
// Validate the table includes the task and status.
pty.ExpectMatch(task.Name)
pty.ExpectMatch("initializing")
pty.ExpectMatch(ws.Name)
pty.ExpectMatch("running")
pty.ExpectMatch(wantPrompt)
})
@@ -175,12 +154,12 @@ func TestExpTaskList(t *testing.T) {
owner := coderdtest.CreateFirstUser(t, client)
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Create two AI tasks: one initializing, one paused.
initializingTask := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me initializing")
pausedTask := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
// Create two AI tasks: one running, one stopped.
running := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me running")
stopped := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
// Use JSON output to reliably validate filtering.
inv, root := clitest.New(t, "exp", "task", "list", "--status=paused", "--output=json")
inv, root := clitest.New(t, "exp", "task", "list", "--status=stopped", "--output=json")
clitest.SetupConfig(t, memberClient, root)
ctx := testutil.Context(t, testutil.WaitShort)
@@ -194,10 +173,10 @@ func TestExpTaskList(t *testing.T) {
var tasks []codersdk.Task
require.NoError(t, json.Unmarshal(stdout.Bytes(), &tasks))
// Only the paused task is returned.
// Only the stopped task is returned.
require.Len(t, tasks, 1, "expected one task after filtering")
require.Equal(t, pausedTask.ID, tasks[0].ID)
require.NotEqual(t, initializingTask.ID, tasks[0].ID)
require.Equal(t, stopped.ID, tasks[0].ID)
require.NotEqual(t, running.ID, tasks[0].ID)
})
t.Run("UserFlag_Me_Table", func(t *testing.T) {
@@ -209,7 +188,7 @@ func TestExpTaskList(t *testing.T) {
_, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
_ = makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "other-task")
task := makeAITask(t, db, owner.OrganizationID, owner.UserID, owner.UserID, database.WorkspaceTransitionStart, "me-task")
ws := makeAITask(t, db, owner.OrganizationID, owner.UserID, owner.UserID, database.WorkspaceTransitionStart, "me-task")
inv, root := clitest.New(t, "exp", "task", "list", "--user", "me")
//nolint:gocritic // Owner client is intended here smoke test the member task not showing up.
@@ -221,7 +200,7 @@ func TestExpTaskList(t *testing.T) {
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
pty.ExpectMatch(task.Name)
pty.ExpectMatch(ws.Name)
})
t.Run("Quiet", func(t *testing.T) {
@@ -234,7 +213,7 @@ func TestExpTaskList(t *testing.T) {
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Given: We have two tasks
task1 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me active")
task1 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me running")
task2 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
// Given: We add the `--quiet` flag
+15 -7
View File
@@ -3,6 +3,7 @@ package cli
import (
"fmt"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
@@ -40,17 +41,24 @@ func (r *RootCmd) taskLogs() *serpent.Command {
}
var (
ctx = inv.Context()
exp = codersdk.NewExperimentalClient(client)
identifier = inv.Args[0]
ctx = inv.Context()
exp = codersdk.NewExperimentalClient(client)
task = inv.Args[0]
taskID uuid.UUID
)
task, err := exp.TaskByIdentifier(ctx, identifier)
if err != nil {
return xerrors.Errorf("resolve task %q: %w", identifier, err)
if id, err := uuid.Parse(task); err == nil {
taskID = id
} else {
ws, err := namedWorkspace(ctx, client, task)
if err != nil {
return xerrors.Errorf("resolve task %q: %w", task, err)
}
taskID = ws.ID
}
logs, err := exp.TaskLogs(ctx, codersdk.Me, task.ID)
logs, err := exp.TaskLogs(ctx, codersdk.Me, taskID)
if err != nil {
return xerrors.Errorf("get task logs: %w", err)
}
+13 -13
View File
@@ -38,15 +38,15 @@ func Test_TaskLogs(t *testing.T) {
},
}
t.Run("ByTaskName_JSON", func(t *testing.T) {
t.Run("ByWorkspaceName_JSON", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client // user already has access to their own workspace
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", task.Name, "--output", "json")
inv, root := clitest.New(t, "exp", "task", "logs", workspace.Name, "--output", "json")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
@@ -64,15 +64,15 @@ func Test_TaskLogs(t *testing.T) {
require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
})
t.Run("ByTaskID_JSON", func(t *testing.T) {
t.Run("ByWorkspaceID_JSON", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String(), "--output", "json")
inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String(), "--output", "json")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
@@ -90,15 +90,15 @@ func Test_TaskLogs(t *testing.T) {
require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
})
t.Run("ByTaskID_Table", func(t *testing.T) {
t.Run("ByWorkspaceID_Table", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String())
inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String())
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
@@ -112,7 +112,7 @@ func Test_TaskLogs(t *testing.T) {
require.Contains(t, output, "output")
})
t.Run("TaskNotFound_ByName", func(t *testing.T) {
t.Run("WorkspaceNotFound_ByName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -130,7 +130,7 @@ func Test_TaskLogs(t *testing.T) {
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
})
t.Run("TaskNotFound_ByID", func(t *testing.T) {
t.Run("WorkspaceNotFound_ByID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -152,10 +152,10 @@ func Test_TaskLogs(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsErr(assert.AnError))
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsErr(assert.AnError))
userClient := client
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String())
inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String())
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
+15 -7
View File
@@ -3,6 +3,7 @@ package cli
import (
"io"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
@@ -38,11 +39,12 @@ func (r *RootCmd) taskSend() *serpent.Command {
}
var (
ctx = inv.Context()
exp = codersdk.NewExperimentalClient(client)
identifier = inv.Args[0]
ctx = inv.Context()
exp = codersdk.NewExperimentalClient(client)
task = inv.Args[0]
taskInput string
taskID uuid.UUID
)
if stdin {
@@ -60,12 +62,18 @@ func (r *RootCmd) taskSend() *serpent.Command {
taskInput = inv.Args[1]
}
task, err := exp.TaskByIdentifier(ctx, identifier)
if err != nil {
return xerrors.Errorf("resolve task: %w", err)
if id, err := uuid.Parse(task); err == nil {
taskID = id
} else {
ws, err := namedWorkspace(ctx, client, task)
if err != nil {
return xerrors.Errorf("resolve task: %w", err)
}
taskID = ws.ID
}
if err = exp.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
if err = exp.TaskSend(ctx, codersdk.Me, taskID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
return xerrors.Errorf("send input to task: %w", err)
}
+13 -13
View File
@@ -22,15 +22,15 @@ import (
func Test_TaskSend(t *testing.T) {
t.Parallel()
t.Run("ByTaskName_WithArgument", func(t *testing.T) {
t.Run("ByWorkspaceName_WithArgument", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "carry on with the task")
inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "carry on with the task")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
@@ -38,15 +38,15 @@ func Test_TaskSend(t *testing.T) {
require.NoError(t, err)
})
t.Run("ByTaskID_WithArgument", func(t *testing.T) {
t.Run("ByWorkspaceID_WithArgument", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.ID.String(), "carry on with the task")
inv, root := clitest.New(t, "exp", "task", "send", workspace.ID.String(), "carry on with the task")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
@@ -54,15 +54,15 @@ func Test_TaskSend(t *testing.T) {
require.NoError(t, err)
})
t.Run("ByTaskName_WithStdin", func(t *testing.T) {
t.Run("ByWorkspaceName_WithStdin", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "--stdin")
inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "--stdin")
inv.Stdout = &stdout
inv.Stdin = strings.NewReader("carry on with the task")
clitest.SetupConfig(t, userClient, root)
@@ -71,7 +71,7 @@ func Test_TaskSend(t *testing.T) {
require.NoError(t, err)
})
t.Run("TaskNotFound_ByName", func(t *testing.T) {
t.Run("WorkspaceNotFound_ByName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -89,7 +89,7 @@ func Test_TaskSend(t *testing.T) {
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
})
t.Run("TaskNotFound_ByID", func(t *testing.T) {
t.Run("WorkspaceNotFound_ByID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -111,10 +111,10 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
userClient, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
userClient, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "some task input")
inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "some task input")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
+31 -22
View File
@@ -5,6 +5,7 @@ import (
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
@@ -83,10 +84,21 @@ func (r *RootCmd) taskStatus() *serpent.Command {
}
ctx := i.Context()
exp := codersdk.NewExperimentalClient(client)
ec := codersdk.NewExperimentalClient(client)
identifier := i.Args[0]
task, err := exp.TaskByIdentifier(ctx, identifier)
taskID, err := uuid.Parse(identifier)
if err != nil {
// Try to resolve the task as a named workspace
// TODO: right now tasks are still "workspaces" under the hood.
// We should update this once we have a proper task model.
ws, err := namedWorkspace(ctx, client, identifier)
if err != nil {
return err
}
taskID = ws.ID
}
task, err := ec.TaskByID(ctx, taskID)
if err != nil {
return err
}
@@ -107,7 +119,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
// TODO: implement streaming updates instead of polling
lastStatusRow := tsr
for range t.C {
task, err := exp.TaskByID(ctx, task.ID)
task, err := ec.TaskByID(ctx, taskID)
if err != nil {
return err
}
@@ -140,7 +152,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
}
func taskWatchIsEnded(task codersdk.Task) bool {
if task.WorkspaceStatus == codersdk.WorkspaceStatusStopped {
if task.Status == codersdk.WorkspaceStatusStopped {
return true
}
if task.WorkspaceAgentHealth == nil || !task.WorkspaceAgentHealth.Healthy {
@@ -156,21 +168,28 @@ func taskWatchIsEnded(task codersdk.Task) bool {
}
type taskStatusRow struct {
codersdk.Task `table:"r,recursive_inline"`
ChangedAgo string `json:"-" table:"state changed"`
Healthy bool `json:"-" table:"healthy"`
codersdk.Task `table:"-"`
ChangedAgo string `json:"-" table:"state changed,default_sort"`
Timestamp time.Time `json:"-" table:"-"`
TaskStatus string `json:"-" table:"status"`
Healthy bool `json:"-" table:"healthy"`
TaskState string `json:"-" table:"state"`
Message string `json:"-" table:"message"`
}
func taskStatusRowEqual(r1, r2 taskStatusRow) bool {
return r1.Status == r2.Status &&
return r1.TaskStatus == r2.TaskStatus &&
r1.Healthy == r2.Healthy &&
taskStateEqual(r1.CurrentState, r2.CurrentState)
r1.TaskState == r2.TaskState &&
r1.Message == r2.Message
}
func toStatusRow(task codersdk.Task) taskStatusRow {
tsr := taskStatusRow{
Task: task,
ChangedAgo: time.Since(task.UpdatedAt).Truncate(time.Second).String() + " ago",
Timestamp: task.UpdatedAt,
TaskStatus: string(task.Status),
}
tsr.Healthy = task.WorkspaceAgentHealth != nil &&
task.WorkspaceAgentHealth.Healthy &&
@@ -180,19 +199,9 @@ func toStatusRow(task codersdk.Task) taskStatusRow {
if task.CurrentState != nil {
tsr.ChangedAgo = time.Since(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
tsr.Timestamp = task.CurrentState.Timestamp
tsr.TaskState = string(task.CurrentState.State)
tsr.Message = task.CurrentState.Message
}
return tsr
}
func taskStateEqual(se1, se2 *codersdk.TaskStateEntry) bool {
var s1, m1, s2, m2 string
if se1 != nil {
s1 = string(se1.State)
m1 = se1.Message
}
if se2 != nil {
s2 = string(se2.State)
m2 = se2.Message
}
return s1 == s2 && m1 == m2
}
+69 -149
View File
@@ -36,17 +36,26 @@ func Test_TaskStatus(t *testing.T) {
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{},
Count: 0,
})
return
}
case "/api/v2/users/me/workspace/doesnotexist":
httpapi.ResourceNotFound(w)
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
}
},
},
{
args: []string{"err-fetching-workspace"},
expectError: assert.AnError.Error(),
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me/workspace/err-fetching-workspace":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
httpapi.InternalServerError(w, assert.AnError)
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
@@ -55,45 +64,21 @@ func Test_TaskStatus(t *testing.T) {
},
{
args: []string{"exists"},
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
0s ago active true working Thinking furiously...`,
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
0s ago running true working Thinking furiously...`,
hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now,
UpdatedAt: now,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: now,
Message: "Thinking furiously...",
},
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusActive,
}},
Count: 1,
})
return
}
case "/api/v2/users/me/workspace/exists":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now,
UpdatedAt: now,
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now,
UpdatedAt: now,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: now,
@@ -103,9 +88,7 @@ func Test_TaskStatus(t *testing.T) {
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusActive,
})
return
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
@@ -114,77 +97,50 @@ func Test_TaskStatus(t *testing.T) {
},
{
args: []string{"exists", "--watch"},
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
5s ago pending true
4s ago initializing true
4s ago active true
3s ago active true working Reticulating splines...
2s ago active true complete Splines reticulated successfully!`,
expectOutput: `
STATE CHANGED STATUS HEALTHY STATE MESSAGE
4s ago running true
3s ago running true working Reticulating splines...
2s ago running true complete Splines reticulated successfully!`,
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
var calls atomic.Int64
return func(w http.ResponseWriter, r *http.Request) {
defer calls.Add(1)
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
// Return initial task state for --watch test
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusPending,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-5 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusPending,
}},
Count: 1,
})
return
}
case "/api/v2/users/me/workspace/exists":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
defer calls.Add(1)
switch calls.Load() {
case 0:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusPending,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-5 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusInitializing,
})
return
case 1:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
UpdatedAt: now.Add(-4 * time.Second),
Status: codersdk.TaskStatusActive,
})
return
case 2:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
@@ -194,15 +150,13 @@ func Test_TaskStatus(t *testing.T) {
Timestamp: now.Add(-3 * time.Second),
Message: "Reticulating splines...",
},
Status: codersdk.TaskStatusActive,
})
return
case 3:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
@@ -212,16 +166,13 @@ func Test_TaskStatus(t *testing.T) {
Timestamp: now.Add(-2 * time.Second),
Message: "Splines reticulated successfully!",
},
Status: codersdk.TaskStatusActive,
})
return
default:
httpapi.InternalServerError(w, xerrors.New("too many calls!"))
return
}
default:
httpapi.InternalServerError(w, xerrors.Errorf("unexpected path: %q", r.URL.Path))
return
}
}
},
@@ -232,24 +183,19 @@ func Test_TaskStatus(t *testing.T) {
"id": "11111111-1111-1111-1111-111111111111",
"organization_id": "00000000-0000-0000-0000-000000000000",
"owner_id": "00000000-0000-0000-0000-000000000000",
"owner_name": "me",
"name": "exists",
"owner_name": "",
"name": "",
"template_id": "00000000-0000-0000-0000-000000000000",
"template_version_id": "00000000-0000-0000-0000-000000000000",
"template_name": "",
"template_display_name": "",
"template_icon": "",
"workspace_id": null,
"workspace_name": "",
"workspace_status": "running",
"workspace_agent_id": null,
"workspace_agent_lifecycle": "ready",
"workspace_agent_health": {
"healthy": true
},
"workspace_agent_lifecycle": null,
"workspace_agent_health": null,
"workspace_app_id": null,
"initial_prompt": "",
"status": "active",
"status": "running",
"current_state": {
"timestamp": "2025-08-26T12:34:57Z",
"state": "working",
@@ -259,52 +205,26 @@ func Test_TaskStatus(t *testing.T) {
"created_at": "2025-08-26T12:34:56Z",
"updated_at": "2025-08-26T12:34:56Z"
}`,
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
ts := time.Date(2025, 8, 26, 12, 34, 56, 0, time.UTC)
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: ts,
UpdatedAt: ts,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: ts.Add(time.Second),
Message: "Thinking furiously...",
},
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusActive,
}},
Count: 1,
})
return
}
case "/api/v2/users/me/workspace/exists":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: ts,
UpdatedAt: ts,
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: ts,
UpdatedAt: ts,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: ts.Add(time.Second),
Message: "Thinking furiously...",
},
Status: codersdk.TaskStatusActive,
})
return
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
+6 -229
View File
@@ -2,242 +2,26 @@ package cli_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
agentapisdk "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
// This test performs an integration-style test for tasks functionality.
//
//nolint:tparallel // The sub-tests of this test must be run sequentially.
func Test_Tasks(t *testing.T) {
t.Parallel()
// Given: a template configured for tasks
var (
ctx = testutil.Context(t, testutil.WaitLong)
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner = coderdtest.CreateFirstUser(t, client)
userClient, _ = coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
initMsg = agentapisdk.Message{
Content: "test task input for " + t.Name(),
Id: 0,
Role: "user",
Time: time.Now().UTC(),
}
authToken = uuid.NewString()
echoAgentAPI = startFakeAgentAPI(t, fakeAgentAPIEcho(ctx, t, initMsg, "hello"))
taskTpl = createAITaskTemplate(t, client, owner.OrganizationID, withAgentToken(authToken), withSidebarURL(echoAgentAPI.URL()))
taskName = strings.ReplaceAll(testutil.GetRandomName(t), "_", "-")
)
//nolint:paralleltest // The sub-tests of this test must be run sequentially.
for _, tc := range []struct {
name string
cmdArgs []string
assertFn func(stdout string, userClient *codersdk.Client)
}{
{
name: "create task",
cmdArgs: []string{"exp", "task", "create", "test task input for " + t.Name(), "--name", taskName, "--template", taskTpl.Name},
assertFn: func(stdout string, userClient *codersdk.Client) {
require.Contains(t, stdout, taskName, "task name should be in output")
},
},
{
name: "list tasks after create",
cmdArgs: []string{"exp", "task", "list", "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var tasks []codersdk.Task
err := json.NewDecoder(strings.NewReader(stdout)).Decode(&tasks)
require.NoError(t, err, "list output should unmarshal properly")
require.Len(t, tasks, 1, "expected one task")
require.Equal(t, taskName, tasks[0].Name, "task name should match")
require.Equal(t, initMsg.Content, tasks[0].InitialPrompt, "initial prompt should match")
require.True(t, tasks[0].WorkspaceID.Valid, "workspace should be created")
// For the next test, we need to wait for the workspace to be healthy
ws := coderdtest.MustWorkspace(t, userClient, tasks[0].WorkspaceID.UUID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, userClient, tasks[0].WorkspaceID.UUID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
},
},
{
name: "get task status after create",
cmdArgs: []string{"exp", "task", "status", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var task codersdk.Task
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
require.Equal(t, task.Name, taskName, "task name should match")
require.Equal(t, codersdk.TaskStatusActive, task.Status, "task should be active")
},
},
{
name: "send task message",
cmdArgs: []string{"exp", "task", "send", taskName, "hello"},
// Assertions for this happen in the fake agent API handler.
},
{
name: "read task logs",
cmdArgs: []string{"exp", "task", "logs", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var logs []codersdk.TaskLogEntry
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&logs), "should unmarshal task logs")
require.Len(t, logs, 3, "should have 3 logs")
require.Equal(t, logs[0].Content, initMsg.Content, "first message should be the init message")
require.Equal(t, logs[0].Type, codersdk.TaskLogTypeInput, "first message should be an input")
require.Equal(t, logs[1].Content, "hello", "second message should be the sent message")
require.Equal(t, logs[1].Type, codersdk.TaskLogTypeInput, "second message should be an input")
require.Equal(t, logs[2].Content, "hello", "third message should be the echoed message")
require.Equal(t, logs[2].Type, codersdk.TaskLogTypeOutput, "third message should be an output")
},
},
{
name: "delete task",
cmdArgs: []string{"exp", "task", "delete", taskName, "--yes"},
assertFn: func(stdout string, userClient *codersdk.Client) {
// The task should eventually no longer show up in the list of tasks
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
expClient := codersdk.NewExperimentalClient(userClient)
tasks, err := expClient.Tasks(ctx, &codersdk.TasksFilter{})
if !assert.NoError(t, err) {
return false
}
return slices.IndexFunc(tasks, func(task codersdk.Task) bool {
return task.Name == taskName
}) == -1
}, testutil.IntervalMedium)
},
},
} {
t.Run(tc.name, func(t *testing.T) {
var stdout strings.Builder
inv, root := clitest.New(t, tc.cmdArgs...)
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
require.NoError(t, inv.WithContext(ctx).Run())
if tc.assertFn != nil {
tc.assertFn(stdout.String(), userClient)
}
})
}
}
func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Message, want ...string) map[string]http.HandlerFunc {
t.Helper()
var mmu sync.RWMutex
msgs := []agentapisdk.Message{initMsg}
wantCpy := make([]string, len(want))
copy(wantCpy, want)
t.Cleanup(func() {
mmu.Lock()
defer mmu.Unlock()
if !t.Failed() {
assert.Empty(t, wantCpy, "not all expected messages received: missing %v", wantCpy)
}
})
writeAgentAPIError := func(w http.ResponseWriter, err error, status int) {
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(agentapisdk.ErrorModel{
Errors: ptr.Ref([]agentapisdk.ErrorDetail{
{
Message: ptr.Ref(err.Error()),
},
}),
})
}
return map[string]http.HandlerFunc{
"/status": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(agentapisdk.GetStatusResponse{
Status: "stable",
})
},
"/messages": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
mmu.RLock()
defer mmu.RUnlock()
bs, err := json.Marshal(agentapisdk.GetMessagesResponse{
Messages: msgs,
})
if err != nil {
writeAgentAPIError(w, err, http.StatusBadRequest)
return
}
_, _ = w.Write(bs)
},
"/message": func(w http.ResponseWriter, r *http.Request) {
mmu.Lock()
defer mmu.Unlock()
var params agentapisdk.PostMessageParams
w.Header().Set("Content-Type", "application/json")
err := json.NewDecoder(r.Body).Decode(&params)
if !assert.NoError(t, err, "decode message") {
writeAgentAPIError(w, err, http.StatusBadRequest)
return
}
if len(wantCpy) == 0 {
assert.Fail(t, "unexpected message", "received message %v, but no more expected messages", params)
writeAgentAPIError(w, xerrors.New("no more expected messages"), http.StatusBadRequest)
return
}
exp := wantCpy[0]
wantCpy = wantCpy[1:]
if !assert.Equal(t, exp, params.Content, "message content mismatch") {
writeAgentAPIError(w, xerrors.New("unexpected message content: expected "+exp+", got "+params.Content), http.StatusBadRequest)
return
}
msgs = append(msgs, agentapisdk.Message{
Id: int64(len(msgs) + 1),
Content: params.Content,
Role: agentapisdk.RoleUser,
Time: time.Now().UTC(),
})
msgs = append(msgs, agentapisdk.Message{
Id: int64(len(msgs) + 1),
Content: params.Content,
Role: agentapisdk.RoleAgent,
Time: time.Now().UTC(),
})
assert.NoError(t, json.NewEncoder(w).Encode(agentapisdk.PostMessageResponse{
Ok: true,
}))
},
}
}
// setupCLITaskTest creates a test workspace with an AI task template and agent,
// with a fake agent API configured with the provided set of handlers.
// Returns the user client and workspace.
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Task) {
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Workspace) {
t.Helper()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
@@ -250,18 +34,11 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
template := createAITaskTemplate(t, client, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
wantPrompt := "test prompt"
exp := codersdk.NewExperimentalClient(userClient)
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: wantPrompt,
Name: "test-task",
workspace := coderdtest.CreateWorkspace(t, userClient, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
{Name: codersdk.AITaskPromptParameterName, Value: wantPrompt},
}
})
require.NoError(t, err)
// Wait for the task's underlying workspace to be built
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
@@ -272,7 +49,7 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).
WaitFor(coderdtest.AgentsReady)
return userClient, task
return userClient, workspace
}
// createAITaskTemplate creates a template configured for AI tasks with a sidebar app.
-16
View File
@@ -176,22 +176,6 @@ func (r *RootCmd) scheduleStart() *serpent.Command {
}
schedStr = ptr.Ref(sched.String())
// Check if the template has autostart requirements that may conflict
// with the user's schedule.
template, err := client.Template(inv.Context(), workspace.TemplateID)
if err != nil {
return xerrors.Errorf("get template: %w", err)
}
if len(template.AutostartRequirement.DaysOfWeek) > 0 {
_, _ = fmt.Fprintf(
inv.Stderr,
"Warning: your workspace template restricts autostart to the following days: %s.\n"+
"Your workspace may only autostart on these days.\n",
strings.Join(template.AutostartRequirement.DaysOfWeek, ", "),
)
}
}
err = client.UpdateWorkspaceAutostart(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
-64
View File
@@ -373,67 +373,3 @@ func TestScheduleOverride(t *testing.T) {
})
}
}
//nolint:paralleltest // t.Setenv
func TestScheduleStart_TemplateAutostartRequirement(t *testing.T) {
t.Setenv("TZ", "UTC")
loc, err := tz.TimezoneIANA()
require.NoError(t, err)
require.Equal(t, "UTC", loc.String())
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
// Update template to have autostart requirement
// Note: In AGPL, this will be ignored and all days will be allowed (enterprise feature).
template, err = client.UpdateTemplateMeta(context.Background(), template.ID, codersdk.UpdateTemplateMeta{
AutostartRequirement: &codersdk.TemplateAutostartRequirement{
DaysOfWeek: []string{"monday", "wednesday", "friday"},
},
})
require.NoError(t, err)
// Verify the template - in AGPL, AutostartRequirement will have all days (enterprise feature)
template, err = client.Template(context.Background(), template.ID)
require.NoError(t, err)
require.NotEmpty(t, template.AutostartRequirement.DaysOfWeek, "template should have autostart requirement days")
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
t.Run("ShowsWarning", func(t *testing.T) {
// When: user sets autostart schedule
inv, root := clitest.New(t,
"schedule", "start", workspace.Name, "9:30AM", "Mon-Fri",
)
clitest.SetupConfig(t, client, root)
pty := ptytest.New(t).Attach(inv)
require.NoError(t, inv.Run())
// Then: warning should be shown
// In AGPL, this will show all days (enterprise feature defaults to all days allowed)
pty.ExpectMatch("Warning")
pty.ExpectMatch("may only autostart")
})
t.Run("NoWarningWhenManual", func(t *testing.T) {
// When: user sets manual schedule
inv, root := clitest.New(t,
"schedule", "start", workspace.Name, "manual",
)
clitest.SetupConfig(t, client, root)
var stderrBuf bytes.Buffer
inv.Stderr = &stderrBuf
require.NoError(t, inv.Run())
// Then: no warning should be shown on stderr
stderrOutput := stderrBuf.String()
require.NotContains(t, stderrOutput, "Warning")
})
}
@@ -17,6 +17,9 @@ import (
func TestRegenerateVapidKeypair(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test is only supported on postgres")
}
t.Run("NoExistingVAPIDKeys", func(t *testing.T) {
t.Parallel()
+11
View File
@@ -348,6 +348,9 @@ func TestServer(t *testing.T) {
runGitHubProviderTest := func(t *testing.T, tc testCase) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("test requires postgres")
}
ctx, cancelFunc := context.WithCancel(testutil.Context(t, testutil.WaitLong))
defer cancelFunc()
@@ -2139,6 +2142,10 @@ func TestServerYAMLConfig(t *testing.T) {
func TestConnectToPostgres(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test does not make sense without postgres")
}
t.Run("Migrate", func(t *testing.T) {
t.Parallel()
@@ -2249,6 +2256,10 @@ type runServerOpts struct {
func TestServer_TelemetryDisabled_FinalReport(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
telemetryServerURL, deployment, snapshot := mockTelemetryServer(t)
dbConnURL, err := dbtestutil.Open(t)
require.NoError(t, err)
-47
View File
@@ -109,51 +109,6 @@ func (r *RootCmd) ssh() *serpent.Command {
}
},
),
CompletionHandler: func(inv *serpent.Invocation) []string {
client, err := r.InitClient(inv)
if err != nil {
return []string{}
}
res, err := client.Workspaces(inv.Context(), codersdk.WorkspaceFilter{
Owner: codersdk.Me,
})
if err != nil {
return []string{}
}
var mu sync.Mutex
var completions []string
var wg sync.WaitGroup
for _, ws := range res.Workspaces {
wg.Add(1)
go func() {
defer wg.Done()
resources, err := client.TemplateVersionResources(inv.Context(), ws.LatestBuild.TemplateVersionID)
if err != nil {
return
}
var agents []codersdk.WorkspaceAgent
for _, resource := range resources {
agents = append(agents, resource.Agents...)
}
mu.Lock()
defer mu.Unlock()
if len(agents) == 1 {
completions = append(completions, ws.Name)
} else {
for _, agent := range agents {
completions = append(completions, fmt.Sprintf("%s.%s", ws.Name, agent.Name))
}
}
}()
}
wg.Wait()
slices.Sort(completions)
return completions
},
Handler: func(inv *serpent.Invocation) (retErr error) {
client, err := r.InitClient(inv)
if err != nil {
@@ -951,8 +906,6 @@ func GetWorkspaceAndAgent(ctx context.Context, inv *serpent.Invocation, client *
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with active template version: %w", err)
}
_, _ = fmt.Fprintln(inv.Stdout, "Unable to start the workspace with template version from last build. Your workspace has been updated to the current active template version.")
default:
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
}
} else if err != nil {
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, nil, xerrors.Errorf("start workspace with current template version: %w", err)
-96
View File
@@ -2447,99 +2447,3 @@ func tempDirUnixSocket(t *testing.T) string {
return t.TempDir()
}
func TestSSH_Completion(t *testing.T) {
t.Parallel()
t.Run("SingleAgent", func(t *testing.T) {
t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t)
_ = agenttest.New(t, client.URL, agentToken)
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
var stdout bytes.Buffer
inv, root := clitest.New(t, "ssh", "")
inv.Stdout = &stdout
inv.Environ.Set("COMPLETION_MODE", "1")
clitest.SetupConfig(t, client, root)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
// For single-agent workspaces, the only completion should be the
// bare workspace name.
output := stdout.String()
t.Logf("Completion output: %q", output)
require.Contains(t, output, workspace.Name)
})
t.Run("MultiAgent", func(t *testing.T) {
t.Parallel()
client, store := coderdtest.NewWithDatabase(t, nil)
first := coderdtest.CreateFirstUser(t, client)
userClient, user := coderdtest.CreateAnotherUserMutators(t, client, first.OrganizationID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
r.Username = "multiuser"
})
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
Name: "multiworkspace",
OrganizationID: first.OrganizationID,
OwnerID: user.ID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
return []*proto.Agent{
{
Name: "agent1",
Auth: &proto.Agent_Token{},
},
{
Name: "agent2",
Auth: &proto.Agent_Token{},
},
}
}).Do()
var stdout bytes.Buffer
inv, root := clitest.New(t, "ssh", "")
inv.Stdout = &stdout
inv.Environ.Set("COMPLETION_MODE", "1")
clitest.SetupConfig(t, userClient, root)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
// For multi-agent workspaces, completions should include the
// workspace.agent format but NOT the bare workspace name.
output := stdout.String()
t.Logf("Completion output: %q", output)
lines := strings.Split(strings.TrimSpace(output), "\n")
require.NotContains(t, lines, r.Workspace.Name)
require.Contains(t, output, r.Workspace.Name+".agent1")
require.Contains(t, output, r.Workspace.Name+".agent2")
})
t.Run("NetworkError", func(t *testing.T) {
t.Parallel()
var stdout bytes.Buffer
inv, _ := clitest.New(t, "ssh", "")
inv.Stdout = &stdout
inv.Environ.Set("COMPLETION_MODE", "1")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
output := stdout.String()
require.Empty(t, output)
})
}
-17
View File
@@ -87,7 +87,6 @@ func buildNumberOption(n *int64) serpent.Option {
func (r *RootCmd) statePush() *serpent.Command {
var buildNumber int64
var noBuild bool
cmd := &serpent.Command{
Use: "push <workspace> <file>",
Short: "Push a Terraform state file to a workspace.",
@@ -127,16 +126,6 @@ func (r *RootCmd) statePush() *serpent.Command {
return err
}
if noBuild {
// Update state directly without triggering a build.
err = client.UpdateWorkspaceBuildState(inv.Context(), build.ID, state)
if err != nil {
return err
}
_, _ = fmt.Fprintln(inv.Stdout, "State updated successfully.")
return nil
}
build, err = client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{
TemplateVersionID: build.TemplateVersionID,
Transition: build.Transition,
@@ -150,12 +139,6 @@ func (r *RootCmd) statePush() *serpent.Command {
}
cmd.Options = serpent.OptionSet{
buildNumberOption(&buildNumber),
{
Flag: "no-build",
FlagShorthand: "n",
Description: "Update the state without triggering a workspace build. Useful for state-only migrations.",
Value: serpent.BoolOf(&noBuild),
},
}
return cmd
}
-47
View File
@@ -2,7 +2,6 @@ package cli_test
import (
"bytes"
"context"
"fmt"
"os"
"path/filepath"
@@ -11,7 +10,6 @@ import (
"testing"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/stretchr/testify/require"
@@ -160,49 +158,4 @@ func TestStatePush(t *testing.T) {
err := inv.Run()
require.NoError(t, err)
})
t.Run("NoBuild", func(t *testing.T) {
t.Parallel()
client, store := coderdtest.NewWithDatabase(t, nil)
owner := coderdtest.CreateFirstUser(t, client)
templateAdmin, taUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.RoleTemplateAdmin())
initialState := []byte("initial state")
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
require.NoError(t, err)
_, err = stateFile.Write(wantState)
require.NoError(t, err)
err = stateFile.Close()
require.NoError(t, err)
inv, root := clitest.New(t, "state", "push", "--no-build", r.Workspace.Name, stateFile.Name())
clitest.SetupConfig(t, templateAdmin, root)
var stdout bytes.Buffer
inv.Stdout = &stdout
err = inv.Run()
require.NoError(t, err)
require.Contains(t, stdout.String(), "State updated successfully")
// Verify the state was updated by pulling it.
inv, root = clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
inv.Stdout = &gotState
clitest.SetupConfig(t, templateAdmin, root)
err = inv.Run()
require.NoError(t, err)
require.Equal(t, wantState, bytes.TrimSpace(gotState.Bytes()))
// Verify no new build was created.
builds, err := store.GetWorkspaceBuildsByWorkspaceID(dbauthz.AsSystemRestricted(context.Background()), database.GetWorkspaceBuildsByWorkspaceIDParams{
WorkspaceID: r.Workspace.ID,
})
require.NoError(t, err)
require.Len(t, builds, 1, "expected only the initial build, no new build should be created")
})
}
+1 -2
View File
@@ -90,7 +90,6 @@
"allow_renames": false,
"favorite": false,
"next_start_at": "====[timestamp]=====",
"is_prebuild": false,
"task_id": null
"is_prebuild": false
}
]
-35
View File
@@ -80,41 +80,6 @@ OPTIONS:
Periodically check for new releases of Coder and inform the owner. The
check is performed once per day.
AIBRIDGE OPTIONS:
--aibridge-anthropic-base-url string, $CODER_AIBRIDGE_ANTHROPIC_BASE_URL (default: https://api.anthropic.com/)
The base URL of the Anthropic API.
--aibridge-anthropic-key string, $CODER_AIBRIDGE_ANTHROPIC_KEY
The key to authenticate against the Anthropic API.
--aibridge-bedrock-access-key string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY
The access key to authenticate against the AWS Bedrock API.
--aibridge-bedrock-access-key-secret string, $CODER_AIBRIDGE_BEDROCK_ACCESS_KEY_SECRET
The access key secret to use with the access key to authenticate
against the AWS Bedrock API.
--aibridge-bedrock-model string, $CODER_AIBRIDGE_BEDROCK_MODEL (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0)
The model to use when making requests to the AWS Bedrock API.
--aibridge-bedrock-region string, $CODER_AIBRIDGE_BEDROCK_REGION
The AWS Bedrock API region.
--aibridge-bedrock-small-fastmodel string, $CODER_AIBRIDGE_BEDROCK_SMALL_FAST_MODEL (default: global.anthropic.claude-haiku-4-5-20251001-v1:0)
The small fast model to use when making requests to the AWS Bedrock
API. Claude Code uses Haiku-class models to perform background tasks.
See
https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
--aibridge-enabled bool, $CODER_AIBRIDGE_ENABLED (default: false)
Whether to start an in-memory aibridged instance.
--aibridge-openai-base-url string, $CODER_AIBRIDGE_OPENAI_BASE_URL (default: https://api.openai.com/v1/)
The base URL of the OpenAI API.
--aibridge-openai-key string, $CODER_AIBRIDGE_OPENAI_KEY
The key to authenticate against the OpenAI API.
CLIENT OPTIONS:
These options change the behavior of how clients interact with the Coder.
Clients include the Coder CLI, Coder Desktop, IDE extensions, and the web UI.
-4
View File
@@ -9,9 +9,5 @@ OPTIONS:
-b, --build int
Specify a workspace build to target by name. Defaults to latest.
-n, --no-build bool
Update the state without triggering a workspace build. Useful for
state-only migrations.
———
Run `coder --help` for a list of global options.
-5
View File
@@ -16,10 +16,6 @@ USAGE:
$ coder tokens ls
- Create a scoped token:
$ coder tokens create --scope workspace:read --allow workspace:<uuid>
- Remove a token by ID:
$ coder tokens rm WuoWs4ZsMX
@@ -28,7 +24,6 @@ SUBCOMMANDS:
create Create a token
list List tokens
remove Delete a token
view Display detailed information about a token
———
Run `coder --help` for a list of global options.
+1 -9
View File
@@ -6,20 +6,12 @@ USAGE:
Create a token
OPTIONS:
--allow allow-list
Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).
--lifetime string, $CODER_TOKEN_LIFETIME
Duration for the token lifetime. Supports standard Go duration units
(ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d,
1y, 1d12h30m.
Specify a duration for the lifetime of the token.
-n, --name string, $CODER_TOKEN_NAME
Specify a human-readable name.
--scope string-array
Repeatable scope to attach to the token (e.g. workspace:read).
-u, --user string, $CODER_TOKEN_USER
Specify the user to create the token for (Only works if logged in user
is admin).
+1 -1
View File
@@ -12,7 +12,7 @@ OPTIONS:
Specifies whether all users' tokens will be listed or not (must have
Owner role to see all tokens).
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
-c, --column [id|name|last used|expires at|created at|owner] (default: id,name,last used,expires at,created at)
Columns to display in table output.
-o, --output table|json (default: table)
-16
View File
@@ -1,16 +0,0 @@
coder v0.0.0-devel
USAGE:
coder tokens view [flags] <name|id>
Display detailed information about a token
OPTIONS:
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at,owner)
Columns to display in table output.
-o, --output table|json (default: table)
Output format.
———
Run `coder --help` for a list of global options.
+4 -21
View File
@@ -714,7 +714,8 @@ workspace_prebuilds:
# (default: 3, type: int)
failure_hard_limit: 3
aibridge:
# Whether to start an in-memory aibridged instance.
# Whether to start an in-memory aibridged instance ("aibridge" experiment must be
# enabled, too).
# (default: false, type: bool)
enabled: false
# The base URL of the OpenAI API.
@@ -725,25 +726,7 @@ aibridge:
openai_key: ""
# The base URL of the Anthropic API.
# (default: https://api.anthropic.com/, type: string)
anthropic_base_url: https://api.anthropic.com/
base_url: https://api.anthropic.com/
# The key to authenticate against the Anthropic API.
# (default: <unset>, type: string)
anthropic_key: ""
# The AWS Bedrock API region.
# (default: <unset>, type: string)
bedrock_region: ""
# The access key to authenticate against the AWS Bedrock API.
# (default: <unset>, type: string)
bedrock_access_key: ""
# The access key secret to use with the access key to authenticate against the AWS
# Bedrock API.
# (default: <unset>, type: string)
bedrock_access_key_secret: ""
# The model to use when making requests to the AWS Bedrock API.
# (default: global.anthropic.claude-sonnet-4-5-20250929-v1:0, type: string)
bedrock_model: global.anthropic.claude-sonnet-4-5-20250929-v1:0
# The small fast model to use when making requests to the AWS Bedrock API. Claude
# Code uses Haiku-class models to perform background tasks. See
# https://docs.claude.com/en/docs/claude-code/settings#environment-variables.
# (default: global.anthropic.claude-haiku-4-5-20251001-v1:0, type: string)
bedrock_small_fast_model: global.anthropic.claude-haiku-4-5-20251001-v1:0
key: ""
+6 -104
View File
@@ -4,14 +4,12 @@ import (
"fmt"
"os"
"slices"
"sort"
"strings"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
@@ -29,10 +27,6 @@ func (r *RootCmd) tokens() *serpent.Command {
Description: "List your tokens",
Command: "coder tokens ls",
},
Example{
Description: "Create a scoped token",
Command: "coder tokens create --scope workspace:read --allow workspace:<uuid>",
},
Example{
Description: "Remove a token by ID",
Command: "coder tokens rm WuoWs4ZsMX",
@@ -45,7 +39,6 @@ func (r *RootCmd) tokens() *serpent.Command {
Children: []*serpent.Command{
r.createToken(),
r.listTokens(),
r.viewToken(),
r.removeToken(),
},
}
@@ -57,8 +50,6 @@ func (r *RootCmd) createToken() *serpent.Command {
tokenLifetime string
name string
user string
scopes []string
allowList []codersdk.APIAllowListTarget
)
cmd := &serpent.Command{
Use: "create",
@@ -97,18 +88,10 @@ func (r *RootCmd) createToken() *serpent.Command {
}
}
req := codersdk.CreateTokenRequest{
res, err := client.CreateToken(inv.Context(), userID, codersdk.CreateTokenRequest{
Lifetime: parsedLifetime,
TokenName: name,
}
if len(req.Scopes) == 0 {
req.Scopes = slice.StringEnums[codersdk.APIKeyScope](scopes)
}
if len(allowList) > 0 {
req.AllowList = append([]codersdk.APIAllowListTarget(nil), allowList...)
}
res, err := client.CreateToken(inv.Context(), userID, req)
})
if err != nil {
return xerrors.Errorf("create tokens: %w", err)
}
@@ -123,7 +106,7 @@ func (r *RootCmd) createToken() *serpent.Command {
{
Flag: "lifetime",
Env: "CODER_TOKEN_LIFETIME",
Description: "Duration for the token lifetime. Supports standard Go duration units (ns, us, ms, s, m, h) plus d (days) and y (years). Examples: 8h, 30d, 1y, 1d12h30m.",
Description: "Specify a duration for the lifetime of the token.",
Value: serpent.StringOf(&tokenLifetime),
},
{
@@ -140,16 +123,6 @@ func (r *RootCmd) createToken() *serpent.Command {
Description: "Specify the user to create the token for (Only works if logged in user is admin).",
Value: serpent.StringOf(&user),
},
{
Flag: "scope",
Description: "Repeatable scope to attach to the token (e.g. workspace:read).",
Value: serpent.StringArrayOf(&scopes),
},
{
Flag: "allow",
Description: "Repeatable allow-list entry (<type>:<uuid>, e.g. workspace:1234-...).",
Value: AllowListFlagOf(&allowList),
},
}
return cmd
@@ -163,8 +136,6 @@ type tokenListRow struct {
// For table format:
ID string `json:"-" table:"id,default_sort"`
TokenName string `json:"token_name" table:"name"`
Scopes string `json:"-" table:"scopes"`
Allow string `json:"-" table:"allow list"`
LastUsed time.Time `json:"-" table:"last used"`
ExpiresAt time.Time `json:"-" table:"expires at"`
CreatedAt time.Time `json:"-" table:"created at"`
@@ -172,47 +143,20 @@ type tokenListRow struct {
}
func tokenListRowFromToken(token codersdk.APIKeyWithOwner) tokenListRow {
return tokenListRowFromKey(token.APIKey, token.Username)
}
func tokenListRowFromKey(token codersdk.APIKey, owner string) tokenListRow {
return tokenListRow{
APIKey: token,
APIKey: token.APIKey,
ID: token.ID,
TokenName: token.TokenName,
Scopes: joinScopes(token.Scopes),
Allow: joinAllowList(token.AllowList),
LastUsed: token.LastUsed,
ExpiresAt: token.ExpiresAt,
CreatedAt: token.CreatedAt,
Owner: owner,
Owner: token.Username,
}
}
func joinScopes(scopes []codersdk.APIKeyScope) string {
if len(scopes) == 0 {
return ""
}
vals := slice.ToStrings(scopes)
sort.Strings(vals)
return strings.Join(vals, ", ")
}
func joinAllowList(entries []codersdk.APIAllowListTarget) string {
if len(entries) == 0 {
return ""
}
vals := make([]string, len(entries))
for i, entry := range entries {
vals[i] = entry.String()
}
sort.Strings(vals)
return strings.Join(vals, ", ")
}
func (r *RootCmd) listTokens() *serpent.Command {
// we only display the 'owner' column if the --all argument is passed in
defaultCols := []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at"}
defaultCols := []string{"id", "name", "last used", "expires at", "created at"}
if slices.Contains(os.Args, "-a") || slices.Contains(os.Args, "--all") {
defaultCols = append(defaultCols, "owner")
}
@@ -282,48 +226,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
return cmd
}
func (r *RootCmd) viewToken() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.TableFormat([]tokenListRow{}, []string{"id", "name", "scopes", "allow list", "last used", "expires at", "created at", "owner"}),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "view <name|id>",
Short: "Display detailed information about a token",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
tokenName := inv.Args[0]
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, tokenName)
if err != nil {
maybeID := strings.Split(tokenName, "-")[0]
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
if err != nil {
return xerrors.Errorf("fetch api key by name or id: %w", err)
}
}
row := tokenListRowFromKey(*token, "")
out, err := formatter.Format(inv.Context(), []tokenListRow{row})
if err != nil {
return err
}
_, err = fmt.Fprintln(inv.Stdout, out)
return err
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
func (r *RootCmd) removeToken() *serpent.Command {
cmd := &serpent.Command{
Use: "remove <name|id|token>",
+3 -56
View File
@@ -4,13 +4,10 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/google/uuid"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
@@ -49,18 +46,6 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
id := res[:10]
allowWorkspaceID := uuid.New()
allowSpec := fmt.Sprintf("workspace:%s", allowWorkspaceID.String())
inv, root = clitest.New(t, "tokens", "create", "--name", "scoped-token", "--scope", string(codersdk.APIKeyScopeWorkspaceRead), "--allow", allowSpec)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
scopedTokenID := res[:10]
// Test creating a token for second user from first user's (admin) session
inv, root = clitest.New(t, "tokens", "create", "--name", "token-two", "--user", secondUser.ID.String())
clitest.SetupConfig(t, client, root)
@@ -82,7 +67,7 @@ func TestTokens(t *testing.T) {
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
// Result should only contain the tokens created for the admin user
// Result should only contain the token created for the admin user
require.Contains(t, res, "ID")
require.Contains(t, res, "EXPIRES AT")
require.Contains(t, res, "CREATED AT")
@@ -91,16 +76,6 @@ func TestTokens(t *testing.T) {
// Result should not contain the token created for the second user
require.NotContains(t, res, secondTokenID)
inv, root = clitest.New(t, "tokens", "view", "scoped-token")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.Contains(t, res, string(codersdk.APIKeyScopeWorkspaceRead))
require.Contains(t, res, allowSpec)
// Test listing tokens from the second user's session
inv, root = clitest.New(t, "tokens", "ls")
clitest.SetupConfig(t, secondUserClient, root)
@@ -126,14 +101,6 @@ func TestTokens(t *testing.T) {
// User (non-admin) should not be able to create a token for another user
require.Error(t, err)
inv, root = clitest.New(t, "tokens", "create", "--name", "invalid-allow", "--allow", "badvalue")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "invalid allow_list entry")
inv, root = clitest.New(t, "tokens", "ls", "--output=json")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
@@ -143,17 +110,8 @@ func TestTokens(t *testing.T) {
var tokens []codersdk.APIKey
require.NoError(t, json.Unmarshal(buf.Bytes(), &tokens))
require.Len(t, tokens, 2)
tokenByName := make(map[string]codersdk.APIKey, len(tokens))
for _, tk := range tokens {
tokenByName[tk.TokenName] = tk
}
require.Contains(t, tokenByName, "token-one")
require.Contains(t, tokenByName, "scoped-token")
scopedToken := tokenByName["scoped-token"]
require.Contains(t, scopedToken.Scopes, codersdk.APIKeyScopeWorkspaceRead)
require.Len(t, scopedToken.AllowList, 1)
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
require.Len(t, tokens, 1)
require.Equal(t, id, tokens[0].ID)
// Delete by name
inv, root = clitest.New(t, "tokens", "rm", "token-one")
@@ -177,17 +135,6 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Delete scoped token by ID
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Create third token
inv, root = clitest.New(t, "tokens", "create", "--name", "token-three")
clitest.SetupConfig(t, client, root)
-4
View File
@@ -239,10 +239,6 @@ func (a *API) Serve(ctx context.Context, l net.Listener) error {
return xerrors.Errorf("create agent API server: %w", err)
}
if err := a.ResourcesMonitoringAPI.InitMonitors(ctx); err != nil {
return xerrors.Errorf("initialize resource monitoring: %w", err)
}
return server.Serve(ctx, l)
}
+35 -52
View File
@@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"sync"
"time"
"golang.org/x/xerrors"
@@ -34,60 +33,42 @@ type ResourcesMonitoringAPI struct {
Debounce time.Duration
Config resourcesmonitor.Config
// Cache resource monitors on first call to avoid millions of DB queries per day.
memoryMonitor database.WorkspaceAgentMemoryResourceMonitor
volumeMonitors []database.WorkspaceAgentVolumeResourceMonitor
monitorsLock sync.RWMutex
}
// InitMonitors fetches resource monitors from the database and caches them.
// This must be called once after creating a ResourcesMonitoringAPI, the context should be
// the agent per-RPC connection context. If fetching fails with a real error (not sql.ErrNoRows), the
// connection should be torn down.
func (a *ResourcesMonitoringAPI) InitMonitors(ctx context.Context) error {
memMon, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("fetch memory resource monitor: %w", err)
}
// If sql.ErrNoRows, memoryMonitor stays as zero value (CreatedAt.IsZero() = true).
// Otherwise, store the fetched monitor.
if err == nil {
a.memoryMonitor = memMon
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(ctx context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
memoryMonitor, memoryErr := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if memoryErr != nil && !errors.Is(memoryErr, sql.ErrNoRows) {
return nil, xerrors.Errorf("failed to fetch memory resource monitor: %w", memoryErr)
}
volMons, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
return xerrors.Errorf("fetch volume resource monitors: %w", err)
return nil, xerrors.Errorf("failed to fetch volume resource monitors: %w", err)
}
// 0 length is valid, indicating none configured, since the volume monitors in the DB can be many.
a.volumeMonitors = volMons
return nil
}
func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.Context, _ *proto.GetResourcesMonitoringConfigurationRequest) (*proto.GetResourcesMonitoringConfigurationResponse, error) {
return &proto.GetResourcesMonitoringConfigurationResponse{
Config: &proto.GetResourcesMonitoringConfigurationResponse_Config{
CollectionIntervalSeconds: int32(a.Config.CollectionInterval.Seconds()),
NumDatapoints: a.Config.NumDatapoints,
},
Memory: func() *proto.GetResourcesMonitoringConfigurationResponse_Memory {
if a.memoryMonitor.CreatedAt.IsZero() {
if memoryErr != nil {
return nil
}
return &proto.GetResourcesMonitoringConfigurationResponse_Memory{
Enabled: a.memoryMonitor.Enabled,
Enabled: memoryMonitor.Enabled,
}
}(),
Volumes: func() []*proto.GetResourcesMonitoringConfigurationResponse_Volume {
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(a.volumeMonitors))
for _, monitor := range a.volumeMonitors {
volumes := make([]*proto.GetResourcesMonitoringConfigurationResponse_Volume, 0, len(volumeMonitors))
for _, monitor := range volumeMonitors {
volumes = append(volumes, &proto.GetResourcesMonitoringConfigurationResponse_Volume{
Enabled: monitor.Enabled,
Path: monitor.Path,
})
}
return volumes
}(),
}, nil
@@ -96,10 +77,6 @@ func (a *ResourcesMonitoringAPI) GetResourcesMonitoringConfiguration(_ context.C
func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Context, req *proto.PushResourcesMonitoringUsageRequest) (*proto.PushResourcesMonitoringUsageResponse, error) {
var err error
// Lock for the entire push operation since calls are sequential from the agent
a.monitorsLock.Lock()
defer a.monitorsLock.Unlock()
if memoryErr := a.monitorMemory(ctx, req.Datapoints); memoryErr != nil {
err = errors.Join(err, xerrors.Errorf("monitor memory: %w", memoryErr))
}
@@ -112,7 +89,18 @@ func (a *ResourcesMonitoringAPI) PushResourcesMonitoringUsage(ctx context.Contex
}
func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
if !a.memoryMonitor.Enabled {
monitor, err := a.Database.FetchMemoryResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
// It is valid for an agent to not have a memory monitor, so we
// do not want to treat it as an error.
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return xerrors.Errorf("fetch memory resource monitor: %w", err)
}
if !monitor.Enabled {
return nil
}
@@ -121,15 +109,15 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
usageDatapoints = append(usageDatapoints, datapoint.Memory)
}
usageStates := resourcesmonitor.CalculateMemoryUsageStates(a.memoryMonitor, usageDatapoints)
usageStates := resourcesmonitor.CalculateMemoryUsageStates(monitor, usageDatapoints)
oldState := a.memoryMonitor.State
oldState := monitor.State
newState := resourcesmonitor.NextState(a.Config, oldState, usageStates)
debouncedUntil, shouldNotify := a.memoryMonitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
debouncedUntil, shouldNotify := monitor.Debounce(a.Debounce, a.Clock.Now(), oldState, newState)
//nolint:gocritic // We need to be able to update the resource monitor here.
err := a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
err = a.Database.UpdateMemoryResourceMonitor(dbauthz.AsResourceMonitor(ctx), database.UpdateMemoryResourceMonitorParams{
AgentID: a.AgentID,
State: newState,
UpdatedAt: dbtime.Time(a.Clock.Now()),
@@ -139,11 +127,6 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
return xerrors.Errorf("update workspace monitor: %w", err)
}
// Update cached state
a.memoryMonitor.State = newState
a.memoryMonitor.DebouncedUntil = dbtime.Time(debouncedUntil)
a.memoryMonitor.UpdatedAt = dbtime.Time(a.Clock.Now())
if !shouldNotify {
return nil
}
@@ -160,7 +143,7 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
notifications.TemplateWorkspaceOutOfMemory,
map[string]string{
"workspace": workspace.Name,
"threshold": fmt.Sprintf("%d%%", a.memoryMonitor.Threshold),
"threshold": fmt.Sprintf("%d%%", monitor.Threshold),
},
map[string]any{
// NOTE(DanielleMaywood):
@@ -186,9 +169,14 @@ func (a *ResourcesMonitoringAPI) monitorMemory(ctx context.Context, datapoints [
}
func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints []*proto.PushResourcesMonitoringUsageRequest_Datapoint) error {
volumeMonitors, err := a.Database.FetchVolumesResourceMonitorsByAgentID(ctx, a.AgentID)
if err != nil {
return xerrors.Errorf("get or insert volume monitor: %w", err)
}
outOfDiskVolumes := make([]map[string]any, 0)
for i, monitor := range a.volumeMonitors {
for _, monitor := range volumeMonitors {
if !monitor.Enabled {
continue
}
@@ -231,11 +219,6 @@ func (a *ResourcesMonitoringAPI) monitorVolumes(ctx context.Context, datapoints
}); err != nil {
return xerrors.Errorf("update workspace monitor: %w", err)
}
// Update cached state
a.volumeMonitors[i].State = newState
a.volumeMonitors[i].DebouncedUntil = dbtime.Time(debouncedUntil)
a.volumeMonitors[i].UpdatedAt = dbtime.Time(a.Clock.Now())
}
if len(outOfDiskVolumes) == 0 {
@@ -101,9 +101,6 @@ func TestMemoryResourceMonitorDebounce(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: The monitor is given a state that will trigger NOK
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -307,9 +304,6 @@ func TestMemoryResourceMonitor(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
clock.Set(collectedAt)
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: datapoints,
@@ -343,8 +337,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
State: database.WorkspaceAgentMonitorStateOK,
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two NOK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
@@ -395,9 +387,6 @@ func TestMemoryResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two OK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -477,9 +466,6 @@ func TestVolumeResourceMonitorDebounce(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When:
// - First monitor is in a NOK state
// - Second monitor is in an OK state
@@ -756,9 +742,6 @@ func TestVolumeResourceMonitor(t *testing.T) {
Threshold: tt.thresholdPercent,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
clock.Set(collectedAt)
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: datapoints,
@@ -797,9 +780,6 @@ func TestVolumeResourceMonitorMultiple(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: both of them move to a NOK state
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -852,9 +832,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two NOK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
@@ -914,9 +891,6 @@ func TestVolumeResourceMonitorMissingData(t *testing.T) {
Threshold: 80,
})
// Initialize API to fetch and cache the monitors
require.NoError(t, api.InitMonitors(context.Background()))
// When: A datapoint is missing, surrounded by two OK datapoints.
_, err := api.PushResourcesMonitoringUsage(context.Background(), &agentproto.PushResourcesMonitoringUsageRequest{
Datapoints: []*agentproto.PushResourcesMonitoringUsageRequest_Datapoint{
+376 -304
View File
@@ -2,6 +2,8 @@ package coderd
import (
"context"
"database/sql"
"errors"
"fmt"
"net"
"net/http"
@@ -10,13 +12,12 @@ import (
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/coderd/httpmw"
@@ -24,6 +25,7 @@ import (
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/searchquery"
"github.com/coder/coder/v2/coderd/taskname"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
aiagentapi "github.com/coder/agentapi-sdk-go"
@@ -94,56 +96,33 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
// This endpoint creates a new task for the given user.
func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
auditor = api.Auditor.Load()
mems = httpmw.OrganizationMembersParam(r)
taskResourceInfo = audit.AdditionalFields{}
ctx = r.Context()
apiKey = httpmw.APIKey(r)
auditor = api.Auditor.Load()
mems = httpmw.OrganizationMembersParam(r)
)
if mems.User != nil {
taskResourceInfo.WorkspaceOwner = mems.User.Username
}
aReq, commitAudit := audit.InitRequest[database.TaskTable](rw, &audit.RequestParams{
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
AdditionalFields: taskResourceInfo,
})
defer commitAudit()
var req codersdk.CreateTaskRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
// Fetch the template version to verify access and whether or not it has an
// AI task.
templateVersion, err := api.Database.GetTemplateVersionByID(ctx, req.TemplateVersionID)
hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID)
if err != nil {
if httpapi.Is404Error(err) {
// Avoid using httpapi.ResourceNotFound() here because this is an
// input error and 404 would be confusing.
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Template version not found or you do not have access to this resource",
})
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching template version.",
Message: "Internal error fetching whether the template version has an AI task.",
Detail: err.Error(),
})
return
}
aReq.UpdateOrganizationID(templateVersion.OrganizationID)
if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool {
if !hasAITask {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: `Template does not have a valid "coder_ai_task" resource.`,
Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName),
})
return
}
@@ -198,12 +177,23 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
} else {
// A task can still be created if the caller can read the organization
// member. The organization is required, which can be sourced from the
// templateVersion.
// template.
//
// TODO: This code gets called twice for each workspace build request.
// This is inefficient and costs at most 2 extra RTTs to the DB.
// This can be optimized. It exists as it is now for code simplicity.
// The most common case is to create a workspace for 'Me'. Which does
// not enter this code branch.
template, err := requestTemplate(ctx, createReq, api.Database)
if err != nil {
httperror.WriteResponseError(ctx, rw, err)
return
}
// If the caller can find the organization membership in the same org
// as the template, then they can continue.
orgIndex := slices.IndexFunc(mems.Memberships, func(mem httpmw.OrganizationMember) bool {
return mem.OrganizationID == templateVersion.OrganizationID
return mem.OrganizationID == template.OrganizationID
})
if orgIndex == -1 {
httpapi.ResourceNotFound(rw)
@@ -216,113 +206,56 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
Username: member.Username,
AvatarURL: member.AvatarURL,
}
// Update workspace owner information for audit in case it changed.
taskResourceInfo.WorkspaceOwner = owner.Username
}
// Track insert from preCreateInTX.
var dbTaskTable database.TaskTable
// Ensure an audit log is created for the workspace creation event.
aReqWS, commitAuditWS := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
AdditionalFields: taskResourceInfo,
OrganizationID: templateVersion.OrganizationID,
})
defer commitAuditWS()
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
// Before creating the workspace, ensure that this task can be created.
preCreateInTX: func(ctx context.Context, tx database.Store) error {
// Create task record in the database before creating the workspace so that
// we can request that the workspace be linked to it after creation.
dbTaskTable, err = tx.InsertTask(ctx, database.InsertTaskParams{
ID: uuid.New(),
OrganizationID: templateVersion.OrganizationID,
OwnerID: owner.ID,
Name: taskName,
WorkspaceID: uuid.NullUUID{}, // Will be set after workspace creation.
TemplateVersionID: templateVersion.ID,
TemplateParameters: []byte("{}"),
Prompt: req.Input,
CreatedAt: dbtime.Time(api.Clock.Now()),
})
if err != nil {
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
Message: "Internal error creating task.",
Detail: err.Error(),
})
}
return nil
},
// After the workspace is created, ensure that the task is linked to it.
postCreateInTX: func(ctx context.Context, tx database.Store, workspace database.Workspace) error {
// Update the task record with the workspace ID after creation.
dbTaskTable, err = tx.UpdateTaskWorkspaceID(ctx, database.UpdateTaskWorkspaceIDParams{
ID: dbTaskTable.ID,
WorkspaceID: uuid.NullUUID{
UUID: workspace.ID,
Valid: true,
},
})
if err != nil {
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
Message: "Internal error updating task.",
Detail: err.Error(),
})
}
return nil
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
AdditionalFields: audit.AdditionalFields{
WorkspaceOwner: owner.Username,
},
})
defer commitAudit()
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, createReq, r)
if err != nil {
httperror.WriteResponseError(ctx, rw, err)
return
}
aReq.New = dbTaskTable
// Fetch the task to get the additional columns from the view.
dbTask, err := api.Database.GetTaskByID(ctx, dbTaskTable.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusCreated, taskFromDBTaskAndWorkspace(dbTask, workspace))
task := taskFromWorkspace(w, req.Input)
httpapi.Write(ctx, rw, http.StatusCreated, task)
}
// taskFromDBTaskAndWorkspace creates a codersdk.Task response from the task
// database record and workspace.
func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) codersdk.Task {
func taskFromWorkspace(ws codersdk.Workspace, initialPrompt string) codersdk.Task {
// TODO(DanielleMaywood):
// This just picks up the first agent it discovers.
// This approach _might_ break when a task has multiple agents,
// depending on which agent was found first.
//
// We explicitly do not have support for running tasks
// inside of a sub agent at the moment, so we can be sure
// that any sub agents are not the agent we're looking for.
var taskAgentID uuid.NullUUID
var taskAgentLifecycle *codersdk.WorkspaceAgentLifecycle
var taskAgentHealth *codersdk.WorkspaceAgentHealth
// If we have an agent ID from the task, find the agent details in the
// workspace.
if dbTask.WorkspaceAgentID.Valid {
findTaskAgentLoop:
for _, resource := range ws.LatestBuild.Resources {
for _, agent := range resource.Agents {
if agent.ID == dbTask.WorkspaceAgentID.UUID {
taskAgentLifecycle = &agent.LifecycleState
taskAgentHealth = &agent.Health
break findTaskAgentLoop
}
for _, resource := range ws.LatestBuild.Resources {
for _, agent := range resource.Agents {
if agent.ParentID.Valid {
continue
}
taskAgentID = uuid.NullUUID{Valid: true, UUID: agent.ID}
taskAgentLifecycle = &agent.LifecycleState
taskAgentHealth = &agent.Health
break
}
}
// Ignore 'latest app status' if it is older than the latest build and the
// latest build is a 'start' transition. This ensures that you don't show a
// stale app status from a previous build. For stop transitions, there is
// still value in showing the latest app status.
// Ignore 'latest app status' if it is older than the latest build and the latest build is a 'start' transition.
// This ensures that you don't show a stale app status from a previous build.
// For stop transitions, there is still value in showing the latest app status.
var currentState *codersdk.TaskStateEntry
if ws.LatestAppStatus != nil {
if ws.LatestBuild.Transition != codersdk.WorkspaceTransitionStart || ws.LatestAppStatus.CreatedAt.After(ws.LatestBuild.CreatedAt) {
@@ -335,135 +268,188 @@ func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) cod
}
}
var appID uuid.NullUUID
if ws.LatestBuild.AITaskSidebarAppID != nil {
appID = uuid.NullUUID{
Valid: true,
UUID: *ws.LatestBuild.AITaskSidebarAppID,
}
}
return codersdk.Task{
ID: dbTask.ID,
OrganizationID: dbTask.OrganizationID,
OwnerID: dbTask.OwnerID,
OwnerName: dbTask.OwnerUsername,
OwnerAvatarURL: dbTask.OwnerAvatarUrl,
Name: dbTask.Name,
ID: ws.ID,
OrganizationID: ws.OrganizationID,
OwnerID: ws.OwnerID,
OwnerName: ws.OwnerName,
Name: ws.Name,
TemplateID: ws.TemplateID,
TemplateVersionID: dbTask.TemplateVersionID,
TemplateName: ws.TemplateName,
TemplateDisplayName: ws.TemplateDisplayName,
TemplateIcon: ws.TemplateIcon,
WorkspaceID: dbTask.WorkspaceID,
WorkspaceName: ws.Name,
WorkspaceBuildNumber: dbTask.WorkspaceBuildNumber.Int32,
WorkspaceStatus: ws.LatestBuild.Status,
WorkspaceAgentID: dbTask.WorkspaceAgentID,
WorkspaceID: uuid.NullUUID{Valid: true, UUID: ws.ID},
WorkspaceBuildNumber: ws.LatestBuild.BuildNumber,
WorkspaceAgentID: taskAgentID,
WorkspaceAgentLifecycle: taskAgentLifecycle,
WorkspaceAgentHealth: taskAgentHealth,
WorkspaceAppID: dbTask.WorkspaceAppID,
InitialPrompt: dbTask.Prompt,
Status: codersdk.TaskStatus(dbTask.Status),
CurrentState: currentState,
CreatedAt: dbTask.CreatedAt,
WorkspaceAppID: appID,
CreatedAt: ws.CreatedAt,
UpdatedAt: ws.UpdatedAt,
InitialPrompt: initialPrompt,
Status: ws.LatestBuild.Status,
CurrentState: currentState,
}
}
// tasksFromWorkspaces converts a slice of API workspaces into tasks, fetching
// prompts and mapping status/state. This method enforces that only AI task
// workspaces are given.
func (api *API) tasksFromWorkspaces(ctx context.Context, apiWorkspaces []codersdk.Workspace) ([]codersdk.Task, error) {
// Fetch prompts for each workspace build and map by build ID.
buildIDs := make([]uuid.UUID, 0, len(apiWorkspaces))
for _, ws := range apiWorkspaces {
buildIDs = append(buildIDs, ws.LatestBuild.ID)
}
parameters, err := api.Database.GetWorkspaceBuildParametersByBuildIDs(ctx, buildIDs)
if err != nil {
return nil, err
}
promptsByBuildID := make(map[uuid.UUID]string, len(parameters))
for _, p := range parameters {
if p.Name == codersdk.AITaskPromptParameterName {
promptsByBuildID[p.WorkspaceBuildID] = p.Value
}
}
tasks := make([]codersdk.Task, 0, len(apiWorkspaces))
for _, ws := range apiWorkspaces {
tasks = append(tasks, taskFromWorkspace(ws, promptsByBuildID[ws.LatestBuild.ID]))
}
return tasks, nil
}
// tasksListResponse wraps a list of experimental tasks.
//
// Experimental: Response shape is experimental and may change.
type tasksListResponse struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}
// @Summary List AI tasks
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
// @ID list-tasks
// @Security CoderSessionToken
// @Tags Experimental
// @Param q query string false "Search query for filtering tasks. Supports: owner:<username/uuid/me>, organization:<org-name/uuid>, status:<status>"
// @Success 200 {object} codersdk.TasksListResponse
// @Param q query string false "Search query for filtering tasks"
// @Param after_id query string false "Return tasks after this ID for pagination"
// @Param limit query int false "Maximum number of tasks to return" minimum(1) maximum(100) default(25)
// @Param offset query int false "Offset for pagination" minimum(0) default(0)
// @Success 200 {object} coderd.tasksListResponse
// @Router /api/experimental/tasks [get]
//
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
// tasksList is an experimental endpoint to list tasks.
// tasksList is an experimental endpoint to list AI tasks by mapping
// workspaces to a task-shaped response.
func (api *API) tasksList(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
// Parse query parameters for filtering tasks.
// Support standard pagination/filters for workspaces.
page, ok := ParsePagination(rw, r)
if !ok {
return
}
queryStr := r.URL.Query().Get("q")
filter, errs := searchquery.Tasks(ctx, api.Database, queryStr, apiKey.UserID)
filter, errs := searchquery.Workspaces(ctx, api.Database, queryStr, page, api.AgentInactiveDisconnectTimeout)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid task search query.",
Message: "Invalid workspace search query.",
Validations: errs,
})
return
}
// Fetch all tasks matching the filters from the database.
dbTasks, err := api.Database.ListTasks(ctx, filter)
// Ensure that we only include AI task workspaces in the results.
filter.HasAITask = sql.NullBool{Valid: true, Bool: true}
if filter.OwnerUsername == "me" {
filter.OwnerID = apiKey.UserID
filter.OwnerUsername = ""
}
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, policy.ActionRead, rbac.ResourceWorkspace.Type)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching tasks.",
Message: "Internal error preparing sql filter.",
Detail: err.Error(),
})
return
}
tasks, err := api.convertTasks(ctx, apiKey.UserID, dbTasks)
// Order with requester's favorites first, include summary row.
filter.RequesterID = apiKey.UserID
filter.WithSummary = true
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, prepared)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error converting tasks.",
Message: "Internal error fetching workspaces.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.TasksListResponse{
Tasks: tasks,
Count: len(tasks),
})
}
// convertTasks converts database tasks to API tasks, enriching them with
// workspace information.
func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks []database.Task) ([]codersdk.Task, error) {
if len(dbTasks) == 0 {
return []codersdk.Task{}, nil
if len(workspaceRows) == 0 {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspaces.",
Detail: "Workspace summary row is missing.",
})
return
}
if len(workspaceRows) == 1 {
httpapi.Write(ctx, rw, http.StatusOK, tasksListResponse{
Tasks: []codersdk.Task{},
Count: 0,
})
return
}
// Prepare to batch fetch workspaces.
workspaceIDs := make([]uuid.UUID, 0, len(dbTasks))
for _, task := range dbTasks {
if !task.WorkspaceID.Valid {
return nil, xerrors.New("task has no workspace ID")
}
workspaceIDs = append(workspaceIDs, task.WorkspaceID.UUID)
}
// Fetch workspaces for tasks that have workspaces.
workspaceRows, err := api.Database.GetWorkspaces(ctx, database.GetWorkspacesParams{
WorkspaceIds: workspaceIDs,
})
if err != nil {
return nil, xerrors.Errorf("fetch workspaces: %w", err)
}
// Skip summary row.
workspaceRows = workspaceRows[:len(workspaceRows)-1]
workspaces := database.ConvertWorkspaceRows(workspaceRows)
// Gather associated data and convert to API workspaces.
data, err := api.workspaceData(ctx, workspaces)
if err != nil {
return nil, xerrors.Errorf("fetch workspace data: %w", err)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace resources.",
Detail: err.Error(),
})
return
}
apiWorkspaces, err := convertWorkspaces(requesterID, workspaces, data)
apiWorkspaces, err := convertWorkspaces(apiKey.UserID, workspaces, data)
if err != nil {
return nil, xerrors.Errorf("convert workspaces: %w", err)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error converting workspaces.",
Detail: err.Error(),
})
return
}
workspacesByID := make(map[uuid.UUID]codersdk.Workspace)
for _, ws := range apiWorkspaces {
workspacesByID[ws.ID] = ws
tasks, err := api.tasksFromWorkspaces(ctx, apiWorkspaces)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task prompts and states.",
Detail: err.Error(),
})
return
}
// Convert tasks to SDK format.
result := make([]codersdk.Task, 0, len(dbTasks))
for _, dbTask := range dbTasks {
task := taskFromDBTaskAndWorkspace(dbTask, workspacesByID[dbTask.WorkspaceID.UUID])
result = append(result, task)
}
return result, nil
httpapi.Write(ctx, rw, http.StatusOK, tasksListResponse{
Tasks: tasks,
Count: len(tasks),
})
}
// @Summary Get AI task by ID
@@ -472,9 +458,9 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param id path string true "Task ID" format(uuid)
// @Success 200 {object} codersdk.Task
// @Router /api/experimental/tasks/{user}/{task} [get]
// @Router /api/experimental/tasks/{user}/{id} [get]
//
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
// taskGet is an experimental endpoint to fetch a single AI task by ID
@@ -483,22 +469,25 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
task := httpmw.TaskParam(r)
if !task.WorkspaceID.Valid {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task.",
Detail: "Task workspace ID is invalid.",
idStr := chi.URLParam(r, "id")
taskID, err := uuid.Parse(idStr)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
})
return
}
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
// For now, taskID = workspaceID, once we have a task data model in
// the DB, we can change this lookup.
workspaceID := taskID
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace.",
Detail: err.Error(),
@@ -518,6 +507,34 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
httpapi.ResourceNotFound(rw)
return
}
if data.builds[0].HasAITask == nil || !*data.builds[0].HasAITask {
// TODO(DanielleMaywood):
// This is a temporary workaround. When a task has just been created, but
// not yet provisioned, the workspace build will not have `HasAITask` set.
//
// When we reach this code flow, it is _either_ because the workspace is
// not a task, or it is a task that has not yet been provisioned. This
// endpoint should rarely be called with a non-task workspace so we
// should be fine with this extra database call to check if it has the
// special "AI Task" parameter.
parameters, err := api.Database.GetWorkspaceBuildParameters(ctx, data.builds[0].ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace build parameters.",
Detail: err.Error(),
})
return
}
_, hasAITask := slice.Find(parameters, func(t database.WorkspaceBuildParameter) bool {
return t.Name == codersdk.AITaskPromptParameterName
})
if !hasAITask {
httpapi.ResourceNotFound(rw)
return
}
}
appStatus := codersdk.WorkspaceAppStatus{}
if len(data.appStatuses) > 0 {
@@ -540,8 +557,16 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
return
}
taskResp := taskFromDBTaskAndWorkspace(task, ws)
httpapi.Write(ctx, rw, http.StatusOK, taskResp)
tasks, err := api.tasksFromWorkspaces(ctx, []codersdk.Workspace{ws})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task prompt and state.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, tasks[0])
}
// @Summary Delete AI task by ID
@@ -550,71 +575,83 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param id path string true "Task ID" format(uuid)
// @Success 202 "Task deletion initiated"
// @Router /api/experimental/tasks/{user}/{task} [delete]
// @Router /api/experimental/tasks/{user}/{id} [delete]
//
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
// taskDelete is an experimental endpoint to delete a task by ID.
// taskDelete is an experimental endpoint to delete a task by ID (workspace ID).
// It creates a delete workspace build and returns 202 Accepted if the build was
// created.
func (api *API) taskDelete(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
task := httpmw.TaskParam(r)
now := api.Clock.Now()
if task.WorkspaceID.Valid {
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task workspace before deleting task.",
Detail: err.Error(),
})
return
}
// Construct a request to the workspace build creation handler to
// initiate deletion.
buildReq := codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionDelete,
Reason: "Deleted via tasks API",
}
_, err = api.postWorkspaceBuildsInternal(
ctx,
apiKey,
workspace,
buildReq,
func(action policy.Action, object rbac.Objecter) bool {
return api.Authorize(r, action, object)
},
audit.WorkspaceBuildBaggageFromRequest(r),
)
if err != nil {
httperror.WriteWorkspaceBuildError(ctx, rw, err)
return
}
idStr := chi.URLParam(r, "id")
taskID, err := uuid.Parse(idStr)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
})
return
}
_, err := api.Database.DeleteTask(ctx, database.DeleteTaskParams{
ID: task.ID,
DeletedAt: dbtime.Time(now),
})
// For now, taskID = workspaceID, once we have a task data model in
// the DB, we can change this lookup.
workspaceID := taskID
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to delete task",
Message: "Internal error fetching workspace.",
Detail: err.Error(),
})
return
}
// Task deleted and delete build created successfully.
data, err := api.workspaceData(ctx, []database.Workspace{workspace})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace resources.",
Detail: err.Error(),
})
return
}
if len(data.builds) == 0 || len(data.templates) == 0 {
httpapi.ResourceNotFound(rw)
return
}
if data.builds[0].HasAITask == nil || !*data.builds[0].HasAITask {
httpapi.ResourceNotFound(rw)
return
}
// Construct a request to the workspace build creation handler to
// initiate deletion.
buildReq := codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionDelete,
Reason: "Deleted via tasks API",
}
_, err = api.postWorkspaceBuildsInternal(
ctx,
apiKey,
workspace,
buildReq,
func(action policy.Action, object rbac.Objecter) bool {
return api.Authorize(r, action, object)
},
audit.WorkspaceBuildBaggageFromRequest(r),
)
if err != nil {
httperror.WriteWorkspaceBuildError(ctx, rw, err)
return
}
// Delete build created successfully.
rw.WriteHeader(http.StatusAccepted)
}
@@ -624,18 +661,26 @@ func (api *API) taskDelete(rw http.ResponseWriter, r *http.Request) {
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param id path string true "Task ID" format(uuid)
// @Param request body codersdk.TaskSendRequest true "Task input request"
// @Success 204 "Input sent successfully"
// @Router /api/experimental/tasks/{user}/{task}/send [post]
// @Router /api/experimental/tasks/{user}/{id}/send [post]
//
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
// taskSend submits task input to the task app by dialing the agent
// taskSend submits task input to the tasks sidebar app by dialing the agent
// directly over the tailnet. We enforce ApplicationConnect RBAC on the
// workspace and validate the task app health.
// workspace and validate the sidebar app health.
func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
task := httpmw.TaskParam(r)
idStr := chi.URLParam(r, "id")
taskID, err := uuid.Parse(idStr)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
})
return
}
var req codersdk.TaskSendRequest
if !httpapi.Read(ctx, rw, r, &req) {
@@ -648,7 +693,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
return
}
if err := api.authAndDoWithTaskAppClient(r, task, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
if err = api.authAndDoWithTaskSidebarAppClient(r, taskID, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
agentAPIClient, err := aiagentapi.NewClient(appURL.String(), aiagentapi.WithHTTPClient(client))
if err != nil {
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
@@ -698,19 +743,27 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
// @Security CoderSessionToken
// @Tags Experimental
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Param id path string true "Task ID" format(uuid)
// @Success 200 {object} codersdk.TaskLogsResponse
// @Router /api/experimental/tasks/{user}/{task}/logs [get]
// @Router /api/experimental/tasks/{user}/{id}/logs [get]
//
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
// taskLogs reads task output by dialing the agent directly over the tailnet.
// We enforce ApplicationConnect RBAC on the workspace and validate the task app health.
// We enforce ApplicationConnect RBAC on the workspace and validate the sidebar app health.
func (api *API) taskLogs(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
task := httpmw.TaskParam(r)
idStr := chi.URLParam(r, "id")
taskID, err := uuid.Parse(idStr)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
})
return
}
var out codersdk.TaskLogsResponse
if err := api.authAndDoWithTaskAppClient(r, task, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
if err := api.authAndDoWithTaskSidebarAppClient(r, taskID, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
agentAPIClient, err := aiagentapi.NewClient(appURL.String(), aiagentapi.WithHTTPClient(client))
if err != nil {
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
@@ -758,40 +811,24 @@ func (api *API) taskLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, out)
}
// authAndDoWithTaskAppClient centralizes the shared logic to:
// authAndDoWithTaskSidebarAppClient centralizes the shared logic to:
//
// - Fetch the task workspace
// - Authorize ApplicationConnect on the workspace
// - Validate the AI task and task app health
// - Validate the AI task and sidebar app health
// - Dial the agent and construct an HTTP client to the apps loopback URL
//
// The provided callback receives the context, an HTTP client that dials via the
// agent, and the base app URL (as a value URL) to perform any request.
func (api *API) authAndDoWithTaskAppClient(
func (api *API) authAndDoWithTaskSidebarAppClient(
r *http.Request,
task database.Task,
taskID uuid.UUID,
do func(ctx context.Context, client *http.Client, appURL *url.URL) error,
) error {
ctx := r.Context()
if task.Status != database.TaskStatusActive {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task status must be active.",
Detail: fmt.Sprintf("Task status is %q, it must be %q to interact with the task.", task.Status, codersdk.TaskStatusActive),
})
}
if !task.WorkspaceID.Valid {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task does not have a workspace.",
})
}
if !task.WorkspaceAppID.Valid {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task does not have a workspace app.",
})
}
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
workspaceID := taskID
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
if err != nil {
if httpapi.Is404Error(err) {
return httperror.ErrResourceNotFound
@@ -807,30 +844,65 @@ func (api *API) authAndDoWithTaskAppClient(
return httperror.ErrResourceNotFound
}
apps, err := api.Database.GetWorkspaceAppsByAgentID(ctx, task.WorkspaceAgentID.UUID)
data, err := api.workspaceData(ctx, []database.Workspace{workspace})
if err != nil {
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace resources.",
Detail: err.Error(),
})
}
if len(data.builds) == 0 || len(data.templates) == 0 {
return httperror.ErrResourceNotFound
}
build := data.builds[0]
if build.HasAITask == nil || !*build.HasAITask || build.AITaskSidebarAppID == nil || *build.AITaskSidebarAppID == uuid.Nil {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task is not configured with a sidebar app.",
})
}
var app *database.WorkspaceApp
for _, a := range apps {
if a.ID == task.WorkspaceAppID.UUID {
app = &a
break
// Find the sidebar app details to get the URL and validate app health.
sidebarAppID := *build.AITaskSidebarAppID
agentID, sidebarApp, ok := func() (uuid.UUID, codersdk.WorkspaceApp, bool) {
for _, res := range build.Resources {
for _, agent := range res.Agents {
for _, app := range agent.Apps {
if app.ID == sidebarAppID {
return agent.ID, app, true
}
}
}
}
return uuid.Nil, codersdk.WorkspaceApp{}, false
}()
if !ok {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task sidebar app not found in latest build.",
})
}
// Return an informative error if the app isn't healthy rather than trying
// and failing.
switch sidebarApp.Health {
case codersdk.WorkspaceAppHealthDisabled:
// No health check, pass through.
case codersdk.WorkspaceAppHealthInitializing:
return httperror.NewResponseError(http.StatusServiceUnavailable, codersdk.Response{
Message: "Task sidebar app is initializing. Try again shortly.",
})
case codersdk.WorkspaceAppHealthUnhealthy:
return httperror.NewResponseError(http.StatusServiceUnavailable, codersdk.Response{
Message: "Task sidebar app is unhealthy.",
})
}
// Build the direct app URL and dial the agent.
appURL := app.Url.String
if appURL == "" {
if sidebarApp.URL == "" {
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
Message: "Task app URL is not configured.",
Message: "Task sidebar app URL is not configured.",
})
}
parsedURL, err := url.Parse(appURL)
parsedURL, err := url.Parse(sidebarApp.URL)
if err != nil {
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
Message: "Internal error parsing task app URL.",
@@ -845,7 +917,7 @@ func (api *API) authAndDoWithTaskAppClient(
dialCtx, dialCancel := context.WithTimeout(ctx, time.Second*30)
defer dialCancel()
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, task.WorkspaceAgentID.UUID)
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agentID)
if err != nil {
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
Message: "Failed to reach task app endpoint.",
+264 -536
View File
File diff suppressed because it is too large Load Diff
+67 -251
View File
@@ -85,7 +85,7 @@ const docTemplate = `{
}
}
},
"/aibridge/interceptions": {
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
@@ -151,16 +151,39 @@ const docTemplate = `{
"parameters": [
{
"type": "string",
"description": "Search query for filtering tasks. Supports: owner:\u003cusername/uuid/me\u003e, organization:\u003corg-name/uuid\u003e, status:\u003cstatus\u003e",
"description": "Search query for filtering tasks",
"name": "q",
"in": "query"
},
{
"type": "string",
"description": "Return tasks after this ID for pagination",
"name": "after_id",
"in": "query"
},
{
"maximum": 100,
"minimum": 1,
"type": "integer",
"default": 25,
"description": "Maximum number of tasks to return",
"name": "limit",
"in": "query"
},
{
"minimum": 0,
"type": "integer",
"default": 0,
"description": "Offset for pagination",
"name": "offset",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.TasksListResponse"
"$ref": "#/definitions/coderd.tasksListResponse"
}
}
}
@@ -206,7 +229,7 @@ const docTemplate = `{
}
}
},
"/api/experimental/tasks/{user}/{task}": {
"/api/experimental/tasks/{user}/{id}": {
"get": {
"security": [
{
@@ -230,7 +253,7 @@ const docTemplate = `{
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
}
@@ -267,7 +290,7 @@ const docTemplate = `{
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
}
@@ -279,7 +302,7 @@ const docTemplate = `{
}
}
},
"/api/experimental/tasks/{user}/{task}/logs": {
"/api/experimental/tasks/{user}/{id}/logs": {
"get": {
"security": [
{
@@ -303,7 +326,7 @@ const docTemplate = `{
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
}
@@ -318,7 +341,7 @@ const docTemplate = `{
}
}
},
"/api/experimental/tasks/{user}/{task}/send": {
"/api/experimental/tasks/{user}/{id}/send": {
"post": {
"security": [
{
@@ -342,7 +365,7 @@ const docTemplate = `{
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
},
@@ -3059,45 +3082,6 @@ const docTemplate = `{
}
}
},
"/oauth2/revoke": {
"post": {
"consumes": [
"application/x-www-form-urlencoded"
],
"tags": [
"Enterprise"
],
"summary": "Revoke OAuth2 tokens (RFC 7009).",
"operationId": "oauth2-token-revocation",
"parameters": [
{
"type": "string",
"description": "Client ID for authentication",
"name": "client_id",
"in": "formData",
"required": true
},
{
"type": "string",
"description": "The token to revoke",
"name": "token",
"in": "formData",
"required": true
},
{
"type": "string",
"description": "Hint about token type (access_token or refresh_token)",
"name": "token_type_hint",
"in": "formData"
}
],
"responses": {
"200": {
"description": "Token successfully revoked"
}
}
}
},
"/oauth2/tokens": {
"post": {
"produces": [
@@ -10008,45 +9992,6 @@ const docTemplate = `{
}
}
}
},
"put": {
"security": [
{
"CoderSessionToken": []
}
],
"consumes": [
"application/json"
],
"tags": [
"Builds"
],
"summary": "Update workspace build state",
"operationId": "update-workspace-build-state",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace build ID",
"name": "workspacebuild",
"in": "path",
"required": true
},
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceBuildStateRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/workspacebuilds/{workspacebuild}/timings": {
@@ -11679,6 +11624,20 @@ const docTemplate = `{
}
}
},
"coderd.tasksListResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"tasks": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Task"
}
}
}
},
"codersdk.ACLAvailable": {
"type": "object",
"properties": {
@@ -11707,35 +11666,12 @@ const docTemplate = `{
}
}
},
"codersdk.AIBridgeBedrockConfig": {
"type": "object",
"properties": {
"access_key": {
"type": "string"
},
"access_key_secret": {
"type": "string"
},
"model": {
"type": "string"
},
"region": {
"type": "string"
},
"small_fast_model": {
"type": "string"
}
}
},
"codersdk.AIBridgeConfig": {
"type": "object",
"properties": {
"anthropic": {
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
},
"bedrock": {
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
},
"enabled": {
"type": "boolean"
},
@@ -11747,10 +11683,6 @@ const docTemplate = `{
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
@@ -11795,14 +11727,14 @@ const docTemplate = `{
"codersdk.AIBridgeListInterceptionsResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"results": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeInterception"
}
},
"total": {
"type": "integer"
}
}
},
@@ -11946,12 +11878,6 @@ const docTemplate = `{
"user_id"
],
"properties": {
"allow_list": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.APIAllowListTarget"
}
},
"created_at": {
"type": "string",
"format": "date-time"
@@ -12562,13 +12488,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -13792,13 +13711,6 @@ const docTemplate = `{
"name": {
"type": "string"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
"type": "array",
@@ -14355,9 +14267,11 @@ const docTemplate = `{
"web-push",
"oauth2",
"mcp-server-http",
"workspace-sharing"
"workspace-sharing",
"aibridge"
],
"x-enum-comments": {
"ExperimentAIBridge": "Enables AI Bridge functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -14375,7 +14289,8 @@ const docTemplate = `{
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
"ExperimentWorkspaceSharing",
"ExperimentAIBridge"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -14982,15 +14897,7 @@ const docTemplate = `{
"enum": [
"bug",
"chat",
"docs",
"star"
]
},
"location": {
"type": "string",
"enum": [
"navbar",
"dropdown"
"docs"
]
},
"name": {
@@ -15438,9 +15345,6 @@ const docTemplate = `{
},
"token": {
"type": "string"
},
"token_revoke": {
"type": "string"
}
}
},
@@ -15540,10 +15444,7 @@ const docTemplate = `{
}
},
"registration_access_token": {
"type": "array",
"items": {
"type": "integer"
}
"type": "string"
},
"registration_client_uri": {
"type": "string"
@@ -17563,13 +17464,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -17814,9 +17708,6 @@ const docTemplate = `{
"type": "string",
"format": "uuid"
},
"owner_avatar_url": {
"type": "string"
},
"owner_id": {
"type": "string",
"format": "uuid"
@@ -17827,15 +17718,19 @@ const docTemplate = `{
"status": {
"enum": [
"pending",
"initializing",
"active",
"paused",
"unknown",
"error"
"starting",
"running",
"stopping",
"stopped",
"failed",
"canceling",
"canceled",
"deleting",
"deleted"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.TaskStatus"
"$ref": "#/definitions/codersdk.WorkspaceStatus"
}
]
},
@@ -17852,10 +17747,6 @@ const docTemplate = `{
"template_name": {
"type": "string"
},
"template_version_id": {
"type": "string",
"format": "uuid"
},
"updated_at": {
"type": "string",
"format": "date-time"
@@ -17892,28 +17783,6 @@ const docTemplate = `{
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"workspace_name": {
"type": "string"
},
"workspace_status": {
"enum": [
"pending",
"starting",
"running",
"stopping",
"stopped",
"failed",
"canceling",
"canceled",
"deleting",
"deleted"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.WorkspaceStatus"
}
]
}
}
},
@@ -17998,39 +17867,6 @@ const docTemplate = `{
}
}
},
"codersdk.TaskStatus": {
"type": "string",
"enum": [
"pending",
"initializing",
"active",
"paused",
"unknown",
"error"
],
"x-enum-varnames": [
"TaskStatusPending",
"TaskStatusInitializing",
"TaskStatusActive",
"TaskStatusPaused",
"TaskStatusUnknown",
"TaskStatusError"
]
},
"codersdk.TasksListResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"tasks": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Task"
}
}
}
},
"codersdk.TelemetryConfig": {
"type": "object",
"properties": {
@@ -19198,17 +19034,6 @@ const docTemplate = `{
}
}
},
"codersdk.UpdateWorkspaceBuildStateRequest": {
"type": "object",
"properties": {
"state": {
"type": "array",
"items": {
"type": "integer"
}
}
}
},
"codersdk.UpdateWorkspaceDormancy": {
"type": "object",
"properties": {
@@ -19762,14 +19587,6 @@ const docTemplate = `{
"description": "OwnerName is the username of the owner of the workspace.",
"type": "string"
},
"task_id": {
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
"allOf": [
{
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"template_active_version_id": {
"type": "string",
"format": "uuid"
@@ -20577,7 +20394,6 @@ const docTemplate = `{
"type": "object",
"properties": {
"ai_task_sidebar_app_id": {
"description": "Deprecated: This field has been replaced with ` + "`" + `Task.WorkspaceAppID` + "`" + `",
"type": "string",
"format": "uuid"
},
+67 -239
View File
@@ -65,7 +65,7 @@
}
}
},
"/aibridge/interceptions": {
"/api/experimental/aibridge/interceptions": {
"get": {
"security": [
{
@@ -125,16 +125,39 @@
"parameters": [
{
"type": "string",
"description": "Search query for filtering tasks. Supports: owner:\u003cusername/uuid/me\u003e, organization:\u003corg-name/uuid\u003e, status:\u003cstatus\u003e",
"description": "Search query for filtering tasks",
"name": "q",
"in": "query"
},
{
"type": "string",
"description": "Return tasks after this ID for pagination",
"name": "after_id",
"in": "query"
},
{
"maximum": 100,
"minimum": 1,
"type": "integer",
"default": 25,
"description": "Maximum number of tasks to return",
"name": "limit",
"in": "query"
},
{
"minimum": 0,
"type": "integer",
"default": 0,
"description": "Offset for pagination",
"name": "offset",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.TasksListResponse"
"$ref": "#/definitions/coderd.tasksListResponse"
}
}
}
@@ -178,7 +201,7 @@
}
}
},
"/api/experimental/tasks/{user}/{task}": {
"/api/experimental/tasks/{user}/{id}": {
"get": {
"security": [
{
@@ -200,7 +223,7 @@
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
}
@@ -235,7 +258,7 @@
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
}
@@ -247,7 +270,7 @@
}
}
},
"/api/experimental/tasks/{user}/{task}/logs": {
"/api/experimental/tasks/{user}/{id}/logs": {
"get": {
"security": [
{
@@ -269,7 +292,7 @@
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
}
@@ -284,7 +307,7 @@
}
}
},
"/api/experimental/tasks/{user}/{task}/send": {
"/api/experimental/tasks/{user}/{id}/send": {
"post": {
"security": [
{
@@ -306,7 +329,7 @@
"type": "string",
"format": "uuid",
"description": "Task ID",
"name": "task",
"name": "id",
"in": "path",
"required": true
},
@@ -2697,41 +2720,6 @@
}
}
},
"/oauth2/revoke": {
"post": {
"consumes": ["application/x-www-form-urlencoded"],
"tags": ["Enterprise"],
"summary": "Revoke OAuth2 tokens (RFC 7009).",
"operationId": "oauth2-token-revocation",
"parameters": [
{
"type": "string",
"description": "Client ID for authentication",
"name": "client_id",
"in": "formData",
"required": true
},
{
"type": "string",
"description": "The token to revoke",
"name": "token",
"in": "formData",
"required": true
},
{
"type": "string",
"description": "Hint about token type (access_token or refresh_token)",
"name": "token_type_hint",
"in": "formData"
}
],
"responses": {
"200": {
"description": "Token successfully revoked"
}
}
}
},
"/oauth2/tokens": {
"post": {
"produces": ["application/json"],
@@ -8870,41 +8858,6 @@
}
}
}
},
"put": {
"security": [
{
"CoderSessionToken": []
}
],
"consumes": ["application/json"],
"tags": ["Builds"],
"summary": "Update workspace build state",
"operationId": "update-workspace-build-state",
"parameters": [
{
"type": "string",
"format": "uuid",
"description": "Workspace build ID",
"name": "workspacebuild",
"in": "path",
"required": true
},
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateWorkspaceBuildStateRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/workspacebuilds/{workspacebuild}/timings": {
@@ -10371,6 +10324,20 @@
}
}
},
"coderd.tasksListResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"tasks": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Task"
}
}
}
},
"codersdk.ACLAvailable": {
"type": "object",
"properties": {
@@ -10399,35 +10366,12 @@
}
}
},
"codersdk.AIBridgeBedrockConfig": {
"type": "object",
"properties": {
"access_key": {
"type": "string"
},
"access_key_secret": {
"type": "string"
},
"model": {
"type": "string"
},
"region": {
"type": "string"
},
"small_fast_model": {
"type": "string"
}
}
},
"codersdk.AIBridgeConfig": {
"type": "object",
"properties": {
"anthropic": {
"$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig"
},
"bedrock": {
"$ref": "#/definitions/codersdk.AIBridgeBedrockConfig"
},
"enabled": {
"type": "boolean"
},
@@ -10439,10 +10383,6 @@
"codersdk.AIBridgeInterception": {
"type": "object",
"properties": {
"ended_at": {
"type": "string",
"format": "date-time"
},
"id": {
"type": "string",
"format": "uuid"
@@ -10487,14 +10427,14 @@
"codersdk.AIBridgeListInterceptionsResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"results": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.AIBridgeInterception"
}
},
"total": {
"type": "integer"
}
}
},
@@ -10638,12 +10578,6 @@
"user_id"
],
"properties": {
"allow_list": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.APIAllowListTarget"
}
},
"created_at": {
"type": "string",
"format": "date-time"
@@ -11240,13 +11174,6 @@
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -12402,13 +12329,6 @@
"name": {
"type": "string"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific to the organization the role belongs to.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific to the organization the role belongs to.",
"type": "array",
@@ -12958,9 +12878,11 @@
"web-push",
"oauth2",
"mcp-server-http",
"workspace-sharing"
"workspace-sharing",
"aibridge"
],
"x-enum-comments": {
"ExperimentAIBridge": "Enables AI Bridge functionality.",
"ExperimentAutoFillParameters": "This should not be taken out of experiments until we have redesigned the feature.",
"ExperimentExample": "This isn't used for anything.",
"ExperimentMCPServerHTTP": "Enables the MCP HTTP server functionality.",
@@ -12978,7 +12900,8 @@
"ExperimentWebPush",
"ExperimentOAuth2",
"ExperimentMCPServerHTTP",
"ExperimentWorkspaceSharing"
"ExperimentWorkspaceSharing",
"ExperimentAIBridge"
]
},
"codersdk.ExternalAPIKeyScopes": {
@@ -13566,11 +13489,7 @@
"properties": {
"icon": {
"type": "string",
"enum": ["bug", "chat", "docs", "star"]
},
"location": {
"type": "string",
"enum": ["navbar", "dropdown"]
"enum": ["bug", "chat", "docs"]
},
"name": {
"type": "string"
@@ -13988,9 +13907,6 @@
},
"token": {
"type": "string"
},
"token_revoke": {
"type": "string"
}
}
},
@@ -14090,10 +14006,7 @@
}
},
"registration_access_token": {
"type": "array",
"items": {
"type": "integer"
}
"type": "string"
},
"registration_client_uri": {
"type": "string"
@@ -16051,13 +15964,6 @@
"type": "string",
"format": "uuid"
},
"organization_member_permissions": {
"description": "OrganizationMemberPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Permission"
}
},
"organization_permissions": {
"description": "OrganizationPermissions are specific for the organization in the field 'OrganizationID' above.",
"type": "array",
@@ -16298,9 +16204,6 @@
"type": "string",
"format": "uuid"
},
"owner_avatar_url": {
"type": "string"
},
"owner_id": {
"type": "string",
"format": "uuid"
@@ -16311,15 +16214,19 @@
"status": {
"enum": [
"pending",
"initializing",
"active",
"paused",
"unknown",
"error"
"starting",
"running",
"stopping",
"stopped",
"failed",
"canceling",
"canceled",
"deleting",
"deleted"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.TaskStatus"
"$ref": "#/definitions/codersdk.WorkspaceStatus"
}
]
},
@@ -16336,10 +16243,6 @@
"template_name": {
"type": "string"
},
"template_version_id": {
"type": "string",
"format": "uuid"
},
"updated_at": {
"type": "string",
"format": "date-time"
@@ -16376,28 +16279,6 @@
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"workspace_name": {
"type": "string"
},
"workspace_status": {
"enum": [
"pending",
"starting",
"running",
"stopping",
"stopped",
"failed",
"canceling",
"canceled",
"deleting",
"deleted"
],
"allOf": [
{
"$ref": "#/definitions/codersdk.WorkspaceStatus"
}
]
}
}
},
@@ -16471,39 +16352,6 @@
}
}
},
"codersdk.TaskStatus": {
"type": "string",
"enum": [
"pending",
"initializing",
"active",
"paused",
"unknown",
"error"
],
"x-enum-varnames": [
"TaskStatusPending",
"TaskStatusInitializing",
"TaskStatusActive",
"TaskStatusPaused",
"TaskStatusUnknown",
"TaskStatusError"
]
},
"codersdk.TasksListResponse": {
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"tasks": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Task"
}
}
}
},
"codersdk.TelemetryConfig": {
"type": "object",
"properties": {
@@ -17616,17 +17464,6 @@
}
}
},
"codersdk.UpdateWorkspaceBuildStateRequest": {
"type": "object",
"properties": {
"state": {
"type": "array",
"items": {
"type": "integer"
}
}
}
},
"codersdk.UpdateWorkspaceDormancy": {
"type": "object",
"properties": {
@@ -18144,14 +17981,6 @@
"description": "OwnerName is the username of the owner of the workspace.",
"type": "string"
},
"task_id": {
"description": "TaskID, if set, indicates that the workspace is relevant to the given codersdk.Task.",
"allOf": [
{
"$ref": "#/definitions/uuid.NullUUID"
}
]
},
"template_active_version_id": {
"type": "string",
"format": "uuid"
@@ -18907,7 +18736,6 @@
"type": "object",
"properties": {
"ai_task_sidebar_app_id": {
"description": "Deprecated: This field has been replaced with `Task.WorkspaceAppID`",
"type": "string",
"format": "uuid"
},
+15 -28
View File
@@ -2,7 +2,6 @@ package apikey
import (
"crypto/sha256"
"crypto/subtle"
"fmt"
"net"
"time"
@@ -45,17 +44,12 @@ type CreateParams struct {
// database representation. It is the responsibility of the caller to insert it
// into the database.
func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error) {
// Length of an API Key ID.
keyID, err := cryptorand.String(10)
keyID, keySecret, err := generateKey()
if err != nil {
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key ID: %w", err)
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key: %w", err)
}
// Length of an API Key secret.
keySecret, hashedSecret, err := GenerateSecret(22)
if err != nil {
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key secret: %w", err)
}
hashed := sha256.Sum256([]byte(keySecret))
// Default expires at to now+lifetime, or use the configured value if not
// set.
@@ -126,7 +120,7 @@ func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error)
ExpiresAt: params.ExpiresAt.UTC(),
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
HashedSecret: hashedSecret,
HashedSecret: hashed[:],
LoginType: params.LoginType,
Scopes: scopes,
AllowList: params.AllowList,
@@ -134,24 +128,17 @@ func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error)
}, token, nil
}
func GenerateSecret(length int) (secret string, hashed []byte, err error) {
secret, err = cryptorand.String(length)
// generateKey a new ID and secret for an API key.
func generateKey() (id string, secret string, err error) {
// Length of an API Key ID.
id, err = cryptorand.String(10)
if err != nil {
return "", nil, err
return "", "", err
}
hash := HashSecret(secret)
return secret, hash, nil
}
// ValidateHash compares a secret against an expected hashed secret.
func ValidateHash(hashedSecret []byte, secret string) bool {
hash := HashSecret(secret)
return subtle.ConstantTimeCompare(hashedSecret, hash) == 1
}
// HashSecret is the single function used to hash API key secrets.
// Use this to ensure a consistent hashing algorithm.
func HashSecret(secret string) []byte {
hash := sha256.Sum256([]byte(secret))
return hash[:]
// Length of an API Key secret.
secret, err = cryptorand.String(22)
if err != nil {
return "", "", err
}
return id, secret, nil
}
+3 -16
View File
@@ -1,6 +1,7 @@
package apikey_test
import (
"crypto/sha256"
"strings"
"testing"
"time"
@@ -125,8 +126,8 @@ func TestGenerate(t *testing.T) {
require.Equal(t, key.ID, keytokens[0])
// Assert that the hashed secret is correct.
equal := apikey.ValidateHash(key.HashedSecret, keytokens[1])
require.True(t, equal, "valid secret")
hashed := sha256.Sum256([]byte(keytokens[1]))
assert.ElementsMatch(t, hashed, key.HashedSecret)
assert.Equal(t, tc.params.UserID, key.UserID)
assert.WithinDuration(t, dbtime.Now(), key.CreatedAt, time.Second*5)
@@ -172,17 +173,3 @@ func TestGenerate(t *testing.T) {
})
}
}
// TestInvalid just ensures the false case is asserted by some tests.
// Otherwise, a function that just `returns true` might pass all tests incorrectly.
func TestInvalid(t *testing.T) {
t.Parallel()
require.Falsef(t, apikey.ValidateHash([]byte{}, "secret"), "empty hash")
secret, hash, err := apikey.GenerateSecret(10)
require.NoError(t, err)
require.Falsef(t, apikey.ValidateHash(hash, secret+"_"), "different secret")
require.Falsef(t, apikey.ValidateHash(hash[:len(hash)-1], secret), "different hash length")
}
-6
View File
@@ -51,8 +51,6 @@ func TestTokenCRUD(t *testing.T) {
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
require.Equal(t, codersdk.APIKeyScopeAll, keys[0].Scope)
require.Len(t, keys[0].AllowList, 1)
require.Equal(t, "*:*", keys[0].AllowList[0].String())
// no update
@@ -88,8 +86,6 @@ func TestTokenScoped(t *testing.T) {
require.EqualValues(t, len(keys), 1)
require.Contains(t, res.Key, keys[0].ID)
require.Equal(t, keys[0].Scope, codersdk.APIKeyScopeApplicationConnect)
require.Len(t, keys[0].AllowList, 1)
require.Equal(t, "*:*", keys[0].AllowList[0].String())
}
// Ensure backward-compat: when a token is created using the legacy singular
@@ -136,8 +132,6 @@ func TestTokenLegacySingularScopeCompat(t *testing.T) {
require.Len(t, keys, 1)
require.Equal(t, tc.scope, keys[0].Scope)
require.ElementsMatch(t, keys[0].Scopes, tc.scopes)
require.Len(t, keys[0].AllowList, 1)
require.Equal(t, "*:*", keys[0].AllowList[0].String())
})
}
}
+2 -2
View File
@@ -509,11 +509,11 @@ func (api *API) auditLogResourceLink(ctx context.Context, alog database.GetAudit
if err != nil {
return ""
}
user, err := api.Database.GetUserByID(ctx, task.OwnerID)
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
if err != nil {
return ""
}
return fmt.Sprintf("/tasks/%s/%s", user.Username, task.ID)
return fmt.Sprintf("/tasks/%s/%s", workspace.OwnerName, task.Name)
default:
return ""
+10 -11
View File
@@ -50,13 +50,6 @@ func TestCheckPermissions(t *testing.T) {
},
Action: "read",
},
readOrgWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OrganizationID: adminUser.OrganizationID.String(),
},
Action: "read",
},
readMyself: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceUser,
@@ -65,10 +58,16 @@ func TestCheckPermissions(t *testing.T) {
Action: "read",
},
readOwnWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OwnerID: "me",
},
Action: "read",
},
readOrgWorkspaces: {
Object: codersdk.AuthorizationObject{
ResourceType: codersdk.ResourceWorkspace,
OrganizationID: adminUser.OrganizationID.String(),
OwnerID: "me",
},
Action: "read",
},
@@ -93,9 +92,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: adminUser.UserID,
Check: map[string]bool{
readAllUsers: true,
readOrgWorkspaces: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
@@ -105,9 +104,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: orgAdminUser.ID,
Check: map[string]bool{
readAllUsers: true,
readOrgWorkspaces: true,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: true,
updateSpecificTemplate: true,
},
},
@@ -117,9 +116,9 @@ func TestCheckPermissions(t *testing.T) {
UserID: memberUser.ID,
Check: map[string]bool{
readAllUsers: false,
readOrgWorkspaces: false,
readMyself: true,
readOwnWorkspaces: true,
readOrgWorkspaces: false,
updateSpecificTemplate: false,
},
},
+8 -172
View File
@@ -776,6 +776,10 @@ func TestExecutorWorkspaceAutostopNoWaitChangedMyMind(t *testing.T) {
}
func TestExecutorAutostartMultipleOK(t *testing.T) {
if !dbtestutil.WillUsePostgres() {
t.Skip(`This test only really works when using a "real" database, similar to a HA setup`)
}
t.Parallel()
var (
@@ -1255,6 +1259,10 @@ func TestNotifications(t *testing.T) {
func TestExecutorPrebuilds(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
// Prebuild workspaces should not be autostopped when the deadline is reached.
// After being claimed, the workspace should stop at the deadline.
t.Run("OnlyStopsAfterClaimed", func(t *testing.T) {
@@ -1764,175 +1772,3 @@ func TestExecutorAutostartSkipsWhenNoProvisionersAvailable(t *testing.T) {
assert.Len(t, stats.Transitions, 1, "should create builds when provisioners are available")
}
func TestExecutorTaskWorkspace(t *testing.T) {
t.Parallel()
createTaskTemplate := func(t *testing.T, client *codersdk.Client, orgID uuid.UUID, ctx context.Context, defaultTTL time.Duration) codersdk.Template {
t.Helper()
taskAppID := uuid.New()
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{HasAiTasks: true},
},
},
},
ProvisionApply: []*proto.Response{
{
Type: &proto.Response_Apply{
Apply: &proto.ApplyComplete{
Resources: []*proto.Resource{
{
Agents: []*proto.Agent{
{
Id: uuid.NewString(),
Name: "dev",
Auth: &proto.Agent_Token{
Token: uuid.NewString(),
},
Apps: []*proto.App{
{
Id: taskAppID.String(),
Slug: "task-app",
},
},
},
},
},
},
AiTasks: []*proto.AITask{
{
AppId: taskAppID.String(),
},
},
},
},
},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
if defaultTTL > 0 {
_, err := client.UpdateTemplateMeta(ctx, template.ID, codersdk.UpdateTemplateMeta{
DefaultTTLMillis: defaultTTL.Milliseconds(),
})
require.NoError(t, err)
}
return template
}
createTaskWorkspace := func(t *testing.T, client *codersdk.Client, template codersdk.Template, ctx context.Context, input string) codersdk.Workspace {
t.Helper()
exp := codersdk.NewExperimentalClient(client)
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: input,
})
require.NoError(t, err)
require.True(t, task.WorkspaceID.Valid, "task should have a workspace")
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
return workspace
}
t.Run("Autostart", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
sched = mustSchedule(t, "CRON_TZ=UTC 0 * * * *")
tickCh = make(chan time.Time)
statsCh = make(chan autobuild.Stats)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
AutobuildTicker: tickCh,
IncludeProvisionerDaemon: true,
AutobuildStats: statsCh,
})
admin = coderdtest.CreateFirstUser(t, client)
)
// Given: A task workspace
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 0)
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostart")
// Given: The task workspace has an autostart schedule
err := client.UpdateWorkspaceAutostart(ctx, workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
Schedule: ptr.Ref(sched.String()),
})
require.NoError(t, err)
// Given: That the workspace is in a stopped state.
workspace = coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
require.NoError(t, err)
// When: the autobuild executor ticks after the scheduled time
go func() {
tickTime := sched.Next(workspace.LatestBuild.CreatedAt)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
}()
// Then: We expect to see a start transition
stats := <-statsCh
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStart, stats.Transitions[workspace.ID], "should autostart the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
})
t.Run("Autostop", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
tickCh = make(chan time.Time)
statsCh = make(chan autobuild.Stats)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
AutobuildTicker: tickCh,
IncludeProvisionerDaemon: true,
AutobuildStats: statsCh,
})
admin = coderdtest.CreateFirstUser(t, client)
)
// Given: A task workspace with an 8 hour deadline
template := createTaskTemplate(t, client, admin.OrganizationID, ctx, 8*time.Hour)
workspace := createTaskWorkspace(t, client, template, ctx, "test task for autostop")
// Given: The workspace is currently running
workspace = coderdtest.MustWorkspace(t, client, workspace.ID)
require.Equal(t, codersdk.WorkspaceTransitionStart, workspace.LatestBuild.Transition)
require.NotZero(t, workspace.LatestBuild.Deadline, "workspace should have a deadline for autostop")
p, err := coderdtest.GetProvisionerForTags(db, time.Now(), workspace.OrganizationID, map[string]string{})
require.NoError(t, err)
// When: the autobuild executor ticks after the deadline
go func() {
tickTime := workspace.LatestBuild.Deadline.Time.Add(time.Minute)
coderdtest.UpdateProvisionerLastSeenAt(t, db, p.ID, tickTime)
tickCh <- tickTime
close(tickCh)
}()
// Then: We expect to see a stop transition
stats := <-statsCh
require.Len(t, stats.Transitions, 1, "lifecycle executor should transition the task workspace")
assert.Contains(t, stats.Transitions, workspace.ID, "task workspace should be in transitions")
assert.Equal(t, database.WorkspaceTransitionStop, stats.Transitions[workspace.ID], "should autostop the workspace")
require.Empty(t, stats.Errors, "should have no errors when managing task workspaces")
})
}
+4 -19
View File
@@ -985,16 +985,6 @@ func New(options *Options) *API {
r.Post("/", api.postOAuth2ProviderAppToken())
})
// RFC 7009 Token Revocation Endpoint
r.Route("/revoke", func(r chi.Router) {
r.Use(
// RFC 7009 endpoint uses OAuth2 client authentication, not API key
httpmw.AsAuthzSystem(httpmw.ExtractOAuth2ProviderAppWithOAuth2Errors(options.Database)),
)
// POST /revoke is the standard OAuth2 token revocation endpoint per RFC 7009
r.Post("/", api.revokeOAuth2Token())
})
// RFC 7591 Dynamic Client Registration - Public endpoint
r.Post("/register", api.postOAuth2ClientRegistration())
@@ -1032,15 +1022,11 @@ func New(options *Options) *API {
r.Route("/{user}", func(r chi.Router) {
r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize))
r.Get("/{id}", api.taskGet)
r.Delete("/{id}", api.taskDelete)
r.Post("/{id}/send", api.taskSend)
r.Get("/{id}/logs", api.taskLogs)
r.Post("/", api.tasksCreate)
r.Route("/{task}", func(r chi.Router) {
r.Use(httpmw.ExtractTaskParam(options.Database))
r.Get("/", api.taskGet)
r.Delete("/", api.taskDelete)
r.Post("/send", api.taskSend)
r.Get("/logs", api.taskLogs)
})
})
})
r.Route("/mcp", func(r chi.Router) {
@@ -1496,7 +1482,6 @@ func New(options *Options) *API {
r.Get("/parameters", api.workspaceBuildParameters)
r.Get("/resources", api.workspaceBuildResourcesDeprecated)
r.Get("/state", api.workspaceBuildState)
r.Put("/state", api.workspaceBuildUpdateState)
r.Get("/timings", api.workspaceBuildTimings)
})
r.Route("/authcheck", func(r chi.Router) {
-2
View File
@@ -6,7 +6,6 @@ type CheckConstraint string
// CheckConstraint enums.
const (
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
@@ -14,7 +13,6 @@ const (
CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents
CheckWorkspaceBuildsAiTaskSidebarAppIDRequired CheckConstraint = "workspace_builds_ai_task_sidebar_app_id_required" // workspace_builds
CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds
CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks
CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters
CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events
)
+8 -23
View File
@@ -51,13 +51,6 @@ func ListLazy[F any, T any](convert func(F) T) func(list []F) []T {
}
}
func APIAllowListTarget(entry rbac.AllowListElement) codersdk.APIAllowListTarget {
return codersdk.APIAllowListTarget{
Type: codersdk.RBACResource(entry.Type),
ID: entry.ID,
}
}
type ExternalAuthMeta struct {
Authenticated bool
ValidateError string
@@ -390,9 +383,6 @@ func OAuth2ProviderApp(accessURL *url.URL, dbApp database.OAuth2ProviderApp) cod
}).String(),
// We do not currently support DeviceAuth.
DeviceAuth: "",
TokenRevoke: accessURL.ResolveReference(&url.URL{
Path: "/oauth2/revoke",
}).String(),
},
}
}
@@ -714,13 +704,12 @@ func RBACRole(role rbac.Role) codersdk.Role {
orgPerms := role.ByOrgID[slim.OrganizationID]
return codersdk.Role{
Name: slim.Name,
OrganizationID: slim.OrganizationID,
DisplayName: slim.DisplayName,
SitePermissions: List(role.Site, RBACPermission),
UserPermissions: List(role.User, RBACPermission),
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
OrganizationMemberPermissions: List(orgPerms.Member, RBACPermission),
Name: slim.Name,
OrganizationID: slim.OrganizationID,
DisplayName: slim.DisplayName,
SitePermissions: List(role.Site, RBACPermission),
OrganizationPermissions: List(orgPerms.Org, RBACPermission),
UserPermissions: List(role.User, RBACPermission),
}
}
@@ -735,8 +724,8 @@ func Role(role database.CustomRole) codersdk.Role {
OrganizationID: orgID,
DisplayName: role.DisplayName,
SitePermissions: List(role.SitePermissions, Permission),
UserPermissions: List(role.UserPermissions, Permission),
OrganizationPermissions: List(role.OrgPermissions, Permission),
UserPermissions: List(role.UserPermissions, Permission),
}
}
@@ -963,7 +952,7 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
// created_at ASC
return sdkToolUsages[i].CreatedAt.Before(sdkToolUsages[j].CreatedAt)
})
intc := codersdk.AIBridgeInterception{
return codersdk.AIBridgeInterception{
ID: interception.ID,
Initiator: MinimalUserFromVisibleUser(initiator),
Provider: interception.Provider,
@@ -974,10 +963,6 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
UserPrompts: sdkUserPrompts,
ToolUsages: sdkToolUsages,
}
if interception.EndedAt.Valid {
intc.EndedAt = &interception.EndedAt.Time
}
return intc
}
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
+4
View File
@@ -85,6 +85,10 @@ func TestNestedInTx(t *testing.T) {
func testSQLDB(t testing.TB) *sql.DB {
t.Helper()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
connection, err := dbtestutil.Open(t)
require.NoError(t, err)
+20 -165
View File
@@ -217,10 +217,10 @@ var (
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
// Unsure why provisionerd needs update and read personal
rbac.ResourceUser.Type: {policy.ActionRead, policy.ActionReadPersonal, policy.ActionUpdatePersonal},
rbac.ResourceWorkspaceDormant.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStop, policy.ActionCreateAgent},
rbac.ResourceWorkspaceDormant.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStop},
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop, policy.ActionCreateAgent},
// Provisionerd needs to read, update, and delete tasks associated with workspaces.
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionDelete},
// Provisionerd needs to read and update tasks associated with workspaces.
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceApiKey.Type: {policy.WildcardSymbol},
// When org scoped provisioner credentials are implemented,
// this can be reduced to read a specific org.
@@ -254,7 +254,6 @@ var (
rbac.ResourceFile.Type: {policy.ActionRead}, // Required to read terraform files
rbac.ResourceNotificationMessage.Type: {policy.ActionCreate, policy.ActionRead},
rbac.ResourceSystem.Type: {policy.WildcardSymbol},
rbac.ResourceTask.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceTemplate.Type: {policy.ActionRead, policy.ActionUpdate},
rbac.ResourceUser.Type: {policy.ActionRead},
rbac.ResourceWorkspace.Type: {policy.ActionDelete, policy.ActionRead, policy.ActionUpdate, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop},
@@ -396,13 +395,11 @@ var (
Identifier: rbac.RoleIdentifier{Name: "subagentapi"},
DisplayName: "Sub Agent API",
Site: []rbac.Permission{},
User: []rbac.Permission{},
User: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
}),
ByOrgID: map[string]rbac.OrgPermissions{
orgID.String(): {
Member: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate, policy.ActionCreateAgent, policy.ActionDeleteAgent},
}),
},
orgID.String(): {},
},
},
}),
@@ -449,34 +446,6 @@ var (
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectSystemOAuth2 = rbac.Subject{
Type: rbac.SubjectTypeSystemOAuth,
FriendlyName: "System OAuth2",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "system-oauth2"},
DisplayName: "System OAuth2",
Site: rbac.Permissions(map[string][]policy.Action{
// OAuth2 resources - full CRUD permissions
rbac.ResourceOauth2App.Type: rbac.ResourceOauth2App.AvailableActions(),
rbac.ResourceOauth2AppSecret.Type: rbac.ResourceOauth2AppSecret.AvailableActions(),
rbac.ResourceOauth2AppCodeToken.Type: rbac.ResourceOauth2AppCodeToken.AvailableActions(),
// API key permissions needed for OAuth2 token revocation
rbac.ResourceApiKey.Type: {policy.ActionRead, policy.ActionDelete},
// Minimal read permissions that might be needed for OAuth2 operations
rbac.ResourceUser.Type: {policy.ActionRead},
rbac.ResourceOrganization.Type: {policy.ActionRead},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectSystemReadProvisionerDaemons = rbac.Subject{
Type: rbac.SubjectTypeSystemReadProvisionerDaemons,
FriendlyName: "Provisioner Daemons Reader",
@@ -674,12 +643,6 @@ func AsSystemRestricted(ctx context.Context) context.Context {
return As(ctx, subjectSystemRestricted)
}
// AsSystemOAuth2 returns a context with an actor that has permissions
// required for OAuth2 provider operations (token revocation, device codes, registration).
func AsSystemOAuth2(ctx context.Context) context.Context {
return As(ctx, subjectSystemOAuth2)
}
// AsSystemReadProvisionerDaemons returns a context with an actor that has permissions
// to read provisioner daemons.
func AsSystemReadProvisionerDaemons(ctx context.Context) context.Context {
@@ -1293,17 +1256,14 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
return xerrors.Errorf("invalid role: %w", err)
}
if len(rbacRole.ByOrgID) > 0 && (len(rbacRole.Site) > 0 || len(rbacRole.User) > 0) {
// This is a choice to keep roles simple. If we allow mixing site and org
// scoped perms, then knowing who can do what gets more complicated. Roles
// should either be entirely org-scoped or entirely unrelated to
// organizations.
return xerrors.Errorf("invalid custom role, cannot assign both org-scoped and site/user permissions at the same time")
if len(rbacRole.ByOrgID) > 0 && len(rbacRole.Site) > 0 {
// This is a choice to keep roles simple. If we allow mixing site and org scoped perms, then knowing who can
// do what gets more complicated.
return xerrors.Errorf("invalid custom role, cannot assign both org and site permissions at the same time")
}
if len(rbacRole.ByOrgID) > 1 {
// Again to avoid more complexity in our roles. Roles are limited to one
// organization.
// Again to avoid more complexity in our roles
return xerrors.Errorf("invalid custom role, cannot assign permissions to more than 1 org at a time")
}
@@ -1319,18 +1279,7 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
for _, orgPerm := range perms.Org {
err := q.customRoleEscalationCheck(ctx, act, orgPerm, rbac.Object{OrgID: orgID, Type: orgPerm.ResourceType})
if err != nil {
return xerrors.Errorf("org=%q: org: %w", orgID, err)
}
}
for _, memberPerm := range perms.Member {
// The person giving the permission should still be required to have
// the permissions throughout the org in order to give individuals the
// same permission among their own resources, since the role can be given
// to anyone. The `Owner` is intentionally omitted from the `Object` to
// enforce this.
err := q.customRoleEscalationCheck(ctx, act, memberPerm, rbac.Object{OrgID: orgID, Type: memberPerm.ResourceType})
if err != nil {
return xerrors.Errorf("org=%q: member: %w", orgID, err)
return xerrors.Errorf("org=%q: %w", orgID, err)
}
}
}
@@ -1348,8 +1297,8 @@ func (q *querier) customRoleCheck(ctx context.Context, role database.CustomRole)
func (q *querier) authorizeProvisionerJob(ctx context.Context, job database.ProvisionerJob) error {
switch job.Type {
case database.ProvisionerJobTypeWorkspaceBuild:
// Authorized call to get workspace build. If we can read the build, we can
// read the job.
// Authorized call to get workspace build. If we can read the build, we
// can read the job.
_, err := q.GetWorkspaceBuildByJobID(ctx, job.ID)
if err != nil {
return xerrors.Errorf("fetch related workspace build: %w", err)
@@ -1392,8 +1341,8 @@ func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.Activi
}
func (q *querier) AllUserIDs(ctx context.Context, includeSystem bool) ([]uuid.UUID, error) {
// Although this technically only reads users, only system-related functions
// should be allowed to call this.
// Although this technically only reads users, only system-related functions should be
// allowed to call this.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
@@ -1412,8 +1361,8 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
}
func (q *querier) BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg database.BatchUpdateWorkspaceLastUsedAtParams) error {
// Could be any workspace and checking auth to each workspace is overkill for
// the purpose of this function.
// Could be any workspace and checking auth to each workspace is overkill for the purpose
// of this function.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceWorkspace.All()); err != nil {
return err
}
@@ -1441,13 +1390,6 @@ func (q *querier) BulkMarkNotificationMessagesSent(ctx context.Context, arg data
return q.db.BulkMarkNotificationMessagesSent(ctx, arg)
}
func (q *querier) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, err
}
return q.db.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
}
func (q *querier) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
empty := database.ClaimPrebuiltWorkspaceRow{}
@@ -1536,13 +1478,6 @@ func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.Coun
return q.db.CountInProgressPrebuilds(ctx)
}
func (q *querier) CountPendingNonActivePrebuilds(ctx context.Context) ([]database.CountPendingNonActivePrebuildsRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
return nil, err
}
return q.db.CountPendingNonActivePrebuilds(ctx)
}
func (q *querier) CountUnreadInboxNotificationsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceInboxNotification.WithOwner(userID.String())); err != nil {
return 0, err
@@ -1747,13 +1682,6 @@ func (q *querier) DeleteOldProvisionerDaemons(ctx context.Context) error {
return q.db.DeleteOldProvisionerDaemons(ctx)
}
func (q *querier) DeleteOldTelemetryLocks(ctx context.Context, beforeTime time.Time) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return err
}
return q.db.DeleteOldTelemetryLocks(ctx, beforeTime)
}
func (q *querier) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return err
@@ -1836,19 +1764,6 @@ func (q *querier) DeleteTailnetTunnel(ctx context.Context, arg database.DeleteTa
return q.db.DeleteTailnetTunnel(ctx, arg)
}
func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
task, err := q.db.GetTaskByID(ctx, arg.ID)
if err != nil {
return database.TaskTable{}, err
}
if err := q.authorizeContext(ctx, policy.ActionDelete, task.RBACObject()); err != nil {
return database.TaskTable{}, err
}
return q.db.DeleteTask(ctx, arg)
}
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
@@ -2513,7 +2428,7 @@ func (q *querier) GetOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) (d
return q.db.GetOAuth2ProviderAppByID(ctx, id)
}
func (q *querier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
func (q *querier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2App); err != nil {
return database.OAuth2ProviderApp{}, err
}
@@ -2649,13 +2564,6 @@ func (q *querier) GetOrganizationsByUserID(ctx context.Context, userID database.
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationsByUserID)(ctx, userID)
}
func (q *querier) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOrganization.All()); err != nil {
return nil, err
}
return q.db.GetOrganizationsWithPrebuildStatus(ctx, arg)
}
func (q *querier) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
version, err := q.db.GetTemplateVersionByJobID(ctx, jobID)
if err != nil {
@@ -4250,13 +4158,6 @@ func (q *querier) InsertTelemetryItemIfNotExists(ctx context.Context, arg databa
return q.db.InsertTelemetryItemIfNotExists(ctx, arg)
}
func (q *querier) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
if err := q.authorizeContext(ctx, policy.ActionCreate, rbac.ResourceSystem); err != nil {
return err
}
return q.db.InsertTelemetryLock(ctx, arg)
}
func (q *querier) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
obj := rbac.ResourceTemplate.InOrg(arg.OrganizationID)
if err := q.authorizeContext(ctx, policy.ActionCreate, obj); err != nil {
@@ -4568,13 +4469,6 @@ func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.Li
return q.db.ListAuthorizedAIBridgeInterceptions(ctx, arg, prep)
}
func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil {
return nil, err
}
return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
}
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
// This function is a system function until we implement a join for aibridge interceptions.
// Matches the behavior of the workspaces listing endpoint.
@@ -4763,13 +4657,6 @@ func (q *querier) UnfavoriteWorkspace(ctx context.Context, id uuid.UUID) error {
return update(q.log, q.auth, fetch, q.db.UnfavoriteWorkspace)(ctx, id)
}
func (q *querier) UpdateAIBridgeInterceptionEnded(ctx context.Context, params database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
if err := q.authorizeAIBridgeInterceptionAction(ctx, policy.ActionUpdate, params.ID); err != nil {
return database.AIBridgeInterception{}, err
}
return q.db.UpdateAIBridgeInterceptionEnded(ctx, params)
}
func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
fetch := func(ctx context.Context, arg database.UpdateAPIKeyByIDParams) (database.APIKey, error) {
return q.db.GetAPIKeyByID(ctx, arg.ID)
@@ -4941,14 +4828,6 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
}
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
return []database.UpdatePrebuildProvisionerJobWithCancelRow{}, err
}
return q.db.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
}
func (q *querier) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error {
preset, err := q.db.GetPresetByID(ctx, arg.PresetID)
if err != nil {
@@ -5096,30 +4975,6 @@ func (q *querier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg
return q.db.UpdateTailnetPeerStatusByCoordinator(ctx, arg)
}
func (q *querier) UpdateTaskWorkspaceID(ctx context.Context, arg database.UpdateTaskWorkspaceIDParams) (database.TaskTable, error) {
// An actor is allowed to update the workspace ID of a task if they are the
// owner of the task and workspace or have the appropriate permissions.
task, err := q.db.GetTaskByID(ctx, arg.ID)
if err != nil {
return database.TaskTable{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, task.RBACObject()); err != nil {
return database.TaskTable{}, err
}
ws, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID.UUID)
if err != nil {
return database.TaskTable{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, ws.RBACObject()); err != nil {
return database.TaskTable{}, err
}
return q.db.UpdateTaskWorkspaceID(ctx, arg)
}
func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) {
return q.db.GetTemplateByID(ctx, arg.ID)
+4 -83
View File
@@ -641,19 +641,6 @@ func (s *MethodTestSuite) TestProvisionerJob() {
dbm.EXPECT().UpdateProvisionerJobWithCancelByID(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(v.RBACObject(tpl), []policy.Action{policy.ActionRead, policy.ActionUpdate}).Returns()
}))
s.Run("UpdatePrebuildProvisionerJobWithCancel", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.UpdatePrebuildProvisionerJobWithCancelParams{
PresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
Now: dbtime.Now(),
}
canceledJobs := []database.UpdatePrebuildProvisionerJobWithCancelRow{
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
{ID: uuid.New(), WorkspaceID: uuid.New(), TemplateID: uuid.New(), TemplateVersionPresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(canceledJobs, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(canceledJobs)
}))
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
org := testutil.Fake(s.T(), faker, database.Organization{})
org2 := testutil.Fake(s.T(), faker, database.Organization{})
@@ -2375,16 +2362,6 @@ func (s *MethodTestSuite) TestTasks() {
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
check.Args(task.ID).Asserts(task, policy.ActionRead).Returns(task)
}))
s.Run("DeleteTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
task := testutil.Fake(s.T(), faker, database.Task{})
arg := database.DeleteTaskParams{
ID: task.ID,
DeletedAt: dbtime.Now(),
}
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
dbm.EXPECT().DeleteTask(gomock.Any(), arg).Return(database.TaskTable{}, nil).AnyTimes()
check.Args(arg).Asserts(task, policy.ActionDelete).Returns(database.TaskTable{})
}))
s.Run("InsertTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
tpl := testutil.Fake(s.T(), faker, database.Template{})
tv := testutil.Fake(s.T(), faker, database.TemplateVersion{
@@ -2418,20 +2395,6 @@ func (s *MethodTestSuite) TestTasks() {
check.Args(arg).Asserts(task, policy.ActionUpdate).Returns(database.TaskWorkspaceApp{})
}))
s.Run("UpdateTaskWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
task := testutil.Fake(s.T(), faker, database.Task{})
ws := testutil.Fake(s.T(), faker, database.Workspace{})
arg := database.UpdateTaskWorkspaceIDParams{
ID: task.ID,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
}
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
dbm.EXPECT().UpdateTaskWorkspaceID(gomock.Any(), arg).Return(database.TaskTable{}, nil).AnyTimes()
check.Args(arg).Asserts(task, policy.ActionUpdate, ws, policy.ActionUpdate).Returns(database.TaskTable{})
}))
s.Run("GetTaskByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
task := testutil.Fake(s.T(), faker, database.Task{})
task.WorkspaceID = uuid.NullUUID{UUID: uuid.New(), Valid: true}
@@ -2983,6 +2946,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
dbm.EXPECT().GetParameterSchemasByJobID(gomock.Any(), jobID).Return([]database.ParameterSchema{}, nil).AnyTimes()
check.Args(jobID).
Asserts(tpl, policy.ActionRead).
ErrorsWithInMemDB(sql.ErrNoRows).
Returns([]database.ParameterSchema{})
}))
s.Run("GetWorkspaceAppsByAgentIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
@@ -3225,7 +3189,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}))
s.Run("GetAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetAppSecurityKey(gomock.Any()).Return("", sql.ErrNoRows).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows)
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).ErrorsWithPG(sql.ErrNoRows)
}))
s.Run("UpsertAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertAppSecurityKey(gomock.Any(), "foo").Return(nil).AnyTimes()
@@ -3759,14 +3723,6 @@ func (s *MethodTestSuite) TestPrebuilds() {
dbm.EXPECT().GetPrebuildMetrics(gomock.Any()).Return([]database.GetPrebuildMetricsRow{}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
}))
s.Run("GetOrganizationsWithPrebuildStatus", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetOrganizationsWithPrebuildStatusParams{
UserID: uuid.New(),
GroupName: "test",
}
dbm.EXPECT().GetOrganizationsWithPrebuildStatus(gomock.Any(), arg).Return([]database.GetOrganizationsWithPrebuildStatusRow{}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceOrganization.All(), policy.ActionRead)
}))
s.Run("GetPrebuildsSettings", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetPrebuildsSettings(gomock.Any()).Return("{}", nil).AnyTimes()
check.Args().Asserts()
@@ -3779,10 +3735,6 @@ func (s *MethodTestSuite) TestPrebuilds() {
dbm.EXPECT().CountInProgressPrebuilds(gomock.Any()).Return([]database.CountInProgressPrebuildsRow{}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
}))
s.Run("CountPendingNonActivePrebuilds", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().CountPendingNonActivePrebuilds(gomock.Any()).Return([]database.CountPendingNonActivePrebuildsRow{}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
}))
s.Run("GetPresetsAtFailureLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetPresetsAtFailureLimit(gomock.Any(), int64(0)).Return([]database.GetPresetsAtFailureLimitRow{}, nil).AnyTimes()
check.Args(int64(0)).Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights)
@@ -3950,9 +3902,9 @@ func (s *MethodTestSuite) TestOAuth2ProviderApps() {
}))
s.Run("GetOAuth2ProviderAppByRegistrationToken", s.Subtest(func(db database.Store, check *expects) {
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{
RegistrationAccessToken: []byte("test-token"),
RegistrationAccessToken: sql.NullString{String: "test-token", Valid: true},
})
check.Args([]byte("test-token")).Asserts(rbac.ResourceOauth2App, policy.ActionRead).Returns(app)
check.Args(sql.NullString{String: "test-token", Valid: true}).Asserts(rbac.ResourceOauth2App, policy.ActionRead).Returns(app)
}))
}
@@ -4628,35 +4580,4 @@ func (s *MethodTestSuite) TestAIBridge() {
db.EXPECT().ListAIBridgeToolUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeToolUsage{}, nil).AnyTimes()
check.Args(ids).Asserts(rbac.ResourceSystem, policy.ActionRead).Returns([]database.AIBridgeToolUsage{})
}))
s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
intcID := uuid.UUID{1}
params := database.UpdateAIBridgeInterceptionEndedParams{ID: intcID}
intc := testutil.Fake(s.T(), faker, database.AIBridgeInterception{ID: intcID})
db.EXPECT().GetAIBridgeInterceptionByID(gomock.Any(), intcID).Return(intc, nil).AnyTimes() // Validation.
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), params).Return(intc, nil).AnyTimes()
check.Args(params).Asserts(intc, policy.ActionUpdate).Returns(intc)
}))
}
func (s *MethodTestSuite) TestTelemetry() {
s.Run("InsertTelemetryLock", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().InsertTelemetryLock(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
check.Args(database.InsertTelemetryLockParams{}).Asserts(rbac.ResourceSystem, policy.ActionCreate)
}))
s.Run("DeleteOldTelemetryLocks", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().DeleteOldTelemetryLocks(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
check.Args(time.Time{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("ListAIBridgeInterceptionsTelemetrySummaries", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().ListAIBridgeInterceptionsTelemetrySummaries(gomock.Any(), gomock.Any()).Return([]database.ListAIBridgeInterceptionsTelemetrySummariesRow{}, nil).AnyTimes()
check.Args(database.ListAIBridgeInterceptionsTelemetrySummariesParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
}))
s.Run("CalculateAIBridgeInterceptionsTelemetrySummary", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
db.EXPECT().CalculateAIBridgeInterceptionsTelemetrySummary(gomock.Any(), gomock.Any()).Return(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow{}, nil).AnyTimes()
check.Args(database.CalculateAIBridgeInterceptionsTelemetrySummaryParams{}).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead)
}))
}
+18
View File
@@ -430,6 +430,24 @@ func (m *expects) Errors(err error) *expects {
return m
}
// ErrorsWithPG is optional. If it is never called, it will not be asserted.
// It will only be asserted if the test is running with a Postgres database.
func (m *expects) ErrorsWithPG(err error) *expects {
if dbtestutil.WillUsePostgres() {
return m.Errors(err)
}
return m
}
// ErrorsWithInMemDB is optional. If it is never called, it will not be asserted.
// It will only be asserted if the test is running with an in-memory database.
func (m *expects) ErrorsWithInMemDB(err error) *expects {
if !dbtestutil.WillUsePostgres() {
return m.Errors(err)
}
return m
}
func (m *expects) FailSystemObjectChecks() *expects {
return m.WithSuccessAuthorizer(func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error {
if obj.Type == rbac.ResourceSystem.Type {
+21 -161
View File
@@ -41,7 +41,6 @@ type WorkspaceResponse struct {
Build database.WorkspaceBuild
AgentToken string
TemplateVersionResponse
Task database.Task
}
// WorkspaceBuildBuilder generates workspace builds and associated
@@ -56,9 +55,12 @@ type WorkspaceBuildBuilder struct {
resources []*sdkproto.Resource
params []database.WorkspaceBuildParameter
agentToken string
jobStatus database.ProvisionerJobStatus
dispo workspaceBuildDisposition
taskAppID uuid.UUID
taskSeed database.TaskTable
}
type workspaceBuildDisposition struct {
starting bool
}
// WorkspaceBuild generates a workspace build for the provided workspace.
@@ -117,28 +119,21 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
return b
}
func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sdkproto.App) WorkspaceBuildBuilder {
//nolint:revive // returns modified struct
b.taskSeed = taskSeed
if appSeed == nil {
appSeed = &sdkproto.App{}
}
var err error
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
//nolint: revive // returns modified struct
b.taskAppID, err = uuid.Parse(takeFirst(appSeed.Id, uuid.NewString()))
require.NoError(b.t, err)
b.taskAppID = uuid.New()
if seed == nil {
seed = &sdkproto.App{}
}
return b.Params(database.WorkspaceBuildParameter{
Name: codersdk.AITaskPromptParameterName,
Value: b.taskSeed.Prompt,
Value: "list me",
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
a[0].Apps = []*sdkproto.App{
{
Id: b.taskAppID.String(),
Slug: takeFirst(appSeed.Slug, "task-app"),
Url: takeFirst(appSeed.Url, ""),
Id: takeFirst(seed.Id, b.taskAppID.String()),
Slug: takeFirst(seed.Slug, "vcode"),
Url: takeFirst(seed.Url, ""),
},
}
return a
@@ -146,17 +141,8 @@ func (b WorkspaceBuildBuilder) WithTask(taskSeed database.TaskTable, appSeed *sd
}
func (b WorkspaceBuildBuilder) Starting() WorkspaceBuildBuilder {
b.jobStatus = database.ProvisionerJobStatusRunning
return b
}
func (b WorkspaceBuildBuilder) Pending() WorkspaceBuildBuilder {
b.jobStatus = database.ProvisionerJobStatusPending
return b
}
func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
b.jobStatus = database.ProvisionerJobStatusCanceled
//nolint: revive // returns modified struct
b.dispo.starting = true
return b
}
@@ -166,19 +152,6 @@ func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
// Workspace will be optionally populated if no ID is set on the provided
// workspace.
func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
var resp WorkspaceResponse
// Use transaction, like real wsbuilder.
err := b.db.InTx(func(tx database.Store) error {
//nolint:revive // calls do on modified struct
b.db = tx
resp = b.doInTX()
return nil
}, nil)
require.NoError(b.t, err)
return resp
}
func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
b.t.Helper()
jobID := uuid.New()
b.seed.ID = uuid.New()
@@ -222,45 +195,14 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
if b.ws.ID == uuid.Nil {
// nolint: revive
b.ws = dbgen.Workspace(b.t, b.db, b.ws)
resp.Workspace = b.ws
b.logger.Debug(context.Background(), "created workspace",
slog.F("name", b.ws.Name),
slog.F("workspace_id", b.ws.ID))
slog.F("name", resp.Workspace.Name),
slog.F("workspace_id", resp.Workspace.ID))
}
resp.Workspace = b.ws
b.seed.WorkspaceID = b.ws.ID
b.seed.InitiatorID = takeFirst(b.seed.InitiatorID, b.ws.OwnerID)
// If a task was requested, ensure it exists and is associated with this
// workspace.
if b.taskAppID != uuid.Nil {
b.logger.Debug(context.Background(), "creating or updating task", "task_id", b.taskSeed.ID)
b.taskSeed.OrganizationID = takeFirst(b.taskSeed.OrganizationID, b.ws.OrganizationID)
b.taskSeed.OwnerID = takeFirst(b.taskSeed.OwnerID, b.ws.OwnerID)
b.taskSeed.Name = takeFirst(b.taskSeed.Name, b.ws.Name)
b.taskSeed.WorkspaceID = uuid.NullUUID{UUID: takeFirst(b.taskSeed.WorkspaceID.UUID, b.ws.ID), Valid: true}
b.taskSeed.TemplateVersionID = takeFirst(b.taskSeed.TemplateVersionID, b.seed.TemplateVersionID)
// Try to fetch existing task and update its workspace ID.
if task, err := b.db.GetTaskByID(ownerCtx, b.taskSeed.ID); err == nil {
if !task.WorkspaceID.Valid {
b.logger.Info(context.Background(), "updating task workspace id", "task_id", b.taskSeed.ID, "workspace_id", b.ws.ID)
_, err = b.db.UpdateTaskWorkspaceID(ownerCtx, database.UpdateTaskWorkspaceIDParams{
ID: b.taskSeed.ID,
WorkspaceID: uuid.NullUUID{UUID: b.ws.ID, Valid: true},
})
require.NoError(b.t, err, "update task workspace id")
} else if task.WorkspaceID.UUID != b.ws.ID {
require.Fail(b.t, "task already has a workspace id, mismatch", task.WorkspaceID.UUID, b.ws.ID)
}
} else if errors.Is(err, sql.ErrNoRows) {
task := dbgen.Task(b.t, b.db, b.taskSeed)
b.taskSeed.ID = task.ID
b.logger.Info(context.Background(), "created new task", "task_id", b.taskSeed.ID)
} else {
require.NoError(b.t, err, "get task by id")
}
}
// Create a provisioner job for the build!
payload, err := json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: b.seed.ID,
@@ -285,11 +227,7 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
require.NoError(b.t, err, "insert job")
b.logger.Debug(context.Background(), "inserted provisioner job", slog.F("job_id", job.ID))
switch b.jobStatus {
case database.ProvisionerJobStatusPending:
// Provisioner jobs are created in 'pending' status
b.logger.Debug(context.Background(), "pending the provisioner job")
case database.ProvisionerJobStatusRunning:
if b.dispo.starting {
// might need to do this multiple times if we got a template version
// import job as well
b.logger.Debug(context.Background(), "looping to acquire provisioner job")
@@ -313,23 +251,7 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
break
}
}
case database.ProvisionerJobStatusCanceled:
// Set provisioner job status to 'canceled'
b.logger.Debug(context.Background(), "canceling the provisioner job")
err = b.db.UpdateProvisionerJobWithCancelByID(ownerCtx, database.UpdateProvisionerJobWithCancelByIDParams{
ID: jobID,
CanceledAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
CompletedAt: sql.NullTime{
Time: dbtime.Now(),
Valid: true,
},
})
require.NoError(b.t, err, "cancel job")
default:
// By default, consider jobs in 'succeeded' status
} else {
b.logger.Debug(context.Background(), "completing the provisioner job")
err = b.db.UpdateProvisionerJobWithCompleteByID(ownerCtx, database.UpdateProvisionerJobWithCompleteByIDParams{
ID: job.ID,
@@ -351,35 +273,6 @@ func (b WorkspaceBuildBuilder) doInTX() WorkspaceResponse {
slog.F("workspace_id", resp.Workspace.ID),
slog.F("build_number", resp.Build.BuildNumber))
// If this is a task workspace, link it to the workspace build.
task, err := b.db.GetTaskByWorkspaceID(ownerCtx, resp.Workspace.ID)
if err != nil {
if b.taskAppID != uuid.Nil {
require.Fail(b.t, "task app configured but failed to get task by workspace id", err)
}
} else {
if b.taskAppID == uuid.Nil {
require.Fail(b.t, "task app not configured but workspace is a task workspace")
}
app := mustWorkspaceAppByWorkspaceAndBuildAndAppID(ownerCtx, b.t, b.db, resp.Workspace.ID, resp.Build.BuildNumber, b.taskAppID)
_, err = b.db.UpsertTaskWorkspaceApp(ownerCtx, database.UpsertTaskWorkspaceAppParams{
TaskID: task.ID,
WorkspaceBuildNumber: resp.Build.BuildNumber,
WorkspaceAgentID: uuid.NullUUID{UUID: app.AgentID, Valid: true},
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
})
require.NoError(b.t, err, "upsert task workspace app")
b.logger.Debug(context.Background(), "linked task to workspace build",
slog.F("task_id", task.ID),
slog.F("build_number", resp.Build.BuildNumber))
// Update task after linking.
task, err = b.db.GetTaskByID(ownerCtx, task.ID)
require.NoError(b.t, err, "get task by id")
resp.Task = task
}
for i := range b.params {
b.params[i].WorkspaceBuildID = resp.Build.ID
}
@@ -650,12 +543,6 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
t.params[i] = dbgen.TemplateVersionParameter(t.t, t.db, param)
}
// Update response with template and version
if resp.Template.ID == uuid.Nil && version.TemplateID.Valid {
template, err := t.db.GetTemplateByID(ownerCtx, version.TemplateID.UUID)
require.NoError(t.t, err)
resp.Template = template
}
resp.TemplateVersion = version
return resp
}
@@ -736,30 +623,3 @@ func takeFirst[Value comparable](values ...Value) Value {
return v != empty
})
}
// mustWorkspaceAppByWorkspaceAndBuildAndAppID finds a workspace app by
// workspace ID, build number, and app ID. It returns the workspace app
// if found, otherwise fails the test.
func mustWorkspaceAppByWorkspaceAndBuildAndAppID(ctx context.Context, t testing.TB, db database.Store, workspaceID uuid.UUID, buildNumber int32, appID uuid.UUID) database.WorkspaceApp {
t.Helper()
agents, err := db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
WorkspaceID: workspaceID,
BuildNumber: buildNumber,
})
require.NoError(t, err, "get workspace agents")
require.NotEmpty(t, agents, "no agents found for workspace")
for _, agent := range agents {
apps, err := db.GetWorkspaceAppsByAgentID(ctx, agent.ID)
require.NoError(t, err, "get workspace apps")
for _, app := range apps {
if app.ID == appID {
return app
}
}
}
require.FailNow(t, "could not find workspace app", "workspaceID=%s buildNumber=%d appID=%s", workspaceID, buildNumber, appID)
return database.WorkspaceApp{} // Unreachable.
}
+10 -17
View File
@@ -3,6 +3,7 @@ package dbgen
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
@@ -19,7 +20,6 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -161,8 +161,8 @@ func Template(t testing.TB, db database.Store, seed database.Template) database.
func APIKey(t testing.TB, db database.Store, seed database.APIKey, munge ...func(*database.InsertAPIKeyParams)) (key database.APIKey, token string) {
id, _ := cryptorand.String(10)
secret, hashed, err := apikey.GenerateSecret(22)
require.NoError(t, err)
secret, _ := cryptorand.String(22)
hashed := sha256.Sum256([]byte(secret))
ip := seed.IPAddress
if !ip.Valid {
@@ -179,7 +179,7 @@ func APIKey(t testing.TB, db database.Store, seed database.APIKey, munge ...func
ID: takeFirst(seed.ID, id),
// 0 defaults to 86400 at the db layer
LifetimeSeconds: takeFirst(seed.LifetimeSeconds, 0),
HashedSecret: takeFirstSlice(seed.HashedSecret, hashed),
HashedSecret: takeFirstSlice(seed.HashedSecret, hashed[:]),
IPAddress: ip,
UserID: takeFirst(seed.UserID, uuid.New()),
LastUsed: takeFirst(seed.LastUsed, dbtime.Now()),
@@ -194,7 +194,7 @@ func APIKey(t testing.TB, db database.Store, seed database.APIKey, munge ...func
for _, fn := range munge {
fn(&params)
}
key, err = db.InsertAPIKey(genCtx, params)
key, err := db.InsertAPIKey(genCtx, params)
require.NoError(t, err, "insert api key")
return key, fmt.Sprintf("%s-%s", key.ID, secret)
}
@@ -980,15 +980,16 @@ func WorkspaceResourceMetadatums(t testing.TB, db database.Store, seed database.
}
func WorkspaceProxy(t testing.TB, db database.Store, orig database.WorkspaceProxy) (database.WorkspaceProxy, string) {
secret, hashedSecret, err := apikey.GenerateSecret(64)
secret, err := cryptorand.HexString(64)
require.NoError(t, err, "generate secret")
hashedSecret := sha256.Sum256([]byte(secret))
proxy, err := db.InsertWorkspaceProxy(genCtx, database.InsertWorkspaceProxyParams{
ID: takeFirst(orig.ID, uuid.New()),
Name: takeFirst(orig.Name, testutil.GetRandomName(t)),
DisplayName: takeFirst(orig.DisplayName, testutil.GetRandomName(t)),
Icon: takeFirst(orig.Icon, testutil.GetRandomName(t)),
TokenHashedSecret: hashedSecret,
TokenHashedSecret: hashedSecret[:],
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
DerpEnabled: takeFirst(orig.DerpEnabled, false),
@@ -1258,7 +1259,7 @@ func OAuth2ProviderApp(t testing.TB, db database.Store, seed database.OAuth2Prov
Jwks: seed.Jwks, // pqtype.NullRawMessage{} is not comparable, use existing value
SoftwareID: takeFirst(seed.SoftwareID, sql.NullString{}),
SoftwareVersion: takeFirst(seed.SoftwareVersion, sql.NullString{}),
RegistrationAccessToken: seed.RegistrationAccessToken,
RegistrationAccessToken: takeFirst(seed.RegistrationAccessToken, sql.NullString{}),
RegistrationClientUri: takeFirst(seed.RegistrationClientUri, sql.NullString{}),
})
require.NoError(t, err, "insert oauth2 app")
@@ -1495,7 +1496,7 @@ func ClaimPrebuild(
return claimedWorkspace
}
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams, endedAt *time.Time) database.AIBridgeInterception {
func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertAIBridgeInterceptionParams) database.AIBridgeInterception {
interception, err := db.InsertAIBridgeInterception(genCtx, database.InsertAIBridgeInterceptionParams{
ID: takeFirst(seed.ID, uuid.New()),
InitiatorID: takeFirst(seed.InitiatorID, uuid.New()),
@@ -1504,13 +1505,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")),
StartedAt: takeFirst(seed.StartedAt, dbtime.Now()),
})
if endedAt != nil {
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
ID: interception.ID,
EndedAt: *endedAt,
})
require.NoError(t, err, "insert aibridge interception")
}
require.NoError(t, err, "insert aibridge interception")
return interception
}
@@ -1576,7 +1570,6 @@ func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Tas
}
task, err := db.InsertTask(genCtx, database.InsertTaskParams{
ID: takeFirst(orig.ID, uuid.New()),
OrganizationID: orig.OrganizationID,
OwnerID: orig.OwnerID,
Name: takeFirst(orig.Name, taskname.GenerateFallback()),
+2 -71
View File
@@ -5,6 +5,7 @@ package dbmetrics
import (
"context"
"database/sql"
"slices"
"time"
@@ -158,13 +159,6 @@ func (m queryMetricsStore) BulkMarkNotificationMessagesSent(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
start := time.Now()
r0, r1 := m.s.CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg)
m.queryLatencies.WithLabelValues("CalculateAIBridgeInterceptionsTelemetrySummary").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
start := time.Now()
r0, r1 := m.s.ClaimPrebuiltWorkspace(ctx, arg)
@@ -221,13 +215,6 @@ func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]data
return r0, r1
}
func (m queryMetricsStore) CountPendingNonActivePrebuilds(ctx context.Context) ([]database.CountPendingNonActivePrebuildsRow, error) {
start := time.Now()
r0, r1 := m.s.CountPendingNonActivePrebuilds(ctx)
m.queryLatencies.WithLabelValues("CountPendingNonActivePrebuilds").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) CountUnreadInboxNotificationsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
start := time.Now()
r0, r1 := m.s.CountUnreadInboxNotificationsByUserID(ctx, userID)
@@ -410,13 +397,6 @@ func (m queryMetricsStore) DeleteOldProvisionerDaemons(ctx context.Context) erro
return r0
}
func (m queryMetricsStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
start := time.Now()
r0 := m.s.DeleteOldTelemetryLocks(ctx, periodEndingAtBefore)
m.queryLatencies.WithLabelValues("DeleteOldTelemetryLocks").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, arg time.Time) error {
start := time.Now()
r0 := m.s.DeleteOldWorkspaceAgentLogs(ctx, arg)
@@ -494,13 +474,6 @@ func (m queryMetricsStore) DeleteTailnetTunnel(ctx context.Context, arg database
return r0, r1
}
func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
start := time.Now()
r0, r1 := m.s.DeleteTask(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteTask").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteUserSecret(ctx, id)
@@ -1124,7 +1097,7 @@ func (m queryMetricsStore) GetOAuth2ProviderAppByID(ctx context.Context, id uuid
return r0, r1
}
func (m queryMetricsStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
func (m queryMetricsStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) {
start := time.Now()
r0, r1 := m.s.GetOAuth2ProviderAppByRegistrationToken(ctx, registrationAccessToken)
m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppByRegistrationToken").Observe(time.Since(start).Seconds())
@@ -1243,13 +1216,6 @@ func (m queryMetricsStore) GetOrganizationsByUserID(ctx context.Context, userID
return organizations, err
}
func (m queryMetricsStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
start := time.Now()
r0, r1 := m.s.GetOrganizationsWithPrebuildStatus(ctx, arg)
m.queryLatencies.WithLabelValues("GetOrganizationsWithPrebuildStatus").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
start := time.Now()
schemas, err := m.s.GetParameterSchemasByJobID(ctx, jobID)
@@ -2538,13 +2504,6 @@ func (m queryMetricsStore) InsertTelemetryItemIfNotExists(ctx context.Context, a
return r0
}
func (m queryMetricsStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
start := time.Now()
r0 := m.s.InsertTelemetryLock(ctx, arg)
m.queryLatencies.WithLabelValues("InsertTelemetryLock").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
start := time.Now()
err := m.s.InsertTemplate(ctx, arg)
@@ -2762,13 +2721,6 @@ func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg da
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg)
m.queryLatencies.WithLabelValues("ListAIBridgeInterceptionsTelemetrySummaries").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
start := time.Now()
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
@@ -2923,13 +2875,6 @@ func (m queryMetricsStore) UnfavoriteWorkspace(ctx context.Context, arg uuid.UUI
return r0
}
func (m queryMetricsStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, id database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
start := time.Now()
r0, r1 := m.s.UpdateAIBridgeInterceptionEnded(ctx, id)
m.queryLatencies.WithLabelValues("UpdateAIBridgeInterceptionEnded").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
start := time.Now()
err := m.s.UpdateAPIKeyByID(ctx, arg)
@@ -3049,13 +2994,6 @@ func (m queryMetricsStore) UpdateOrganizationDeletedByID(ctx context.Context, ar
return r0
}
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
start := time.Now()
r0, r1 := m.s.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
m.queryLatencies.WithLabelValues("UpdatePrebuildProvisionerJobWithCancel").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error {
start := time.Now()
r0 := m.s.UpdatePresetPrebuildStatus(ctx, arg)
@@ -3126,13 +3064,6 @@ func (m queryMetricsStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Cont
return r0
}
func (m queryMetricsStore) UpdateTaskWorkspaceID(ctx context.Context, arg database.UpdateTaskWorkspaceIDParams) (database.TaskTable, error) {
start := time.Now()
r0, r1 := m.s.UpdateTaskWorkspaceID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateTaskWorkspaceID").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m queryMetricsStore) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
start := time.Now()
err := m.s.UpdateTemplateACLByID(ctx, arg)
+2 -149
View File
@@ -11,6 +11,7 @@ package dbmock
import (
context "context"
sql "database/sql"
reflect "reflect"
time "time"
@@ -190,21 +191,6 @@ func (mr *MockStoreMockRecorder) BulkMarkNotificationMessagesSent(ctx, arg any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkMarkNotificationMessagesSent", reflect.TypeOf((*MockStore)(nil).BulkMarkNotificationMessagesSent), ctx, arg)
}
// CalculateAIBridgeInterceptionsTelemetrySummary mocks base method.
func (m *MockStore) CalculateAIBridgeInterceptionsTelemetrySummary(ctx context.Context, arg database.CalculateAIBridgeInterceptionsTelemetrySummaryParams) (database.CalculateAIBridgeInterceptionsTelemetrySummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CalculateAIBridgeInterceptionsTelemetrySummary", ctx, arg)
ret0, _ := ret[0].(database.CalculateAIBridgeInterceptionsTelemetrySummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CalculateAIBridgeInterceptionsTelemetrySummary indicates an expected call of CalculateAIBridgeInterceptionsTelemetrySummary.
func (mr *MockStoreMockRecorder) CalculateAIBridgeInterceptionsTelemetrySummary(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CalculateAIBridgeInterceptionsTelemetrySummary", reflect.TypeOf((*MockStore)(nil).CalculateAIBridgeInterceptionsTelemetrySummary), ctx, arg)
}
// ClaimPrebuiltWorkspace mocks base method.
func (m *MockStore) ClaimPrebuiltWorkspace(ctx context.Context, arg database.ClaimPrebuiltWorkspaceParams) (database.ClaimPrebuiltWorkspaceRow, error) {
m.ctrl.T.Helper()
@@ -367,21 +353,6 @@ func (mr *MockStoreMockRecorder) CountInProgressPrebuilds(ctx any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountInProgressPrebuilds", reflect.TypeOf((*MockStore)(nil).CountInProgressPrebuilds), ctx)
}
// CountPendingNonActivePrebuilds mocks base method.
func (m *MockStore) CountPendingNonActivePrebuilds(ctx context.Context) ([]database.CountPendingNonActivePrebuildsRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountPendingNonActivePrebuilds", ctx)
ret0, _ := ret[0].([]database.CountPendingNonActivePrebuildsRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountPendingNonActivePrebuilds indicates an expected call of CountPendingNonActivePrebuilds.
func (mr *MockStoreMockRecorder) CountPendingNonActivePrebuilds(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountPendingNonActivePrebuilds", reflect.TypeOf((*MockStore)(nil).CountPendingNonActivePrebuilds), ctx)
}
// CountUnreadInboxNotificationsByUserID mocks base method.
func (m *MockStore) CountUnreadInboxNotificationsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
m.ctrl.T.Helper()
@@ -751,20 +722,6 @@ func (mr *MockStoreMockRecorder) DeleteOldProvisionerDaemons(ctx any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldProvisionerDaemons", reflect.TypeOf((*MockStore)(nil).DeleteOldProvisionerDaemons), ctx)
}
// DeleteOldTelemetryLocks mocks base method.
func (m *MockStore) DeleteOldTelemetryLocks(ctx context.Context, periodEndingAtBefore time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOldTelemetryLocks", ctx, periodEndingAtBefore)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteOldTelemetryLocks indicates an expected call of DeleteOldTelemetryLocks.
func (mr *MockStoreMockRecorder) DeleteOldTelemetryLocks(ctx, periodEndingAtBefore any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldTelemetryLocks", reflect.TypeOf((*MockStore)(nil).DeleteOldTelemetryLocks), ctx, periodEndingAtBefore)
}
// DeleteOldWorkspaceAgentLogs mocks base method.
func (m *MockStore) DeleteOldWorkspaceAgentLogs(ctx context.Context, threshold time.Time) error {
m.ctrl.T.Helper()
@@ -923,21 +880,6 @@ func (mr *MockStoreMockRecorder) DeleteTailnetTunnel(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetTunnel", reflect.TypeOf((*MockStore)(nil).DeleteTailnetTunnel), ctx, arg)
}
// DeleteTask mocks base method.
func (m *MockStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteTask", ctx, arg)
ret0, _ := ret[0].(database.TaskTable)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteTask indicates an expected call of DeleteTask.
func (mr *MockStoreMockRecorder) DeleteTask(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockStore)(nil).DeleteTask), ctx, arg)
}
// DeleteUserSecret mocks base method.
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
@@ -2368,7 +2310,7 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppByID(ctx, id any) *gomock.C
}
// GetOAuth2ProviderAppByRegistrationToken mocks base method.
func (m *MockStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
func (m *MockStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOAuth2ProviderAppByRegistrationToken", ctx, registrationAccessToken)
ret0, _ := ret[0].(database.OAuth2ProviderApp)
@@ -2622,21 +2564,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationsByUserID(ctx, arg any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationsByUserID), ctx, arg)
}
// GetOrganizationsWithPrebuildStatus mocks base method.
func (m *MockStore) GetOrganizationsWithPrebuildStatus(ctx context.Context, arg database.GetOrganizationsWithPrebuildStatusParams) ([]database.GetOrganizationsWithPrebuildStatusRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrganizationsWithPrebuildStatus", ctx, arg)
ret0, _ := ret[0].([]database.GetOrganizationsWithPrebuildStatusRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrganizationsWithPrebuildStatus indicates an expected call of GetOrganizationsWithPrebuildStatus.
func (mr *MockStoreMockRecorder) GetOrganizationsWithPrebuildStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationsWithPrebuildStatus", reflect.TypeOf((*MockStore)(nil).GetOrganizationsWithPrebuildStatus), ctx, arg)
}
// GetParameterSchemasByJobID mocks base method.
func (m *MockStore) GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]database.ParameterSchema, error) {
m.ctrl.T.Helper()
@@ -5436,20 +5363,6 @@ func (mr *MockStoreMockRecorder) InsertTelemetryItemIfNotExists(ctx, arg any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryItemIfNotExists", reflect.TypeOf((*MockStore)(nil).InsertTelemetryItemIfNotExists), ctx, arg)
}
// InsertTelemetryLock mocks base method.
func (m *MockStore) InsertTelemetryLock(ctx context.Context, arg database.InsertTelemetryLockParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertTelemetryLock", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// InsertTelemetryLock indicates an expected call of InsertTelemetryLock.
func (mr *MockStoreMockRecorder) InsertTelemetryLock(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertTelemetryLock", reflect.TypeOf((*MockStore)(nil).InsertTelemetryLock), ctx, arg)
}
// InsertTemplate mocks base method.
func (m *MockStore) InsertTemplate(ctx context.Context, arg database.InsertTemplateParams) error {
m.ctrl.T.Helper()
@@ -5905,21 +5818,6 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptions(ctx, arg any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptions), ctx, arg)
}
// ListAIBridgeInterceptionsTelemetrySummaries mocks base method.
func (m *MockStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg database.ListAIBridgeInterceptionsTelemetrySummariesParams) ([]database.ListAIBridgeInterceptionsTelemetrySummariesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListAIBridgeInterceptionsTelemetrySummaries", ctx, arg)
ret0, _ := ret[0].([]database.ListAIBridgeInterceptionsTelemetrySummariesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListAIBridgeInterceptionsTelemetrySummaries indicates an expected call of ListAIBridgeInterceptionsTelemetrySummaries.
func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg)
}
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
m.ctrl.T.Helper()
@@ -6289,21 +6187,6 @@ func (mr *MockStoreMockRecorder) UnfavoriteWorkspace(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnfavoriteWorkspace", reflect.TypeOf((*MockStore)(nil).UnfavoriteWorkspace), ctx, id)
}
// UpdateAIBridgeInterceptionEnded mocks base method.
func (m *MockStore) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg database.UpdateAIBridgeInterceptionEndedParams) (database.AIBridgeInterception, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAIBridgeInterceptionEnded", ctx, arg)
ret0, _ := ret[0].(database.AIBridgeInterception)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateAIBridgeInterceptionEnded indicates an expected call of UpdateAIBridgeInterceptionEnded.
func (mr *MockStoreMockRecorder) UpdateAIBridgeInterceptionEnded(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAIBridgeInterceptionEnded", reflect.TypeOf((*MockStore)(nil).UpdateAIBridgeInterceptionEnded), ctx, arg)
}
// UpdateAPIKeyByID mocks base method.
func (m *MockStore) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error {
m.ctrl.T.Helper()
@@ -6554,21 +6437,6 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOrganizationDeletedByID", reflect.TypeOf((*MockStore)(nil).UpdateOrganizationDeletedByID), ctx, arg)
}
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdatePrebuildProvisionerJobWithCancel", ctx, arg)
ret0, _ := ret[0].([]database.UpdatePrebuildProvisionerJobWithCancelRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdatePrebuildProvisionerJobWithCancel indicates an expected call of UpdatePrebuildProvisionerJobWithCancel.
func (mr *MockStoreMockRecorder) UpdatePrebuildProvisionerJobWithCancel(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePrebuildProvisionerJobWithCancel", reflect.TypeOf((*MockStore)(nil).UpdatePrebuildProvisionerJobWithCancel), ctx, arg)
}
// UpdatePresetPrebuildStatus mocks base method.
func (m *MockStore) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error {
m.ctrl.T.Helper()
@@ -6710,21 +6578,6 @@ func (mr *MockStoreMockRecorder) UpdateTailnetPeerStatusByCoordinator(ctx, arg a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTailnetPeerStatusByCoordinator", reflect.TypeOf((*MockStore)(nil).UpdateTailnetPeerStatusByCoordinator), ctx, arg)
}
// UpdateTaskWorkspaceID mocks base method.
func (m *MockStore) UpdateTaskWorkspaceID(ctx context.Context, arg database.UpdateTaskWorkspaceIDParams) (database.TaskTable, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateTaskWorkspaceID", ctx, arg)
ret0, _ := ret[0].(database.TaskTable)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateTaskWorkspaceID indicates an expected call of UpdateTaskWorkspaceID.
func (mr *MockStoreMockRecorder) UpdateTaskWorkspaceID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskWorkspaceID", reflect.TypeOf((*MockStore)(nil).UpdateTaskWorkspaceID), ctx, arg)
}
// UpdateTemplateACLByID mocks base method.
func (m *MockStore) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
m.ctrl.T.Helper()
-10
View File
@@ -24,12 +24,6 @@ const (
// but we won't touch the `connection_logs` table.
maxAuditLogConnectionEventAge = 90 * 24 * time.Hour // 90 days
auditLogConnectionEventBatchSize = 1000
// Telemetry heartbeats are used to deduplicate events across replicas. We
// don't need to persist heartbeat rows for longer than 24 hours, as they
// are only used for deduplication across replicas. The time needs to be
// long enough to cover the maximum interval of a heartbeat event (currently
// 1 hour) plus some buffer.
maxTelemetryHeartbeatAge = 24 * time.Hour
)
// New creates a new periodically purging database instance.
@@ -77,10 +71,6 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, clk quartz.
if err := tx.ExpirePrebuildsAPIKeys(ctx, dbtime.Time(start)); err != nil {
return xerrors.Errorf("failed to expire prebuilds user api keys: %w", err)
}
deleteOldTelemetryLocksBefore := start.Add(-maxTelemetryHeartbeatAge)
if err := tx.DeleteOldTelemetryLocks(ctx, deleteOldTelemetryLocksBefore); err != nil {
return xerrors.Errorf("failed to delete old telemetry locks: %w", err)
}
deleteOldAuditLogConnectionEventsBefore := start.Add(-maxAuditLogConnectionEventAge)
if err := tx.DeleteOldAuditLogConnectionEvents(ctx, database.DeleteOldAuditLogConnectionEventsParams{
-53
View File
@@ -704,56 +704,3 @@ func TestExpireOldAPIKeys(t *testing.T) {
// Out of an abundance of caution, we do not expire explicitly named prebuilds API keys.
assertKeyActive(namedPrebuildsAPIKey.ID)
}
func TestDeleteOldTelemetryHeartbeats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, _, sqlDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
clk := quartz.NewMock(t)
now := clk.Now().UTC()
// Insert telemetry heartbeats.
err := db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now.Add(-25 * time.Hour), // should be purged
})
require.NoError(t, err)
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now.Add(-23 * time.Hour), // should be kept
})
require.NoError(t, err)
err = db.InsertTelemetryLock(ctx, database.InsertTelemetryLockParams{
EventType: "aibridge_interceptions_summary",
PeriodEndingAt: now, // should be kept
})
require.NoError(t, err)
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, clk)
defer closer.Close()
<-done // doTick() has now run.
require.Eventuallyf(t, func() bool {
// We use an SQL queries directly here because we don't expose queries
// for deleting heartbeats in the application code.
var totalCount int
err := sqlDB.QueryRowContext(ctx, `
SELECT COUNT(*) FROM telemetry_locks;
`).Scan(&totalCount)
assert.NoError(t, err)
var oldCount int
err = sqlDB.QueryRowContext(ctx, `
SELECT COUNT(*) FROM telemetry_locks WHERE period_ending_at < $1;
`, now.Add(-24*time.Hour)).Scan(&oldCount)
assert.NoError(t, err)
// Expect 2 heartbeats remaining and none older than 24 hours.
t.Logf("eventually: total count: %d, old count: %d", totalCount, oldCount)
return totalCount == 2 && oldCount == 0
}, testutil.WaitShort, testutil.IntervalFast, "it should delete old telemetry heartbeats")
}
@@ -52,6 +52,10 @@ func (w *wrapUpsertDB) UpsertTemplateUsageStats(ctx context.Context) error {
func TestRollup_TwoInstancesUseLocking(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("Skipping test; only works with PostgreSQL.")
}
db, ps := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure())
logger := testutil.Logger(t)
+6 -51
View File
@@ -6,8 +6,6 @@ import (
_ "embed"
"fmt"
"os"
"runtime"
"strings"
"sync"
"time"
@@ -47,8 +45,6 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
host = defaultConnectionParams.Host
port = defaultConnectionParams.Port
)
packageName := getTestPackageName(t)
testName := t.Name()
// Use a time-based prefix to make it easier to find the database
// when debugging.
@@ -59,9 +55,9 @@ func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error
}
dbName := now + "_" + dbSuffix
// TODO: add package and test name
_, err = b.coderTestingDB.Exec(
"INSERT INTO test_databases (name, process_uuid, test_package, test_name) VALUES ($1, $2, $3, $4)",
dbName, b.uuid, packageName, testName)
"INSERT INTO test_databases (name, process_uuid) VALUES ($1, $2)", dbName, b.uuid)
if err != nil {
return ConnectionParams{}, xerrors.Errorf("insert test_database row: %w", err)
}
@@ -108,10 +104,10 @@ func (b *Broker) clean(t TBSubset, dbName string) func() {
func (b *Broker) init(t TBSubset) error {
b.Lock()
defer b.Unlock()
b.refCount++
t.Cleanup(b.decRef)
if b.coderTestingDB != nil {
// already initialized
b.refCount++
t.Cleanup(b.decRef)
return nil
}
@@ -128,8 +124,8 @@ func (b *Broker) init(t TBSubset) error {
return xerrors.Errorf("open postgres connection: %w", err)
}
// coderTestingSQLInit is idempotent, so we can run it every time.
_, err = coderTestingDB.Exec(coderTestingSQLInit)
// creating the db can succeed even if the database doesn't exist. Ping it to find out.
err = coderTestingDB.Ping()
var pqErr *pq.Error
if xerrors.As(err, &pqErr) && pqErr.Code == "3D000" {
// database does not exist.
@@ -149,8 +145,6 @@ func (b *Broker) init(t TBSubset) error {
return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err)
}
b.coderTestingDB = coderTestingDB
b.refCount++
t.Cleanup(b.decRef)
if b.uuid == uuid.Nil {
b.uuid = uuid.New()
@@ -192,42 +186,3 @@ func (b *Broker) decRef() {
b.coderTestingDB = nil
}
}
// getTestPackageName returns the package name of the test that called it.
func getTestPackageName(t TBSubset) string {
packageName := "unknown"
// Ask runtime.Callers for up to 100 program counters, including runtime.Callers itself.
pc := make([]uintptr, 100)
n := runtime.Callers(0, pc)
if n == 0 {
// No PCs available. This can happen if the first argument to
// runtime.Callers is large.
//
// Return now to avoid processing the zero Frame that would
// otherwise be returned by frames.Next below.
t.Logf("could not determine test package name: no PCs available")
return packageName
}
pc = pc[:n] // pass only valid pcs to runtime.CallersFrames
frames := runtime.CallersFrames(pc)
// Loop to get frames.
// A fixed number of PCs can expand to an indefinite number of Frames.
for {
frame, more := frames.Next()
if strings.HasPrefix(frame.Function, "github.com/coder/coder/v2/") {
packageName = strings.SplitN(strings.TrimPrefix(frame.Function, "github.com/coder/coder/v2/"), ".", 2)[0]
}
if strings.HasPrefix(frame.Function, "testing") {
break
}
// Check whether there are more frames to process after this one.
if !more {
break
}
}
return packageName
}
@@ -1,13 +0,0 @@
package dbtestutil
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetTestPackageName(t *testing.T) {
t.Parallel()
packageName := getTestPackageName(t)
require.Equal(t, "coderd/database/dbtestutil", packageName)
}
@@ -1,6 +1,3 @@
BEGIN TRANSACTION;
SELECT pg_advisory_xact_lock(7283699);
CREATE TABLE IF NOT EXISTS test_databases (
name text PRIMARY KEY,
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -9,10 +6,3 @@ CREATE TABLE IF NOT EXISTS test_databases (
);
CREATE INDEX IF NOT EXISTS test_databases_process_uuid ON test_databases (process_uuid, dropped_at);
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_name text;
COMMENT ON COLUMN test_databases.test_name IS 'Name of the test that created the database';
ALTER TABLE test_databases ADD COLUMN IF NOT EXISTS test_package text;
COMMENT ON COLUMN test_databases.test_package IS 'Package of the test that created the database';
COMMIT;
+11
View File
@@ -23,6 +23,13 @@ import (
"github.com/coder/coder/v2/testutil"
)
// WillUsePostgres returns true if a call to NewDB() will return a real, postgres-backed Store and Pubsub.
// TODO(hugodutka): since we removed the in-memory database, this is always true,
// and we need to remove this function. https://github.com/coder/internal/issues/758
func WillUsePostgres() bool {
return true
}
type options struct {
fixedTimezone string
dumpOnFailure bool
@@ -68,6 +75,10 @@ func withReturnSQLDB(f func(*sql.DB)) Option {
func NewDBWithSQLDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub, *sql.DB) {
t.Helper()
if !WillUsePostgres() {
t.Fatal("cannot use NewDBWithSQLDB without PostgreSQL, consider adding `if !dbtestutil.WillUsePostgres() { t.Skip() }` to this test")
}
var sqlDB *sql.DB
opts = append(opts, withReturnSQLDB(func(db *sql.DB) {
sqlDB = db
@@ -20,6 +20,9 @@ func TestMain(m *testing.M) {
func TestOpen(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
connect, err := dbtestutil.Open(t)
require.NoError(t, err)
@@ -34,6 +37,9 @@ func TestOpen(t *testing.T) {
func TestOpen_InvalidDBFrom(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
_, err := dbtestutil.Open(t, dbtestutil.WithDBFrom("__invalid__"))
require.Error(t, err)
@@ -43,6 +49,9 @@ func TestOpen_InvalidDBFrom(t *testing.T) {
func TestOpen_ValidDBFrom(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
// first check if we can create a new template db
dsn, err := dbtestutil.Open(t, dbtestutil.WithDBFrom(""))
@@ -106,6 +115,9 @@ func TestOpen_ValidDBFrom(t *testing.T) {
func TestOpen_Panic(t *testing.T) {
t.Skip("unskip this to manually test that we don't leak a database into postgres")
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
_, err := dbtestutil.Open(t)
require.NoError(t, err)
@@ -115,6 +127,9 @@ func TestOpen_Panic(t *testing.T) {
func TestOpen_Timeout(t *testing.T) {
t.Skip("unskip this and set a short timeout to manually test that we don't leak a database into postgres")
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
_, err := dbtestutil.Open(t)
require.NoError(t, err)
+17 -46
View File
@@ -1055,8 +1055,7 @@ CREATE TABLE aibridge_interceptions (
provider text NOT NULL,
model text NOT NULL,
started_at timestamp with time zone NOT NULL,
metadata jsonb,
ended_at timestamp with time zone
metadata jsonb
);
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
@@ -1126,8 +1125,7 @@ CREATE TABLE api_keys (
ip_address inet DEFAULT '0.0.0.0'::inet NOT NULL,
token_name text DEFAULT ''::text NOT NULL,
scopes api_key_scope[] NOT NULL,
allow_list text[] NOT NULL,
CONSTRAINT api_keys_allow_list_not_empty CHECK ((array_length(allow_list, 1) > 0))
allow_list text[] NOT NULL
);
COMMENT ON COLUMN api_keys.hashed_secret IS 'hashed_secret contains a SHA256 hash of the key secret. This is considered a secret and MUST NOT be returned from the API as it is used for API key encryption in app proxying code.';
@@ -1538,7 +1536,7 @@ CREATE TABLE oauth2_provider_apps (
jwks jsonb,
software_id text,
software_version text,
registration_access_token bytea,
registration_access_token text,
registration_client_uri text
);
@@ -1828,15 +1826,6 @@ CREATE TABLE tasks (
deleted_at timestamp with time zone
);
CREATE VIEW visible_users AS
SELECT users.id,
users.username,
users.name,
users.avatar_url
FROM users;
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
CREATE TABLE workspace_agents (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
@@ -1987,16 +1976,8 @@ CREATE VIEW tasks_with_status AS
END AS status,
task_app.workspace_build_number,
task_app.workspace_agent_id,
task_app.workspace_app_id,
task_owner.owner_username,
task_owner.owner_name,
task_owner.owner_avatar_url
FROM (((((tasks
CROSS JOIN LATERAL ( SELECT vu.username AS owner_username,
vu.name AS owner_name,
vu.avatar_url AS owner_avatar_url
FROM visible_users vu
WHERE (vu.id = tasks.owner_id)) task_owner)
task_app.workspace_app_id
FROM ((((tasks
LEFT JOIN LATERAL ( SELECT task_app_1.workspace_build_number,
task_app_1.workspace_agent_id,
task_app_1.workspace_app_id
@@ -2029,18 +2010,6 @@ CREATE TABLE telemetry_items (
updated_at timestamp with time zone DEFAULT now() NOT NULL
);
CREATE TABLE telemetry_locks (
event_type text NOT NULL,
period_ending_at timestamp with time zone NOT NULL,
CONSTRAINT telemetry_lock_event_type_constraint CHECK ((event_type = 'aibridge_interceptions_summary'::text))
);
COMMENT ON TABLE telemetry_locks IS 'Telemetry lock tracking table for deduplication of heartbeat events across replicas.';
COMMENT ON COLUMN telemetry_locks.event_type IS 'The type of event that was sent.';
COMMENT ON COLUMN telemetry_locks.period_ending_at IS 'The heartbeat period end timestamp.';
CREATE TABLE template_usage_stats (
start_time timestamp with time zone NOT NULL,
end_time timestamp with time zone NOT NULL,
@@ -2227,6 +2196,15 @@ COMMENT ON COLUMN template_versions.external_auth_providers IS 'IDs of External
COMMENT ON COLUMN template_versions.message IS 'Message describing the changes in this version of the template, similar to a Git commit message. Like a commit message, this should be a short, high-level description of the changes in this version of the template. This message is immutable and should not be updated after the fact.';
CREATE VIEW visible_users AS
SELECT users.id,
users.username,
users.name,
users.avatar_url
FROM users;
COMMENT ON VIEW visible_users IS 'Visible fields of users are allowed to be joined with other tables for including context of other resources.';
CREATE VIEW template_version_with_user AS
SELECT template_versions.id,
template_versions.template_id,
@@ -2922,13 +2900,11 @@ CREATE VIEW workspaces_expanded AS
templates.name AS template_name,
templates.display_name AS template_display_name,
templates.icon AS template_icon,
templates.description AS template_description,
tasks.id AS task_id
FROM ((((workspaces
templates.description AS template_description
FROM (((workspaces
JOIN visible_users ON ((workspaces.owner_id = visible_users.id)))
JOIN organizations ON ((workspaces.organization_id = organizations.id)))
JOIN templates ON ((workspaces.template_id = templates.id)))
LEFT JOIN tasks ON ((workspaces.id = tasks.workspace_id)));
JOIN templates ON ((workspaces.template_id = templates.id)));
COMMENT ON VIEW workspaces_expanded IS 'Joins in the display name information such as username, avatar, and organization name.';
@@ -3112,9 +3088,6 @@ ALTER TABLE ONLY tasks
ALTER TABLE ONLY telemetry_items
ADD CONSTRAINT telemetry_items_pkey PRIMARY KEY (key);
ALTER TABLE ONLY telemetry_locks
ADD CONSTRAINT telemetry_locks_pkey PRIMARY KEY (event_type, period_ending_at);
ALTER TABLE ONLY template_usage_stats
ADD CONSTRAINT template_usage_stats_pkey PRIMARY KEY (start_time, template_id, user_id);
@@ -3340,8 +3313,6 @@ CREATE INDEX idx_tailnet_tunnels_dst_id ON tailnet_tunnels USING hash (dst_id);
CREATE INDEX idx_tailnet_tunnels_src_id ON tailnet_tunnels USING hash (src_id);
CREATE INDEX idx_telemetry_locks_period_ending_at ON telemetry_locks USING btree (period_ending_at);
CREATE UNIQUE INDEX idx_template_version_presets_default ON template_version_presets USING btree (template_version_id) WHERE (is_default = true);
CREATE INDEX idx_template_versions_has_ai_task ON template_versions USING btree (has_ai_task);
@@ -141,19 +141,13 @@ ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'workspace_proxy:read';
ALTER TYPE api_key_scope ADD VALUE IF NOT EXISTS 'workspace_proxy:update';
-- End enum extensions
-- Purge old API keys to speed up the migration for large deployments.
-- Note: that problem should be solved in coderd once PR 20863 is released:
-- https://github.com/coder/coder/blob/main/coderd/database/dbpurge/dbpurge.go#L85
DELETE FROM api_keys WHERE expires_at < NOW() - INTERVAL '7 days';
-- Add new columns without defaults; backfill; then enforce NOT NULL
ALTER TABLE api_keys ADD COLUMN scopes api_key_scope[];
ALTER TABLE api_keys ADD COLUMN allow_list text[];
-- Backfill existing rows for compatibility
UPDATE api_keys SET
scopes = ARRAY[scope::api_key_scope],
allow_list = ARRAY['*:*'];
UPDATE api_keys SET scopes = ARRAY[scope::api_key_scope];
UPDATE api_keys SET allow_list = ARRAY['*:*'];
-- Enforce NOT NULL
ALTER TABLE api_keys ALTER COLUMN scopes SET NOT NULL;

Some files were not shown because too many files have changed in this diff Show More