Compare commits
96 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9764926f92 | |||
| 10d4e42fc1 | |||
| 217ddf46c4 | |||
| 0d3d493eae | |||
| 89b060e245 | |||
| 820d53b66a | |||
| f550028052 | |||
| e6873c8d61 | |||
| 8c0bfcb570 | |||
| c322b92ab0 | |||
| 216a5ac562 | |||
| 86447126d5 | |||
| 55c5b707fb | |||
| 4616c82f3c | |||
| 9ca30e28d6 | |||
| 34c1370090 | |||
| 851c4f907c | |||
| e3dfe45f35 | |||
| 7bad7e35ae | |||
| cd0a2849d0 | |||
| f6e86c6fdb | |||
| c301a0d804 | |||
| 6c621364f8 | |||
| 51d3abb904 | |||
| c6e551f538 | |||
| f684831f56 | |||
| f947a34103 | |||
| fb9d8e3030 | |||
| e60112e54f | |||
| e8e31dcb2c | |||
| 40e1784846 | |||
| 5a31c590e6 | |||
| e13a34c145 | |||
| 33b42fca7a | |||
| 86ef3fb497 | |||
| 13ca9ead3a | |||
| 906149317d | |||
| 6187acff8a | |||
| a106d67c07 | |||
| 2c6cbf15e2 | |||
| 1cb2ac65e5 | |||
| c6f63990cf | |||
| 9855460524 | |||
| 79728c30fa | |||
| 8daf4f35b1 | |||
| 5c802c2627 | |||
| 0f342ecc04 | |||
| e62c5db678 | |||
| 4244b20823 | |||
| 70cc3dd14a | |||
| d455f6ea2b | |||
| 4bd7c7b7e0 | |||
| 5f97ad0988 | |||
| 48f77d0c01 | |||
| da31a4bed9 | |||
| 9730c86f17 | |||
| 5ecab7b5f0 | |||
| df3b1bb6c7 | |||
| caeca1097b | |||
| 823b14aa34 | |||
| f2a410566c | |||
| aa689cbb39 | |||
| 1230cacf78 | |||
| 7bbeef4999 | |||
| f64ac8f5f7 | |||
| 69c2c40512 | |||
| 9da60a9dc5 | |||
| e73f9d356b | |||
| 87ce021035 | |||
| 86f0f39863 | |||
| 650dc860bd | |||
| c2dcf9348a | |||
| ea261a1f7c | |||
| 01ff28db11 | |||
| 77e8d2b887 | |||
| ccf0b34872 | |||
| 0652b18ebc | |||
| 5a18cf4c86 | |||
| 6a3bf6ff53 | |||
| b022ccefa7 | |||
| 66f1603f6a | |||
| 2e45236d31 | |||
| 0c2288d802 | |||
| 712d036192 | |||
| c1f8465de6 | |||
| 88851d248c | |||
| a13f29ff95 | |||
| e92b4fe13d | |||
| 784592a2dc | |||
| 251f787743 | |||
| 1a766a271f | |||
| cbaa97cb78 | |||
| 141ef23c81 | |||
| b0a045cba0 | |||
| f6526b789a | |||
| cfbbcfc65a |
@@ -91,6 +91,9 @@
|
||||
|
||||
## Systematic Debugging Approach
|
||||
|
||||
YOU MUST ALWAYS find the root cause of any issue you are debugging
|
||||
YOU MUST NEVER fix a symptom or add a workaround instead of finding a root cause, even if it is faster.
|
||||
|
||||
### Multi-Issue Problem Solving
|
||||
|
||||
When facing multiple failing tests or complex integration issues:
|
||||
@@ -98,16 +101,21 @@ When facing multiple failing tests or complex integration issues:
|
||||
1. **Identify Root Causes**:
|
||||
- Run failing tests individually to isolate issues
|
||||
- Use LSP tools to trace through call chains
|
||||
- Check both compilation and runtime errors
|
||||
- Read Error Messages Carefully: Check both compilation and runtime errors
|
||||
- Reproduce Consistently: Ensure you can reliably reproduce the issue before investigating
|
||||
- Check Recent Changes: What changed that could have caused this? Git diff, recent commits, etc.
|
||||
- When You Don't Know: Say "I don't understand X" rather than pretending to know
|
||||
|
||||
2. **Fix in Logical Order**:
|
||||
- Address compilation issues first (imports, syntax)
|
||||
- Fix authorization and RBAC issues next
|
||||
- Resolve business logic and validation issues
|
||||
- Handle edge cases and race conditions last
|
||||
- IF your first fix doesn't work, STOP and re-analyze rather than adding more fixes
|
||||
|
||||
3. **Verification Strategy**:
|
||||
- Test each fix individually before moving to next issue
|
||||
- Always Test each fix individually before moving to next issue
|
||||
- Verify Before Continuing: Did your test work? If not, form new hypothesis - don't add more fixes
|
||||
- Use `make lint` and `make gen` after database changes
|
||||
- Verify RFC compliance with actual specifications
|
||||
- Run comprehensive test suites before considering complete
|
||||
|
||||
@@ -40,11 +40,15 @@
|
||||
- Use proper error types
|
||||
- Pattern: `xerrors.Errorf("failed to X: %w", err)`
|
||||
|
||||
### Naming Conventions
|
||||
## Naming Conventions
|
||||
|
||||
- Use clear, descriptive names
|
||||
- Abbreviate only when obvious
|
||||
- Names MUST tell what code does, not how it's implemented or its history
|
||||
- Follow Go and TypeScript naming conventions
|
||||
- When changing code, never document the old behavior or the behavior change
|
||||
- NEVER use implementation details in names (e.g., "ZodValidator", "MCPWrapper", "JSONParser")
|
||||
- NEVER use temporal/historical context in names (e.g., "LegacyHandler", "UnifiedTool", "ImprovedInterface", "EnhancedParser")
|
||||
- NEVER use pattern names unless they add clarity (e.g., prefer "Tool" over "ToolFactory")
|
||||
- Abbreviate only when obvious
|
||||
|
||||
### Comments
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ jobs:
|
||||
echo "LINT_CACHE_DIR=$dir" >> "$GITHUB_ENV"
|
||||
|
||||
- name: golangci-lint cache
|
||||
uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4
|
||||
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
|
||||
with:
|
||||
path: |
|
||||
${{ env.LINT_CACHE_DIR }}
|
||||
@@ -191,7 +191,7 @@ jobs:
|
||||
|
||||
# Check for any typos
|
||||
- name: Check for typos
|
||||
uses: crate-ci/typos@85f62a8a84f939ae994ab3763f01a0296d61a7ee # v1.36.2
|
||||
uses: crate-ci/typos@80c8a4945eec0f6d464eaf9e65ed98ef085283d1 # v1.38.1
|
||||
with:
|
||||
config: .github/workflows/typos.toml
|
||||
|
||||
@@ -806,7 +806,7 @@ jobs:
|
||||
# the check to pass. This is desired in PRs, but not in mainline.
|
||||
- name: Publish to Chromatic (non-mainline)
|
||||
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@20c7e42e1b2f6becd5d188df9acb02f3e2f51519 # v13.2.0
|
||||
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -838,7 +838,7 @@ jobs:
|
||||
# infinitely "in progress" in mainline unless we re-review each build.
|
||||
- name: Publish to Chromatic (mainline)
|
||||
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@20c7e42e1b2f6becd5d188df9acb02f3e2f51519 # v13.2.0
|
||||
uses: chromaui/action@4ffe736a2a8262ea28067ff05a13b635ba31ec05 # v13.3.0
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -1123,7 +1123,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: GHCR Login
|
||||
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -1537,7 +1537,7 @@ jobs:
|
||||
steps:
|
||||
- name: Send Slack notification
|
||||
run: |
|
||||
ESCAPED_PROMPT=$(printf "%s" "<@U08TJ4YNCA3> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
|
||||
ESCAPED_PROMPT=$(printf "%s" "<@U09LQ75AHKR> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
|
||||
curl -X POST -H 'Content-type: application/json' \
|
||||
--data '{
|
||||
"blocks": [
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: GHCR Login
|
||||
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -92,7 +92,7 @@ jobs:
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
|
||||
- name: Set up Flux CLI
|
||||
uses: fluxcd/flux2/action@6bf37f6a560fd84982d67f853162e4b3c2235edb # v2.6.4
|
||||
uses: fluxcd/flux2/action@4a15fa6a023259353ef750acf1c98fe88407d4d0 # v2.7.2
|
||||
with:
|
||||
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
|
||||
version: "2.7.0"
|
||||
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Docker login
|
||||
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Setup Node
|
||||
uses: ./.github/actions/setup-node
|
||||
|
||||
- uses: tj-actions/changed-files@4563c729c555b4141fac99c80f699f571219b836 # v45.0.7
|
||||
- uses: tj-actions/changed-files@d03a93c0dbfac6d6dd6a0d8a5e7daff992b07449 # v45.0.7
|
||||
id: changed-files
|
||||
with:
|
||||
files: |
|
||||
|
||||
@@ -82,7 +82,7 @@ jobs:
|
||||
|
||||
- name: Login to DockerHub
|
||||
if: github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
@@ -170,7 +170,7 @@ jobs:
|
||||
steps:
|
||||
- name: Send Slack notification
|
||||
run: |
|
||||
ESCAPED_PROMPT=$(printf "%s" "<@U08TJ4YNCA3> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
|
||||
ESCAPED_PROMPT=$(printf "%s" "<@U09LQ75AHKR> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
|
||||
curl -X POST -H 'Content-type: application/json' \
|
||||
--data '{
|
||||
"blocks": [
|
||||
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Find Comment
|
||||
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
|
||||
uses: peter-evans/find-comment@b30e6a3c0ed37e7c023ccd3f1db5c6c0b0c23aad # v4.0.0
|
||||
id: fc
|
||||
with:
|
||||
issue-number: ${{ needs.get_info.outputs.PR_NUMBER }}
|
||||
@@ -199,7 +199,7 @@ jobs:
|
||||
|
||||
- name: Comment on PR
|
||||
id: comment_id
|
||||
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
|
||||
uses: peter-evans/create-or-update-comment@e8674b075228eee787fea43ef493e45ece1004c9 # v5.0.0
|
||||
with:
|
||||
comment-id: ${{ steps.fc.outputs.comment-id }}
|
||||
issue-number: ${{ needs.get_info.outputs.PR_NUMBER }}
|
||||
@@ -248,7 +248,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-sqlc
|
||||
|
||||
- name: GHCR Login
|
||||
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -491,7 +491,7 @@ jobs:
|
||||
PASSWORD: ${{ steps.setup_deployment.outputs.password }}
|
||||
|
||||
- name: Find Comment
|
||||
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
|
||||
uses: peter-evans/find-comment@b30e6a3c0ed37e7c023ccd3f1db5c6c0b0c23aad # v4.0.0
|
||||
id: fc
|
||||
with:
|
||||
issue-number: ${{ env.PR_NUMBER }}
|
||||
@@ -500,7 +500,7 @@ jobs:
|
||||
direction: last
|
||||
|
||||
- name: Comment on PR
|
||||
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
|
||||
uses: peter-evans/create-or-update-comment@e8674b075228eee787fea43ef493e45ece1004c9 # v5.0.0
|
||||
env:
|
||||
STATUS: ${{ needs.get_info.outputs.NEW == 'true' && 'Created' || 'Updated' }}
|
||||
with:
|
||||
|
||||
@@ -239,7 +239,7 @@ jobs:
|
||||
cat "$CODER_RELEASE_NOTES_FILE"
|
||||
|
||||
- name: Docker Login
|
||||
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -785,7 +785,7 @@ jobs:
|
||||
|
||||
- name: Send repository-dispatch event
|
||||
if: ${{ !inputs.dry_run }}
|
||||
uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0
|
||||
uses: peter-evans/repository-dispatch@5fc4efd1a4797ddb68ffd0714a238564e4cc0e6f # v4.0.0
|
||||
with:
|
||||
token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
|
||||
repository: coder/packages
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: "Run analysis"
|
||||
uses: ossf/scorecard-action@05b42c624433fc40578a4040d5cf5e36ddca8cde # v2.4.2
|
||||
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
|
||||
with:
|
||||
results_file: results.sarif
|
||||
results_format: sarif
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/init@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/analyze@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
@@ -154,7 +154,7 @@ jobs:
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@16140ae1a102900babc80a33c44059580f687047 # v3.29.5
|
||||
with:
|
||||
sarif_file: trivy-results.sarif
|
||||
category: "Trivy"
|
||||
|
||||
@@ -23,7 +23,7 @@ jobs:
|
||||
egress-policy: audit
|
||||
|
||||
- name: stale
|
||||
uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
|
||||
uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
|
||||
with:
|
||||
stale-issue-label: "stale"
|
||||
stale-pr-label: "stale"
|
||||
|
||||
@@ -13,12 +13,12 @@ on:
|
||||
template_name:
|
||||
description: "Coder template to use for workspace"
|
||||
required: true
|
||||
default: "traiage"
|
||||
default: "coder"
|
||||
type: string
|
||||
template_preset:
|
||||
description: "Template preset to use"
|
||||
required: true
|
||||
default: "Default"
|
||||
default: "none"
|
||||
type: string
|
||||
prefix:
|
||||
description: "Prefix for workspace name"
|
||||
@@ -66,8 +66,8 @@ jobs:
|
||||
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
|
||||
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
|
||||
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
|
||||
INPUTS_TEMPLATE_NAME: ${{ inputs.template_name || 'traiage' }}
|
||||
INPUTS_TEMPLATE_PRESET: ${{ inputs.template_preset || 'Default'}}
|
||||
INPUTS_TEMPLATE_NAME: ${{ inputs.template_name || 'coder' }}
|
||||
INPUTS_TEMPLATE_PRESET: ${{ inputs.template_preset || 'none'}}
|
||||
INPUTS_PREFIX: ${{ inputs.prefix || 'traiage' }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
@@ -168,7 +168,7 @@ jobs:
|
||||
echo "coder_username=${coder_username}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check Markdown links
|
||||
uses: umbrelladocs/action-linkspector@874d01cae9fd488e3077b08952093235bd626977 # v1.3.7
|
||||
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
|
||||
id: markdown-link-check
|
||||
# checks all markdown files from /docs including all subfolders
|
||||
with:
|
||||
|
||||
@@ -12,6 +12,9 @@ node_modules/
|
||||
vendor/
|
||||
yarn-error.log
|
||||
|
||||
# Test output files
|
||||
test-output/
|
||||
|
||||
# VSCode settings.
|
||||
**/.vscode/*
|
||||
# Allow VSCode recommendations and default settings in project root.
|
||||
|
||||
+11
-1
@@ -169,6 +169,16 @@ linters-settings:
|
||||
- name: var-declaration
|
||||
- name: var-naming
|
||||
- name: waitgroup-by-value
|
||||
usetesting:
|
||||
# Only os-setenv is enabled because we migrated to usetesting from another linter that
|
||||
# only covered os-setenv.
|
||||
os-setenv: true
|
||||
os-create-temp: false
|
||||
os-mkdir-temp: false
|
||||
os-temp-dir: false
|
||||
os-chdir: false
|
||||
context-background: false
|
||||
context-todo: false
|
||||
|
||||
# irrelevant as of Go v1.22: https://go.dev/blog/loopvar-preview
|
||||
govet:
|
||||
@@ -252,7 +262,6 @@ linters:
|
||||
# - wastedassign
|
||||
|
||||
- staticcheck
|
||||
- tenv
|
||||
# In Go, it's possible for a package to test it's internal functionality
|
||||
# without testing any exported functions. This is enabled to promote
|
||||
# decomposing a package before testing it's internals. A function caller
|
||||
@@ -265,4 +274,5 @@ linters:
|
||||
- typecheck
|
||||
- unconvert
|
||||
- unused
|
||||
- usetesting
|
||||
- dupl
|
||||
|
||||
@@ -1,11 +1,41 @@
|
||||
# Coder Development Guidelines
|
||||
|
||||
You are an experienced, pragmatic software engineer. You don't over-engineer a solution when a simple one is possible.
|
||||
Rule #1: If you want exception to ANY rule, YOU MUST STOP and get explicit permission first. BREAKING THE LETTER OR SPIRIT OF THE RULES IS FAILURE.
|
||||
|
||||
## Foundational rules
|
||||
|
||||
- Doing it right is better than doing it fast. You are not in a rush. NEVER skip steps or take shortcuts.
|
||||
- Tedious, systematic work is often the correct solution. Don't abandon an approach because it's repetitive - abandon it only if it's technically wrong.
|
||||
- Honesty is a core value.
|
||||
|
||||
## Our relationship
|
||||
|
||||
- Act as a critical peer reviewer. Your job is to disagree with me when I'm wrong, not to please me. Prioritize accuracy and reasoning over agreement.
|
||||
- YOU MUST speak up immediately when you don't know something or we're in over our heads
|
||||
- YOU MUST call out bad ideas, unreasonable expectations, and mistakes - I depend on this
|
||||
- NEVER be agreeable just to be nice - I NEED your HONEST technical judgment
|
||||
- NEVER write the phrase "You're absolutely right!" You are not a sycophant. We're working together because I value your opinion. Do not agree with me unless you can justify it with evidence or reasoning.
|
||||
- YOU MUST ALWAYS STOP and ask for clarification rather than making assumptions.
|
||||
- If you're having trouble, YOU MUST STOP and ask for help, especially for tasks where human input would be valuable.
|
||||
- When you disagree with my approach, YOU MUST push back. Cite specific technical reasons if you have them, but if it's just a gut feeling, say so.
|
||||
- If you're uncomfortable pushing back out loud, just say "Houston, we have a problem". I'll know what you mean
|
||||
- We discuss architectutral decisions (framework changes, major refactoring, system design) together before implementation. Routine fixes and clear implementations don't need discussion.
|
||||
|
||||
## Proactiveness
|
||||
|
||||
When asked to do something, just do it - including obvious follow-up actions needed to complete the task properly.
|
||||
Only pause to ask for confirmation when:
|
||||
|
||||
- Multiple valid approaches exist and the choice matters
|
||||
- The action would delete or significantly restructure existing code
|
||||
- You genuinely don't understand what's being asked
|
||||
- Your partner asked a question (answer the question, don't jump to implementation)
|
||||
|
||||
@.claude/docs/WORKFLOWS.md
|
||||
@.cursorrules
|
||||
@README.md
|
||||
@package.json
|
||||
|
||||
## 🚀 Essential Commands
|
||||
## Essential Commands
|
||||
|
||||
| Task | Command | Notes |
|
||||
|-------------------|--------------------------|----------------------------------|
|
||||
@@ -21,22 +51,13 @@
|
||||
| **Format** | `make fmt` | Auto-format code |
|
||||
| **Clean** | `make clean` | Clean build artifacts |
|
||||
|
||||
### Frontend Commands (site directory)
|
||||
|
||||
- `pnpm build` - Build frontend
|
||||
- `pnpm dev` - Run development server
|
||||
- `pnpm check` - Run code checks
|
||||
- `pnpm format` - Format frontend code
|
||||
- `pnpm lint` - Lint frontend code
|
||||
- `pnpm test` - Run frontend tests
|
||||
|
||||
### Documentation Commands
|
||||
|
||||
- `pnpm run format-docs` - Format markdown tables in docs
|
||||
- `pnpm run lint-docs` - Lint and fix markdown files
|
||||
- `pnpm run storybook` - Run Storybook (from site directory)
|
||||
|
||||
## 🔧 Critical Patterns
|
||||
## Critical Patterns
|
||||
|
||||
### Database Changes (ALWAYS FOLLOW)
|
||||
|
||||
@@ -78,7 +99,7 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
|
||||
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
```
|
||||
|
||||
## 📋 Quick Reference
|
||||
## Quick Reference
|
||||
|
||||
### Full workflows available in imported WORKFLOWS.md
|
||||
|
||||
@@ -88,14 +109,14 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
- [ ] Check if feature touches database - you'll need migrations
|
||||
- [ ] Check if feature touches audit logs - update `enterprise/audit/table.go`
|
||||
|
||||
## 🏗️ Architecture
|
||||
## Architecture
|
||||
|
||||
- **coderd**: Main API service
|
||||
- **provisionerd**: Infrastructure provisioning
|
||||
- **Agents**: Workspace services (SSH, port forwarding)
|
||||
- **Database**: PostgreSQL with `dbauthz` authorization
|
||||
|
||||
## 🧪 Testing
|
||||
## Testing
|
||||
|
||||
### Race Condition Prevention
|
||||
|
||||
@@ -112,21 +133,21 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
NEVER use `time.Sleep` to mitigate timing issues. If an issue
|
||||
seems like it should use `time.Sleep`, read through https://github.com/coder/quartz and specifically the [README](https://github.com/coder/quartz/blob/main/README.md) to better understand how to handle timing issues.
|
||||
|
||||
## 🎯 Code Style
|
||||
## Code Style
|
||||
|
||||
### Detailed guidelines in imported WORKFLOWS.md
|
||||
|
||||
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
|
||||
- Commit format: `type(scope): message`
|
||||
|
||||
## 📚 Detailed Development Guides
|
||||
## Detailed Development Guides
|
||||
|
||||
@.claude/docs/OAUTH2.md
|
||||
@.claude/docs/TESTING.md
|
||||
@.claude/docs/TROUBLESHOOTING.md
|
||||
@.claude/docs/DATABASE.md
|
||||
|
||||
## 🚨 Common Pitfalls
|
||||
## Common Pitfalls
|
||||
|
||||
1. **Audit table errors** → Update `enterprise/audit/table.go`
|
||||
2. **OAuth2 errors** → Return RFC-compliant format
|
||||
|
||||
@@ -642,6 +642,7 @@ AIBRIDGED_MOCKS := \
|
||||
GEN_FILES := \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
agent/proto/agent.pb.go \
|
||||
agent/agentsocket/proto/agentsocket.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
@@ -676,6 +677,7 @@ gen/db: $(DB_GEN_FILES)
|
||||
.PHONY: gen/db
|
||||
|
||||
gen/golden-files: \
|
||||
agent/unit/testdata/.gen-golden \
|
||||
cli/testdata/.gen-golden \
|
||||
coderd/.gen-golden \
|
||||
coderd/notifications/.gen-golden \
|
||||
@@ -799,6 +801,14 @@ agent/proto/agent.pb.go: agent/proto/agent.proto
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./agent/proto/agent.proto
|
||||
|
||||
agent/agentsocket/proto/agentsocket.pb.go: agent/agentsocket/proto/agentsocket.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
@@ -952,6 +962,10 @@ clean/golden-files:
|
||||
-type f -name '*.golden' -delete
|
||||
.PHONY: clean/golden-files
|
||||
|
||||
agent/unit/testdata/.gen-golden: $(wildcard agent/unit/testdata/*.golden) $(GO_SRC_FILES) $(wildcard agent/unit/*_test.go)
|
||||
TZ=UTC go test ./agent/unit -run="TestGraph" -update
|
||||
touch "$@"
|
||||
|
||||
cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(wildcard cli/*.tpl) $(GO_SRC_FILES) $(wildcard cli/*_test.go)
|
||||
TZ=UTC go test ./cli -run="Test(CommandHelp|ServerYAML|ErrorExamples|.*Golden)" -update
|
||||
touch "$@"
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/coder/coder/v2/agent/agentcontainers"
|
||||
"github.com/coder/coder/v2/agent/agentexec"
|
||||
"github.com/coder/coder/v2/agent/agentscripts"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
|
||||
@@ -91,6 +92,7 @@ type Options struct {
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
Clock quartz.Clock
|
||||
SocketPath string // Path for the agent socket server
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@@ -190,6 +192,7 @@ func New(options Options) Agent {
|
||||
|
||||
devcontainers: options.Devcontainers,
|
||||
containerAPIOptions: options.DevcontainerAPIOptions,
|
||||
socketPath: options.SocketPath,
|
||||
}
|
||||
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
|
||||
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
|
||||
@@ -271,6 +274,9 @@ type agent struct {
|
||||
devcontainers bool
|
||||
containerAPIOptions []agentcontainers.Option
|
||||
containerAPI *agentcontainers.API
|
||||
|
||||
socketPath string
|
||||
socketServer *agentsocket.Server
|
||||
}
|
||||
|
||||
func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
@@ -350,9 +356,35 @@ func (a *agent) init() {
|
||||
s.ExperimentalContainers = a.devcontainers
|
||||
},
|
||||
)
|
||||
|
||||
a.initSocketServer()
|
||||
|
||||
go a.runLoop()
|
||||
}
|
||||
|
||||
// initSocketServer initializes server that allows direct communication with a workspace agent using IPC.
|
||||
func (a *agent) initSocketServer() {
|
||||
if a.socketPath == "" {
|
||||
a.logger.Info(a.hardCtx, "socket server disabled (no path configured)")
|
||||
return
|
||||
}
|
||||
|
||||
server, err := agentsocket.NewServer(a.socketPath, a.logger.Named("socket"))
|
||||
if err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
a.logger.Warn(a.hardCtx, "failed to start socket server", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
a.socketServer = server
|
||||
a.logger.Debug(a.hardCtx, "socket server started", slog.F("path", a.socketPath))
|
||||
}
|
||||
|
||||
// runLoop attempts to start the agent in a retry loop.
|
||||
// Coder may be offline temporarily, a connection issue
|
||||
// may be happening, but regardless after the intermittent
|
||||
@@ -1920,6 +1952,13 @@ func (a *agent) Close() error {
|
||||
lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError
|
||||
}
|
||||
}
|
||||
|
||||
if a.socketServer != nil {
|
||||
if err := a.socketServer.Stop(); err != nil {
|
||||
a.logger.Error(a.hardCtx, "socket server close", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
a.setLifecycle(lifecycleState)
|
||||
|
||||
err = a.scriptRunner.Close()
|
||||
|
||||
+76
-33
@@ -3462,11 +3462,7 @@ func TestAgent_Metrics_SSH(t *testing.T) {
|
||||
registry := prometheus.NewRegistry()
|
||||
|
||||
//nolint:dogsled
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
|
||||
// Make sure we always get a DERP connection for
|
||||
// currently_reachable_peers.
|
||||
DisableDirectConnections: true,
|
||||
}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.PrometheusRegistry = registry
|
||||
})
|
||||
|
||||
@@ -3481,16 +3477,31 @@ func TestAgent_Metrics_SSH(t *testing.T) {
|
||||
err = session.Shell()
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := []*proto.Stats_Metric{
|
||||
expected := []struct {
|
||||
Name string
|
||||
Type proto.Stats_Metric_Type
|
||||
CheckFn func(float64) error
|
||||
Labels []*proto.Stats_Metric_Label
|
||||
}{
|
||||
{
|
||||
Name: "agent_reconnecting_pty_connections_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
Value: 0,
|
||||
Name: "agent_reconnecting_pty_connections_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
CheckFn: func(v float64) error {
|
||||
if v == 0 {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("expected 0, got %f", v)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agent_sessions_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
Value: 1,
|
||||
Name: "agent_sessions_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
CheckFn: func(v float64) error {
|
||||
if v == 1 {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("expected 1, got %f", v)
|
||||
},
|
||||
Labels: []*proto.Stats_Metric_Label{
|
||||
{
|
||||
Name: "magic_type",
|
||||
@@ -3503,24 +3514,44 @@ func TestAgent_Metrics_SSH(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agent_ssh_server_failed_connections_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
Value: 0,
|
||||
Name: "agent_ssh_server_failed_connections_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
CheckFn: func(v float64) error {
|
||||
if v == 0 {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("expected 0, got %f", v)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agent_ssh_server_sftp_connections_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
Value: 0,
|
||||
Name: "agent_ssh_server_sftp_connections_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
CheckFn: func(v float64) error {
|
||||
if v == 0 {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("expected 0, got %f", v)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "agent_ssh_server_sftp_server_errors_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
Value: 0,
|
||||
Name: "agent_ssh_server_sftp_server_errors_total",
|
||||
Type: proto.Stats_Metric_COUNTER,
|
||||
CheckFn: func(v float64) error {
|
||||
if v == 0 {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("expected 0, got %f", v)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "coderd_agentstats_currently_reachable_peers",
|
||||
Type: proto.Stats_Metric_GAUGE,
|
||||
Value: 1,
|
||||
Name: "coderd_agentstats_currently_reachable_peers",
|
||||
Type: proto.Stats_Metric_GAUGE,
|
||||
CheckFn: func(float64) error {
|
||||
// We can't reliably ping a peer here, and networking is out of
|
||||
// scope of this test, so we just test that the metric exists
|
||||
// with the correct labels.
|
||||
return nil
|
||||
},
|
||||
Labels: []*proto.Stats_Metric_Label{
|
||||
{
|
||||
Name: "connection_type",
|
||||
@@ -3529,9 +3560,11 @@ func TestAgent_Metrics_SSH(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "coderd_agentstats_currently_reachable_peers",
|
||||
Type: proto.Stats_Metric_GAUGE,
|
||||
Value: 0,
|
||||
Name: "coderd_agentstats_currently_reachable_peers",
|
||||
Type: proto.Stats_Metric_GAUGE,
|
||||
CheckFn: func(float64) error {
|
||||
return nil
|
||||
},
|
||||
Labels: []*proto.Stats_Metric_Label{
|
||||
{
|
||||
Name: "connection_type",
|
||||
@@ -3540,9 +3573,20 @@ func TestAgent_Metrics_SSH(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "coderd_agentstats_startup_script_seconds",
|
||||
Type: proto.Stats_Metric_GAUGE,
|
||||
Value: 1,
|
||||
Name: "coderd_agentstats_startup_script_seconds",
|
||||
Type: proto.Stats_Metric_GAUGE,
|
||||
CheckFn: func(f float64) error {
|
||||
if f >= 0 {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("expected >= 0, got %f", f)
|
||||
},
|
||||
Labels: []*proto.Stats_Metric_Label{
|
||||
{
|
||||
Name: "success",
|
||||
Value: "true",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3564,11 +3608,10 @@ func TestAgent_Metrics_SSH(t *testing.T) {
|
||||
for _, m := range mf.GetMetric() {
|
||||
assert.Equal(t, expected[i].Name, mf.GetName())
|
||||
assert.Equal(t, expected[i].Type.String(), mf.GetType().String())
|
||||
// Value is max expected
|
||||
if expected[i].Type == proto.Stats_Metric_GAUGE {
|
||||
assert.GreaterOrEqualf(t, expected[i].Value, m.GetGauge().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetGauge().GetValue())
|
||||
assert.NoError(t, expected[i].CheckFn(m.GetGauge().GetValue()), "check fn for %s failed", expected[i].Name)
|
||||
} else if expected[i].Type == proto.Stats_Metric_COUNTER {
|
||||
assert.GreaterOrEqualf(t, expected[i].Value, m.GetCounter().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetCounter().GetValue())
|
||||
assert.NoError(t, expected[i].CheckFn(m.GetCounter().GetValue()), "check fn for %s failed", expected[i].Name)
|
||||
}
|
||||
for j, lbl := range expected[i].Labels {
|
||||
assert.Equal(t, m.GetLabel()[j], &promgo.LabelPair{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,88 @@
|
||||
syntax = "proto3";
|
||||
option go_package = "github.com/coder/coder/v2/agent/agentsocket/proto";
|
||||
|
||||
package coder.agentsocket.v1;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message PingRequest {}
|
||||
|
||||
message PingResponse {
|
||||
string message = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
}
|
||||
|
||||
message SyncStartRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncStartResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncWantRequest {
|
||||
string unit = 1;
|
||||
string depends_on = 2;
|
||||
}
|
||||
|
||||
message SyncWantResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncCompleteRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncCompleteResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncReadyRequest {
|
||||
string unit = 1;
|
||||
}
|
||||
|
||||
message SyncReadyResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message SyncStatusRequest {
|
||||
string unit = 1;
|
||||
bool recursive = 2;
|
||||
}
|
||||
|
||||
message DependencyInfo {
|
||||
string depends_on = 1;
|
||||
string required_status = 2;
|
||||
string current_status = 3;
|
||||
bool is_satisfied = 4;
|
||||
}
|
||||
|
||||
message SyncStatusResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
string unit = 3;
|
||||
string status = 4;
|
||||
bool is_ready = 5;
|
||||
repeated DependencyInfo dependencies = 6;
|
||||
string dot = 7;
|
||||
}
|
||||
|
||||
// AgentSocket provides direct access to the agent over local IPC.
|
||||
service AgentSocket {
|
||||
// Ping the agent to check if it is alive.
|
||||
rpc Ping(PingRequest) returns (PingResponse);
|
||||
// Report the start of a unit.
|
||||
rpc SyncStart(SyncStartRequest) returns (SyncStartResponse);
|
||||
// Declare a dependency between units.
|
||||
rpc SyncWant(SyncWantRequest) returns (SyncWantResponse);
|
||||
// Report the completion of a unit.
|
||||
rpc SyncComplete(SyncCompleteRequest) returns (SyncCompleteResponse);
|
||||
// Request whether a unit is ready to be started. That is, all dependencies are satisfied.
|
||||
rpc SyncReady(SyncReadyRequest) returns (SyncReadyResponse);
|
||||
// Get the status of a unit and list its dependencies.
|
||||
rpc SyncStatus(SyncStatusRequest) returns (SyncStatusResponse);
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.34
|
||||
// source: agent/agentsocket/proto/agentsocket.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
errors "errors"
|
||||
protojson "google.golang.org/protobuf/encoding/protojson"
|
||||
proto "google.golang.org/protobuf/proto"
|
||||
drpc "storj.io/drpc"
|
||||
drpcerr "storj.io/drpc/drpcerr"
|
||||
)
|
||||
|
||||
type drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto struct{}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Marshal(msg drpc.Message) ([]byte, error) {
|
||||
return proto.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) MarshalAppend(buf []byte, msg drpc.Message) ([]byte, error) {
|
||||
return proto.MarshalOptions{}.MarshalAppend(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) Unmarshal(buf []byte, msg drpc.Message) error {
|
||||
return proto.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONMarshal(msg drpc.Message) ([]byte, error) {
|
||||
return protojson.Marshal(msg.(proto.Message))
|
||||
}
|
||||
|
||||
func (drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto) JSONUnmarshal(buf []byte, msg drpc.Message) error {
|
||||
return protojson.Unmarshal(buf, msg.(proto.Message))
|
||||
}
|
||||
|
||||
type DRPCAgentSocketClient interface {
|
||||
DRPCConn() drpc.Conn
|
||||
|
||||
Ping(ctx context.Context, in *PingRequest) (*PingResponse, error)
|
||||
SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error)
|
||||
SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error)
|
||||
SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error)
|
||||
SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error)
|
||||
SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error)
|
||||
}
|
||||
|
||||
type drpcAgentSocketClient struct {
|
||||
cc drpc.Conn
|
||||
}
|
||||
|
||||
func NewDRPCAgentSocketClient(cc drpc.Conn) DRPCAgentSocketClient {
|
||||
return &drpcAgentSocketClient{cc}
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcAgentSocketClient) Ping(ctx context.Context, in *PingRequest) (*PingResponse, error) {
|
||||
out := new(PingResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncStart(ctx context.Context, in *SyncStartRequest) (*SyncStartResponse, error) {
|
||||
out := new(SyncStartResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncWant(ctx context.Context, in *SyncWantRequest) (*SyncWantResponse, error) {
|
||||
out := new(SyncWantResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncComplete(ctx context.Context, in *SyncCompleteRequest) (*SyncCompleteResponse, error) {
|
||||
out := new(SyncCompleteResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncReady(ctx context.Context, in *SyncReadyRequest) (*SyncReadyResponse, error) {
|
||||
out := new(SyncReadyResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcAgentSocketClient) SyncStatus(ctx context.Context, in *SyncStatusRequest) (*SyncStatusResponse, error) {
|
||||
out := new(SyncStatusResponse)
|
||||
err := c.cc.Invoke(ctx, "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type DRPCAgentSocketServer interface {
|
||||
Ping(context.Context, *PingRequest) (*PingResponse, error)
|
||||
SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error)
|
||||
SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error)
|
||||
SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error)
|
||||
SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error)
|
||||
SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error)
|
||||
}
|
||||
|
||||
type DRPCAgentSocketUnimplementedServer struct{}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) Ping(context.Context, *PingRequest) (*PingResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncStart(context.Context, *SyncStartRequest) (*SyncStartResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncWant(context.Context, *SyncWantRequest) (*SyncWantResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncComplete(context.Context, *SyncCompleteRequest) (*SyncCompleteResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncReady(context.Context, *SyncReadyRequest) (*SyncReadyResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketUnimplementedServer) SyncStatus(context.Context, *SyncStatusRequest) (*SyncStatusResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCAgentSocketDescription struct{}
|
||||
|
||||
func (DRPCAgentSocketDescription) NumMethods() int { return 6 }
|
||||
|
||||
func (DRPCAgentSocketDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/coder.agentsocket.v1.AgentSocket/Ping", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
Ping(
|
||||
ctx,
|
||||
in1.(*PingRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.Ping, true
|
||||
case 1:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncStart", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncStart(
|
||||
ctx,
|
||||
in1.(*SyncStartRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncStart, true
|
||||
case 2:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncWant", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncWant(
|
||||
ctx,
|
||||
in1.(*SyncWantRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncWant, true
|
||||
case 3:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncComplete", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncComplete(
|
||||
ctx,
|
||||
in1.(*SyncCompleteRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncComplete, true
|
||||
case 4:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncReady", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncReady(
|
||||
ctx,
|
||||
in1.(*SyncReadyRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncReady, true
|
||||
case 5:
|
||||
return "/coder.agentsocket.v1.AgentSocket/SyncStatus", drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCAgentSocketServer).
|
||||
SyncStatus(
|
||||
ctx,
|
||||
in1.(*SyncStatusRequest),
|
||||
)
|
||||
}, DRPCAgentSocketServer.SyncStatus, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func DRPCRegisterAgentSocket(mux drpc.Mux, impl DRPCAgentSocketServer) error {
|
||||
return mux.Register(impl, DRPCAgentSocketDescription{})
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_PingStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*PingResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_PingStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_PingStream) SendAndClose(m *PingResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncStartStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncStartResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncStartStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncStartStream) SendAndClose(m *SyncStartResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncWantStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncWantResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncWantStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncWantStream) SendAndClose(m *SyncWantResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncCompleteStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncCompleteResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncCompleteStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncCompleteStream) SendAndClose(m *SyncCompleteResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncReadyStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncReadyResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncReadyStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncReadyStream) SendAndClose(m *SyncReadyResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCAgentSocket_SyncStatusStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SyncStatusResponse) error
|
||||
}
|
||||
|
||||
type drpcAgentSocket_SyncStatusStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcAgentSocket_SyncStatusStream) SendAndClose(m *SyncStatusResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_agent_agentsocket_proto_agentsocket_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package proto
|
||||
|
||||
import "github.com/coder/coder/v2/apiversion"
|
||||
|
||||
// Version history:
|
||||
//
|
||||
// API v1.0:
|
||||
// - Initial release
|
||||
// - Ping
|
||||
// - Sync operations: SyncStart, SyncWant, SyncComplete, SyncWait, SyncStatus
|
||||
|
||||
const (
|
||||
CurrentMajor = 1
|
||||
CurrentMinor = 0
|
||||
)
|
||||
|
||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
|
||||
@@ -0,0 +1,185 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
)
|
||||
|
||||
// Server provides access to the DRPCAgentSocketService via a Unix domain socket.
|
||||
// Do not invoke Server{} directly. Use NewServer() instead.
|
||||
type Server struct {
|
||||
logger slog.Logger
|
||||
path string
|
||||
listener net.Listener
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
drpcServer *drpcserver.Server
|
||||
service *DRPCAgentSocketService
|
||||
}
|
||||
|
||||
func NewServer(path string, logger slog.Logger) (*Server, error) {
|
||||
logger = logger.Named("agentsocket")
|
||||
server := &Server{
|
||||
logger: logger,
|
||||
path: path,
|
||||
service: &DRPCAgentSocketService{
|
||||
logger: logger,
|
||||
unitManager: unit.NewManager[string, string](),
|
||||
},
|
||||
}
|
||||
|
||||
mux := drpcmux.New()
|
||||
err := proto.DRPCRegisterAgentSocket(mux, server.service)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to register drpc service: %w", err)
|
||||
}
|
||||
|
||||
server.drpcServer = drpcserver.NewWithOptions(mux, drpcserver.Options{
|
||||
Manager: drpcsdk.DefaultDRPCOptions(nil),
|
||||
Log: func(err error) {
|
||||
if errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
logger.Debug(context.Background(), "drpc server error", slog.Error(err))
|
||||
},
|
||||
})
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
var ErrServerAlreadyStarted = xerrors.New("server already started")
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return ErrServerAlreadyStarted
|
||||
}
|
||||
|
||||
// This context is canceled by s.Stop() when the server is stopped.
|
||||
// canceling it will close all connections.
|
||||
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if s.path == "" {
|
||||
var err error
|
||||
s.path, err = getDefaultSocketPath()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get default socket path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
listener, err := createSocket(s.path)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create socket: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
|
||||
s.logger.Info(s.ctx, "agent socket server started", slog.F("path", s.path))
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.acceptConnections()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info(s.ctx, "stopping agent socket server")
|
||||
|
||||
s.cancel()
|
||||
|
||||
if err := s.listener.Close(); err != nil {
|
||||
s.logger.Warn(s.ctx, "error closing socket listener", slog.Error(err))
|
||||
}
|
||||
|
||||
// Wait for all connections to finish
|
||||
s.wg.Wait()
|
||||
|
||||
if err := cleanupSocket(s.path); err != nil {
|
||||
s.logger.Warn(s.ctx, "error cleaning up socket file", slog.Error(err))
|
||||
}
|
||||
|
||||
s.listener = nil
|
||||
s.logger.Info(s.ctx, "agent socket server stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) acceptConnections() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
s.logger.Warn(s.ctx, "error accepting connection", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handleConnection(conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
s.logger.Warn(s.ctx, "failed to set connection deadline", slog.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Debug(s.ctx, "new connection accepted", slog.F("remote_addr", conn.RemoteAddr()))
|
||||
|
||||
config := yamux.DefaultConfig()
|
||||
config.Logger = nil
|
||||
session, err := yamux.Server(conn, config)
|
||||
if err != nil {
|
||||
s.logger.Warn(s.ctx, "failed to create yamux session", slog.Error(err))
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
err = s.drpcServer.Serve(s.ctx, session)
|
||||
if err != nil {
|
||||
s.logger.Debug(s.ctx, "drpc server finished", slog.Error(err))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package agentsocket_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
)
|
||||
|
||||
func TestServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("StartStop", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.NoError(t, server.Stop())
|
||||
})
|
||||
|
||||
t.Run("AlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.ErrorIs(t, server.Start(), agentsocket.ErrServerAlreadyStarted)
|
||||
})
|
||||
|
||||
t.Run("AutoSocketPath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
logger := slog.Make().Leveled(slog.LevelDebug)
|
||||
server, err := agentsocket.NewServer(socketPath, logger)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, server.Start())
|
||||
require.NoError(t, server.Stop())
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket/proto"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
var _ proto.DRPCAgentSocketServer = (*DRPCAgentSocketService)(nil)
|
||||
|
||||
type DRPCAgentSocketService struct {
|
||||
mu sync.RWMutex
|
||||
unitManager *unit.Manager[string, string]
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func (*DRPCAgentSocketService) Ping(_ context.Context, _ *proto.PingRequest) (*proto.PingResponse, error) {
|
||||
return &proto.PingResponse{
|
||||
Message: "pong",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncStart(_ context.Context, req *proto.SyncStartRequest) (*proto.SyncStartResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "dependency tracker not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.Unit); err != nil {
|
||||
// If already registered, that's okay - we can still update status
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to register unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
if !isReady {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Unit is not ready",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusStarted); err != nil {
|
||||
return &proto.SyncStartResponse{
|
||||
Success: false,
|
||||
Message: "Failed to update status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncStartResponse{
|
||||
Success: true,
|
||||
Message: "Unit " + req.Unit + " started successfully",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncWant(_ context.Context, req *proto.SyncWantRequest) (*proto.SyncWantResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" || req.DependsOn == "" {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "unit and depends_on are required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.Unit); err != nil {
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to register unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.unitManager.Register(req.DependsOn); err != nil {
|
||||
if !errors.Is(err, unit.ErrConsumerAlreadyRegistered) {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to register dependency unit: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.unitManager.AddDependency(req.Unit, req.DependsOn, unit.StatusComplete); err != nil {
|
||||
return &proto.SyncWantResponse{
|
||||
Success: false,
|
||||
Message: "failed to add dependency: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncWantResponse{
|
||||
Success: true,
|
||||
Message: "Unit " + req.Unit + " now depends on " + req.DependsOn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncComplete(_ context.Context, req *proto.SyncCompleteRequest) (*proto.SyncCompleteResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := s.unitManager.UpdateStatus(req.Unit, unit.StatusComplete); err != nil {
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: false,
|
||||
Message: "failed to update status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncCompleteResponse{
|
||||
Success: true,
|
||||
Message: "unit " + req.Unit + " completed successfully",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncReady(_ context.Context, req *proto.SyncReadyRequest) (*proto.SyncReadyResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: "failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if !isReady {
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: false,
|
||||
Message: unit.ErrDependenciesNotSatisfied.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &proto.SyncReadyResponse{
|
||||
Success: true,
|
||||
Message: "unit " + req.Unit + " dependencies are satisfied",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DRPCAgentSocketService) SyncStatus(_ context.Context, req *proto.SyncStatusRequest) (*proto.SyncStatusResponse, error) {
|
||||
if s.unitManager == nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "unit manager not available",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if req.Unit == "" {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "unit name is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
status, err := s.unitManager.GetStatus(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to get unit status: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
isReady, err := s.unitManager.IsReady(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to check readiness: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
dependencies, err := s.unitManager.GetAllDependencies(req.Unit)
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to get dependencies: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var depInfos []*proto.DependencyInfo
|
||||
for _, dep := range dependencies {
|
||||
depInfos = append(depInfos, &proto.DependencyInfo{
|
||||
DependsOn: dep.DependsOn,
|
||||
RequiredStatus: dep.RequiredStatus,
|
||||
CurrentStatus: dep.CurrentStatus,
|
||||
IsSatisfied: dep.IsSatisfied,
|
||||
})
|
||||
}
|
||||
|
||||
var dotStr string
|
||||
if req.Recursive {
|
||||
dotStr, err = s.unitManager.ExportDOT("dependency_graph")
|
||||
if err != nil {
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: false,
|
||||
Message: "failed to export DOT: " + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.SyncStatusResponse{
|
||||
Success: true,
|
||||
Message: "unit status retrieved successfully",
|
||||
Unit: req.Unit,
|
||||
Status: status,
|
||||
IsReady: isReady,
|
||||
Dependencies: depInfos,
|
||||
Dot: dotStr,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package agentsocket_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/agent/agentsocket"
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func TestDRPCAgentSocketService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
response, err := client.Ping(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "pong", response.Message)
|
||||
})
|
||||
|
||||
t.Run("SyncStart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NewUnit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitAlreadyStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
// First Start
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Second Start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.ErrorContains(t, err, unit.ErrSameStatusAlreadySet.Error())
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// First start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Complete the unit
|
||||
err = client.SyncComplete(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", status.Status)
|
||||
|
||||
// Second start
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
})
|
||||
|
||||
t.Run("UnitNotReady", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.ErrorContains(t, err, "Unit is not ready")
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", status.Status)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SyncWant", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NewUnits", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// If units are not registered, they are registered automatically
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
|
||||
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Start the dependency unit
|
||||
err = client.SyncStart(context.Background(), "dependency-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "dependency-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Add the dependency after the dependency unit has already started
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
|
||||
// Dependencies can be added even if the dependency unit has already started
|
||||
require.NoError(t, err)
|
||||
|
||||
// The dependency is now reflected in the test unit's status
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
|
||||
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
socketPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
server, err := agentsocket.NewServer(
|
||||
socketPath,
|
||||
slog.Make().Leveled(slog.LevelDebug),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err = server.Start()
|
||||
require.NoError(t, err)
|
||||
defer server.Stop()
|
||||
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: socketPath,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Start the dependent unit
|
||||
err = client.SyncStart(context.Background(), "test-unit")
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "started", status.Status)
|
||||
|
||||
// Add the dependency after the dependency unit has already started
|
||||
err = client.SyncWant(context.Background(), "test-unit", "dependency-unit")
|
||||
|
||||
// Dependencies can be added even if the dependent unit has already started.
|
||||
// The dependency applies the next time a unit is started. The current status is not updated.
|
||||
// This is to allow flexible dependency management. It does mean that users of this API should
|
||||
// take care to add dependencies before they start their dependent units.
|
||||
require.NoError(t, err)
|
||||
|
||||
// The dependency is now reflected in the test unit's status
|
||||
status, err = client.SyncStatus(context.Background(), "test-unit", false)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "dependency-unit", status.Dependencies[0].DependsOn)
|
||||
require.Equal(t, "completed", status.Dependencies[0].RequiredStatus)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
//go:build !windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// createSocket creates a Unix domain socket listener
|
||||
func createSocket(path string) (net.Listener, error) {
|
||||
if !isSocketAvailable(path) {
|
||||
return nil, xerrors.Errorf("socket path %s is not available", path)
|
||||
}
|
||||
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return nil, xerrors.Errorf("remove existing socket: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directory if it doesn't exist
|
||||
parentDir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(parentDir, 0o700); err != nil {
|
||||
return nil, xerrors.Errorf("create socket directory: %w", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listen on unix socket: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(path, 0o600); err != nil {
|
||||
_ = listener.Close()
|
||||
return nil, xerrors.Errorf("set socket permissions: %w", err)
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// getDefaultSocketPath returns the default socket path for Unix-like systems.
// It prefers the per-user XDG runtime directory and otherwise falls back to a
// UID-namespaced file under /tmp.
func getDefaultSocketPath() (string, error) {
	// Try XDG_RUNTIME_DIR first.
	if runtimeDir, ok := os.LookupEnv("XDG_RUNTIME_DIR"); ok && runtimeDir != "" {
		return filepath.Join(runtimeDir, "coder-agent.sock"), nil
	}

	// Fall back to /tmp, namespaced by UID to avoid collisions between users.
	return filepath.Join("/tmp", fmt.Sprintf("coder-agent-%d.sock", os.Getuid())), nil
}
|
||||
|
||||
// cleanupSocket removes the socket file at path. The caller receives any
// removal error, including the not-exist error when the file is already gone.
func cleanupSocket(path string) error {
	return os.Remove(path)
}
|
||||
|
||||
// isSocketAvailable reports whether path may be (re)used for a new socket.
// A nonexistent path is free; an existing path is free only when nothing is
// actively accepting connections on it (e.g. a stale socket file).
func isSocketAvailable(path string) bool {
	// A path that does not exist yet is free to claim.
	if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
		return true
	}

	// The file exists — probe it. A failed dial means no live listener owns
	// it, so the path can be reclaimed.
	conn, dialErr := net.Dial("unix", path)
	if dialErr != nil {
		return true
	}
	_ = conn.Close()
	// A live listener answered; the path is taken.
	return false
}
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build windows
|
||||
|
||||
package agentsocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// CreateSocket creates a socket listener on Windows.
// It first attempts a Unix domain socket (supported since Windows 10 build
// 17063) and falls back to a pipe-style path when that fails.
//
// NOTE(review): the fallback passes a named-pipe path to net.Listen("tcp"),
// which expects a host:port address — this fallback almost certainly errors
// at runtime. Real named-pipe support requires a dedicated library; confirm
// the intended behavior.
// NOTE(review): this function is exported while the unix implementation
// (createSocket) is unexported; platform-neutral callers cannot compile
// against both spellings — verify the intended casing across build tags.
func CreateSocket(path string) (net.Listener, error) {
	// Try Unix domain socket first (Windows 10 build 17063+)
	listener, err := net.Listen("unix", path)
	if err == nil {
		return listener, nil
	}

	// Fall back to named pipe
	pipePath := `\\.\pipe\coder-agent`
	listener, err = net.Listen("tcp", pipePath)
	if err != nil {
		return nil, err
	}
	return listener, nil
}
|
||||
|
||||
// GetDefaultSocketPath returns the default socket path for Windows: a
// user-scoped directory under the temp directory containing "agent.sock".
// The directory is created (0700) if it does not exist.
//
// NOTE(review): exported here but unexported in the unix implementation —
// verify the intended casing across build tags.
func GetDefaultSocketPath() (string, error) {
	// Try to use a temporary directory.
	tempDir := os.TempDir()
	if tempDir == "" {
		tempDir = "C:\\temp"
	}

	// Create a user-specific subdirectory. os.Getuid returns -1 on Windows,
	// so fall back to a stable segment instead of producing a "-1" directory.
	userSegment := "default"
	if uid := os.Getuid(); uid >= 0 {
		userSegment = strconv.Itoa(uid)
	}
	userDir := filepath.Join(tempDir, "coder-agent", userSegment)

	if err := os.MkdirAll(userDir, 0o700); err != nil {
		return "", fmt.Errorf("create user directory: %w", err)
	}

	return filepath.Join(userDir, "agent.sock"), nil
}
|
||||
|
||||
// CleanupSocket removes the socket file at path.
//
// NOTE(review): exported here but unexported (cleanupSocket) in the unix
// implementation — platform-neutral callers cannot compile against both;
// verify the intended casing.
func CleanupSocket(path string) error {
	return os.Remove(path)
}
|
||||
|
||||
// isSocketAvailable checks if a socket path is available for use
|
||||
func IsSocketAvailable(path string, logger slog.Logger) bool {
|
||||
logger.Debug(context.Background(), "Checking socket availability on Windows", slog.F("path", path))
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
logger.Debug(context.Background(), "Socket file does not exist, path is available", slog.F("path", path))
|
||||
return true
|
||||
}
|
||||
logger.Debug(context.Background(), "Socket file exists, checking if it's listening", slog.F("path", path))
|
||||
|
||||
// Try to connect to see if it's actually listening
|
||||
conn, err := net.Dial("unix", path)
|
||||
if err != nil {
|
||||
// If we can't connect, the socket is not in use
|
||||
logger.Debug(context.Background(), "Cannot connect to socket, path is available", slog.F("path", path), slog.Error(err))
|
||||
return true
|
||||
}
|
||||
_ = conn.Close()
|
||||
logger.Debug(context.Background(), "Socket is listening, path is not available", slog.F("path", path))
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSocketInfo returns file metadata for the socket at path.
//
// NOTE(review): ownership is not resolved on Windows in this implementation;
// UID/GID are reported as 0 and Owner/Group as "unknown". A complete version
// would read the file's security descriptor.
func GetSocketInfo(path string) (*SocketInfo, error) {
	stat, err := os.Stat(path)
	if err != nil {
		return nil, err
	}

	info := SocketInfo{
		Path:    path,
		UID:     0, // not resolved on Windows
		GID:     0, // not resolved on Windows
		Mode:    stat.Mode(),
		ModTime: stat.ModTime(),
		Owner:   "unknown",
		Group:   "unknown",
	}
	return &info, nil
}

// SocketInfo contains information about a socket file on disk.
type SocketInfo struct {
	Path    string
	UID     int
	GID     int
	Mode    os.FileMode
	ModTime time.Time
	Owner   string // Windows SID string
	Group   string // Windows SID string
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"gonum.org/v1/gonum/graph/encoding/dot"
|
||||
"gonum.org/v1/gonum/graph/simple"
|
||||
"gonum.org/v1/gonum/graph/topo"
|
||||
)
|
||||
|
||||
// Graph provides a bidirectional interface over gonum's directed graph implementation.
// While the underlying gonum graph is directed, we overlay bidirectional semantics
// by distinguishing between forward and reverse edges. Wanting and being wanted by
// other units are related but different concepts that have different graph traversal
// implications when Units update their status.
//
// The graph stores edge types to represent different relationships between units,
// allowing for domain-specific semantics beyond simple connectivity.
//
// The zero value is usable: AddEdge lazily initializes the internal state on
// first call, and the query methods treat an uninitialized graph as empty.
// Methods synchronize on mu, so a single Graph may be shared across goroutines.
type Graph[EdgeType, VertexType comparable] struct {
	// mu guards all fields below: taken for writing in AddEdge and for
	// reading in the adjacency queries and ToDOT.
	mu sync.RWMutex
	// The underlying gonum graph. It stores vertices and edges without knowing about the types of the vertices and edges.
	gonumGraph *simple.DirectedGraph
	// Maps vertices to their IDs so that a gonum vertex ID can be used to lookup the vertex type.
	vertexToID map[VertexType]int64
	// Maps vertex IDs to their types so that a vertex type can be used to lookup the gonum vertex ID.
	idToVertex map[int64]VertexType
	// The next ID to assign to a vertex (IDs start at 1).
	nextID int64
	// Store edge types by "fromID->toID" key. This is used to lookup the edge type for a given edge.
	edgeTypes map[string]EdgeType
}
|
||||
|
||||
// Edge is a convenience type for representing an edge in the graph.
// It encapsulates the from and to vertices and the edge type itself.
// Values of this type are returned by the adjacency query methods.
type Edge[EdgeType, VertexType comparable] struct {
	From VertexType
	To   VertexType
	Edge EdgeType
}
|
||||
|
||||
// AddEdge adds an edge to the graph. It initializes the graph and metadata on first use,
|
||||
// checks for cycles, and adds the edge to the gonum graph.
|
||||
func (g *Graph[EdgeType, VertexType]) AddEdge(from, to VertexType, edge EdgeType) error {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
if g.gonumGraph == nil {
|
||||
g.gonumGraph = simple.NewDirectedGraph()
|
||||
g.vertexToID = make(map[VertexType]int64)
|
||||
g.idToVertex = make(map[int64]VertexType)
|
||||
g.edgeTypes = make(map[string]EdgeType)
|
||||
g.nextID = 1
|
||||
}
|
||||
|
||||
fromID := g.getOrCreateVertexID(from)
|
||||
toID := g.getOrCreateVertexID(to)
|
||||
|
||||
if g.canReach(to, from) {
|
||||
return xerrors.Errorf("adding edge (%v -> %v) would create a cycle", from, to)
|
||||
}
|
||||
|
||||
g.gonumGraph.SetEdge(simple.Edge{F: simple.Node(fromID), T: simple.Node(toID)})
|
||||
|
||||
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
|
||||
g.edgeTypes[edgeKey] = edge
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetForwardAdjacentVertices returns all the edges that originate from the given vertex.
|
||||
func (g *Graph[EdgeType, VertexType]) GetForwardAdjacentVertices(from VertexType) []Edge[EdgeType, VertexType] {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
|
||||
fromID, exists := g.vertexToID[from]
|
||||
if !exists {
|
||||
return []Edge[EdgeType, VertexType]{}
|
||||
}
|
||||
|
||||
edges := []Edge[EdgeType, VertexType]{}
|
||||
toNodes := g.gonumGraph.From(fromID)
|
||||
for toNodes.Next() {
|
||||
toID := toNodes.Node().ID()
|
||||
to := g.idToVertex[toID]
|
||||
|
||||
// Get the edge type
|
||||
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
|
||||
edgeType := g.edgeTypes[edgeKey]
|
||||
|
||||
edges = append(edges, Edge[EdgeType, VertexType]{From: from, To: to, Edge: edgeType})
|
||||
}
|
||||
|
||||
return edges
|
||||
}
|
||||
|
||||
// GetReverseAdjacentVertices returns all the edges that terminate at the given vertex.
|
||||
func (g *Graph[EdgeType, VertexType]) GetReverseAdjacentVertices(to VertexType) []Edge[EdgeType, VertexType] {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
|
||||
toID, exists := g.vertexToID[to]
|
||||
if !exists {
|
||||
return []Edge[EdgeType, VertexType]{}
|
||||
}
|
||||
|
||||
edges := []Edge[EdgeType, VertexType]{}
|
||||
fromNodes := g.gonumGraph.To(toID)
|
||||
for fromNodes.Next() {
|
||||
fromID := fromNodes.Node().ID()
|
||||
from := g.idToVertex[fromID]
|
||||
|
||||
// Get the edge type
|
||||
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
|
||||
edgeType := g.edgeTypes[edgeKey]
|
||||
|
||||
edges = append(edges, Edge[EdgeType, VertexType]{From: from, To: to, Edge: edgeType})
|
||||
}
|
||||
|
||||
return edges
|
||||
}
|
||||
|
||||
// getOrCreateVertexID returns the ID for a vertex, creating it if it doesn't exist.
|
||||
func (g *Graph[EdgeType, VertexType]) getOrCreateVertexID(vertex VertexType) int64 {
|
||||
if id, exists := g.vertexToID[vertex]; exists {
|
||||
return id
|
||||
}
|
||||
|
||||
id := g.nextID
|
||||
g.nextID++
|
||||
g.vertexToID[vertex] = id
|
||||
g.idToVertex[id] = vertex
|
||||
|
||||
// Add the node to the gonum graph
|
||||
g.gonumGraph.AddNode(simple.Node(id))
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
// canReach checks if there is a path from the start vertex to the end vertex.
|
||||
func (g *Graph[EdgeType, VertexType]) canReach(start, end VertexType) bool {
|
||||
if start == end {
|
||||
return true
|
||||
}
|
||||
|
||||
startID, startExists := g.vertexToID[start]
|
||||
endID, endExists := g.vertexToID[end]
|
||||
|
||||
if !startExists || !endExists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Use gonum's built-in path existence check
|
||||
return topo.PathExistsIn(g.gonumGraph, simple.Node(startID), simple.Node(endID))
|
||||
}
|
||||
|
||||
// ToDOT exports the graph to DOT format for visualization
|
||||
func (g *Graph[EdgeType, VertexType]) ToDOT(name string) (string, error) {
|
||||
g.mu.RLock()
|
||||
defer g.mu.RUnlock()
|
||||
|
||||
if g.gonumGraph == nil {
|
||||
return "", xerrors.New("graph is not initialized")
|
||||
}
|
||||
|
||||
// Marshal the graph to DOT format
|
||||
dotBytes, err := dot.Marshal(g.gonumGraph, name, "", " ")
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("failed to marshal graph to DOT: %w", err)
|
||||
}
|
||||
|
||||
return string(dotBytes), nil
|
||||
}
|
||||
@@ -0,0 +1,454 @@
|
||||
// Package unit_test provides tests for the unit package.
|
||||
//
|
||||
// DOT Graph Testing:
|
||||
// The graph tests use golden files for DOT representation verification.
|
||||
// To update the golden files:
|
||||
// make gen/golden-files
|
||||
//
|
||||
// The golden files contain the expected DOT representation and can be easily
|
||||
// inspected, version controlled, and updated when the graph structure changes.
|
||||
package unit_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/cryptorand"
|
||||
)
|
||||
|
||||
// testGraphEdge labels test edges; the values mirror the unit lifecycle
// statuses ("started"/"completed") used by the dependency graph under test.
type testGraphEdge string

const (
	testEdgeStarted   testGraphEdge = "started"
	testEdgeCompleted testGraphEdge = "completed"
)

// testGraphVertex is a minimal vertex type; tests use pointer identity to
// distinguish vertices.
type testGraphVertex struct {
	Name string
}

// Aliases keeping the generic instantiations readable throughout the tests.
type (
	testGraph = unit.Graph[testGraphEdge, *testGraphVertex]
	testEdge  = unit.Edge[testGraphEdge, *testGraphVertex]
)
|
||||
|
||||
// randInt generates a random integer in the range [0, limit).
|
||||
func randInt(limit int) int {
|
||||
if limit <= 0 {
|
||||
return 0
|
||||
}
|
||||
n, err := cryptorand.Int63n(int64(limit))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return int(n)
|
||||
}
|
||||
|
||||
// UpdateGoldenFiles indicates golden files should be updated.
// To update the golden files:
// make gen/golden-files
//
// The flag is read by assertDOTGraph to decide whether to rewrite the
// expected DOT output before comparing.
var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files")
|
||||
|
||||
// assertDOTGraph requires that the graph's DOT representation matches the golden file
|
||||
func assertDOTGraph(t *testing.T, graph *testGraph, goldenName string) {
|
||||
t.Helper()
|
||||
|
||||
dot, err := graph.ToDOT(goldenName)
|
||||
require.NoError(t, err)
|
||||
|
||||
goldenFile := filepath.Join("testdata", goldenName+".golden")
|
||||
if *UpdateGoldenFiles {
|
||||
t.Logf("update golden file for: %q: %s", goldenName, goldenFile)
|
||||
err := os.MkdirAll(filepath.Dir(goldenFile), 0o755)
|
||||
require.NoError(t, err, "want no error creating golden file directory")
|
||||
err = os.WriteFile(goldenFile, []byte(dot), 0o600)
|
||||
require.NoError(t, err, "update golden file")
|
||||
}
|
||||
|
||||
expected, err := os.ReadFile(goldenFile)
|
||||
require.NoError(t, err, "read golden file, run \"make gen/golden-files\" and commit the changes")
|
||||
|
||||
// Normalize line endings for cross-platform compatibility
|
||||
expected = normalizeLineEndings(expected)
|
||||
normalizedDot := normalizeLineEndings([]byte(dot))
|
||||
|
||||
assert.Empty(t, cmp.Diff(string(expected), string(normalizedDot)), "golden file mismatch (-want +got): %s, run \"make gen/golden-files\", verify and commit the changes", goldenFile)
|
||||
}
|
||||
|
||||
// normalizeLineEndings rewrites CRLF and bare CR line endings to LF so that
// text comparisons behave the same across platforms (notably Windows).
func normalizeLineEndings(content []byte) []byte {
	crlf, cr, lf := []byte("\r\n"), []byte("\r"), []byte("\n")
	// Replace CRLF first so the remaining bare-CR pass cannot double-convert.
	content = bytes.ReplaceAll(content, crlf, lf)
	return bytes.ReplaceAll(content, cr, lf)
}
|
||||
|
||||
func TestGraph(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testFuncs := map[string]func(t *testing.T) *unit.Graph[testGraphEdge, *testGraphVertex]{
|
||||
"ForwardAndReverseEdges": func(t *testing.T) *unit.Graph[testGraphEdge, *testGraphVertex] {
|
||||
graph := &unit.Graph[testGraphEdge, *testGraphVertex]{}
|
||||
unit1 := &testGraphVertex{Name: "unit1"}
|
||||
unit2 := &testGraphVertex{Name: "unit2"}
|
||||
unit3 := &testGraphVertex{Name: "unit3"}
|
||||
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
|
||||
require.NoError(t, err)
|
||||
err = graph.AddEdge(unit1, unit3, testEdgeStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check for forward edge
|
||||
vertices := graph.GetForwardAdjacentVertices(unit1)
|
||||
require.Len(t, vertices, 2)
|
||||
// Unit 1 depends on the completion of Unit2
|
||||
require.Contains(t, vertices, testEdge{
|
||||
From: unit1,
|
||||
To: unit2,
|
||||
Edge: testEdgeCompleted,
|
||||
})
|
||||
// Unit 1 depends on the start of Unit3
|
||||
require.Contains(t, vertices, testEdge{
|
||||
From: unit1,
|
||||
To: unit3,
|
||||
Edge: testEdgeStarted,
|
||||
})
|
||||
|
||||
// Check for reverse edges
|
||||
unit2ReverseEdges := graph.GetReverseAdjacentVertices(unit2)
|
||||
require.Len(t, unit2ReverseEdges, 1)
|
||||
// Unit 2 must be completed before Unit 1 can start
|
||||
require.Contains(t, unit2ReverseEdges, testEdge{
|
||||
From: unit1,
|
||||
To: unit2,
|
||||
Edge: testEdgeCompleted,
|
||||
})
|
||||
|
||||
unit3ReverseEdges := graph.GetReverseAdjacentVertices(unit3)
|
||||
require.Len(t, unit3ReverseEdges, 1)
|
||||
// Unit 3 must be started before Unit 1 can complete
|
||||
require.Contains(t, unit3ReverseEdges, testEdge{
|
||||
From: unit1,
|
||||
To: unit3,
|
||||
Edge: testEdgeStarted,
|
||||
})
|
||||
|
||||
return graph
|
||||
},
|
||||
"SelfReference": func(t *testing.T) *testGraph {
|
||||
graph := &testGraph{}
|
||||
unit1 := &testGraphVertex{Name: "unit1"}
|
||||
err := graph.AddEdge(unit1, unit1, testEdgeCompleted)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, fmt.Sprintf("adding edge (%v -> %v) would create a cycle", unit1, unit1))
|
||||
|
||||
return graph
|
||||
},
|
||||
"Cycle": func(t *testing.T) *testGraph {
|
||||
graph := &testGraph{}
|
||||
unit1 := &testGraphVertex{Name: "unit1"}
|
||||
unit2 := &testGraphVertex{Name: "unit2"}
|
||||
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
|
||||
require.NoError(t, err)
|
||||
err = graph.AddEdge(unit2, unit1, testEdgeStarted)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, fmt.Sprintf("adding edge (%v -> %v) would create a cycle", unit2, unit1))
|
||||
|
||||
return graph
|
||||
},
|
||||
"MultipleDependenciesSameStatus": func(t *testing.T) *testGraph {
|
||||
graph := &testGraph{}
|
||||
unit1 := &testGraphVertex{Name: "unit1"}
|
||||
unit2 := &testGraphVertex{Name: "unit2"}
|
||||
unit3 := &testGraphVertex{Name: "unit3"}
|
||||
unit4 := &testGraphVertex{Name: "unit4"}
|
||||
|
||||
// Unit1 depends on completion of both unit2 and unit3 (same status type)
|
||||
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
|
||||
require.NoError(t, err)
|
||||
err = graph.AddEdge(unit1, unit3, testEdgeCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unit1 also depends on starting of unit4 (different status type)
|
||||
err = graph.AddEdge(unit1, unit4, testEdgeStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that unit1 has 3 forward dependencies
|
||||
forwardEdges := graph.GetForwardAdjacentVertices(unit1)
|
||||
require.Len(t, forwardEdges, 3)
|
||||
|
||||
// Verify all expected dependencies exist
|
||||
expectedDependencies := []testEdge{
|
||||
{From: unit1, To: unit2, Edge: testEdgeCompleted},
|
||||
{From: unit1, To: unit3, Edge: testEdgeCompleted},
|
||||
{From: unit1, To: unit4, Edge: testEdgeStarted},
|
||||
}
|
||||
|
||||
for _, expected := range expectedDependencies {
|
||||
require.Contains(t, forwardEdges, expected)
|
||||
}
|
||||
|
||||
// Check reverse dependencies
|
||||
unit2ReverseEdges := graph.GetReverseAdjacentVertices(unit2)
|
||||
require.Len(t, unit2ReverseEdges, 1)
|
||||
require.Contains(t, unit2ReverseEdges, testEdge{
|
||||
From: unit1, To: unit2, Edge: testEdgeCompleted,
|
||||
})
|
||||
|
||||
unit3ReverseEdges := graph.GetReverseAdjacentVertices(unit3)
|
||||
require.Len(t, unit3ReverseEdges, 1)
|
||||
require.Contains(t, unit3ReverseEdges, testEdge{
|
||||
From: unit1, To: unit3, Edge: testEdgeCompleted,
|
||||
})
|
||||
|
||||
unit4ReverseEdges := graph.GetReverseAdjacentVertices(unit4)
|
||||
require.Len(t, unit4ReverseEdges, 1)
|
||||
require.Contains(t, unit4ReverseEdges, testEdge{
|
||||
From: unit1, To: unit4, Edge: testEdgeStarted,
|
||||
})
|
||||
|
||||
return graph
|
||||
},
|
||||
}
|
||||
|
||||
for testName, testFunc := range testFuncs {
|
||||
var graph *testGraph
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
graph = testFunc(t)
|
||||
assertDOTGraph(t, graph, testName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGraphThreadSafety exercises the graph's concurrency guarantees: mixed
// reads/writes, concurrent cycle detection, and concurrent DOT export.
// Run with -race to get the full benefit of these subtests.
func TestGraphThreadSafety(t *testing.T) {
	t.Parallel()

	t.Run("ConcurrentReadWrite", func(t *testing.T) {
		t.Parallel()

		graph := &testGraph{}
		var wg sync.WaitGroup
		const numWriters = 50
		const numReaders = 100
		const operationsPerWriter = 1000
		const operationsPerReader = 2000

		// All goroutines block on the barrier so reads and writes start at once.
		barrier := make(chan struct{})
		// Launch writers
		for i := 0; i < numWriters; i++ {
			wg.Add(1)
			go func(writerID int) {
				defer wg.Done()
				<-barrier
				for j := 0; j < operationsPerWriter; j++ {
					// Each writer appends to its own chain (j -> j+1), so no cycles form.
					from := &testGraphVertex{Name: fmt.Sprintf("writer-%d-%d", writerID, j)}
					to := &testGraphVertex{Name: fmt.Sprintf("writer-%d-%d", writerID, j+1)}
					// AddEdge error intentionally ignored: this subtest only checks
					// that concurrent access does not panic or race.
					graph.AddEdge(from, to, testEdgeCompleted)
				}
			}(i)
		}

		// Launch readers
		readerResults := make([]struct {
			panicked  bool
			readCount int
		}, numReaders)

		for i := 0; i < numReaders; i++ {
			wg.Add(1)
			go func(readerID int) {
				defer wg.Done()
				<-barrier
				// Record panics instead of failing here; assertions happen on the
				// test goroutine after wg.Wait().
				defer func() {
					if r := recover(); r != nil {
						readerResults[readerID].panicked = true
					}
				}()

				readCount := 0
				for j := 0; j < operationsPerReader; j++ {
					// Create a test vertex and read
					testUnit := &testGraphVertex{Name: fmt.Sprintf("test-reader-%d-%d", readerID, j)}
					forwardEdges := graph.GetForwardAdjacentVertices(testUnit)
					reverseEdges := graph.GetReverseAdjacentVertices(testUnit)

					// Just verify no panics (results may be nil for non-existent vertices)
					_ = forwardEdges
					_ = reverseEdges
					readCount++
				}
				readerResults[readerID].readCount = readCount
			}(i)
		}

		close(barrier)
		wg.Wait()

		// Verify no panics occurred in readers
		for i, result := range readerResults {
			require.False(t, result.panicked, "reader %d panicked", i)
			require.Equal(t, operationsPerReader, result.readCount, "reader %d should have performed expected reads", i)
		}
	})

	t.Run("ConcurrentCycleDetection", func(t *testing.T) {
		t.Parallel()

		graph := &testGraph{}

		// Pre-create chain: A→B→C→D
		unitA := &testGraphVertex{Name: "A"}
		unitB := &testGraphVertex{Name: "B"}
		unitC := &testGraphVertex{Name: "C"}
		unitD := &testGraphVertex{Name: "D"}

		err := graph.AddEdge(unitA, unitB, testEdgeCompleted)
		require.NoError(t, err)
		err = graph.AddEdge(unitB, unitC, testEdgeCompleted)
		require.NoError(t, err)
		err = graph.AddEdge(unitC, unitD, testEdgeCompleted)
		require.NoError(t, err)

		barrier := make(chan struct{})
		var wg sync.WaitGroup
		const numGoroutines = 50
		cycleErrors := make([]error, numGoroutines)

		// Launch goroutines trying to add D→A (creates cycle)
		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go func(goroutineID int) {
				defer wg.Done()
				<-barrier
				err := graph.AddEdge(unitD, unitA, testEdgeCompleted)
				cycleErrors[goroutineID] = err
			}(i)
		}

		close(barrier)
		wg.Wait()

		// Verify all attempts correctly returned cycle error
		for i, err := range cycleErrors {
			require.Error(t, err, "goroutine %d should have detected cycle", i)
			require.Contains(t, err.Error(), "would create a cycle")
		}

		// Verify graph remains valid (original chain intact)
		dot, err := graph.ToDOT("test")
		require.NoError(t, err)
		require.NotEmpty(t, dot)
	})

	t.Run("ConcurrentToDOT", func(t *testing.T) {
		t.Parallel()

		graph := &testGraph{}

		// Pre-populate graph
		for i := 0; i < 20; i++ {
			from := &testGraphVertex{Name: fmt.Sprintf("dot-unit-%d", i)}
			to := &testGraphVertex{Name: fmt.Sprintf("dot-unit-%d", i+1)}
			err := graph.AddEdge(from, to, testEdgeCompleted)
			require.NoError(t, err)
		}

		barrier := make(chan struct{})
		var wg sync.WaitGroup
		const numReaders = 100
		const numWriters = 20
		dotResults := make([]string, numReaders)

		// Launch readers calling ToDOT
		dotErrors := make([]error, numReaders)
		for i := 0; i < numReaders; i++ {
			wg.Add(1)
			go func(readerID int) {
				defer wg.Done()
				<-barrier
				dot, err := graph.ToDOT(fmt.Sprintf("test-%d", readerID))
				dotErrors[readerID] = err
				if err == nil {
					dotResults[readerID] = dot
				}
			}(i)
		}

		// Launch writers adding edges
		for i := 0; i < numWriters; i++ {
			wg.Add(1)
			go func(writerID int) {
				defer wg.Done()
				<-barrier
				from := &testGraphVertex{Name: fmt.Sprintf("writer-dot-%d", writerID)}
				to := &testGraphVertex{Name: fmt.Sprintf("writer-dot-target-%d", writerID)}
				// AddEdge error intentionally ignored; readers validate output below.
				graph.AddEdge(from, to, testEdgeCompleted)
			}(i)
		}

		close(barrier)
		wg.Wait()

		// Verify no errors occurred during DOT generation
		for i, err := range dotErrors {
			require.NoError(t, err, "DOT generation error at index %d", i)
		}

		// Verify all DOT results are valid
		for i, dot := range dotResults {
			require.NotEmpty(t, dot, "DOT result %d should not be empty", i)
		}
	})
}
|
||||
|
||||
// BenchmarkGraph_ConcurrentMixedOperations measures graph throughput under a
// mixed workload: 200 goroutines each performing 50 operations, ~60% reads of
// (mostly non-existent) vertices and ~40% edge insertions.
func BenchmarkGraph_ConcurrentMixedOperations(b *testing.B) {
	graph := &testGraph{}
	var wg sync.WaitGroup
	const numGoroutines = 200

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// Launch goroutines performing random operations
		for j := 0; j < numGoroutines; j++ {
			wg.Add(1)
			go func(goroutineID int) {
				defer wg.Done()
				operationCount := 0

				for operationCount < 50 {
					// Map a random int in [0,100) onto [0,1) to pick read vs write.
					operation := float32(randInt(100)) / 100.0

					if operation < 0.6 { // 60% reads
						// Read operation
						testUnit := &testGraphVertex{Name: fmt.Sprintf("bench-read-%d-%d", goroutineID, operationCount)}
						forwardEdges := graph.GetForwardAdjacentVertices(testUnit)
						reverseEdges := graph.GetReverseAdjacentVertices(testUnit)

						// Just verify no panics (results may be nil for non-existent vertices)
						_ = forwardEdges
						_ = reverseEdges
					} else { // 40% writes
						// Write operation
						from := &testGraphVertex{Name: fmt.Sprintf("bench-write-%d-%d", goroutineID, operationCount)}
						to := &testGraphVertex{Name: fmt.Sprintf("bench-write-target-%d-%d", goroutineID, operationCount)}
						// AddEdge error intentionally ignored; the benchmark only
						// measures throughput, not correctness.
						graph.AddEdge(from, to, testEdgeCompleted)
					}

					operationCount++
				}
			}(j)
		}

		wg.Wait()
	}
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package unit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// ErrConsumerNotFound is returned when a consumer ID is not registered.
|
||||
var ErrConsumerNotFound = xerrors.New("consumer not found")
|
||||
|
||||
// ErrConsumerAlreadyRegistered is returned when a consumer ID is already registered.
|
||||
var ErrConsumerAlreadyRegistered = xerrors.New("consumer already registered")
|
||||
|
||||
// ErrCannotUpdateOtherConsumer is returned when attempting to update another consumer's status.
|
||||
var ErrCannotUpdateOtherConsumer = xerrors.New("cannot update other consumer's status")
|
||||
|
||||
// ErrDependenciesNotSatisfied is returned when a consumer's dependencies are not satisfied.
|
||||
var ErrDependenciesNotSatisfied = xerrors.New("unit dependencies not satisfied")
|
||||
|
||||
// ErrSameStatusAlreadySet is returned when attempting to set the same status as the current status.
|
||||
var ErrSameStatusAlreadySet = xerrors.New("same status already set")
|
||||
|
||||
// Status constants for dependency tracking
const (
	StatusStarted = "started"
	// NOTE(review): name uses "Complete" while StatusStarted uses the past
	// tense and the value is "completed" — consider StatusCompleted for
	// consistency (rename would change the exported API, so flagging only).
	StatusComplete = "completed"
)
|
||||
|
||||
// dependencyVertex represents a vertex in the dependency graph that is associated with a consumer.
// Manager stores exactly one instance per consumer (see consumerVertices) so
// the same pointer is always used when talking to the underlying Graph.
type dependencyVertex[ConsumerID comparable] struct {
	// ID is the consumer this vertex stands for.
	ID ConsumerID
}
|
||||
|
||||
// Dependency represents a dependency relationship between consumers.
// It is a read-only snapshot returned by GetUnmetDependencies and
// GetAllDependencies.
type Dependency[StatusType, ConsumerID comparable] struct {
	// Consumer is the consumer that holds the dependency.
	Consumer ConsumerID
	// DependsOn is the consumer that must reach RequiredStatus.
	DependsOn ConsumerID
	// RequiredStatus is the status DependsOn must reach to satisfy this dependency.
	RequiredStatus StatusType
	// CurrentStatus is DependsOn's status at snapshot time (zero value if none was ever set).
	CurrentStatus StatusType
	// IsSatisfied reports whether CurrentStatus matched RequiredStatus at snapshot time.
	IsSatisfied bool
}
|
||||
|
||||
// Manager provides reactive dependency tracking over a Graph.
// It manages consumer registration, dependency relationships, and status updates
// with automatic recalculation of readiness when dependencies are satisfied.
type Manager[StatusType, ConsumerID comparable] struct {
	// mu guards all maps below; every exported method takes it before
	// touching Manager state.
	mu sync.RWMutex

	// The underlying graph that stores dependency relationships
	graph *Graph[StatusType, *dependencyVertex[ConsumerID]]

	// Track current status of each consumer
	consumerStatus map[ConsumerID]StatusType

	// Track readiness state (cached to avoid repeated graph traversal)
	consumerReadiness map[ConsumerID]bool

	// Track which consumers are registered
	registeredConsumers map[ConsumerID]bool

	// Store vertex instances for each consumer to ensure consistent references
	consumerVertices map[ConsumerID]*dependencyVertex[ConsumerID]
}
|
||||
|
||||
// NewManager creates a new Manager instance.
|
||||
func NewManager[StatusType, ConsumerID comparable]() *Manager[StatusType, ConsumerID] {
|
||||
return &Manager[StatusType, ConsumerID]{
|
||||
graph: &Graph[StatusType, *dependencyVertex[ConsumerID]]{},
|
||||
consumerStatus: make(map[ConsumerID]StatusType),
|
||||
consumerReadiness: make(map[ConsumerID]bool),
|
||||
registeredConsumers: make(map[ConsumerID]bool),
|
||||
consumerVertices: make(map[ConsumerID]*dependencyVertex[ConsumerID]),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new consumer as a vertex in the dependency graph.
|
||||
func (dt *Manager[StatusType, ConsumerID]) Register(id ConsumerID) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if dt.registeredConsumers[id] {
|
||||
return ErrConsumerAlreadyRegistered
|
||||
}
|
||||
|
||||
// Create and store the vertex for this consumer
|
||||
vertex := &dependencyVertex[ConsumerID]{ID: id}
|
||||
dt.consumerVertices[id] = vertex
|
||||
dt.registeredConsumers[id] = true
|
||||
dt.consumerReadiness[id] = true // New consumers start as ready (no dependencies)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDependency adds a dependency relationship between consumers.
|
||||
// The consumer depends on the dependsOn consumer reaching the requiredStatus.
|
||||
func (dt *Manager[StatusType, ConsumerID]) AddDependency(consumer ConsumerID, dependsOn ConsumerID, requiredStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return xerrors.Errorf("consumer %v is not registered", consumer)
|
||||
}
|
||||
if !dt.registeredConsumers[dependsOn] {
|
||||
return xerrors.Errorf("consumer %v is not registered", dependsOn)
|
||||
}
|
||||
|
||||
// Get the stored vertices for both consumers
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependsOnVertex := dt.consumerVertices[dependsOn]
|
||||
|
||||
// Add the dependency edge to the graph
|
||||
// The edge goes from consumer to dependsOn, representing the dependency
|
||||
err := dt.graph.AddEdge(consumerVertex, dependsOnVertex, requiredStatus)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to add dependency: %w", err)
|
||||
}
|
||||
|
||||
// Recalculate readiness for the consumer since it now has a dependency
|
||||
dt.recalculateReadinessUnsafe(consumer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates a consumer's status and recalculates readiness for affected dependents.
|
||||
func (dt *Manager[StatusType, ConsumerID]) UpdateStatus(consumer ConsumerID, newStatus StatusType) error {
|
||||
dt.mu.Lock()
|
||||
defer dt.mu.Unlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return ErrConsumerNotFound
|
||||
}
|
||||
|
||||
// Update the consumer's status
|
||||
if dt.consumerStatus[consumer] == newStatus {
|
||||
return ErrSameStatusAlreadySet
|
||||
}
|
||||
dt.consumerStatus[consumer] = newStatus
|
||||
|
||||
// Get all consumers that depend on this one (reverse adjacent vertices)
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
dependentEdges := dt.graph.GetReverseAdjacentVertices(consumerVertex)
|
||||
|
||||
// Recalculate readiness for all dependents
|
||||
for _, edge := range dependentEdges {
|
||||
dt.recalculateReadinessUnsafe(edge.From.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady checks if all dependencies for a consumer are satisfied.
|
||||
func (dt *Manager[StatusType, ConsumerID]) IsReady(consumer ConsumerID) (bool, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return false, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
return dt.consumerReadiness[consumer], nil
|
||||
}
|
||||
|
||||
// GetUnmetDependencies returns a list of unsatisfied dependencies for a consumer.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetUnmetDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var unmetDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
if !isSatisfied {
|
||||
unmetDependencies = append(unmetDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return unmetDependencies, nil
|
||||
}
|
||||
|
||||
// recalculateReadinessUnsafe recalculates the readiness state for a consumer.
|
||||
// This method assumes the caller holds the write lock.
|
||||
func (dt *Manager[StatusType, ConsumerID]) recalculateReadinessUnsafe(consumer ConsumerID) {
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
// If there are no dependencies, the consumer is ready
|
||||
if len(forwardEdges) == 0 {
|
||||
dt.consumerReadiness[consumer] = true
|
||||
return
|
||||
}
|
||||
|
||||
// Check if all dependencies are satisfied
|
||||
allSatisfied := true
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists || currentStatus != requiredStatus {
|
||||
allSatisfied = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
dt.consumerReadiness[consumer] = allSatisfied
|
||||
}
|
||||
|
||||
// GetGraph returns the underlying graph for visualization and debugging.
// This should be used carefully as it exposes the internal graph structure.
// NOTE(review): operations through the returned graph bypass dt.mu, so
// mutating it can leave the cached readiness map stale — confirm callers
// only read from it.
func (dt *Manager[StatusType, ConsumerID]) GetGraph() *Graph[StatusType, *dependencyVertex[ConsumerID]] {
	return dt.graph
}
|
||||
|
||||
// GetStatus returns the current status of a consumer.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetStatus(consumer ConsumerID) (StatusType, error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
var zeroStatus StatusType
|
||||
return zeroStatus, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
status, exists := dt.consumerStatus[consumer]
|
||||
if !exists {
|
||||
var zeroStatus StatusType
|
||||
return zeroStatus, nil
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// GetAllDependencies returns all dependencies for a consumer, both satisfied and unsatisfied.
|
||||
func (dt *Manager[StatusType, ConsumerID]) GetAllDependencies(consumer ConsumerID) ([]Dependency[StatusType, ConsumerID], error) {
|
||||
dt.mu.RLock()
|
||||
defer dt.mu.RUnlock()
|
||||
|
||||
if !dt.registeredConsumers[consumer] {
|
||||
return nil, ErrConsumerNotFound
|
||||
}
|
||||
|
||||
consumerVertex := dt.consumerVertices[consumer]
|
||||
forwardEdges := dt.graph.GetForwardAdjacentVertices(consumerVertex)
|
||||
|
||||
var allDependencies []Dependency[StatusType, ConsumerID]
|
||||
|
||||
for _, edge := range forwardEdges {
|
||||
dependsOnConsumer := edge.To.ID
|
||||
requiredStatus := edge.Edge
|
||||
currentStatus, exists := dt.consumerStatus[dependsOnConsumer]
|
||||
if !exists {
|
||||
// If the dependency consumer has no status, it's not satisfied
|
||||
var zeroStatus StatusType
|
||||
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: zeroStatus, // Zero value
|
||||
IsSatisfied: false,
|
||||
})
|
||||
} else {
|
||||
isSatisfied := currentStatus == requiredStatus
|
||||
allDependencies = append(allDependencies, Dependency[StatusType, ConsumerID]{
|
||||
Consumer: consumer,
|
||||
DependsOn: dependsOnConsumer,
|
||||
RequiredStatus: requiredStatus,
|
||||
CurrentStatus: currentStatus,
|
||||
IsSatisfied: isSatisfied,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return allDependencies, nil
|
||||
}
|
||||
|
||||
// ExportDOT exports the dependency graph to DOT format for visualization.
// The name argument is forwarded to Graph.ToDOT as the graph's name.
func (dt *Manager[StatusType, ConsumerID]) ExportDOT(name string) (string, error) {
	return dt.graph.ToDOT(name)
}
|
||||
@@ -0,0 +1,691 @@
|
||||
package unit_test
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
)
|
||||
|
||||
// testStatus is the StatusType used to instantiate Manager in these tests.
type testStatus string

const (
	statusStarted testStatus = "started"
	statusRunning testStatus = "running"
	statusCompleted testStatus = "completed"
)

// testConsumerID is the ConsumerID used to instantiate Manager in these tests.
type testConsumerID string

const (
	consumerA testConsumerID = "serviceA"
	consumerB testConsumerID = "serviceB"
	consumerC testConsumerID = "serviceC"
	consumerD testConsumerID = "serviceD"
)
|
||||
|
||||
func TestDependencyTracker_Register(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
t.Run("RegisterNewConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Consumer should be ready initially (no dependencies)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("RegisterDuplicateConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tracker.Register(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already registered")
|
||||
})
|
||||
|
||||
t.Run("RegisterMultipleConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// All should be ready initially
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDependencyTracker_AddDependency covers AddDependency: a valid edge
// between registered consumers, and rejection when either side is
// unregistered.
func TestDependencyTracker_AddDependency(t *testing.T) {
	t.Parallel()

	t.Run("AddDependencyBetweenRegisteredConsumers", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)

		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)

		// A should no longer be ready (depends on B)
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)

		// B should still be ready (no dependencies)
		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.True(t, ready)
	})

	t.Run("AddDependencyWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)

		// Try to add dependency to unregistered consumer
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not registered")
	})

	t.Run("AddDependencyFromUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerB)
		require.NoError(t, err)

		// Try to add dependency from unregistered consumer
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not registered")
	})
}
|
||||
|
||||
// TestDependencyTracker_UpdateStatus covers UpdateStatus: readiness
// recalculation on a single edge, the unregistered-consumer error, and
// cascading readiness along a linear A→B→C chain.
func TestDependencyTracker_UpdateStatus(t *testing.T) {
	t.Parallel()

	t.Run("UpdateStatusTriggersReadinessRecalculation", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)

		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)

		// Initially A is not ready
		ready, err := tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)

		// Update B to "running" - A should become ready
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)

		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})

	t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()

		err := tracker.UpdateStatus(consumerA, statusRunning)
		require.Error(t, err)
		assert.Equal(t, unit.ErrConsumerNotFound, err)
	})

	t.Run("LinearChainDependencies", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()

		// Register all consumers
		consumers := []testConsumerID{consumerA, consumerB, consumerC}
		for _, consumer := range consumers {
			err := tracker.Register(consumer)
			require.NoError(t, err)
		}

		// Create chain: A depends on B being "started", B depends on C being "completed"
		err := tracker.AddDependency(consumerA, consumerB, statusStarted)
		require.NoError(t, err)
		err = tracker.AddDependency(consumerB, consumerC, statusCompleted)
		require.NoError(t, err)

		// Initially only C is ready
		ready, err := tracker.IsReady(consumerC)
		require.NoError(t, err)
		assert.True(t, ready)

		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.False(t, ready)

		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)

		// Update C to "completed" - B should become ready
		err = tracker.UpdateStatus(consumerC, statusCompleted)
		require.NoError(t, err)

		ready, err = tracker.IsReady(consumerB)
		require.NoError(t, err)
		assert.True(t, ready)

		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.False(t, ready)

		// Update B to "started" - A should become ready
		err = tracker.UpdateStatus(consumerB, statusStarted)
		require.NoError(t, err)

		ready, err = tracker.IsReady(consumerA)
		require.NoError(t, err)
		assert.True(t, ready)
	})
}
|
||||
|
||||
// TestDependencyTracker_GetUnmetDependencies covers GetUnmetDependencies:
// no dependencies, an unsatisfied dependency, a satisfied dependency, and
// an unregistered consumer.
func TestDependencyTracker_GetUnmetDependencies(t *testing.T) {
	t.Parallel()

	t.Run("GetUnmetDependenciesForConsumerWithNoDependencies", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)

		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		assert.Empty(t, unmet)
	})

	t.Run("GetUnmetDependenciesForConsumerWithUnsatisfiedDependencies", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)

		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)

		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		require.Len(t, unmet, 1)

		assert.Equal(t, consumerA, unmet[0].Consumer)
		assert.Equal(t, consumerB, unmet[0].DependsOn)
		assert.Equal(t, statusRunning, unmet[0].RequiredStatus)
		assert.False(t, unmet[0].IsSatisfied)
	})

	t.Run("GetUnmetDependenciesForConsumerWithSatisfiedDependencies", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()
		err := tracker.Register(consumerA)
		require.NoError(t, err)
		err = tracker.Register(consumerB)
		require.NoError(t, err)

		// A depends on B being "running"
		err = tracker.AddDependency(consumerA, consumerB, statusRunning)
		require.NoError(t, err)

		// Update B to "running"
		err = tracker.UpdateStatus(consumerB, statusRunning)
		require.NoError(t, err)

		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.NoError(t, err)
		assert.Empty(t, unmet)
	})

	t.Run("GetUnmetDependenciesForUnregisteredConsumer", func(t *testing.T) {
		t.Parallel()

		tracker := unit.NewManager[testStatus, testConsumerID]()

		unmet, err := tracker.GetUnmetDependencies(consumerA)
		require.Error(t, err)
		assert.Equal(t, unit.ErrConsumerNotFound, err)
		assert.Nil(t, unmet)
	})
}
|
||||
|
||||
func TestDependencyTracker_ConcurrentOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConcurrentStatusUpdates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create dependencies: A depends on B, B depends on C, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 10
|
||||
|
||||
// Launch goroutines that update statuses
|
||||
errors := make([]error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Update D to completed (should make C ready)
|
||||
err := tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update C to started (should make B ready)
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
|
||||
// Update B to running (should make A ready)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
if err != nil {
|
||||
errors[goroutineID] = err
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for any errors in goroutines
|
||||
for i, err := range errors {
|
||||
require.NoError(t, err, "goroutine %d had error", i)
|
||||
}
|
||||
|
||||
// All consumers should be ready after the updates
|
||||
for _, consumer := range consumers {
|
||||
ready, err := tracker.IsReady(consumer)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentReadinessChecks", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const numGoroutines = 20
|
||||
|
||||
// Launch goroutines that check readiness
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(goroutineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Check readiness multiple times
|
||||
for j := 0; j < 10; j++ {
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
// Initially should be false, then true after B is updated
|
||||
_ = ready
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
// B should always be ready (no dependencies)
|
||||
assert.True(t, ready)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Update B to "running" in the middle of readiness checks
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_MultipleDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ConsumerWithMultipleDependencies", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// A depends on B being "running" AND C being "started"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A should not be ready (depends on both B and C)
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C too)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("ComplexDependencyChain", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph:
|
||||
// A depends on B being "running" AND C being "started"
|
||||
// B depends on D being "completed"
|
||||
// C depends on D being "completed"
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initially only D is ready
|
||||
ready, err := tracker.IsReady(consumerD)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update D to "completed" - B and C should become ready
|
||||
err = tracker.UpdateStatus(consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerB)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerC)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update B to "running" - A should still not be ready (needs C)
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "started" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
|
||||
t.Run("DifferentStatusTypes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerC)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B being "running" AND C being "completed"
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update B to "running" but not C - A should not be ready
|
||||
err = tracker.UpdateStatus(consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ready)
|
||||
|
||||
// Update C to "completed" - A should now be ready
|
||||
err = tracker.UpdateStatus(consumerC, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
ready, err = tracker.IsReady(consumerA)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ready)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ErrorCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("UpdateStatusWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
err := tracker.UpdateStatus(consumerA, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
})
|
||||
|
||||
t.Run("IsReadyWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
ready, err := tracker.IsReady(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.False(t, ready)
|
||||
})
|
||||
|
||||
t.Run("GetUnmetDependenciesWithUnregisteredConsumer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
unmet, err := tracker.GetUnmetDependencies(consumerA)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, unit.ErrConsumerNotFound, err)
|
||||
assert.Nil(t, unmet)
|
||||
})
|
||||
|
||||
t.Run("AddDependencyWithUnregisteredConsumers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Try to add dependency with unregistered consumers
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not registered")
|
||||
})
|
||||
|
||||
t.Run("CyclicDependencyDetection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// A depends on B
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to make B depend on A (creates cycle)
|
||||
err = tracker.AddDependency(consumerB, consumerA, statusStarted)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "would create a cycle")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDependencyTracker_ToDOT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ExportSimpleGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register consumers
|
||||
err := tracker.Register(consumerA)
|
||||
require.NoError(t, err)
|
||||
err = tracker.Register(consumerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add dependency
|
||||
err = tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("test")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
|
||||
t.Run("ExportComplexGraph", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tracker := unit.NewManager[testStatus, testConsumerID]()
|
||||
|
||||
// Register all consumers
|
||||
consumers := []testConsumerID{consumerA, consumerB, consumerC, consumerD}
|
||||
for _, consumer := range consumers {
|
||||
err := tracker.Register(consumer)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create complex dependency graph
|
||||
// A depends on B and C, B depends on D, C depends on D
|
||||
err := tracker.AddDependency(consumerA, consumerB, statusRunning)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerA, consumerC, statusStarted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerB, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
err = tracker.AddDependency(consumerC, consumerD, statusCompleted)
|
||||
require.NoError(t, err)
|
||||
|
||||
dot, err := tracker.ExportDOT("complex")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dot)
|
||||
assert.Contains(t, dot, "digraph")
|
||||
})
|
||||
}
|
||||
Vendored
+8
@@ -0,0 +1,8 @@
|
||||
strict digraph Cycle {
|
||||
// Node definitions.
|
||||
1;
|
||||
2;
|
||||
|
||||
// Edge definitions.
|
||||
1 -> 2;
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
strict digraph ForwardAndReverseEdges {
|
||||
// Node definitions.
|
||||
1;
|
||||
2;
|
||||
3;
|
||||
|
||||
// Edge definitions.
|
||||
1 -> 2;
|
||||
1 -> 3;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
strict digraph MultipleDependenciesSameStatus {
|
||||
// Node definitions.
|
||||
1;
|
||||
2;
|
||||
3;
|
||||
4;
|
||||
|
||||
// Edge definitions.
|
||||
1 -> 2;
|
||||
1 -> 3;
|
||||
1 -> 4;
|
||||
}
|
||||
+4
@@ -0,0 +1,4 @@
|
||||
strict digraph SelfReference {
|
||||
// Node definitions.
|
||||
1;
|
||||
}
|
||||
@@ -56,6 +56,7 @@ func workspaceAgent() *serpent.Command {
|
||||
devcontainers bool
|
||||
devcontainerProjectDiscovery bool
|
||||
devcontainerDiscoveryAutostart bool
|
||||
socketPath string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
@@ -297,6 +298,7 @@ func workspaceAgent() *serpent.Command {
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
|
||||
},
|
||||
SocketPath: socketPath,
|
||||
})
|
||||
|
||||
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
|
||||
@@ -449,6 +451,12 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: "Allow the agent to autostart devcontainer projects it discovers based on their configuration.",
|
||||
Value: serpent.BoolOf(&devcontainerDiscoveryAutostart),
|
||||
},
|
||||
{
|
||||
Flag: "socket-path",
|
||||
Env: "CODER_AGENT_SOCKET_PATH",
|
||||
Description: "Specify the path for the agent socket.",
|
||||
Value: serpent.StringOf(&socketPath),
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
|
||||
+18
-9
@@ -296,22 +296,23 @@ func renderTable(out any, sort string, headers table.Row, filterColumns []string
|
||||
// returned. If the table tag is malformed, an error is returned.
|
||||
//
|
||||
// The returned name is transformed from "snake_case" to "normal text".
|
||||
func parseTableStructTag(field reflect.StructField) (name string, defaultSort, noSortOpt, recursive, skipParentName bool, err error) {
|
||||
func parseTableStructTag(field reflect.StructField) (name string, defaultSort, noSortOpt, recursive, skipParentName, emptyNil bool, err error) {
|
||||
tags, err := structtag.Parse(string(field.Tag))
|
||||
if err != nil {
|
||||
return "", false, false, false, false, xerrors.Errorf("parse struct field tag %q: %w", string(field.Tag), err)
|
||||
return "", false, false, false, false, false, xerrors.Errorf("parse struct field tag %q: %w", string(field.Tag), err)
|
||||
}
|
||||
|
||||
tag, err := tags.Get("table")
|
||||
if err != nil || tag.Name == "-" {
|
||||
// tags.Get only returns an error if the tag is not found.
|
||||
return "", false, false, false, false, nil
|
||||
return "", false, false, false, false, false, nil
|
||||
}
|
||||
|
||||
defaultSortOpt := false
|
||||
noSortOpt = false
|
||||
recursiveOpt := false
|
||||
skipParentNameOpt := false
|
||||
emptyNilOpt := false
|
||||
for _, opt := range tag.Options {
|
||||
switch opt {
|
||||
case "default_sort":
|
||||
@@ -326,12 +327,14 @@ func parseTableStructTag(field reflect.StructField) (name string, defaultSort, n
|
||||
// make sure the child name is unique across all nested structs in the parent.
|
||||
recursiveOpt = true
|
||||
skipParentNameOpt = true
|
||||
case "empty_nil":
|
||||
emptyNilOpt = true
|
||||
default:
|
||||
return "", false, false, false, false, xerrors.Errorf("unknown option %q in struct field tag", opt)
|
||||
return "", false, false, false, false, false, xerrors.Errorf("unknown option %q in struct field tag", opt)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.ReplaceAll(tag.Name, "_", " "), defaultSortOpt, noSortOpt, recursiveOpt, skipParentNameOpt, nil
|
||||
return strings.ReplaceAll(tag.Name, "_", " "), defaultSortOpt, noSortOpt, recursiveOpt, skipParentNameOpt, emptyNilOpt, nil
|
||||
}
|
||||
|
||||
func isStructOrStructPointer(t reflect.Type) bool {
|
||||
@@ -358,7 +361,7 @@ func typeToTableHeaders(t reflect.Type, requireDefault bool) ([]string, string,
|
||||
noSortOpt := false
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
name, defaultSort, noSort, recursive, skip, err := parseTableStructTag(field)
|
||||
name, defaultSort, noSort, recursive, skip, _, err := parseTableStructTag(field)
|
||||
if err != nil {
|
||||
return nil, "", xerrors.Errorf("parse struct tags for field %q in type %q: %w", field.Name, t.String(), err)
|
||||
}
|
||||
@@ -435,7 +438,7 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Type().Field(i)
|
||||
fieldVal := val.Field(i)
|
||||
name, _, _, recursive, skip, err := parseTableStructTag(field)
|
||||
name, _, _, recursive, skip, emptyNil, err := parseTableStructTag(field)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse struct tags for field %q in type %T: %w", field.Name, val, err)
|
||||
}
|
||||
@@ -443,8 +446,14 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Recurse if it's a struct.
|
||||
fieldType := field.Type
|
||||
|
||||
// If empty_nil is set and this is a nil pointer, use a zero value.
|
||||
if emptyNil && fieldVal.Kind() == reflect.Pointer && fieldVal.IsNil() {
|
||||
fieldVal = reflect.New(fieldType.Elem())
|
||||
}
|
||||
|
||||
// Recurse if it's a struct.
|
||||
if recursive {
|
||||
if !isStructOrStructPointer(fieldType) {
|
||||
return nil, xerrors.Errorf("field %q in type %q is marked as recursive but does not contain a struct or a pointer to a struct", field.Name, fieldType.String())
|
||||
@@ -467,7 +476,7 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
|
||||
}
|
||||
|
||||
// Otherwise, we just use the field value.
|
||||
row[name] = val.Field(i).Interface()
|
||||
row[name] = fieldVal.Interface()
|
||||
}
|
||||
|
||||
return row, nil
|
||||
|
||||
@@ -400,6 +400,78 @@ foo <nil> 10 [a, b, c] foo1 11 foo2 12 fo
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("EmptyNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type emptyNilTest struct {
|
||||
Name string `table:"name,default_sort"`
|
||||
EmptyOnNil *string `table:"empty_on_nil,empty_nil"`
|
||||
NormalBehavior *string `table:"normal_behavior"`
|
||||
}
|
||||
|
||||
value := "value"
|
||||
in := []emptyNilTest{
|
||||
{
|
||||
Name: "has_value",
|
||||
EmptyOnNil: &value,
|
||||
NormalBehavior: &value,
|
||||
},
|
||||
{
|
||||
Name: "has_nil",
|
||||
EmptyOnNil: nil,
|
||||
NormalBehavior: nil,
|
||||
},
|
||||
}
|
||||
|
||||
expected := `
|
||||
NAME EMPTY ON NIL NORMAL BEHAVIOR
|
||||
has_nil <nil>
|
||||
has_value value value
|
||||
`
|
||||
|
||||
out, err := cliui.DisplayTable(in, "", nil)
|
||||
log.Println("rendered table:\n" + out)
|
||||
require.NoError(t, err)
|
||||
compareTables(t, expected, out)
|
||||
})
|
||||
|
||||
t.Run("EmptyNilWithRecursiveInline", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type nestedData struct {
|
||||
Name string `table:"name"`
|
||||
}
|
||||
|
||||
type inlineTest struct {
|
||||
Nested *nestedData `table:"ignored,recursive_inline,empty_nil"`
|
||||
Count int `table:"count,default_sort"`
|
||||
}
|
||||
|
||||
in := []inlineTest{
|
||||
{
|
||||
Nested: &nestedData{
|
||||
Name: "alice",
|
||||
},
|
||||
Count: 1,
|
||||
},
|
||||
{
|
||||
Nested: nil,
|
||||
Count: 2,
|
||||
},
|
||||
}
|
||||
|
||||
expected := `
|
||||
NAME COUNT
|
||||
alice 1
|
||||
2
|
||||
`
|
||||
|
||||
out, err := cliui.DisplayTable(in, "", nil)
|
||||
log.Println("rendered table:\n" + out)
|
||||
require.NoError(t, err)
|
||||
compareTables(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
// compareTables normalizes the incoming table lines
|
||||
|
||||
@@ -185,9 +185,6 @@ func TestDelete(t *testing.T) {
|
||||
|
||||
t.Run("WarnNoProvisioners", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
store, ps, db := dbtestutil.NewDBWithSQLDB(t)
|
||||
client, closeDaemon := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{
|
||||
@@ -228,9 +225,6 @@ func TestDelete(t *testing.T) {
|
||||
|
||||
t.Run("Prebuilt workspace delete permissions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
// Setup
|
||||
db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure())
|
||||
|
||||
+12
-324
@@ -29,7 +29,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
@@ -40,7 +39,6 @@ import (
|
||||
"github.com/coder/coder/v2/scaletest/dashboard"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/coder/v2/scaletest/notifications"
|
||||
"github.com/coder/coder/v2/scaletest/reconnectingpty"
|
||||
"github.com/coder/coder/v2/scaletest/workspacebuild"
|
||||
"github.com/coder/coder/v2/scaletest/workspacetraffic"
|
||||
@@ -66,6 +64,7 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
|
||||
r.scaletestWorkspaceTraffic(),
|
||||
r.scaletestAutostart(),
|
||||
r.scaletestNotifications(),
|
||||
r.scaletestSMTP(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1921,259 +1920,6 @@ func (r *RootCmd) scaletestAutostart() *serpent.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) scaletestNotifications() *serpent.Command {
|
||||
var (
|
||||
userCount int64
|
||||
ownerUserPercentage float64
|
||||
notificationTimeout time.Duration
|
||||
dialTimeout time.Duration
|
||||
noCleanup bool
|
||||
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
|
||||
// This test requires unlimited concurrency.
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "notifications",
|
||||
Short: "Simulate notification delivery by creating many users listening to notifications.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
me, err := requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if userCount <= 0 {
|
||||
return xerrors.Errorf("--user-count must be greater than 0")
|
||||
}
|
||||
|
||||
if ownerUserPercentage < 0 || ownerUserPercentage > 100 {
|
||||
return xerrors.Errorf("--owner-user-percentage must be between 0 and 100")
|
||||
}
|
||||
|
||||
ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100)
|
||||
if ownerUserCount == 0 && ownerUserPercentage > 0 {
|
||||
ownerUserCount = 1
|
||||
}
|
||||
regularUserCount := userCount - ownerUserCount
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
|
||||
_, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount)
|
||||
_, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage)
|
||||
_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage)
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("could not parse --output flags")
|
||||
}
|
||||
|
||||
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tracer provider: %w", err)
|
||||
}
|
||||
tracer := tracerProvider.Tracer(scaletestTracerName)
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := notifications.NewMetrics(reg)
|
||||
|
||||
logger := inv.Logger
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
defer func() {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
|
||||
if err := closeTracing(ctx); err != nil {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
|
||||
}
|
||||
// Wait for prometheus metrics to be scraped
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")
|
||||
|
||||
dialBarrier := &sync.WaitGroup{}
|
||||
ownerWatchBarrier := &sync.WaitGroup{}
|
||||
dialBarrier.Add(int(userCount))
|
||||
ownerWatchBarrier.Add(int(ownerUserCount))
|
||||
|
||||
expectedNotifications := map[uuid.UUID]chan time.Time{
|
||||
notificationsLib.TemplateUserAccountCreated: make(chan time.Time, 1),
|
||||
notificationsLib.TemplateUserAccountDeleted: make(chan time.Time, 1),
|
||||
}
|
||||
|
||||
configs := make([]notifications.Config, 0, userCount)
|
||||
for range ownerUserCount {
|
||||
config := notifications.Config{
|
||||
User: createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
},
|
||||
Roles: []string{codersdk.RoleOwner},
|
||||
NotificationTimeout: notificationTimeout,
|
||||
DialTimeout: dialTimeout,
|
||||
DialBarrier: dialBarrier,
|
||||
ReceivingWatchBarrier: ownerWatchBarrier,
|
||||
ExpectedNotifications: expectedNotifications,
|
||||
Metrics: metrics,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
configs = append(configs, config)
|
||||
}
|
||||
for range regularUserCount {
|
||||
config := notifications.Config{
|
||||
User: createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
},
|
||||
Roles: []string{},
|
||||
NotificationTimeout: notificationTimeout,
|
||||
DialTimeout: dialTimeout,
|
||||
DialBarrier: dialBarrier,
|
||||
ReceivingWatchBarrier: ownerWatchBarrier,
|
||||
Metrics: metrics,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
configs = append(configs, config)
|
||||
}
|
||||
|
||||
go triggerUserNotifications(
|
||||
ctx,
|
||||
logger,
|
||||
client,
|
||||
me.OrganizationIDs[0],
|
||||
dialBarrier,
|
||||
dialTimeout,
|
||||
expectedNotifications,
|
||||
)
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
for i, config := range configs {
|
||||
id := strconv.Itoa(i)
|
||||
name := fmt.Sprintf("notifications-%s", id)
|
||||
var runner harness.Runnable = notifications.NewRunner(client, config)
|
||||
if tracingEnabled {
|
||||
runner = &runnableTraceWrapper{
|
||||
tracer: tracer,
|
||||
spanName: name,
|
||||
runner: runner,
|
||||
}
|
||||
}
|
||||
|
||||
th.AddRun(name, id, runner)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...")
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCleanup {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
err = th.Cleanup(cleanupCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "user-count",
|
||||
FlagShorthand: "c",
|
||||
Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT",
|
||||
Description: "Required: Total number of users to create.",
|
||||
Value: serpent.Int64Of(&userCount),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Flag: "owner-user-percentage",
|
||||
Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE",
|
||||
Default: "20.0",
|
||||
Description: "Percentage of users to assign Owner role to (0-100).",
|
||||
Value: serpent.Float64Of(&ownerUserPercentage),
|
||||
},
|
||||
{
|
||||
Flag: "notification-timeout",
|
||||
Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT",
|
||||
Default: "5m",
|
||||
Description: "How long to wait for notifications after triggering.",
|
||||
Value: serpent.DurationOf(¬ificationTimeout),
|
||||
},
|
||||
{
|
||||
Flag: "dial-timeout",
|
||||
Env: "CODER_SCALETEST_DIAL_TIMEOUT",
|
||||
Default: "2m",
|
||||
Description: "Timeout for dialing the notification websocket endpoint.",
|
||||
Value: serpent.DurationOf(&dialTimeout),
|
||||
},
|
||||
{
|
||||
Flag: "no-cleanup",
|
||||
Env: "CODER_SCALETEST_NO_CLEANUP",
|
||||
Description: "Do not clean up resources after the test completes.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
}
|
||||
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
type runnableTraceWrapper struct {
|
||||
tracer trace.Tracer
|
||||
spanName string
|
||||
@@ -2183,8 +1929,9 @@ type runnableTraceWrapper struct {
|
||||
}
|
||||
|
||||
var (
|
||||
_ harness.Runnable = &runnableTraceWrapper{}
|
||||
_ harness.Cleanable = &runnableTraceWrapper{}
|
||||
_ harness.Runnable = &runnableTraceWrapper{}
|
||||
_ harness.Cleanable = &runnableTraceWrapper{}
|
||||
_ harness.Collectable = &runnableTraceWrapper{}
|
||||
)
|
||||
|
||||
func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
|
||||
@@ -2226,6 +1973,14 @@ func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string, logs io.W
|
||||
return c.Cleanup(ctx, id, logs)
|
||||
}
|
||||
|
||||
func (r *runnableTraceWrapper) GetMetrics() map[string]any {
|
||||
c, ok := r.runner.(harness.Collectable)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return c.GetMetrics()
|
||||
}
|
||||
|
||||
func getScaletestWorkspaces(ctx context.Context, client *codersdk.Client, owner, template string) ([]codersdk.Workspace, int, error) {
|
||||
var (
|
||||
pageNumber = 0
|
||||
@@ -2374,73 +2129,6 @@ func parseTargetRange(name, targets string) (start, end int, err error) {
|
||||
return start, end, nil
|
||||
}
|
||||
|
||||
// triggerUserNotifications waits for all test users to connect,
|
||||
// then creates and deletes a test user to trigger notification events for testing.
|
||||
func triggerUserNotifications(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
client *codersdk.Client,
|
||||
orgID uuid.UUID,
|
||||
dialBarrier *sync.WaitGroup,
|
||||
dialTimeout time.Duration,
|
||||
expectedNotifications map[uuid.UUID]chan time.Time,
|
||||
) {
|
||||
logger.Info(ctx, "waiting for all users to connect")
|
||||
|
||||
// Wait for all users to connect
|
||||
waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
dialBarrier.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
logger.Info(ctx, "all users connected")
|
||||
case <-waitCtx.Done():
|
||||
if waitCtx.Err() == context.DeadlineExceeded {
|
||||
logger.Error(ctx, "timeout waiting for users to connect")
|
||||
} else {
|
||||
logger.Info(ctx, "context canceled while waiting for users")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
triggerUsername = "scaletest-trigger-user"
|
||||
triggerEmail = "scaletest-trigger@example.com"
|
||||
)
|
||||
|
||||
logger.Info(ctx, "creating test user to test notifications",
|
||||
slog.F("username", triggerUsername),
|
||||
slog.F("email", triggerEmail),
|
||||
slog.F("org_id", orgID))
|
||||
|
||||
testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{orgID},
|
||||
Username: triggerUsername,
|
||||
Email: triggerEmail,
|
||||
Password: "test-password-123",
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(ctx, "create test user", slog.Error(err))
|
||||
return
|
||||
}
|
||||
expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now()
|
||||
|
||||
err = client.DeleteUser(ctx, testUser.ID)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "delete test user", slog.Error(err))
|
||||
return
|
||||
}
|
||||
expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now()
|
||||
close(expectedNotifications[notificationsLib.TemplateUserAccountCreated])
|
||||
close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted])
|
||||
}
|
||||
|
||||
func createWorkspaceAppConfig(client *codersdk.Client, appHost, app string, workspace codersdk.Workspace, agent codersdk.WorkspaceAgent) (workspacetraffic.AppConfig, error) {
|
||||
if app == "" {
|
||||
return workspacetraffic.AppConfig{}, nil
|
||||
|
||||
@@ -4,16 +4,20 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/dynamicparameters"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,8 +25,15 @@ const (
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
|
||||
var templateName string
|
||||
var numEvals int64
|
||||
var (
|
||||
templateName string
|
||||
provisionerTags []string
|
||||
numEvals int64
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
// This test requires unlimited concurrency
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
)
|
||||
orgContext := NewOrganizationContext()
|
||||
output := &scaletestOutputFlags{}
|
||||
|
||||
@@ -46,20 +57,63 @@ func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
|
||||
return xerrors.Errorf("template cannot be empty")
|
||||
}
|
||||
|
||||
tags, err := ParseProvisionerTags(provisionerTags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
org, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := dynamicparameters.NewMetrics(reg, "concurrent_evaluations")
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stdout)).Leveled(slog.LevelDebug)
|
||||
partitions, err := dynamicparameters.SetupPartitions(ctx, client, org.ID, templateName, numEvals, logger)
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tracer provider: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
// Allow time for traces to flush even if command context is
|
||||
// canceled. This is a no-op if tracing is not enabled.
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
|
||||
if err := closeTracing(ctx); err != nil {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
|
||||
}
|
||||
// Wait for prometheus metrics to be scraped
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
tracer := tracerProvider.Tracer(scaletestTracerName)
|
||||
|
||||
partitions, err := dynamicparameters.SetupPartitions(ctx, client, org.ID, templateName, tags, numEvals, logger)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("setup dynamic parameters partitions: %w", err)
|
||||
}
|
||||
|
||||
th := harness.NewTestHarness(harness.ConcurrentExecutionStrategy{}, harness.ConcurrentExecutionStrategy{})
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := dynamicparameters.NewMetrics(reg, "concurrent_evaluations")
|
||||
th := harness.NewTestHarness(
|
||||
timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}),
|
||||
// there is no cleanup since it's just a connection that we sever.
|
||||
nil)
|
||||
|
||||
for i, part := range partitions {
|
||||
for j := range part.ConcurrentEvaluations {
|
||||
@@ -68,12 +122,21 @@ func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
|
||||
Metrics: metrics,
|
||||
MetricLabelValues: []string{fmt.Sprintf("%d", part.ConcurrentEvaluations)},
|
||||
}
|
||||
runner := dynamicparameters.NewRunner(client, cfg)
|
||||
var runner harness.Runnable = dynamicparameters.NewRunner(client, cfg)
|
||||
if tracingEnabled {
|
||||
runner = &runnableTraceWrapper{
|
||||
tracer: tracer,
|
||||
spanName: fmt.Sprintf("%s/%d/%d", dynamicParametersTestName, i, j),
|
||||
runner: runner,
|
||||
}
|
||||
}
|
||||
th.AddRun(dynamicParametersTestName, fmt.Sprintf("%d/%d", j, i), runner)
|
||||
}
|
||||
}
|
||||
|
||||
err = th.Run(ctx)
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness: %w", err)
|
||||
}
|
||||
@@ -103,8 +166,16 @@ func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
|
||||
Default: "100",
|
||||
Value: serpent.Int64Of(&numEvals),
|
||||
},
|
||||
{
|
||||
Flag: "provisioner-tag",
|
||||
Description: "Specify a set of tags to target provisioner daemons.",
|
||||
Value: serpent.StringArrayOf(&provisionerTags),
|
||||
},
|
||||
}
|
||||
orgContext.AttachOptions(cmd)
|
||||
output.attach(&cmd.Options)
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -0,0 +1,447 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/notifications"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestNotifications() *serpent.Command {
|
||||
var (
|
||||
userCount int64
|
||||
ownerUserPercentage float64
|
||||
notificationTimeout time.Duration
|
||||
dialTimeout time.Duration
|
||||
noCleanup bool
|
||||
smtpAPIURL string
|
||||
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
|
||||
// This test requires unlimited concurrency.
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "notifications",
|
||||
Short: "Simulate notification delivery by creating many users listening to notifications.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
me, err := requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if userCount <= 0 {
|
||||
return xerrors.Errorf("--user-count must be greater than 0")
|
||||
}
|
||||
|
||||
if ownerUserPercentage < 0 || ownerUserPercentage > 100 {
|
||||
return xerrors.Errorf("--owner-user-percentage must be between 0 and 100")
|
||||
}
|
||||
|
||||
if smtpAPIURL != "" && !strings.HasPrefix(smtpAPIURL, "http://") && !strings.HasPrefix(smtpAPIURL, "https://") {
|
||||
return xerrors.Errorf("--smtp-api-url must start with http:// or https://")
|
||||
}
|
||||
|
||||
ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100)
|
||||
if ownerUserCount == 0 && ownerUserPercentage > 0 {
|
||||
ownerUserCount = 1
|
||||
}
|
||||
regularUserCount := userCount - ownerUserCount
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
|
||||
_, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount)
|
||||
_, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage)
|
||||
_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage)
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("could not parse --output flags")
|
||||
}
|
||||
|
||||
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tracer provider: %w", err)
|
||||
}
|
||||
tracer := tracerProvider.Tracer(scaletestTracerName)
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := notifications.NewMetrics(reg)
|
||||
|
||||
logger := inv.Logger
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
defer func() {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
|
||||
if err := closeTracing(ctx); err != nil {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
|
||||
}
|
||||
// Wait for prometheus metrics to be scraped
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")
|
||||
|
||||
dialBarrier := &sync.WaitGroup{}
|
||||
ownerWatchBarrier := &sync.WaitGroup{}
|
||||
dialBarrier.Add(int(userCount))
|
||||
ownerWatchBarrier.Add(int(ownerUserCount))
|
||||
|
||||
expectedNotificationIDs := map[uuid.UUID]struct{}{
|
||||
notificationsLib.TemplateUserAccountCreated: {},
|
||||
notificationsLib.TemplateUserAccountDeleted: {},
|
||||
}
|
||||
|
||||
triggerTimes := make(map[uuid.UUID]chan time.Time, len(expectedNotificationIDs))
|
||||
for id := range expectedNotificationIDs {
|
||||
triggerTimes[id] = make(chan time.Time, 1)
|
||||
}
|
||||
|
||||
configs := make([]notifications.Config, 0, userCount)
|
||||
for range ownerUserCount {
|
||||
config := notifications.Config{
|
||||
User: createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
},
|
||||
Roles: []string{codersdk.RoleOwner},
|
||||
NotificationTimeout: notificationTimeout,
|
||||
DialTimeout: dialTimeout,
|
||||
DialBarrier: dialBarrier,
|
||||
ReceivingWatchBarrier: ownerWatchBarrier,
|
||||
ExpectedNotificationsIDs: expectedNotificationIDs,
|
||||
Metrics: metrics,
|
||||
SMTPApiURL: smtpAPIURL,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
configs = append(configs, config)
|
||||
}
|
||||
for range regularUserCount {
|
||||
config := notifications.Config{
|
||||
User: createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
},
|
||||
Roles: []string{},
|
||||
NotificationTimeout: notificationTimeout,
|
||||
DialTimeout: dialTimeout,
|
||||
DialBarrier: dialBarrier,
|
||||
ReceivingWatchBarrier: ownerWatchBarrier,
|
||||
Metrics: metrics,
|
||||
SMTPApiURL: smtpAPIURL,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
configs = append(configs, config)
|
||||
}
|
||||
|
||||
go triggerUserNotifications(
|
||||
ctx,
|
||||
logger,
|
||||
client,
|
||||
me.OrganizationIDs[0],
|
||||
dialBarrier,
|
||||
dialTimeout,
|
||||
triggerTimes,
|
||||
)
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
for i, config := range configs {
|
||||
id := strconv.Itoa(i)
|
||||
name := fmt.Sprintf("notifications-%s", id)
|
||||
var runner harness.Runnable = notifications.NewRunner(client, config)
|
||||
if tracingEnabled {
|
||||
runner = &runnableTraceWrapper{
|
||||
tracer: tracer,
|
||||
spanName: name,
|
||||
runner: runner,
|
||||
}
|
||||
}
|
||||
|
||||
th.AddRun(name, id, runner)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...")
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
|
||||
if err := computeNotificationLatencies(ctx, logger, triggerTimes, res, metrics); err != nil {
|
||||
return xerrors.Errorf("compute notification latencies: %w", err)
|
||||
}
|
||||
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCleanup {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
err = th.Cleanup(cleanupCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "user-count",
|
||||
FlagShorthand: "c",
|
||||
Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT",
|
||||
Description: "Required: Total number of users to create.",
|
||||
Value: serpent.Int64Of(&userCount),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Flag: "owner-user-percentage",
|
||||
Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE",
|
||||
Default: "20.0",
|
||||
Description: "Percentage of users to assign Owner role to (0-100).",
|
||||
Value: serpent.Float64Of(&ownerUserPercentage),
|
||||
},
|
||||
{
|
||||
Flag: "notification-timeout",
|
||||
Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT",
|
||||
Default: "5m",
|
||||
Description: "How long to wait for notifications after triggering.",
|
||||
Value: serpent.DurationOf(¬ificationTimeout),
|
||||
},
|
||||
{
|
||||
Flag: "dial-timeout",
|
||||
Env: "CODER_SCALETEST_DIAL_TIMEOUT",
|
||||
Default: "2m",
|
||||
Description: "Timeout for dialing the notification websocket endpoint.",
|
||||
Value: serpent.DurationOf(&dialTimeout),
|
||||
},
|
||||
{
|
||||
Flag: "no-cleanup",
|
||||
Env: "CODER_SCALETEST_NO_CLEANUP",
|
||||
Description: "Do not clean up resources after the test completes.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
{
|
||||
Flag: "smtp-api-url",
|
||||
Env: "CODER_SCALETEST_SMTP_API_URL",
|
||||
Description: "SMTP mock HTTP API address.",
|
||||
Value: serpent.StringOf(&smtpAPIURL),
|
||||
},
|
||||
}
|
||||
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func computeNotificationLatencies(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
expectedNotifications map[uuid.UUID]chan time.Time,
|
||||
results harness.Results,
|
||||
metrics *notifications.Metrics,
|
||||
) error {
|
||||
triggerTimes := make(map[uuid.UUID]time.Time)
|
||||
for notificationID, triggerTimeChan := range expectedNotifications {
|
||||
select {
|
||||
case triggerTime := <-triggerTimeChan:
|
||||
triggerTimes[notificationID] = triggerTime
|
||||
logger.Info(ctx, "received trigger time",
|
||||
slog.F("notification_id", notificationID),
|
||||
slog.F("trigger_time", triggerTime))
|
||||
default:
|
||||
logger.Warn(ctx, "no trigger time received for notification",
|
||||
slog.F("notification_id", notificationID))
|
||||
}
|
||||
}
|
||||
|
||||
if len(triggerTimes) == 0 {
|
||||
logger.Warn(ctx, "no trigger times available, skipping latency computation")
|
||||
return nil
|
||||
}
|
||||
|
||||
var totalLatencies int
|
||||
for runID, runResult := range results.Runs {
|
||||
if runResult.Error != nil {
|
||||
logger.Debug(ctx, "skipping failed run for latency computation",
|
||||
slog.F("run_id", runID))
|
||||
continue
|
||||
}
|
||||
|
||||
if runResult.Metrics == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Process websocket notifications.
|
||||
if wsReceiptTimes, ok := runResult.Metrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok {
|
||||
for notificationID, receiptTime := range wsReceiptTimes {
|
||||
if triggerTime, ok := triggerTimes[notificationID]; ok {
|
||||
latency := receiptTime.Sub(triggerTime)
|
||||
metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeWebsocket)
|
||||
totalLatencies++
|
||||
logger.Debug(ctx, "computed websocket latency",
|
||||
slog.F("run_id", runID),
|
||||
slog.F("notification_id", notificationID),
|
||||
slog.F("latency", latency))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process SMTP notifications
|
||||
if smtpReceiptTimes, ok := runResult.Metrics[notifications.SMTPNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok {
|
||||
for notificationID, receiptTime := range smtpReceiptTimes {
|
||||
if triggerTime, ok := triggerTimes[notificationID]; ok {
|
||||
latency := receiptTime.Sub(triggerTime)
|
||||
metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeSMTP)
|
||||
totalLatencies++
|
||||
logger.Debug(ctx, "computed SMTP latency",
|
||||
slog.F("run_id", runID),
|
||||
slog.F("notification_id", notificationID),
|
||||
slog.F("latency", latency))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info(ctx, "finished computing notification latencies",
|
||||
slog.F("total_runs", results.TotalRuns),
|
||||
slog.F("total_latencies_computed", totalLatencies))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// triggerUserNotifications waits for all test users to connect,
|
||||
// then creates and deletes a test user to trigger notification events for testing.
|
||||
func triggerUserNotifications(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
client *codersdk.Client,
|
||||
orgID uuid.UUID,
|
||||
dialBarrier *sync.WaitGroup,
|
||||
dialTimeout time.Duration,
|
||||
expectedNotifications map[uuid.UUID]chan time.Time,
|
||||
) {
|
||||
logger.Info(ctx, "waiting for all users to connect")
|
||||
|
||||
// Wait for all users to connect
|
||||
waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
dialBarrier.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
logger.Info(ctx, "all users connected")
|
||||
case <-waitCtx.Done():
|
||||
if waitCtx.Err() == context.DeadlineExceeded {
|
||||
logger.Error(ctx, "timeout waiting for users to connect")
|
||||
} else {
|
||||
logger.Info(ctx, "context canceled while waiting for users")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
triggerUsername = "scaletest-trigger-user"
|
||||
triggerEmail = "scaletest-trigger@example.com"
|
||||
)
|
||||
|
||||
logger.Info(ctx, "creating test user to test notifications",
|
||||
slog.F("username", triggerUsername),
|
||||
slog.F("email", triggerEmail),
|
||||
slog.F("org_id", orgID))
|
||||
|
||||
testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{orgID},
|
||||
Username: triggerUsername,
|
||||
Email: triggerEmail,
|
||||
Password: "test-password-123",
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(ctx, "create test user", slog.Error(err))
|
||||
return
|
||||
}
|
||||
expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now()
|
||||
|
||||
err = client.DeleteUser(ctx, testUser.ID)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "delete test user", slog.Error(err))
|
||||
return
|
||||
}
|
||||
expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now()
|
||||
close(expectedNotifications[notificationsLib.TemplateUserAccountCreated])
|
||||
close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted])
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/smtpmock"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) scaletestSMTP() *serpent.Command {
|
||||
var (
|
||||
hostAddress string
|
||||
smtpPort int64
|
||||
apiPort int64
|
||||
purgeAtCount int64
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "smtp",
|
||||
Short: "Start a mock SMTP server for testing",
|
||||
Long: `Start a mock SMTP server with an HTTP API server that can be used to purge
|
||||
messages and get messages by email.`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)
|
||||
config := smtpmock.Config{
|
||||
HostAddress: hostAddress,
|
||||
SMTPPort: int(smtpPort),
|
||||
APIPort: int(apiPort),
|
||||
Logger: logger,
|
||||
}
|
||||
srv := new(smtpmock.Server)
|
||||
|
||||
if err := srv.Start(ctx, config); err != nil {
|
||||
return xerrors.Errorf("start mock SMTP server: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = srv.Stop()
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Mock SMTP server started on %s\n", srv.SMTPAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "HTTP API server started on %s\n", srv.APIAddress())
|
||||
if purgeAtCount > 0 {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " Auto-purge when message count reaches %d\n", purgeAtCount)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "\nTotal messages received since last purge: %d\n", srv.MessageCount())
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
count := srv.MessageCount()
|
||||
if count > 0 {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Messages received: %d\n", count)
|
||||
}
|
||||
|
||||
if purgeAtCount > 0 && int64(count) >= purgeAtCount {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Message count (%d) reached threshold (%d). Purging...\n", count, purgeAtCount)
|
||||
srv.Purge()
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = []serpent.Option{
|
||||
{
|
||||
Flag: "host-address",
|
||||
Env: "CODER_SCALETEST_SMTP_HOST_ADDRESS",
|
||||
Default: "localhost",
|
||||
Description: "Host address to bind the mock SMTP and API servers.",
|
||||
Value: serpent.StringOf(&hostAddress),
|
||||
},
|
||||
{
|
||||
Flag: "smtp-port",
|
||||
Env: "CODER_SCALETEST_SMTP_PORT",
|
||||
Description: "Port for the mock SMTP server. Uses a random port if not specified.",
|
||||
Value: serpent.Int64Of(&smtpPort),
|
||||
},
|
||||
{
|
||||
Flag: "api-port",
|
||||
Env: "CODER_SCALETEST_SMTP_API_PORT",
|
||||
Description: "Port for the HTTP API server. Uses a random port if not specified.",
|
||||
Value: serpent.Int64Of(&apiPort),
|
||||
},
|
||||
{
|
||||
Flag: "purge-at-count",
|
||||
Env: "CODER_SCALETEST_SMTP_PURGE_AT_COUNT",
|
||||
Default: "100000",
|
||||
Description: "Maximum number of messages to keep before auto-purging. Set to 0 to disable.",
|
||||
Value: serpent.Int64Of(&purgeAtCount),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
+10
-34
@@ -5,7 +5,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
@@ -47,43 +46,19 @@ func (r *RootCmd) taskDelete() *serpent.Command {
|
||||
}
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
|
||||
type toDelete struct {
|
||||
ID uuid.UUID
|
||||
Owner string
|
||||
Display string
|
||||
}
|
||||
|
||||
var items []toDelete
|
||||
var tasks []codersdk.Task
|
||||
for _, identifier := range inv.Args {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return xerrors.New("task identifier cannot be empty or whitespace")
|
||||
}
|
||||
|
||||
// Check task identifier, try UUID first.
|
||||
if id, err := uuid.Parse(identifier); err == nil {
|
||||
task, err := exp.TaskByID(ctx, id)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", identifier, err)
|
||||
}
|
||||
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
items = append(items, toDelete{ID: id, Display: display, Owner: task.OwnerName})
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-UUID, treat as a workspace identifier (name or owner/name).
|
||||
ws, err := namedWorkspace(ctx, client, identifier)
|
||||
task, err := exp.TaskByIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", identifier, err)
|
||||
}
|
||||
display := ws.FullName()
|
||||
items = append(items, toDelete{ID: ws.ID, Display: display, Owner: ws.OwnerName})
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
// Confirm deletion of the tasks.
|
||||
var displayList []string
|
||||
for _, it := range items {
|
||||
displayList = append(displayList, it.Display)
|
||||
for _, task := range tasks {
|
||||
displayList = append(displayList, fmt.Sprintf("%s/%s", task.OwnerName, task.Name))
|
||||
}
|
||||
_, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: fmt.Sprintf("Delete these tasks: %s?", pretty.Sprint(cliui.DefaultStyles.Code, strings.Join(displayList, ", "))),
|
||||
@@ -94,12 +69,13 @@ func (r *RootCmd) taskDelete() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if err := exp.DeleteTask(ctx, item.Owner, item.ID); err != nil {
|
||||
return xerrors.Errorf("delete task %q: %w", item.Display, err)
|
||||
for i, task := range tasks {
|
||||
display := displayList[i]
|
||||
if err := exp.DeleteTask(ctx, task.OwnerName, task.ID); err != nil {
|
||||
return xerrors.Errorf("delete task %q: %w", display, err)
|
||||
}
|
||||
_, _ = fmt.Fprintln(
|
||||
inv.Stdout, "Deleted task "+pretty.Sprint(cliui.DefaultStyles.Keyword, item.Display)+" at "+cliui.Timestamp(time.Now()),
|
||||
inv.Stdout, "Deleted task "+pretty.Sprint(cliui.DefaultStyles.Keyword, display)+" at "+cliui.Timestamp(time.Now()),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
+41
-17
@@ -56,12 +56,18 @@ func TestExpTaskDelete(t *testing.T) {
|
||||
taskID := uuid.MustParse(id1)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/exists":
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
|
||||
c.nameResolves.Add(1)
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
|
||||
ID: taskID,
|
||||
Name: "exists",
|
||||
OwnerName: "me",
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{{
|
||||
ID: taskID,
|
||||
Name: "exists",
|
||||
OwnerName: "me",
|
||||
}},
|
||||
Count: 1,
|
||||
})
|
||||
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id1:
|
||||
c.deleteCalls.Add(1)
|
||||
@@ -104,12 +110,18 @@ func TestExpTaskDelete(t *testing.T) {
|
||||
firstID := uuid.MustParse(id3)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/first":
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
|
||||
c.nameResolves.Add(1)
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
|
||||
ID: firstID,
|
||||
Name: "first",
|
||||
OwnerName: "me",
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{{
|
||||
ID: firstID,
|
||||
Name: "first",
|
||||
OwnerName: "me",
|
||||
}},
|
||||
Count: 1,
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/"+id4:
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{
|
||||
@@ -139,8 +151,14 @@ func TestExpTaskDelete(t *testing.T) {
|
||||
buildHandler: func(_ *testCounters) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/doesnotexist":
|
||||
httpapi.ResourceNotFound(w)
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{},
|
||||
Count: 0,
|
||||
})
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
|
||||
}
|
||||
@@ -156,12 +174,18 @@ func TestExpTaskDelete(t *testing.T) {
|
||||
taskID := uuid.MustParse(id5)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/bad":
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
|
||||
c.nameResolves.Add(1)
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
|
||||
ID: taskID,
|
||||
Name: "bad",
|
||||
OwnerName: "me",
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{{
|
||||
ID: taskID,
|
||||
Name: "bad",
|
||||
OwnerName: "me",
|
||||
}},
|
||||
Count: 1,
|
||||
})
|
||||
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id5:
|
||||
httpapi.InternalServerError(w, xerrors.New("boom"))
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
@@ -98,10 +99,10 @@ func (r *RootCmd) taskList() *serpent.Command {
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "status",
|
||||
Description: "Filter by task status (e.g. running, failed, etc).",
|
||||
Description: "Filter by task status.",
|
||||
Flag: "status",
|
||||
Default: "",
|
||||
Value: serpent.StringOf(&statusFilter),
|
||||
Value: serpent.EnumOf(&statusFilter, slice.ToStrings(codersdk.AllTaskStatuses())...),
|
||||
},
|
||||
{
|
||||
Name: "all",
|
||||
@@ -143,7 +144,7 @@ func (r *RootCmd) taskList() *serpent.Command {
|
||||
|
||||
tasks, err := exp.Tasks(ctx, &codersdk.TasksFilter{
|
||||
Owner: targetUser,
|
||||
Status: statusFilter,
|
||||
Status: codersdk.TaskStatus(statusFilter),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list tasks: %w", err)
|
||||
|
||||
+36
-15
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
@@ -29,7 +30,7 @@ import (
|
||||
)
|
||||
|
||||
// makeAITask creates an AI-task workspace.
|
||||
func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UUID, transition database.WorkspaceTransition, prompt string) (workspace database.WorkspaceTable) {
|
||||
func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UUID, transition database.WorkspaceTransition, prompt string) database.Task {
|
||||
t.Helper()
|
||||
|
||||
tv := dbfake.TemplateVersion(t, db).
|
||||
@@ -91,7 +92,27 @@ func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UU
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return build.Workspace
|
||||
// Create a task record in the tasks table for the new data model.
|
||||
task := dbgen.Task(t, db, database.TaskTable{
|
||||
OrganizationID: orgID,
|
||||
OwnerID: ownerID,
|
||||
Name: build.Workspace.Name,
|
||||
WorkspaceID: uuid.NullUUID{UUID: build.Workspace.ID, Valid: true},
|
||||
TemplateVersionID: tv.TemplateVersion.ID,
|
||||
TemplateParameters: []byte("{}"),
|
||||
Prompt: prompt,
|
||||
CreatedAt: dbtime.Now(),
|
||||
})
|
||||
|
||||
// Link the task to the workspace app.
|
||||
dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
|
||||
TaskID: task.ID,
|
||||
WorkspaceBuildNumber: build.Build.BuildNumber,
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
|
||||
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
|
||||
})
|
||||
|
||||
return task
|
||||
}
|
||||
|
||||
func TestExpTaskList(t *testing.T) {
|
||||
@@ -128,7 +149,7 @@ func TestExpTaskList(t *testing.T) {
|
||||
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
wantPrompt := "build me a web app"
|
||||
ws := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, wantPrompt)
|
||||
task := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, wantPrompt)
|
||||
|
||||
inv, root := clitest.New(t, "exp", "task", "list", "--column", "id,name,status,initial prompt")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
@@ -140,8 +161,8 @@ func TestExpTaskList(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate the table includes the task and status.
|
||||
pty.ExpectMatch(ws.Name)
|
||||
pty.ExpectMatch("running")
|
||||
pty.ExpectMatch(task.Name)
|
||||
pty.ExpectMatch("initializing")
|
||||
pty.ExpectMatch(wantPrompt)
|
||||
})
|
||||
|
||||
@@ -154,12 +175,12 @@ func TestExpTaskList(t *testing.T) {
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
// Create two AI tasks: one running, one stopped.
|
||||
running := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me running")
|
||||
stopped := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
|
||||
// Create two AI tasks: one initializing, one paused.
|
||||
initializingTask := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me initializing")
|
||||
pausedTask := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
|
||||
|
||||
// Use JSON output to reliably validate filtering.
|
||||
inv, root := clitest.New(t, "exp", "task", "list", "--status=stopped", "--output=json")
|
||||
inv, root := clitest.New(t, "exp", "task", "list", "--status=paused", "--output=json")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
@@ -173,10 +194,10 @@ func TestExpTaskList(t *testing.T) {
|
||||
var tasks []codersdk.Task
|
||||
require.NoError(t, json.Unmarshal(stdout.Bytes(), &tasks))
|
||||
|
||||
// Only the stopped task is returned.
|
||||
// Only the paused task is returned.
|
||||
require.Len(t, tasks, 1, "expected one task after filtering")
|
||||
require.Equal(t, stopped.ID, tasks[0].ID)
|
||||
require.NotEqual(t, running.ID, tasks[0].ID)
|
||||
require.Equal(t, pausedTask.ID, tasks[0].ID)
|
||||
require.NotEqual(t, initializingTask.ID, tasks[0].ID)
|
||||
})
|
||||
|
||||
t.Run("UserFlag_Me_Table", func(t *testing.T) {
|
||||
@@ -188,7 +209,7 @@ func TestExpTaskList(t *testing.T) {
|
||||
_, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
_ = makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "other-task")
|
||||
ws := makeAITask(t, db, owner.OrganizationID, owner.UserID, owner.UserID, database.WorkspaceTransitionStart, "me-task")
|
||||
task := makeAITask(t, db, owner.OrganizationID, owner.UserID, owner.UserID, database.WorkspaceTransitionStart, "me-task")
|
||||
|
||||
inv, root := clitest.New(t, "exp", "task", "list", "--user", "me")
|
||||
//nolint:gocritic // Owner client is intended here smoke test the member task not showing up.
|
||||
@@ -200,7 +221,7 @@ func TestExpTaskList(t *testing.T) {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
pty.ExpectMatch(ws.Name)
|
||||
pty.ExpectMatch(task.Name)
|
||||
})
|
||||
|
||||
t.Run("Quiet", func(t *testing.T) {
|
||||
@@ -213,7 +234,7 @@ func TestExpTaskList(t *testing.T) {
|
||||
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
// Given: We have two tasks
|
||||
task1 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me running")
|
||||
task1 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me active")
|
||||
task2 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
|
||||
|
||||
// Given: We add the `--quiet` flag
|
||||
|
||||
+7
-15
@@ -3,7 +3,6 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
@@ -41,24 +40,17 @@ func (r *RootCmd) taskLogs() *serpent.Command {
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = inv.Context()
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
task = inv.Args[0]
|
||||
taskID uuid.UUID
|
||||
ctx = inv.Context()
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
identifier = inv.Args[0]
|
||||
)
|
||||
|
||||
if id, err := uuid.Parse(task); err == nil {
|
||||
taskID = id
|
||||
} else {
|
||||
ws, err := namedWorkspace(ctx, client, task)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", task, err)
|
||||
}
|
||||
|
||||
taskID = ws.ID
|
||||
task, err := exp.TaskByIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", identifier, err)
|
||||
}
|
||||
|
||||
logs, err := exp.TaskLogs(ctx, codersdk.Me, taskID)
|
||||
logs, err := exp.TaskLogs(ctx, codersdk.Me, task.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get task logs: %w", err)
|
||||
}
|
||||
|
||||
+13
-13
@@ -38,15 +38,15 @@ func Test_TaskLogs(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("ByWorkspaceName_JSON", func(t *testing.T) {
|
||||
t.Run("ByTaskName_JSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
userClient := client // user already has access to their own workspace
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", workspace.Name, "--output", "json")
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", task.Name, "--output", "json")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
@@ -64,15 +64,15 @@ func Test_TaskLogs(t *testing.T) {
|
||||
require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
|
||||
})
|
||||
|
||||
t.Run("ByWorkspaceID_JSON", func(t *testing.T) {
|
||||
t.Run("ByTaskID_JSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
userClient := client
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String(), "--output", "json")
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String(), "--output", "json")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
@@ -90,15 +90,15 @@ func Test_TaskLogs(t *testing.T) {
|
||||
require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
|
||||
})
|
||||
|
||||
t.Run("ByWorkspaceID_Table", func(t *testing.T) {
|
||||
t.Run("ByTaskID_Table", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
|
||||
userClient := client
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String())
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String())
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
@@ -112,7 +112,7 @@ func Test_TaskLogs(t *testing.T) {
|
||||
require.Contains(t, output, "output")
|
||||
})
|
||||
|
||||
t.Run("WorkspaceNotFound_ByName", func(t *testing.T) {
|
||||
t.Run("TaskNotFound_ByName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
@@ -130,7 +130,7 @@ func Test_TaskLogs(t *testing.T) {
|
||||
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceNotFound_ByID", func(t *testing.T) {
|
||||
t.Run("TaskNotFound_ByID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
@@ -152,10 +152,10 @@ func Test_TaskLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsErr(assert.AnError))
|
||||
userClient := client
|
||||
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String())
|
||||
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String())
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
err := inv.WithContext(ctx).Run()
|
||||
|
||||
+7
-15
@@ -3,7 +3,6 @@ package cli
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -39,12 +38,11 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = inv.Context()
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
task = inv.Args[0]
|
||||
ctx = inv.Context()
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
identifier = inv.Args[0]
|
||||
|
||||
taskInput string
|
||||
taskID uuid.UUID
|
||||
)
|
||||
|
||||
if stdin {
|
||||
@@ -62,18 +60,12 @@ func (r *RootCmd) taskSend() *serpent.Command {
|
||||
taskInput = inv.Args[1]
|
||||
}
|
||||
|
||||
if id, err := uuid.Parse(task); err == nil {
|
||||
taskID = id
|
||||
} else {
|
||||
ws, err := namedWorkspace(ctx, client, task)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task: %w", err)
|
||||
}
|
||||
|
||||
taskID = ws.ID
|
||||
task, err := exp.TaskByIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task: %w", err)
|
||||
}
|
||||
|
||||
if err = exp.TaskSend(ctx, codersdk.Me, taskID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
if err = exp.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task: %w", err)
|
||||
}
|
||||
|
||||
|
||||
+13
-13
@@ -22,15 +22,15 @@ import (
|
||||
func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ByWorkspaceName_WithArgument", func(t *testing.T) {
|
||||
t.Run("ByTaskName_WithArgument", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
userClient := client
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "carry on with the task")
|
||||
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
@@ -38,15 +38,15 @@ func Test_TaskSend(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("ByWorkspaceID_WithArgument", func(t *testing.T) {
|
||||
t.Run("ByTaskID_WithArgument", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
userClient := client
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "exp", "task", "send", workspace.ID.String(), "carry on with the task")
|
||||
inv, root := clitest.New(t, "exp", "task", "send", task.ID.String(), "carry on with the task")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
@@ -54,15 +54,15 @@ func Test_TaskSend(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("ByWorkspaceName_WithStdin", func(t *testing.T) {
|
||||
t.Run("ByTaskName_WithStdin", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
|
||||
userClient := client
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "--stdin")
|
||||
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "--stdin")
|
||||
inv.Stdout = &stdout
|
||||
inv.Stdin = strings.NewReader("carry on with the task")
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
@@ -71,7 +71,7 @@ func Test_TaskSend(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceNotFound_ByName", func(t *testing.T) {
|
||||
t.Run("TaskNotFound_ByName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
@@ -89,7 +89,7 @@ func Test_TaskSend(t *testing.T) {
|
||||
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
|
||||
})
|
||||
|
||||
t.Run("WorkspaceNotFound_ByID", func(t *testing.T) {
|
||||
t.Run("TaskNotFound_ByID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
@@ -111,10 +111,10 @@ func Test_TaskSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
userClient, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
|
||||
userClient, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
|
||||
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "some task input")
|
||||
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "some task input")
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
|
||||
|
||||
+22
-31
@@ -5,7 +5,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
@@ -84,21 +83,10 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
}
|
||||
|
||||
ctx := i.Context()
|
||||
ec := codersdk.NewExperimentalClient(client)
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
identifier := i.Args[0]
|
||||
|
||||
taskID, err := uuid.Parse(identifier)
|
||||
if err != nil {
|
||||
// Try to resolve the task as a named workspace
|
||||
// TODO: right now tasks are still "workspaces" under the hood.
|
||||
// We should update this once we have a proper task model.
|
||||
ws, err := namedWorkspace(ctx, client, identifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID = ws.ID
|
||||
}
|
||||
task, err := ec.TaskByID(ctx, taskID)
|
||||
task, err := exp.TaskByIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -119,7 +107,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
// TODO: implement streaming updates instead of polling
|
||||
lastStatusRow := tsr
|
||||
for range t.C {
|
||||
task, err := ec.TaskByID(ctx, taskID)
|
||||
task, err := exp.TaskByID(ctx, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -152,7 +140,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
}
|
||||
|
||||
func taskWatchIsEnded(task codersdk.Task) bool {
|
||||
if task.Status == codersdk.WorkspaceStatusStopped {
|
||||
if task.WorkspaceStatus == codersdk.WorkspaceStatusStopped {
|
||||
return true
|
||||
}
|
||||
if task.WorkspaceAgentHealth == nil || !task.WorkspaceAgentHealth.Healthy {
|
||||
@@ -168,28 +156,21 @@ func taskWatchIsEnded(task codersdk.Task) bool {
|
||||
}
|
||||
|
||||
type taskStatusRow struct {
|
||||
codersdk.Task `table:"-"`
|
||||
ChangedAgo string `json:"-" table:"state changed,default_sort"`
|
||||
Timestamp time.Time `json:"-" table:"-"`
|
||||
TaskStatus string `json:"-" table:"status"`
|
||||
Healthy bool `json:"-" table:"healthy"`
|
||||
TaskState string `json:"-" table:"state"`
|
||||
Message string `json:"-" table:"message"`
|
||||
codersdk.Task `table:"r,recursive_inline"`
|
||||
ChangedAgo string `json:"-" table:"state changed"`
|
||||
Healthy bool `json:"-" table:"healthy"`
|
||||
}
|
||||
|
||||
func taskStatusRowEqual(r1, r2 taskStatusRow) bool {
|
||||
return r1.TaskStatus == r2.TaskStatus &&
|
||||
return r1.Status == r2.Status &&
|
||||
r1.Healthy == r2.Healthy &&
|
||||
r1.TaskState == r2.TaskState &&
|
||||
r1.Message == r2.Message
|
||||
taskStateEqual(r1.CurrentState, r2.CurrentState)
|
||||
}
|
||||
|
||||
func toStatusRow(task codersdk.Task) taskStatusRow {
|
||||
tsr := taskStatusRow{
|
||||
Task: task,
|
||||
ChangedAgo: time.Since(task.UpdatedAt).Truncate(time.Second).String() + " ago",
|
||||
Timestamp: task.UpdatedAt,
|
||||
TaskStatus: string(task.Status),
|
||||
}
|
||||
tsr.Healthy = task.WorkspaceAgentHealth != nil &&
|
||||
task.WorkspaceAgentHealth.Healthy &&
|
||||
@@ -199,9 +180,19 @@ func toStatusRow(task codersdk.Task) taskStatusRow {
|
||||
|
||||
if task.CurrentState != nil {
|
||||
tsr.ChangedAgo = time.Since(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
|
||||
tsr.Timestamp = task.CurrentState.Timestamp
|
||||
tsr.TaskState = string(task.CurrentState.State)
|
||||
tsr.Message = task.CurrentState.Message
|
||||
}
|
||||
return tsr
|
||||
}
|
||||
|
||||
func taskStateEqual(se1, se2 *codersdk.TaskStateEntry) bool {
|
||||
var s1, m1, s2, m2 string
|
||||
if se1 != nil {
|
||||
s1 = string(se1.State)
|
||||
m1 = se1.Message
|
||||
}
|
||||
if se2 != nil {
|
||||
s2 = string(se2.State)
|
||||
m2 = se2.Message
|
||||
}
|
||||
return s1 == s2 && m1 == m2
|
||||
}
|
||||
|
||||
+149
-69
@@ -36,26 +36,17 @@ func Test_TaskStatus(t *testing.T) {
|
||||
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me/workspace/doesnotexist":
|
||||
httpapi.ResourceNotFound(w)
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"err-fetching-workspace"},
|
||||
expectError: assert.AnError.Error(),
|
||||
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me/workspace/err-fetching-workspace":
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
})
|
||||
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
|
||||
httpapi.InternalServerError(w, assert.AnError)
|
||||
case "/api/experimental/tasks":
|
||||
if r.URL.Query().Get("q") == "owner:\"me\"" {
|
||||
httpapi.Write(ctx, w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{},
|
||||
Count: 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
@@ -64,21 +55,45 @@ func Test_TaskStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
args: []string{"exists"},
|
||||
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
0s ago running true working Thinking furiously...`,
|
||||
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
0s ago active true working Thinking furiously...`,
|
||||
hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me/workspace/exists":
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
})
|
||||
case "/api/experimental/tasks":
|
||||
if r.URL.Query().Get("q") == "owner:\"me\"" {
|
||||
httpapi.Write(ctx, w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Name: "exists",
|
||||
OwnerName: "me",
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateWorking,
|
||||
Timestamp: now,
|
||||
Message: "Thinking furiously...",
|
||||
},
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
Status: codersdk.TaskStatusActive,
|
||||
}},
|
||||
Count: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateWorking,
|
||||
Timestamp: now,
|
||||
@@ -88,7 +103,9 @@ func Test_TaskStatus(t *testing.T) {
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
Status: codersdk.TaskStatusActive,
|
||||
})
|
||||
return
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
@@ -97,50 +114,77 @@ func Test_TaskStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
args: []string{"exists", "--watch"},
|
||||
expectOutput: `
|
||||
STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
4s ago running true
|
||||
3s ago running true working Reticulating splines...
|
||||
2s ago running true complete Splines reticulated successfully!`,
|
||||
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
5s ago pending true
|
||||
4s ago initializing true
|
||||
4s ago active true
|
||||
3s ago active true working Reticulating splines...
|
||||
2s ago active true complete Splines reticulated successfully!`,
|
||||
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
|
||||
var calls atomic.Int64
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer calls.Add(1)
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me/workspace/exists":
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
})
|
||||
case "/api/experimental/tasks":
|
||||
if r.URL.Query().Get("q") == "owner:\"me\"" {
|
||||
// Return initial task state for --watch test
|
||||
httpapi.Write(ctx, w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Name: "exists",
|
||||
OwnerName: "me",
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusPending,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-5 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
Status: codersdk.TaskStatusPending,
|
||||
}},
|
||||
Count: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
|
||||
defer calls.Add(1)
|
||||
switch calls.Load() {
|
||||
case 0:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusPending,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-5 * time.Second),
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Name: "exists",
|
||||
OwnerName: "me",
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
Status: codersdk.TaskStatusInitializing,
|
||||
})
|
||||
return
|
||||
case 1:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
Status: codersdk.TaskStatusActive,
|
||||
})
|
||||
return
|
||||
case 2:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
@@ -150,13 +194,15 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
Timestamp: now.Add(-3 * time.Second),
|
||||
Message: "Reticulating splines...",
|
||||
},
|
||||
Status: codersdk.TaskStatusActive,
|
||||
})
|
||||
return
|
||||
case 3:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
@@ -166,13 +212,16 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
Timestamp: now.Add(-2 * time.Second),
|
||||
Message: "Splines reticulated successfully!",
|
||||
},
|
||||
Status: codersdk.TaskStatusActive,
|
||||
})
|
||||
return
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.New("too many calls!"))
|
||||
return
|
||||
}
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.Errorf("unexpected path: %q", r.URL.Path))
|
||||
return
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -183,19 +232,24 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
"id": "11111111-1111-1111-1111-111111111111",
|
||||
"organization_id": "00000000-0000-0000-0000-000000000000",
|
||||
"owner_id": "00000000-0000-0000-0000-000000000000",
|
||||
"owner_name": "",
|
||||
"name": "",
|
||||
"owner_name": "me",
|
||||
"name": "exists",
|
||||
"template_id": "00000000-0000-0000-0000-000000000000",
|
||||
"template_version_id": "00000000-0000-0000-0000-000000000000",
|
||||
"template_name": "",
|
||||
"template_display_name": "",
|
||||
"template_icon": "",
|
||||
"workspace_id": null,
|
||||
"workspace_name": "",
|
||||
"workspace_status": "running",
|
||||
"workspace_agent_id": null,
|
||||
"workspace_agent_lifecycle": null,
|
||||
"workspace_agent_health": null,
|
||||
"workspace_agent_lifecycle": "ready",
|
||||
"workspace_agent_health": {
|
||||
"healthy": true
|
||||
},
|
||||
"workspace_app_id": null,
|
||||
"initial_prompt": "",
|
||||
"status": "running",
|
||||
"status": "active",
|
||||
"current_state": {
|
||||
"timestamp": "2025-08-26T12:34:57Z",
|
||||
"state": "working",
|
||||
@@ -205,26 +259,52 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
"created_at": "2025-08-26T12:34:56Z",
|
||||
"updated_at": "2025-08-26T12:34:56Z"
|
||||
}`,
|
||||
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
|
||||
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
|
||||
ts := time.Date(2025, 8, 26, 12, 34, 56, 0, time.UTC)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me/workspace/exists":
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
})
|
||||
case "/api/experimental/tasks":
|
||||
if r.URL.Query().Get("q") == "owner:\"me\"" {
|
||||
httpapi.Write(ctx, w, http.StatusOK, struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}{
|
||||
Tasks: []codersdk.Task{{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Name: "exists",
|
||||
OwnerName: "me",
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: ts,
|
||||
UpdatedAt: ts,
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateWorking,
|
||||
Timestamp: ts.Add(time.Second),
|
||||
Message: "Thinking furiously...",
|
||||
},
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
Status: codersdk.TaskStatusActive,
|
||||
}},
|
||||
Count: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: ts,
|
||||
UpdatedAt: ts,
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: ts,
|
||||
UpdatedAt: ts,
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateWorking,
|
||||
Timestamp: ts.Add(time.Second),
|
||||
Message: "Thinking furiously...",
|
||||
},
|
||||
Status: codersdk.TaskStatusActive,
|
||||
})
|
||||
return
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
+229
-6
@@ -2,26 +2,242 @@ package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// This test performs an integration-style test for tasks functionality.
|
||||
//
|
||||
//nolint:tparallel // The sub-tests of this test must be run sequentially.
|
||||
func Test_Tasks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: a template configured for tasks
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitLong)
|
||||
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner = coderdtest.CreateFirstUser(t, client)
|
||||
userClient, _ = coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
initMsg = agentapisdk.Message{
|
||||
Content: "test task input for " + t.Name(),
|
||||
Id: 0,
|
||||
Role: "user",
|
||||
Time: time.Now().UTC(),
|
||||
}
|
||||
authToken = uuid.NewString()
|
||||
echoAgentAPI = startFakeAgentAPI(t, fakeAgentAPIEcho(ctx, t, initMsg, "hello"))
|
||||
taskTpl = createAITaskTemplate(t, client, owner.OrganizationID, withAgentToken(authToken), withSidebarURL(echoAgentAPI.URL()))
|
||||
taskName = strings.ReplaceAll(testutil.GetRandomName(t), "_", "-")
|
||||
)
|
||||
|
||||
//nolint:paralleltest // The sub-tests of this test must be run sequentially.
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
cmdArgs []string
|
||||
assertFn func(stdout string, userClient *codersdk.Client)
|
||||
}{
|
||||
{
|
||||
name: "create task",
|
||||
cmdArgs: []string{"exp", "task", "create", "test task input for " + t.Name(), "--name", taskName, "--template", taskTpl.Name},
|
||||
assertFn: func(stdout string, userClient *codersdk.Client) {
|
||||
require.Contains(t, stdout, taskName, "task name should be in output")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list tasks after create",
|
||||
cmdArgs: []string{"exp", "task", "list", "--output", "json"},
|
||||
assertFn: func(stdout string, userClient *codersdk.Client) {
|
||||
var tasks []codersdk.Task
|
||||
err := json.NewDecoder(strings.NewReader(stdout)).Decode(&tasks)
|
||||
require.NoError(t, err, "list output should unmarshal properly")
|
||||
require.Len(t, tasks, 1, "expected one task")
|
||||
require.Equal(t, taskName, tasks[0].Name, "task name should match")
|
||||
require.Equal(t, initMsg.Content, tasks[0].InitialPrompt, "initial prompt should match")
|
||||
require.True(t, tasks[0].WorkspaceID.Valid, "workspace should be created")
|
||||
// For the next test, we need to wait for the workspace to be healthy
|
||||
ws := coderdtest.MustWorkspace(t, userClient, tasks[0].WorkspaceID.UUID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
|
||||
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, tasks[0].WorkspaceID.UUID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get task status after create",
|
||||
cmdArgs: []string{"exp", "task", "status", taskName, "--output", "json"},
|
||||
assertFn: func(stdout string, userClient *codersdk.Client) {
|
||||
var task codersdk.Task
|
||||
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
|
||||
require.Equal(t, task.Name, taskName, "task name should match")
|
||||
require.Equal(t, codersdk.TaskStatusActive, task.Status, "task should be active")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "send task message",
|
||||
cmdArgs: []string{"exp", "task", "send", taskName, "hello"},
|
||||
// Assertions for this happen in the fake agent API handler.
|
||||
},
|
||||
{
|
||||
name: "read task logs",
|
||||
cmdArgs: []string{"exp", "task", "logs", taskName, "--output", "json"},
|
||||
assertFn: func(stdout string, userClient *codersdk.Client) {
|
||||
var logs []codersdk.TaskLogEntry
|
||||
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&logs), "should unmarshal task logs")
|
||||
require.Len(t, logs, 3, "should have 3 logs")
|
||||
require.Equal(t, logs[0].Content, initMsg.Content, "first message should be the init message")
|
||||
require.Equal(t, logs[0].Type, codersdk.TaskLogTypeInput, "first message should be an input")
|
||||
require.Equal(t, logs[1].Content, "hello", "second message should be the sent message")
|
||||
require.Equal(t, logs[1].Type, codersdk.TaskLogTypeInput, "second message should be an input")
|
||||
require.Equal(t, logs[2].Content, "hello", "third message should be the echoed message")
|
||||
require.Equal(t, logs[2].Type, codersdk.TaskLogTypeOutput, "third message should be an output")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete task",
|
||||
cmdArgs: []string{"exp", "task", "delete", taskName, "--yes"},
|
||||
assertFn: func(stdout string, userClient *codersdk.Client) {
|
||||
// The task should eventually no longer show up in the list of tasks
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
expClient := codersdk.NewExperimentalClient(userClient)
|
||||
tasks, err := expClient.Tasks(ctx, &codersdk.TasksFilter{})
|
||||
if !assert.NoError(t, err) {
|
||||
return false
|
||||
}
|
||||
return slices.IndexFunc(tasks, func(task codersdk.Task) bool {
|
||||
return task.Name == taskName
|
||||
}) == -1
|
||||
}, testutil.IntervalMedium)
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var stdout strings.Builder
|
||||
inv, root := clitest.New(t, tc.cmdArgs...)
|
||||
inv.Stdout = &stdout
|
||||
clitest.SetupConfig(t, userClient, root)
|
||||
require.NoError(t, inv.WithContext(ctx).Run())
|
||||
if tc.assertFn != nil {
|
||||
tc.assertFn(stdout.String(), userClient)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Message, want ...string) map[string]http.HandlerFunc {
|
||||
t.Helper()
|
||||
var mmu sync.RWMutex
|
||||
msgs := []agentapisdk.Message{initMsg}
|
||||
wantCpy := make([]string, len(want))
|
||||
copy(wantCpy, want)
|
||||
t.Cleanup(func() {
|
||||
mmu.Lock()
|
||||
defer mmu.Unlock()
|
||||
if !t.Failed() {
|
||||
assert.Empty(t, wantCpy, "not all expected messages received: missing %v", wantCpy)
|
||||
}
|
||||
})
|
||||
writeAgentAPIError := func(w http.ResponseWriter, err error, status int) {
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(agentapisdk.ErrorModel{
|
||||
Errors: ptr.Ref([]agentapisdk.ErrorDetail{
|
||||
{
|
||||
Message: ptr.Ref(err.Error()),
|
||||
},
|
||||
}),
|
||||
})
|
||||
}
|
||||
return map[string]http.HandlerFunc{
|
||||
"/status": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(agentapisdk.GetStatusResponse{
|
||||
Status: "stable",
|
||||
})
|
||||
},
|
||||
"/messages": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
mmu.RLock()
|
||||
defer mmu.RUnlock()
|
||||
bs, err := json.Marshal(agentapisdk.GetMessagesResponse{
|
||||
Messages: msgs,
|
||||
})
|
||||
if err != nil {
|
||||
writeAgentAPIError(w, err, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(bs)
|
||||
},
|
||||
"/message": func(w http.ResponseWriter, r *http.Request) {
|
||||
mmu.Lock()
|
||||
defer mmu.Unlock()
|
||||
var params agentapisdk.PostMessageParams
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewDecoder(r.Body).Decode(¶ms)
|
||||
if !assert.NoError(t, err, "decode message") {
|
||||
writeAgentAPIError(w, err, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(wantCpy) == 0 {
|
||||
assert.Fail(t, "unexpected message", "received message %v, but no more expected messages", params)
|
||||
writeAgentAPIError(w, xerrors.New("no more expected messages"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
exp := wantCpy[0]
|
||||
wantCpy = wantCpy[1:]
|
||||
|
||||
if !assert.Equal(t, exp, params.Content, "message content mismatch") {
|
||||
writeAgentAPIError(w, xerrors.New("unexpected message content: expected "+exp+", got "+params.Content), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
msgs = append(msgs, agentapisdk.Message{
|
||||
Id: int64(len(msgs) + 1),
|
||||
Content: params.Content,
|
||||
Role: agentapisdk.RoleUser,
|
||||
Time: time.Now().UTC(),
|
||||
})
|
||||
msgs = append(msgs, agentapisdk.Message{
|
||||
Id: int64(len(msgs) + 1),
|
||||
Content: params.Content,
|
||||
Role: agentapisdk.RoleAgent,
|
||||
Time: time.Now().UTC(),
|
||||
})
|
||||
assert.NoError(t, json.NewEncoder(w).Encode(agentapisdk.PostMessageResponse{
|
||||
Ok: true,
|
||||
}))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// setupCLITaskTest creates a test workspace with an AI task template and agent,
|
||||
// with a fake agent API configured with the provided set of handlers.
|
||||
// Returns the user client and workspace.
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Workspace) {
|
||||
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Task) {
|
||||
t.Helper()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
@@ -34,11 +250,18 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
|
||||
template := createAITaskTemplate(t, client, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
|
||||
|
||||
wantPrompt := "test prompt"
|
||||
workspace := coderdtest.CreateWorkspace(t, userClient, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
|
||||
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: wantPrompt},
|
||||
}
|
||||
exp := codersdk.NewExperimentalClient(userClient)
|
||||
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: wantPrompt,
|
||||
Name: "test-task",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the task's underlying workspace to be built
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
|
||||
@@ -49,7 +272,7 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).
|
||||
WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
return userClient, workspace
|
||||
return userClient, task
|
||||
}
|
||||
|
||||
// createAITaskTemplate creates a template configured for AI tasks with a sidebar app.
|
||||
|
||||
@@ -144,6 +144,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
r.syncCommand(),
|
||||
r.tasksCommand(),
|
||||
r.boundary(),
|
||||
}
|
||||
|
||||
@@ -176,6 +176,22 @@ func (r *RootCmd) scheduleStart() *serpent.Command {
|
||||
}
|
||||
|
||||
schedStr = ptr.Ref(sched.String())
|
||||
|
||||
// Check if the template has autostart requirements that may conflict
|
||||
// with the user's schedule.
|
||||
template, err := client.Template(inv.Context(), workspace.TemplateID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template: %w", err)
|
||||
}
|
||||
|
||||
if len(template.AutostartRequirement.DaysOfWeek) > 0 {
|
||||
_, _ = fmt.Fprintf(
|
||||
inv.Stderr,
|
||||
"Warning: your workspace template restricts autostart to the following days: %s.\n"+
|
||||
"Your workspace may only autostart on these days.\n",
|
||||
strings.Join(template.AutostartRequirement.DaysOfWeek, ", "),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
err = client.UpdateWorkspaceAutostart(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{
|
||||
|
||||
@@ -373,3 +373,67 @@ func TestScheduleOverride(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:paralleltest // t.Setenv
|
||||
func TestScheduleStart_TemplateAutostartRequirement(t *testing.T) {
|
||||
t.Setenv("TZ", "UTC")
|
||||
loc, err := tz.TimezoneIANA()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "UTC", loc.String())
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
// Update template to have autostart requirement
|
||||
// Note: In AGPL, this will be ignored and all days will be allowed (enterprise feature).
|
||||
template, err = client.UpdateTemplateMeta(context.Background(), template.ID, codersdk.UpdateTemplateMeta{
|
||||
AutostartRequirement: &codersdk.TemplateAutostartRequirement{
|
||||
DaysOfWeek: []string{"monday", "wednesday", "friday"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the template - in AGPL, AutostartRequirement will have all days (enterprise feature)
|
||||
template, err = client.Template(context.Background(), template.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, template.AutostartRequirement.DaysOfWeek, "template should have autostart requirement days")
|
||||
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
t.Run("ShowsWarning", func(t *testing.T) {
|
||||
// When: user sets autostart schedule
|
||||
inv, root := clitest.New(t,
|
||||
"schedule", "start", workspace.Name, "9:30AM", "Mon-Fri",
|
||||
)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
require.NoError(t, inv.Run())
|
||||
|
||||
// Then: warning should be shown
|
||||
// In AGPL, this will show all days (enterprise feature defaults to all days allowed)
|
||||
pty.ExpectMatch("Warning")
|
||||
pty.ExpectMatch("may only autostart")
|
||||
})
|
||||
|
||||
t.Run("NoWarningWhenManual", func(t *testing.T) {
|
||||
// When: user sets manual schedule
|
||||
inv, root := clitest.New(t,
|
||||
"schedule", "start", workspace.Name, "manual",
|
||||
)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
inv.Stderr = &stderrBuf
|
||||
|
||||
require.NoError(t, inv.Run())
|
||||
|
||||
// Then: no warning should be shown on stderr
|
||||
stderrOutput := stderrBuf.String()
|
||||
require.NotContains(t, stderrOutput, "Warning")
|
||||
})
|
||||
}
|
||||
|
||||
+75
-39
@@ -29,6 +29,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
@@ -1377,6 +1378,7 @@ func IsLocalURL(ctx context.Context, u *url.URL) (bool, error) {
|
||||
}
|
||||
|
||||
func shutdownWithTimeout(shutdown func(context.Context) error, timeout time.Duration) error {
|
||||
// nolint:gocritic // The magic number is parameterized.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
return shutdown(ctx)
|
||||
@@ -2134,50 +2136,83 @@ func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logg
|
||||
return "", nil, xerrors.New("The built-in PostgreSQL cannot run as the root user. Create a non-root user and run again!")
|
||||
}
|
||||
|
||||
// Ensure a password and port have been generated!
|
||||
connectionURL, err := embeddedPostgresURL(cfg)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
pgPassword, err := cfg.PostgresPassword().Read()
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("read postgres password: %w", err)
|
||||
}
|
||||
pgPortRaw, err := cfg.PostgresPort().Read()
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("read postgres port: %w", err)
|
||||
}
|
||||
pgPort, err := strconv.ParseUint(pgPortRaw, 10, 16)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("parse postgres port: %w", err)
|
||||
}
|
||||
|
||||
cachePath := filepath.Join(cfg.PostgresPath(), "cache")
|
||||
if customCacheDir != "" {
|
||||
cachePath = filepath.Join(customCacheDir, "postgres")
|
||||
}
|
||||
stdlibLogger := slog.Stdlib(ctx, logger.Named("postgres"), slog.LevelDebug)
|
||||
ep := embeddedpostgres.NewDatabase(
|
||||
embeddedpostgres.DefaultConfig().
|
||||
Version(embeddedpostgres.V13).
|
||||
BinariesPath(filepath.Join(cfg.PostgresPath(), "bin")).
|
||||
// Default BinaryRepositoryURL repo1.maven.org is flaky.
|
||||
BinaryRepositoryURL("https://repo.maven.apache.org/maven2").
|
||||
DataPath(filepath.Join(cfg.PostgresPath(), "data")).
|
||||
RuntimePath(filepath.Join(cfg.PostgresPath(), "runtime")).
|
||||
CachePath(cachePath).
|
||||
Username("coder").
|
||||
Password(pgPassword).
|
||||
Database("coder").
|
||||
Encoding("UTF8").
|
||||
Port(uint32(pgPort)).
|
||||
Logger(stdlibLogger.Writer()),
|
||||
)
|
||||
err = ep.Start()
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("Failed to start built-in PostgreSQL. Optionally, specify an external deployment with `--postgres-url`: %w", err)
|
||||
|
||||
// If the port is not defined, an available port will be found dynamically.
|
||||
maxAttempts := 1
|
||||
_, err = cfg.PostgresPort().Read()
|
||||
retryPortDiscovery := errors.Is(err, os.ErrNotExist) && testing.Testing()
|
||||
if retryPortDiscovery {
|
||||
// There is no way to tell Postgres to use an ephemeral port, so in order to avoid
|
||||
// flaky tests in CI we need to retry EmbeddedPostgres.Start in case of a race
|
||||
// condition where the port we quickly listen on and close in embeddedPostgresURL()
|
||||
// is not free by the time the embedded postgres starts up. This maximum_should
|
||||
// cover most cases where port conflicts occur in CI and cause flaky tests.
|
||||
maxAttempts = 3
|
||||
}
|
||||
return connectionURL, ep.Stop, nil
|
||||
|
||||
var startErr error
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
// Ensure a password and port have been generated.
|
||||
connectionURL, err := embeddedPostgresURL(cfg)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
pgPassword, err := cfg.PostgresPassword().Read()
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("read postgres password: %w", err)
|
||||
}
|
||||
pgPortRaw, err := cfg.PostgresPort().Read()
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("read postgres port: %w", err)
|
||||
}
|
||||
pgPort, err := strconv.ParseUint(pgPortRaw, 10, 16)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("parse postgres port: %w", err)
|
||||
}
|
||||
|
||||
ep := embeddedpostgres.NewDatabase(
|
||||
embeddedpostgres.DefaultConfig().
|
||||
Version(embeddedpostgres.V13).
|
||||
BinariesPath(filepath.Join(cfg.PostgresPath(), "bin")).
|
||||
// Default BinaryRepositoryURL repo1.maven.org is flaky.
|
||||
BinaryRepositoryURL("https://repo.maven.apache.org/maven2").
|
||||
DataPath(filepath.Join(cfg.PostgresPath(), "data")).
|
||||
RuntimePath(filepath.Join(cfg.PostgresPath(), "runtime")).
|
||||
CachePath(cachePath).
|
||||
Username("coder").
|
||||
Password(pgPassword).
|
||||
Database("coder").
|
||||
Encoding("UTF8").
|
||||
Port(uint32(pgPort)).
|
||||
Logger(stdlibLogger.Writer()),
|
||||
)
|
||||
|
||||
startErr = ep.Start()
|
||||
if startErr == nil {
|
||||
return connectionURL, ep.Stop, nil
|
||||
}
|
||||
|
||||
logger.Warn(ctx, "failed to start embedded postgres",
|
||||
slog.F("attempt", attempt+1),
|
||||
slog.F("max_attempts", maxAttempts),
|
||||
slog.F("port", pgPort),
|
||||
slog.Error(startErr),
|
||||
)
|
||||
|
||||
if retryPortDiscovery {
|
||||
// Since a retry is needed, we wipe the port stored here at the beginning of the loop.
|
||||
_ = cfg.PostgresPort().Delete()
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil, xerrors.Errorf("failed to start built-in PostgreSQL after %d attempts. "+
|
||||
"Optionally, specify an external deployment. See https://coder.com/docs/tutorials/external-database "+
|
||||
"for more details: %w", maxAttempts, startErr)
|
||||
}
|
||||
|
||||
func ConfigureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) {
|
||||
@@ -2286,7 +2321,7 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
|
||||
var err error
|
||||
var sqlDB *sql.DB
|
||||
dbNeedsClosing := true
|
||||
// Try to connect for 30 seconds.
|
||||
// nolint:gocritic // Try to connect for 30 seconds.
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -2382,6 +2417,7 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
|
||||
}
|
||||
|
||||
func pingPostgres(ctx context.Context, db *sql.DB) error {
|
||||
// nolint:gocritic // This is a reasonable magic number for a ping timeout.
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
return db.PingContext(ctx)
|
||||
|
||||
@@ -17,9 +17,6 @@ import (
|
||||
|
||||
func TestRegenerateVapidKeypair(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test is only supported on postgres")
|
||||
}
|
||||
|
||||
t.Run("NoExistingVAPIDKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -348,9 +348,6 @@ func TestServer(t *testing.T) {
|
||||
|
||||
runGitHubProviderTest := func(t *testing.T, tc testCase) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("test requires postgres")
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(testutil.Context(t, testutil.WaitLong))
|
||||
defer cancelFunc()
|
||||
@@ -2142,10 +2139,6 @@ func TestServerYAMLConfig(t *testing.T) {
|
||||
func TestConnectToPostgres(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test does not make sense without postgres")
|
||||
}
|
||||
|
||||
t.Run("Migrate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2256,10 +2249,6 @@ type runServerOpts struct {
|
||||
func TestServer_TelemetryDisabled_FinalReport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
telemetryServerURL, deployment, snapshot := mockTelemetryServer(t)
|
||||
dbConnURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -54,6 +54,7 @@ func TestSharingShare(t *testing.T) {
|
||||
MinimalUser: codersdk.MinimalUser{
|
||||
ID: toShareWithUser.ID,
|
||||
Username: toShareWithUser.Username,
|
||||
Name: toShareWithUser.Name,
|
||||
AvatarURL: toShareWithUser.AvatarURL,
|
||||
},
|
||||
Role: codersdk.WorkspaceRole("use"),
|
||||
@@ -103,6 +104,7 @@ func TestSharingShare(t *testing.T) {
|
||||
MinimalUser: codersdk.MinimalUser{
|
||||
ID: toShareWithUser1.ID,
|
||||
Username: toShareWithUser1.Username,
|
||||
Name: toShareWithUser1.Name,
|
||||
AvatarURL: toShareWithUser1.AvatarURL,
|
||||
},
|
||||
Role: codersdk.WorkspaceRoleUse,
|
||||
@@ -111,6 +113,7 @@ func TestSharingShare(t *testing.T) {
|
||||
MinimalUser: codersdk.MinimalUser{
|
||||
ID: toShareWithUser2.ID,
|
||||
Username: toShareWithUser2.Username,
|
||||
Name: toShareWithUser2.Name,
|
||||
AvatarURL: toShareWithUser2.AvatarURL,
|
||||
},
|
||||
Role: codersdk.WorkspaceRoleUse,
|
||||
@@ -155,6 +158,7 @@ func TestSharingShare(t *testing.T) {
|
||||
MinimalUser: codersdk.MinimalUser{
|
||||
ID: toShareWithUser.ID,
|
||||
Username: toShareWithUser.Username,
|
||||
Name: toShareWithUser.Name,
|
||||
AvatarURL: toShareWithUser.AvatarURL,
|
||||
},
|
||||
Role: codersdk.WorkspaceRoleAdmin,
|
||||
|
||||
+25
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncCommand() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "sync",
|
||||
Short: "Synchronize with the local agent socket",
|
||||
Long: "Commands for interacting with the local Coder agent via socket communication.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.syncPing(),
|
||||
r.syncStart(),
|
||||
r.syncWant(),
|
||||
r.syncComplete(),
|
||||
r.syncWait(),
|
||||
r.syncStatus(),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncComplete() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "complete <unit>",
|
||||
Short: "Mark a unit as complete in the dependency graph",
|
||||
Long: "Set a unit's status to complete in the dependency graph.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Completing unit '%s'...\n", unit)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Complete the unit
|
||||
if err := client.SyncComplete(ctx, unit); err != nil {
|
||||
return xerrors.Errorf("complete unit failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Unit '%s' completed successfully\n", unit)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncPing() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "ping",
|
||||
Short: "Ping the local agent socket",
|
||||
Long: "Test connectivity to the local Coder agent via socket communication.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Show initial message
|
||||
fmt.Println("Pinging agent socket...")
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Measure round-trip time
|
||||
start := time.Now()
|
||||
resp, err := client.Ping(ctx)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Display results
|
||||
fmt.Printf("Response: %s\n", resp.Message)
|
||||
fmt.Printf("Timestamp: %s\n", resp.Timestamp.Format(time.RFC3339))
|
||||
fmt.Printf("Round-trip time: %s\n", duration.Round(time.Microsecond))
|
||||
fmt.Println("Status: healthy")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// SyncPollInterval is the interval between dependency checks for sync start
|
||||
SyncPollInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncStart() *serpent.Command {
|
||||
var timeout time.Duration
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "start <unit>",
|
||||
Short: "Start a unit in the dependency graph",
|
||||
Long: "Register a unit in the dependency graph and set its status to started. Waits for all dependencies to be satisfied before marking as started.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unitName := i.Args[0]
|
||||
|
||||
// Set up context with timeout if specified
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Starting unit '%s'...\n", unitName)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Check if dependencies are satisfied first
|
||||
err = client.SyncReady(ctx, unitName)
|
||||
if err != nil {
|
||||
// Check if it's a "not ready" error (expected if dependencies exist)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Dependencies exist but aren't satisfied, start polling
|
||||
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
|
||||
|
||||
// Poll until dependencies are satisfied
|
||||
ticker := time.NewTicker(SyncPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
pollLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// Check if dependencies are satisfied
|
||||
err := client.SyncReady(ctx, unitName)
|
||||
if err == nil {
|
||||
// Dependencies are satisfied
|
||||
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
|
||||
break pollLoop
|
||||
}
|
||||
|
||||
// Check if it's still a "not ready" error (expected while waiting)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Still waiting, continue polling
|
||||
continue
|
||||
}
|
||||
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
} else {
|
||||
// No dependencies or already satisfied
|
||||
fmt.Printf("Dependencies satisfied, marking unit '%s' as started\n", unitName)
|
||||
}
|
||||
|
||||
// Start the unit
|
||||
if err := client.SyncStart(ctx, unitName); err != nil {
|
||||
return xerrors.Errorf("start unit failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Unit '%s' started successfully\n", unitName)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options, serpent.Option{
|
||||
Flag: "timeout",
|
||||
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
|
||||
Value: serpent.DurationOf(&timeout),
|
||||
})
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
type outputFormat string
|
||||
|
||||
const (
|
||||
outputFormatHuman outputFormat = "human"
|
||||
outputFormatJSON outputFormat = "json"
|
||||
outputFormatDOT outputFormat = "dot"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncStatus() *serpent.Command {
|
||||
var (
|
||||
output string
|
||||
recursive bool
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "status <unit>",
|
||||
Short: "Show the status of a unit and its dependencies",
|
||||
Long: "Display the current status of a unit and information about its dependencies. Supports multiple output formats.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Get status information
|
||||
statusResp, err := client.SyncStatus(ctx, unit, recursive)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get status failed: %w", err)
|
||||
}
|
||||
|
||||
// Output based on format
|
||||
switch outputFormat(output) {
|
||||
case outputFormatJSON:
|
||||
return outputJSON(statusResp)
|
||||
case outputFormatDOT:
|
||||
return outputDOT(statusResp)
|
||||
default: // outputFormatHuman
|
||||
return outputHuman(statusResp)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options,
|
||||
serpent.Option{
|
||||
Flag: "output",
|
||||
FlagShorthand: "o",
|
||||
Description: "Output format: human, json, or dot.",
|
||||
Value: serpent.EnumOf(&output, "human", "json", "dot"),
|
||||
},
|
||||
serpent.Option{
|
||||
Flag: "recursive",
|
||||
FlagShorthand: "r",
|
||||
Description: "Show transitive dependencies and include DOT graph.",
|
||||
Value: serpent.BoolOf(&recursive),
|
||||
},
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func outputJSON(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder.Encode(statusResp)
|
||||
}
|
||||
|
||||
func outputDOT(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
if statusResp.DOT == "" {
|
||||
return xerrors.New("DOT output requires --recursive flag")
|
||||
}
|
||||
fmt.Println(statusResp.DOT)
|
||||
return nil
|
||||
}
|
||||
|
||||
func outputHuman(statusResp *agentsdk.SyncStatusResponse) error {
|
||||
// Unit status
|
||||
fmt.Printf("Unit: %s\n", statusResp.Unit)
|
||||
fmt.Printf("Status: %s\n", statusResp.Status)
|
||||
fmt.Printf("Ready: %t\n", statusResp.IsReady)
|
||||
fmt.Println()
|
||||
|
||||
// Dependencies
|
||||
if len(statusResp.Dependencies) == 0 {
|
||||
fmt.Println("No dependencies")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("Dependencies:")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
fmt.Printf("%-20s %-15s %-15s %-10s\n", "Depends On", "Required", "Current", "Satisfied")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
|
||||
for _, dep := range statusResp.Dependencies {
|
||||
satisfied := "✓"
|
||||
if !dep.IsSatisfied {
|
||||
satisfied = "✗"
|
||||
}
|
||||
fmt.Printf("%-20s %-15s %-15s %-10s\n",
|
||||
dep.DependsOn,
|
||||
dep.RequiredStatus,
|
||||
dep.CurrentStatus,
|
||||
satisfied,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,359 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
// mockAgentSocketServer simulates the agent socket server for testing
|
||||
// mockAgentSocketServer simulates the agent socket server for testing.
type mockAgentSocketServer struct {
	// listener is the unix socket the mock accepts connections on.
	listener net.Listener
	// handlers maps a method name (e.g. "sync.wait") to a function that
	// receives the unit name and returns a response message or an error.
	// NOTE(review): setHandler mutates this map after the serve goroutine
	// has started, with no synchronization — looks like a data race if a
	// connection arrives concurrently; confirm and guard with a mutex.
	handlers map[string]func(string) (string, error)
}
|
||||
|
||||
func newMockAgentSocketServer(t *testing.T, socketPath string) *mockAgentSocketServer {
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &mockAgentSocketServer{
|
||||
listener: listener,
|
||||
handlers: make(map[string]func(string) (string, error)),
|
||||
}
|
||||
|
||||
// Set up default handlers
|
||||
server.handlers["sync.wait"] = func(unitName string) (string, error) {
|
||||
// Always return dependencies not satisfied to trigger polling
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
}
|
||||
|
||||
server.handlers["sync.start"] = func(unitName string) (string, error) {
|
||||
return "Unit " + unitName + " started successfully", nil
|
||||
}
|
||||
|
||||
go server.serve(t)
|
||||
return server
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) serve(t *testing.T) {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
t.Logf("Accept error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go s.handleConnection(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mockAgentSocketServer) handleConnection(t *testing.T, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Simple JSON-RPC-like protocol simulation
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Logf("Read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
request := string(buf[:n])
|
||||
|
||||
// Parse method from request (simplified)
|
||||
var method string
|
||||
if strings.Contains(request, "sync.wait") {
|
||||
method = "sync.wait"
|
||||
} else if strings.Contains(request, "sync.start") {
|
||||
method = "sync.start"
|
||||
}
|
||||
|
||||
handler, exists := s.handlers[method]
|
||||
if !exists {
|
||||
response := `{"error": {"code": -32601, "message": "Method not found"}}`
|
||||
_, _ = conn.Write([]byte(response))
|
||||
return
|
||||
}
|
||||
|
||||
// Extract unit name from request (simplified)
|
||||
unitName := "test-unit"
|
||||
if strings.Contains(request, "test-unit") {
|
||||
unitName = "test-unit"
|
||||
}
|
||||
|
||||
message, err := handler(unitName)
|
||||
if err != nil {
|
||||
response := fmt.Sprintf(`{"error": {"code": -32603, "message": %q}}`, err.Error())
|
||||
_, _ = conn.Write([]byte(response))
|
||||
return
|
||||
}
|
||||
|
||||
response := fmt.Sprintf(`{"result": {"success": true, "message": %q}}`, message)
|
||||
_, _ = conn.Write([]byte(response))
|
||||
}
|
||||
|
||||
// setHandler replaces the handler for the given method name.
// NOTE(review): tests call this after the serve goroutine has started, so
// this map write can race with reads in handleConnection — confirm no
// connection can arrive before setHandler returns, or add a mutex.
func (s *mockAgentSocketServer) setHandler(method string, handler func(string) (string, error)) {
	s.handlers[method] = handler
}
|
||||
|
||||
// close shuts down the listener, which unblocks serve's Accept loop and
// terminates the server goroutine. The error is intentionally ignored:
// best-effort cleanup in tests.
func (s *mockAgentSocketServer) close() {
	_ = s.listener.Close()
}
|
||||
|
||||
func TestSyncStartTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with a short timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", "100ms")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately 100ms
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range (100ms + some buffer for test execution)
|
||||
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
|
||||
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
|
||||
}
|
||||
|
||||
func TestSyncWaitTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with a short timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", "100ms")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately 100ms
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range (100ms + some buffer for test execution)
|
||||
assert.True(t, duration >= 100*time.Millisecond, "Duration should be at least 100ms, got %v", duration)
|
||||
assert.True(t, duration < 2*time.Second, "Duration should be less than 2s, got %v", duration)
|
||||
}
|
||||
|
||||
func TestSyncStartNoTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Set up handler that will eventually succeed
|
||||
callCount := 0
|
||||
server.setHandler("sync.wait", func(unitName string) (string, error) {
|
||||
callCount++
|
||||
if callCount >= 3 {
|
||||
// After 3 calls, dependencies are satisfied
|
||||
return "Dependencies satisfied", nil
|
||||
}
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
})
|
||||
|
||||
// Test without timeout - should eventually succeed
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should succeed after a few polling cycles
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should take at least 2 seconds (2 polling cycles at 1s interval)
|
||||
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
|
||||
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
|
||||
}
|
||||
|
||||
func TestSyncWaitNoTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Set up handler that will eventually succeed
|
||||
callCount := 0
|
||||
server.setHandler("sync.wait", func(unitName string) (string, error) {
|
||||
callCount++
|
||||
if callCount >= 3 {
|
||||
// After 3 calls, dependencies are satisfied
|
||||
return "Dependencies satisfied", nil
|
||||
}
|
||||
return "", unit.ErrDependenciesNotSatisfied
|
||||
})
|
||||
|
||||
// Test without timeout - should eventually succeed
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit")
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should succeed after a few polling cycles
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should take at least 2 seconds (2 polling cycles at 1s interval)
|
||||
assert.True(t, duration >= 2*time.Second, "Duration should be at least 2s, got %v", duration)
|
||||
assert.True(t, callCount >= 3, "Should have made at least 3 calls, got %d", callCount)
|
||||
}
|
||||
|
||||
func TestSyncStartTimeoutWithDifferentValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
timeout string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"50ms", "50ms", 50 * time.Millisecond},
|
||||
{"200ms", "200ms", 200 * time.Millisecond},
|
||||
{"500ms", "500ms", 500 * time.Millisecond},
|
||||
{"1s", "1s", 1 * time.Second},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with specified timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "start", "test-unit", "--timeout", tc.timeout)
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately the specified duration
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range
|
||||
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
|
||||
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncWaitTimeoutWithDifferentValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
timeout string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"50ms", "50ms", 50 * time.Millisecond},
|
||||
{"200ms", "200ms", 200 * time.Millisecond},
|
||||
{"500ms", "500ms", 500 * time.Millisecond},
|
||||
{"1s", "1s", 1 * time.Second},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a unique temporary socket file
|
||||
socketPath := fmt.Sprintf("/tmp/coder-test-%d.sock", time.Now().UnixNano())
|
||||
// Remove existing socket if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
defer func() { _ = os.Remove(socketPath) }()
|
||||
|
||||
// Start mock server
|
||||
server := newMockAgentSocketServer(t, socketPath)
|
||||
defer server.close()
|
||||
|
||||
// Test with specified timeout
|
||||
inv, _ := clitest.New(t, "exp", "sync", "wait", "test-unit", "--timeout", tc.timeout)
|
||||
|
||||
// Override the socket path for this test
|
||||
inv.Args = append(inv.Args, "--agent-socket", socketPath)
|
||||
|
||||
start := time.Now()
|
||||
err := inv.Run()
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should timeout after approximately the specified duration
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "timeout waiting for dependencies of unit 'test-unit'")
|
||||
|
||||
// Should timeout within a reasonable range
|
||||
assert.True(t, duration >= tc.expected, "Duration should be at least %v, got %v", tc.expected, duration)
|
||||
assert.True(t, duration < tc.expected+2*time.Second, "Duration should be less than %v, got %v", tc.expected+2*time.Second, duration)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/agent/unit"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// PollInterval is the interval between dependency checks
|
||||
PollInterval = 1 * time.Second
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncWait() *serpent.Command {
|
||||
var timeout time.Duration
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "wait <unit>",
|
||||
Short: "Wait for a unit's dependencies to be satisfied",
|
||||
Long: "Poll until all dependencies for a unit are met. Exits when dependencies are satisfied or timeout is reached.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 1 {
|
||||
return xerrors.New("exactly one unit name is required")
|
||||
}
|
||||
unitName := i.Args[0]
|
||||
|
||||
// Set up context with timeout if specified
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Waiting for dependencies of unit '%s' to be satisfied...\n", unitName)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Poll until dependencies are satisfied
|
||||
ticker := time.NewTicker(PollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return xerrors.Errorf("timeout waiting for dependencies of unit '%s'", unitName)
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// Check if dependencies are satisfied
|
||||
err := client.SyncReady(ctx, unitName)
|
||||
if err == nil {
|
||||
// Dependencies are satisfied
|
||||
fmt.Printf("Dependencies for unit '%s' are now satisfied\n", unitName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if it's a "not ready" error (expected while waiting)
|
||||
if xerrors.Is(err, unit.ErrDependenciesNotSatisfied) {
|
||||
// Still waiting, continue polling
|
||||
continue
|
||||
}
|
||||
|
||||
// Some other error occurred
|
||||
return xerrors.Errorf("error checking dependencies: %w", err)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = append(cmd.Options, serpent.Option{
|
||||
Flag: "timeout",
|
||||
Description: "Maximum time to wait for dependencies (e.g., 30s, 5m). No timeout by default.",
|
||||
Value: serpent.DurationOf(&timeout),
|
||||
})
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func (r *RootCmd) syncWant() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "want <unit> <depends-on>",
|
||||
Short: "Declare a dependency between units",
|
||||
Long: "Declare that a unit depends on another unit reaching complete status.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if len(i.Args) != 2 {
|
||||
return xerrors.New("exactly two arguments are required: unit and depends-on")
|
||||
}
|
||||
unit := i.Args[0]
|
||||
dependsOn := i.Args[1]
|
||||
|
||||
// Show initial message
|
||||
fmt.Printf("Declaring dependency: '%s' depends on '%s'...\n", unit, dependsOn)
|
||||
|
||||
// Connect to agent socket
|
||||
client, err := agentsdk.NewSocketClient(agentsdk.SocketConfig{
|
||||
Path: "/tmp/coder.sock",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("connect to agent socket: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Declare the dependency
|
||||
if err := client.SyncWant(ctx, unit, dependsOn); err != nil {
|
||||
return xerrors.Errorf("declare dependency failed: %w", err)
|
||||
}
|
||||
|
||||
// Display success message
|
||||
fmt.Printf("Dependency declared: '%s' now depends on '%s'\n", unit, dependsOn)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSyncWant is a placeholder for `exp sync want` coverage.
// TODO(review): implement — an empty test passes silently and masks the
// missing coverage for the want subcommand.
func TestSyncWant(t *testing.T) {
}
|
||||
+3
@@ -67,6 +67,9 @@ OPTIONS:
|
||||
--script-data-dir string, $CODER_AGENT_SCRIPT_DATA_DIR (default: /tmp)
|
||||
Specify the location for storing script data.
|
||||
|
||||
--socket-path string, $CODER_AGENT_SOCKET_PATH
|
||||
Specify the path for the agent socket.
|
||||
|
||||
--ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
|
||||
Specify the max timeout for a SSH connection, it is advisable to set
|
||||
it to a minimum of 60s, but no more than 72h.
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@ USAGE:
|
||||
Get started with a templated template.
|
||||
|
||||
OPTIONS:
|
||||
--id aws-devcontainer|aws-linux|aws-windows|azure-linux|digitalocean-linux|docker|docker-devcontainer|docker-envbuilder|gcp-devcontainer|gcp-linux|gcp-vm-container|gcp-windows|kubernetes|kubernetes-devcontainer|nomad-docker|scratch
|
||||
--id aws-devcontainer|aws-linux|aws-windows|azure-linux|digitalocean-linux|docker|docker-devcontainer|docker-envbuilder|gcp-devcontainer|gcp-linux|gcp-vm-container|gcp-windows|kubernetes|kubernetes-devcontainer|nomad-docker|scratch|tasks-docker
|
||||
Specify a given example template by ID.
|
||||
|
||||
———
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ USAGE:
|
||||
Aliases: ls
|
||||
|
||||
OPTIONS:
|
||||
-c, --column [id|username|email|created at|updated at|status] (default: username,email,created at,status)
|
||||
-c, --column [id|username|name|email|created at|updated at|status] (default: username,email,created at,status)
|
||||
Columns to display in table output.
|
||||
|
||||
--github-user-id int
|
||||
|
||||
+302
-375
@@ -2,8 +2,6 @@ package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -12,12 +10,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
@@ -25,7 +24,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/searchquery"
|
||||
"github.com/coder/coder/v2/coderd/taskname"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
|
||||
aiagentapi "github.com/coder/agentapi-sdk-go"
|
||||
@@ -96,31 +94,54 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
|
||||
// This endpoint creates a new task for the given user.
|
||||
func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
auditor = api.Auditor.Load()
|
||||
mems = httpmw.OrganizationMembersParam(r)
|
||||
ctx = r.Context()
|
||||
apiKey = httpmw.APIKey(r)
|
||||
auditor = api.Auditor.Load()
|
||||
mems = httpmw.OrganizationMembersParam(r)
|
||||
taskResourceInfo = audit.AdditionalFields{}
|
||||
)
|
||||
|
||||
if mems.User != nil {
|
||||
taskResourceInfo.WorkspaceOwner = mems.User.Username
|
||||
}
|
||||
|
||||
aReq, commitAudit := audit.InitRequest[database.TaskTable](rw, &audit.RequestParams{
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
AdditionalFields: taskResourceInfo,
|
||||
})
|
||||
|
||||
defer commitAudit()
|
||||
|
||||
var req codersdk.CreateTaskRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
hasAITask, err := api.Database.GetTemplateVersionHasAITask(ctx, req.TemplateVersionID)
|
||||
// Fetch the template version to verify access and whether or not it has an
|
||||
// AI task.
|
||||
templateVersion, err := api.Database.GetTemplateVersionByID(ctx, req.TemplateVersionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) || rbac.IsUnauthorizedError(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
if httpapi.Is404Error(err) {
|
||||
// Avoid using httpapi.ResourceNotFound() here because this is an
|
||||
// input error and 404 would be confusing.
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Template version not found or you do not have access to this resource",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching whether the template version has an AI task.",
|
||||
Message: "Internal error fetching template version.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !hasAITask {
|
||||
|
||||
aReq.UpdateOrganizationID(templateVersion.OrganizationID)
|
||||
|
||||
if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf(`Template does not have required parameter %q`, codersdk.AITaskPromptParameterName),
|
||||
})
|
||||
@@ -177,23 +198,12 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
// A task can still be created if the caller can read the organization
|
||||
// member. The organization is required, which can be sourced from the
|
||||
// template.
|
||||
// templateVersion.
|
||||
//
|
||||
// TODO: This code gets called twice for each workspace build request.
|
||||
// This is inefficient and costs at most 2 extra RTTs to the DB.
|
||||
// This can be optimized. It exists as it is now for code simplicity.
|
||||
// The most common case is to create a workspace for 'Me'. Which does
|
||||
// not enter this code branch.
|
||||
template, err := requestTemplate(ctx, createReq, api.Database)
|
||||
if err != nil {
|
||||
httperror.WriteResponseError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the caller can find the organization membership in the same org
|
||||
// as the template, then they can continue.
|
||||
orgIndex := slices.IndexFunc(mems.Memberships, func(mem httpmw.OrganizationMember) bool {
|
||||
return mem.OrganizationID == template.OrganizationID
|
||||
return mem.OrganizationID == templateVersion.OrganizationID
|
||||
})
|
||||
if orgIndex == -1 {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
@@ -206,56 +216,112 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
Username: member.Username,
|
||||
AvatarURL: member.AvatarURL,
|
||||
}
|
||||
|
||||
// Update workspace owner information for audit in case it changed.
|
||||
taskResourceInfo.WorkspaceOwner = owner.Username
|
||||
}
|
||||
|
||||
aReq, commitAudit := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
AdditionalFields: audit.AdditionalFields{
|
||||
WorkspaceOwner: owner.Username,
|
||||
// Track insert from preCreateInTX.
|
||||
var dbTaskTable database.TaskTable
|
||||
|
||||
// Ensure an audit log is created for the workspace creation event.
|
||||
aReqWS, commitAuditWS := audit.InitRequest[database.WorkspaceTable](rw, &audit.RequestParams{
|
||||
Audit: *auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionCreate,
|
||||
AdditionalFields: taskResourceInfo,
|
||||
OrganizationID: templateVersion.OrganizationID,
|
||||
})
|
||||
defer commitAuditWS()
|
||||
|
||||
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
|
||||
// Before creating the workspace, ensure that this task can be created.
|
||||
preCreateInTX: func(ctx context.Context, tx database.Store) error {
|
||||
// Create task record in the database before creating the workspace so that
|
||||
// we can request that the workspace be linked to it after creation.
|
||||
dbTaskTable, err = tx.InsertTask(ctx, database.InsertTaskParams{
|
||||
OrganizationID: templateVersion.OrganizationID,
|
||||
OwnerID: owner.ID,
|
||||
Name: taskName,
|
||||
WorkspaceID: uuid.NullUUID{}, // Will be set after workspace creation.
|
||||
TemplateVersionID: templateVersion.ID,
|
||||
TemplateParameters: []byte("{}"),
|
||||
Prompt: req.Input,
|
||||
CreatedAt: dbtime.Time(api.Clock.Now()),
|
||||
})
|
||||
if err != nil {
|
||||
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating task.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
// After the workspace is created, ensure that the task is linked to it.
|
||||
postCreateInTX: func(ctx context.Context, tx database.Store, workspace database.Workspace) error {
|
||||
// Update the task record with the workspace ID after creation.
|
||||
dbTaskTable, err = tx.UpdateTaskWorkspaceID(ctx, database.UpdateTaskWorkspaceIDParams{
|
||||
ID: dbTaskTable.ID,
|
||||
WorkspaceID: uuid.NullUUID{
|
||||
UUID: workspace.ID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating task.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
defer commitAudit()
|
||||
w, err := createWorkspace(ctx, aReq, apiKey.UserID, api, owner, createReq, r)
|
||||
if err != nil {
|
||||
httperror.WriteResponseError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
task := taskFromWorkspace(w, req.Input)
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, task)
|
||||
aReq.New = dbTaskTable
|
||||
|
||||
// Fetch the task to get the additional columns from the view.
|
||||
dbTask, err := api.Database.GetTaskByID(ctx, dbTaskTable.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching task.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, taskFromDBTaskAndWorkspace(dbTask, workspace))
|
||||
}
|
||||
|
||||
func taskFromWorkspace(ws codersdk.Workspace, initialPrompt string) codersdk.Task {
|
||||
// TODO(DanielleMaywood):
|
||||
// This just picks up the first agent it discovers.
|
||||
// This approach _might_ break when a task has multiple agents,
|
||||
// depending on which agent was found first.
|
||||
//
|
||||
// We explicitly do not have support for running tasks
|
||||
// inside of a sub agent at the moment, so we can be sure
|
||||
// that any sub agents are not the agent we're looking for.
|
||||
var taskAgentID uuid.NullUUID
|
||||
// taskFromDBTaskAndWorkspace creates a codersdk.Task response from the task
|
||||
// database record and workspace.
|
||||
func taskFromDBTaskAndWorkspace(dbTask database.Task, ws codersdk.Workspace) codersdk.Task {
|
||||
var taskAgentLifecycle *codersdk.WorkspaceAgentLifecycle
|
||||
var taskAgentHealth *codersdk.WorkspaceAgentHealth
|
||||
for _, resource := range ws.LatestBuild.Resources {
|
||||
for _, agent := range resource.Agents {
|
||||
if agent.ParentID.Valid {
|
||||
continue
|
||||
}
|
||||
|
||||
taskAgentID = uuid.NullUUID{Valid: true, UUID: agent.ID}
|
||||
taskAgentLifecycle = &agent.LifecycleState
|
||||
taskAgentHealth = &agent.Health
|
||||
break
|
||||
// If we have an agent ID from the task, find the agent details in the
|
||||
// workspace.
|
||||
if dbTask.WorkspaceAgentID.Valid {
|
||||
findTaskAgentLoop:
|
||||
for _, resource := range ws.LatestBuild.Resources {
|
||||
for _, agent := range resource.Agents {
|
||||
if agent.ID == dbTask.WorkspaceAgentID.UUID {
|
||||
taskAgentLifecycle = &agent.LifecycleState
|
||||
taskAgentHealth = &agent.Health
|
||||
break findTaskAgentLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ignore 'latest app status' if it is older than the latest build and the latest build is a 'start' transition.
|
||||
// This ensures that you don't show a stale app status from a previous build.
|
||||
// For stop transitions, there is still value in showing the latest app status.
|
||||
// Ignore 'latest app status' if it is older than the latest build and the
|
||||
// latest build is a 'start' transition. This ensures that you don't show a
|
||||
// stale app status from a previous build. For stop transitions, there is
|
||||
// still value in showing the latest app status.
|
||||
var currentState *codersdk.TaskStateEntry
|
||||
if ws.LatestAppStatus != nil {
|
||||
if ws.LatestBuild.Transition != codersdk.WorkspaceTransitionStart || ws.LatestAppStatus.CreatedAt.After(ws.LatestBuild.CreatedAt) {
|
||||
@@ -268,188 +334,135 @@ func taskFromWorkspace(ws codersdk.Workspace, initialPrompt string) codersdk.Tas
|
||||
}
|
||||
}
|
||||
|
||||
var appID uuid.NullUUID
|
||||
if ws.LatestBuild.AITaskSidebarAppID != nil {
|
||||
appID = uuid.NullUUID{
|
||||
Valid: true,
|
||||
UUID: *ws.LatestBuild.AITaskSidebarAppID,
|
||||
}
|
||||
}
|
||||
|
||||
return codersdk.Task{
|
||||
ID: ws.ID,
|
||||
OrganizationID: ws.OrganizationID,
|
||||
OwnerID: ws.OwnerID,
|
||||
ID: dbTask.ID,
|
||||
OrganizationID: dbTask.OrganizationID,
|
||||
OwnerID: dbTask.OwnerID,
|
||||
OwnerName: ws.OwnerName,
|
||||
Name: ws.Name,
|
||||
OwnerAvatarURL: ws.OwnerAvatarURL,
|
||||
Name: dbTask.Name,
|
||||
TemplateID: ws.TemplateID,
|
||||
TemplateVersionID: dbTask.TemplateVersionID,
|
||||
TemplateName: ws.TemplateName,
|
||||
TemplateDisplayName: ws.TemplateDisplayName,
|
||||
TemplateIcon: ws.TemplateIcon,
|
||||
WorkspaceID: uuid.NullUUID{Valid: true, UUID: ws.ID},
|
||||
WorkspaceBuildNumber: ws.LatestBuild.BuildNumber,
|
||||
WorkspaceAgentID: taskAgentID,
|
||||
WorkspaceID: dbTask.WorkspaceID,
|
||||
WorkspaceName: ws.Name,
|
||||
WorkspaceBuildNumber: dbTask.WorkspaceBuildNumber.Int32,
|
||||
WorkspaceStatus: ws.LatestBuild.Status,
|
||||
WorkspaceAgentID: dbTask.WorkspaceAgentID,
|
||||
WorkspaceAgentLifecycle: taskAgentLifecycle,
|
||||
WorkspaceAgentHealth: taskAgentHealth,
|
||||
WorkspaceAppID: appID,
|
||||
CreatedAt: ws.CreatedAt,
|
||||
UpdatedAt: ws.UpdatedAt,
|
||||
InitialPrompt: initialPrompt,
|
||||
Status: ws.LatestBuild.Status,
|
||||
WorkspaceAppID: dbTask.WorkspaceAppID,
|
||||
InitialPrompt: dbTask.Prompt,
|
||||
Status: codersdk.TaskStatus(dbTask.Status),
|
||||
CurrentState: currentState,
|
||||
CreatedAt: dbTask.CreatedAt,
|
||||
UpdatedAt: ws.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// tasksFromWorkspaces converts a slice of API workspaces into tasks, fetching
|
||||
// prompts and mapping status/state. This method enforces that only AI task
|
||||
// workspaces are given.
|
||||
func (api *API) tasksFromWorkspaces(ctx context.Context, apiWorkspaces []codersdk.Workspace) ([]codersdk.Task, error) {
|
||||
// Fetch prompts for each workspace build and map by build ID.
|
||||
buildIDs := make([]uuid.UUID, 0, len(apiWorkspaces))
|
||||
for _, ws := range apiWorkspaces {
|
||||
buildIDs = append(buildIDs, ws.LatestBuild.ID)
|
||||
}
|
||||
parameters, err := api.Database.GetWorkspaceBuildParametersByBuildIDs(ctx, buildIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
promptsByBuildID := make(map[uuid.UUID]string, len(parameters))
|
||||
for _, p := range parameters {
|
||||
if p.Name == codersdk.AITaskPromptParameterName {
|
||||
promptsByBuildID[p.WorkspaceBuildID] = p.Value
|
||||
}
|
||||
}
|
||||
|
||||
tasks := make([]codersdk.Task, 0, len(apiWorkspaces))
|
||||
for _, ws := range apiWorkspaces {
|
||||
tasks = append(tasks, taskFromWorkspace(ws, promptsByBuildID[ws.LatestBuild.ID]))
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// tasksListResponse wraps a list of experimental tasks.
|
||||
//
|
||||
// Experimental: Response shape is experimental and may change.
|
||||
type tasksListResponse struct {
|
||||
Tasks []codersdk.Task `json:"tasks"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// @Summary List AI tasks
|
||||
// @Description: EXPERIMENTAL: this endpoint is experimental and not guaranteed to be stable.
|
||||
// @ID list-tasks
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Experimental
|
||||
// @Param q query string false "Search query for filtering tasks"
|
||||
// @Param after_id query string false "Return tasks after this ID for pagination"
|
||||
// @Param limit query int false "Maximum number of tasks to return" minimum(1) maximum(100) default(25)
|
||||
// @Param offset query int false "Offset for pagination" minimum(0) default(0)
|
||||
// @Success 200 {object} coderd.tasksListResponse
|
||||
// @Param q query string false "Search query for filtering tasks. Supports: owner:<username/uuid/me>, organization:<org-name/uuid>, status:<status>"
|
||||
// @Success 200 {object} codersdk.TasksListResponse
|
||||
// @Router /api/experimental/tasks [get]
|
||||
//
|
||||
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
|
||||
// tasksList is an experimental endpoint to list AI tasks by mapping
|
||||
// workspaces to a task-shaped response.
|
||||
// tasksList is an experimental endpoint to list tasks.
|
||||
func (api *API) tasksList(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
// Support standard pagination/filters for workspaces.
|
||||
page, ok := ParsePagination(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Parse query parameters for filtering tasks.
|
||||
queryStr := r.URL.Query().Get("q")
|
||||
filter, errs := searchquery.Workspaces(ctx, api.Database, queryStr, page, api.AgentInactiveDisconnectTimeout)
|
||||
filter, errs := searchquery.Tasks(ctx, api.Database, queryStr, apiKey.UserID)
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid workspace search query.",
|
||||
Message: "Invalid task search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure that we only include AI task workspaces in the results.
|
||||
filter.HasAITask = sql.NullBool{Valid: true, Bool: true}
|
||||
|
||||
if filter.OwnerUsername == "me" {
|
||||
filter.OwnerID = apiKey.UserID
|
||||
filter.OwnerUsername = ""
|
||||
}
|
||||
|
||||
prepared, err := api.HTTPAuth.AuthorizeSQLFilter(r, policy.ActionRead, rbac.ResourceWorkspace.Type)
|
||||
// Fetch all tasks matching the filters from the database.
|
||||
dbTasks, err := api.Database.ListTasks(ctx, filter)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error preparing sql filter.",
|
||||
Message: "Internal error fetching tasks.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Order with requester's favorites first, include summary row.
|
||||
filter.RequesterID = apiKey.UserID
|
||||
filter.WithSummary = true
|
||||
|
||||
workspaceRows, err := api.Database.GetAuthorizedWorkspaces(ctx, filter, prepared)
|
||||
tasks, err := api.convertTasks(ctx, apiKey.UserID, dbTasks)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspaces.",
|
||||
Message: "Internal error converting tasks.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(workspaceRows) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspaces.",
|
||||
Detail: "Workspace summary row is missing.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(workspaceRows) == 1 {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, tasksListResponse{
|
||||
Tasks: []codersdk.Task{},
|
||||
Count: 0,
|
||||
})
|
||||
return
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.TasksListResponse{
|
||||
Tasks: tasks,
|
||||
Count: len(tasks),
|
||||
})
|
||||
}
|
||||
|
||||
// convertTasks converts database tasks to API tasks, enriching them with
|
||||
// workspace information.
|
||||
func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks []database.Task) ([]codersdk.Task, error) {
|
||||
if len(dbTasks) == 0 {
|
||||
return []codersdk.Task{}, nil
|
||||
}
|
||||
|
||||
// Skip summary row.
|
||||
workspaceRows = workspaceRows[:len(workspaceRows)-1]
|
||||
// Prepare to batch fetch workspaces.
|
||||
workspaceIDs := make([]uuid.UUID, 0, len(dbTasks))
|
||||
for _, task := range dbTasks {
|
||||
if !task.WorkspaceID.Valid {
|
||||
return nil, xerrors.New("task has no workspace ID")
|
||||
}
|
||||
workspaceIDs = append(workspaceIDs, task.WorkspaceID.UUID)
|
||||
}
|
||||
|
||||
// Fetch workspaces for tasks that have workspaces.
|
||||
workspaceRows, err := api.Database.GetWorkspaces(ctx, database.GetWorkspacesParams{
|
||||
WorkspaceIds: workspaceIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("fetch workspaces: %w", err)
|
||||
}
|
||||
|
||||
workspaces := database.ConvertWorkspaceRows(workspaceRows)
|
||||
|
||||
// Gather associated data and convert to API workspaces.
|
||||
data, err := api.workspaceData(ctx, workspaces)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace resources.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
apiWorkspaces, err := convertWorkspaces(apiKey.UserID, workspaces, data)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error converting workspaces.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
return nil, xerrors.Errorf("fetch workspace data: %w", err)
|
||||
}
|
||||
|
||||
tasks, err := api.tasksFromWorkspaces(ctx, apiWorkspaces)
|
||||
apiWorkspaces, err := convertWorkspaces(requesterID, workspaces, data)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching task prompts and states.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
return nil, xerrors.Errorf("convert workspaces: %w", err)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, tasksListResponse{
|
||||
Tasks: tasks,
|
||||
Count: len(tasks),
|
||||
})
|
||||
workspacesByID := make(map[uuid.UUID]codersdk.Workspace)
|
||||
for _, ws := range apiWorkspaces {
|
||||
workspacesByID[ws.ID] = ws
|
||||
}
|
||||
|
||||
// Convert tasks to SDK format.
|
||||
result := make([]codersdk.Task, 0, len(dbTasks))
|
||||
for _, dbTask := range dbTasks {
|
||||
task := taskFromDBTaskAndWorkspace(dbTask, workspacesByID[dbTask.WorkspaceID.UUID])
|
||||
result = append(result, task)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// @Summary Get AI task by ID
|
||||
@@ -458,9 +471,9 @@ func (api *API) tasksList(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Experimental
|
||||
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
|
||||
// @Param id path string true "Task ID" format(uuid)
|
||||
// @Param task path string true "Task ID" format(uuid)
|
||||
// @Success 200 {object} codersdk.Task
|
||||
// @Router /api/experimental/tasks/{user}/{id} [get]
|
||||
// @Router /api/experimental/tasks/{user}/{task} [get]
|
||||
//
|
||||
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
|
||||
// taskGet is an experimental endpoint to fetch a single AI task by ID
|
||||
@@ -469,25 +482,22 @@ func (api *API) tasksList(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
task := httpmw.TaskParam(r)
|
||||
|
||||
idStr := chi.URLParam(r, "id")
|
||||
taskID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
|
||||
if !task.WorkspaceID.Valid {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching task.",
|
||||
Detail: "Task workspace ID is invalid.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// For now, taskID = workspaceID, once we have a task data model in
|
||||
// the DB, we can change this lookup.
|
||||
workspaceID := taskID
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace.",
|
||||
Detail: err.Error(),
|
||||
@@ -507,34 +517,6 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if data.builds[0].HasAITask == nil || !*data.builds[0].HasAITask {
|
||||
// TODO(DanielleMaywood):
|
||||
// This is a temporary workaround. When a task has just been created, but
|
||||
// not yet provisioned, the workspace build will not have `HasAITask` set.
|
||||
//
|
||||
// When we reach this code flow, it is _either_ because the workspace is
|
||||
// not a task, or it is a task that has not yet been provisioned. This
|
||||
// endpoint should rarely be called with a non-task workspace so we
|
||||
// should be fine with this extra database call to check if it has the
|
||||
// special "AI Task" parameter.
|
||||
parameters, err := api.Database.GetWorkspaceBuildParameters(ctx, data.builds[0].ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace build parameters.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
_, hasAITask := slice.Find(parameters, func(t database.WorkspaceBuildParameter) bool {
|
||||
return t.Name == codersdk.AITaskPromptParameterName
|
||||
})
|
||||
|
||||
if !hasAITask {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
appStatus := codersdk.WorkspaceAppStatus{}
|
||||
if len(data.appStatuses) > 0 {
|
||||
@@ -557,16 +539,8 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
tasks, err := api.tasksFromWorkspaces(ctx, []codersdk.Workspace{ws})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching task prompt and state.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, tasks[0])
|
||||
taskResp := taskFromDBTaskAndWorkspace(task, ws)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, taskResp)
|
||||
}
|
||||
|
||||
// @Summary Delete AI task by ID
|
||||
@@ -575,83 +549,71 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Experimental
|
||||
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
|
||||
// @Param id path string true "Task ID" format(uuid)
|
||||
// @Param task path string true "Task ID" format(uuid)
|
||||
// @Success 202 "Task deletion initiated"
|
||||
// @Router /api/experimental/tasks/{user}/{id} [delete]
|
||||
// @Router /api/experimental/tasks/{user}/{task} [delete]
|
||||
//
|
||||
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
|
||||
// taskDelete is an experimental endpoint to delete a task by ID (workspace ID).
|
||||
// taskDelete is an experimental endpoint to delete a task by ID.
|
||||
// It creates a delete workspace build and returns 202 Accepted if the build was
|
||||
// created.
|
||||
func (api *API) taskDelete(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
task := httpmw.TaskParam(r)
|
||||
|
||||
idStr := chi.URLParam(r, "id")
|
||||
taskID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
|
||||
})
|
||||
return
|
||||
now := api.Clock.Now()
|
||||
|
||||
if task.WorkspaceID.Valid {
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching task workspace before deleting task.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Construct a request to the workspace build creation handler to
|
||||
// initiate deletion.
|
||||
buildReq := codersdk.CreateWorkspaceBuildRequest{
|
||||
Transition: codersdk.WorkspaceTransitionDelete,
|
||||
Reason: "Deleted via tasks API",
|
||||
}
|
||||
|
||||
_, err = api.postWorkspaceBuildsInternal(
|
||||
ctx,
|
||||
apiKey,
|
||||
workspace,
|
||||
buildReq,
|
||||
func(action policy.Action, object rbac.Objecter) bool {
|
||||
return api.Authorize(r, action, object)
|
||||
},
|
||||
audit.WorkspaceBuildBaggageFromRequest(r),
|
||||
)
|
||||
if err != nil {
|
||||
httperror.WriteWorkspaceBuildError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// For now, taskID = workspaceID, once we have a task data model in
|
||||
// the DB, we can change this lookup.
|
||||
workspaceID := taskID
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
_, err := api.Database.DeleteTask(ctx, database.DeleteTaskParams{
|
||||
ID: task.ID,
|
||||
DeletedAt: dbtime.Time(now),
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace.",
|
||||
Message: "Failed to delete task",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := api.workspaceData(ctx, []database.Workspace{workspace})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace resources.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(data.builds) == 0 || len(data.templates) == 0 {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if data.builds[0].HasAITask == nil || !*data.builds[0].HasAITask {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
// Construct a request to the workspace build creation handler to
|
||||
// initiate deletion.
|
||||
buildReq := codersdk.CreateWorkspaceBuildRequest{
|
||||
Transition: codersdk.WorkspaceTransitionDelete,
|
||||
Reason: "Deleted via tasks API",
|
||||
}
|
||||
|
||||
_, err = api.postWorkspaceBuildsInternal(
|
||||
ctx,
|
||||
apiKey,
|
||||
workspace,
|
||||
buildReq,
|
||||
func(action policy.Action, object rbac.Objecter) bool {
|
||||
return api.Authorize(r, action, object)
|
||||
},
|
||||
audit.WorkspaceBuildBaggageFromRequest(r),
|
||||
)
|
||||
if err != nil {
|
||||
httperror.WriteWorkspaceBuildError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete build created successfully.
|
||||
// Task deleted and delete build created successfully.
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
|
||||
@@ -661,26 +623,18 @@ func (api *API) taskDelete(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Experimental
|
||||
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
|
||||
// @Param id path string true "Task ID" format(uuid)
|
||||
// @Param task path string true "Task ID" format(uuid)
|
||||
// @Param request body codersdk.TaskSendRequest true "Task input request"
|
||||
// @Success 204 "Input sent successfully"
|
||||
// @Router /api/experimental/tasks/{user}/{id}/send [post]
|
||||
// @Router /api/experimental/tasks/{user}/{task}/send [post]
|
||||
//
|
||||
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
|
||||
// taskSend submits task input to the tasks sidebar app by dialing the agent
|
||||
// taskSend submits task input to the task app by dialing the agent
|
||||
// directly over the tailnet. We enforce ApplicationConnect RBAC on the
|
||||
// workspace and validate the sidebar app health.
|
||||
// workspace and validate the task app health.
|
||||
func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
idStr := chi.URLParam(r, "id")
|
||||
taskID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
|
||||
})
|
||||
return
|
||||
}
|
||||
task := httpmw.TaskParam(r)
|
||||
|
||||
var req codersdk.TaskSendRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
@@ -693,7 +647,7 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = api.authAndDoWithTaskSidebarAppClient(r, taskID, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
|
||||
if err := api.authAndDoWithTaskAppClient(r, task, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
|
||||
agentAPIClient, err := aiagentapi.NewClient(appURL.String(), aiagentapi.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
|
||||
@@ -743,27 +697,19 @@ func (api *API) taskSend(rw http.ResponseWriter, r *http.Request) {
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Experimental
|
||||
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
|
||||
// @Param id path string true "Task ID" format(uuid)
|
||||
// @Param task path string true "Task ID" format(uuid)
|
||||
// @Success 200 {object} codersdk.TaskLogsResponse
|
||||
// @Router /api/experimental/tasks/{user}/{id}/logs [get]
|
||||
// @Router /api/experimental/tasks/{user}/{task}/logs [get]
|
||||
//
|
||||
// EXPERIMENTAL: This endpoint is experimental and not guaranteed to be stable.
|
||||
// taskLogs reads task output by dialing the agent directly over the tailnet.
|
||||
// We enforce ApplicationConnect RBAC on the workspace and validate the sidebar app health.
|
||||
// We enforce ApplicationConnect RBAC on the workspace and validate the task app health.
|
||||
func (api *API) taskLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
idStr := chi.URLParam(r, "id")
|
||||
taskID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Invalid UUID %q for task ID.", idStr),
|
||||
})
|
||||
return
|
||||
}
|
||||
task := httpmw.TaskParam(r)
|
||||
|
||||
var out codersdk.TaskLogsResponse
|
||||
if err := api.authAndDoWithTaskSidebarAppClient(r, taskID, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
|
||||
if err := api.authAndDoWithTaskAppClient(r, task, func(ctx context.Context, client *http.Client, appURL *url.URL) error {
|
||||
agentAPIClient, err := aiagentapi.NewClient(appURL.String(), aiagentapi.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
|
||||
@@ -811,24 +757,40 @@ func (api *API) taskLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, out)
|
||||
}
|
||||
|
||||
// authAndDoWithTaskSidebarAppClient centralizes the shared logic to:
|
||||
// authAndDoWithTaskAppClient centralizes the shared logic to:
|
||||
//
|
||||
// - Fetch the task workspace
|
||||
// - Authorize ApplicationConnect on the workspace
|
||||
// - Validate the AI task and sidebar app health
|
||||
// - Validate the AI task and task app health
|
||||
// - Dial the agent and construct an HTTP client to the apps loopback URL
|
||||
//
|
||||
// The provided callback receives the context, an HTTP client that dials via the
|
||||
// agent, and the base app URL (as a value URL) to perform any request.
|
||||
func (api *API) authAndDoWithTaskSidebarAppClient(
|
||||
func (api *API) authAndDoWithTaskAppClient(
|
||||
r *http.Request,
|
||||
taskID uuid.UUID,
|
||||
task database.Task,
|
||||
do func(ctx context.Context, client *http.Client, appURL *url.URL) error,
|
||||
) error {
|
||||
ctx := r.Context()
|
||||
|
||||
workspaceID := taskID
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, workspaceID)
|
||||
if task.Status != database.TaskStatusActive {
|
||||
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Task status must be active.",
|
||||
Detail: fmt.Sprintf("Task status is %q, it must be %q to interact with the task.", task.Status, codersdk.TaskStatusActive),
|
||||
})
|
||||
}
|
||||
if !task.WorkspaceID.Valid {
|
||||
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Task does not have a workspace.",
|
||||
})
|
||||
}
|
||||
if !task.WorkspaceAppID.Valid {
|
||||
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Task does not have a workspace app.",
|
||||
})
|
||||
}
|
||||
|
||||
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
return httperror.ErrResourceNotFound
|
||||
@@ -844,65 +806,30 @@ func (api *API) authAndDoWithTaskSidebarAppClient(
|
||||
return httperror.ErrResourceNotFound
|
||||
}
|
||||
|
||||
data, err := api.workspaceData(ctx, []database.Workspace{workspace})
|
||||
apps, err := api.Database.GetWorkspaceAppsByAgentID(ctx, task.WorkspaceAgentID.UUID)
|
||||
if err != nil {
|
||||
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace resources.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
if len(data.builds) == 0 || len(data.templates) == 0 {
|
||||
return httperror.ErrResourceNotFound
|
||||
}
|
||||
build := data.builds[0]
|
||||
if build.HasAITask == nil || !*build.HasAITask || build.AITaskSidebarAppID == nil || *build.AITaskSidebarAppID == uuid.Nil {
|
||||
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Task is not configured with a sidebar app.",
|
||||
})
|
||||
}
|
||||
|
||||
// Find the sidebar app details to get the URL and validate app health.
|
||||
sidebarAppID := *build.AITaskSidebarAppID
|
||||
agentID, sidebarApp, ok := func() (uuid.UUID, codersdk.WorkspaceApp, bool) {
|
||||
for _, res := range build.Resources {
|
||||
for _, agent := range res.Agents {
|
||||
for _, app := range agent.Apps {
|
||||
if app.ID == sidebarAppID {
|
||||
return agent.ID, app, true
|
||||
}
|
||||
}
|
||||
}
|
||||
var app *database.WorkspaceApp
|
||||
for _, a := range apps {
|
||||
if a.ID == task.WorkspaceAppID.UUID {
|
||||
app = &a
|
||||
break
|
||||
}
|
||||
return uuid.Nil, codersdk.WorkspaceApp{}, false
|
||||
}()
|
||||
if !ok {
|
||||
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Task sidebar app not found in latest build.",
|
||||
})
|
||||
}
|
||||
|
||||
// Return an informative error if the app isn't healthy rather than trying
|
||||
// and failing.
|
||||
switch sidebarApp.Health {
|
||||
case codersdk.WorkspaceAppHealthDisabled:
|
||||
// No health check, pass through.
|
||||
case codersdk.WorkspaceAppHealthInitializing:
|
||||
return httperror.NewResponseError(http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Task sidebar app is initializing. Try again shortly.",
|
||||
})
|
||||
case codersdk.WorkspaceAppHealthUnhealthy:
|
||||
return httperror.NewResponseError(http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Task sidebar app is unhealthy.",
|
||||
})
|
||||
}
|
||||
|
||||
// Build the direct app URL and dial the agent.
|
||||
if sidebarApp.URL == "" {
|
||||
appURL := app.Url.String
|
||||
if appURL == "" {
|
||||
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Task sidebar app URL is not configured.",
|
||||
Message: "Task app URL is not configured.",
|
||||
})
|
||||
}
|
||||
parsedURL, err := url.Parse(sidebarApp.URL)
|
||||
parsedURL, err := url.Parse(appURL)
|
||||
if err != nil {
|
||||
return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error parsing task app URL.",
|
||||
@@ -917,7 +844,7 @@ func (api *API) authAndDoWithTaskSidebarAppClient(
|
||||
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer dialCancel()
|
||||
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, agentID)
|
||||
agentConn, release, err := api.agentProvider.AgentConn(dialCtx, task.WorkspaceAgentID.UUID)
|
||||
if err != nil {
|
||||
return httperror.NewResponseError(http.StatusBadGateway, codersdk.Response{
|
||||
Message: "Failed to reach task app endpoint.",
|
||||
|
||||
+445
-248
@@ -2,7 +2,7 @@ package coderd_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
@@ -22,7 +23,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
|
||||
@@ -54,10 +54,6 @@ func TestAITasksPrompts(t *testing.T) {
|
||||
t.Run("MultipleBuilds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("This test checks RBAC, which is not supported in the in-memory database")
|
||||
}
|
||||
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
first := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, first.OrganizationID)
|
||||
@@ -215,8 +211,8 @@ func TestTasks(t *testing.T) {
|
||||
Apps: []*proto.App{
|
||||
{
|
||||
Id: taskAppID.String(),
|
||||
Slug: "task-sidebar",
|
||||
DisplayName: "Task Sidebar",
|
||||
Slug: "task-app",
|
||||
DisplayName: "Task App",
|
||||
Url: opt.appURL,
|
||||
},
|
||||
},
|
||||
@@ -226,9 +222,7 @@ func TestTasks(t *testing.T) {
|
||||
},
|
||||
AiTasks: []*proto.AITask{
|
||||
{
|
||||
SidebarApp: &proto.AITaskSidebarApp{
|
||||
Id: taskAppID.String(),
|
||||
},
|
||||
AppId: taskAppID.String(),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -251,27 +245,33 @@ func TestTasks(t *testing.T) {
|
||||
|
||||
template := createAITemplate(t, client, user)
|
||||
|
||||
// Create a workspace (task) with a specific prompt.
|
||||
// Create a task with a specific prompt using the new data model.
|
||||
wantPrompt := "build me a web app"
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
|
||||
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: wantPrompt},
|
||||
}
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: wantPrompt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
|
||||
// Wait for the workspace to be built.
|
||||
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// List tasks via experimental API and verify the prompt and status mapping.
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
tasks, err := exp.Tasks(ctx, &codersdk.TasksFilter{Owner: codersdk.Me})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, ok := slice.Find(tasks, func(task codersdk.Task) bool { return task.ID == workspace.ID })
|
||||
got, ok := slice.Find(tasks, func(t codersdk.Task) bool { return t.ID == task.ID })
|
||||
require.True(t, ok, "task should be found in the list")
|
||||
assert.Equal(t, wantPrompt, got.InitialPrompt, "task prompt should match the AI Prompt parameter")
|
||||
assert.Equal(t, workspace.Name, got.Name, "task name should map from workspace name")
|
||||
assert.Equal(t, workspace.ID, got.WorkspaceID.UUID, "workspace id should match")
|
||||
// Status should be populated via app status or workspace status mapping.
|
||||
assert.Equal(t, task.WorkspaceID.UUID, got.WorkspaceID.UUID, "workspace id should match")
|
||||
assert.Equal(t, task.WorkspaceName, got.WorkspaceName, "workspace name should match")
|
||||
// Status should be populated via the tasks_with_status view.
|
||||
assert.NotEmpty(t, got.Status, "task status should not be empty")
|
||||
assert.NotEmpty(t, got.WorkspaceStatus, "workspace status should not be empty")
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
@@ -282,17 +282,22 @@ func TestTasks(t *testing.T) {
|
||||
ctx = testutil.Context(t, testutil.WaitLong)
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
template = createAITemplate(t, client, user)
|
||||
// Create a workspace (task) with a specific prompt.
|
||||
wantPrompt = "review my code"
|
||||
workspace = coderdtest.CreateWorkspace(t, client, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
|
||||
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: wantPrompt},
|
||||
}
|
||||
})
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
)
|
||||
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
ws := coderdtest.MustWorkspace(t, client, workspace.ID)
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: wantPrompt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid)
|
||||
|
||||
// Get the workspace and wait for it to be ready.
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
ws = coderdtest.MustWorkspace(t, client, task.WorkspaceID.UUID)
|
||||
// Assert invariant: the workspace has exactly one resource with one agent with one app.
|
||||
require.Len(t, ws.LatestBuild.Resources, 1)
|
||||
require.Len(t, ws.LatestBuild.Resources[0].Agents, 1)
|
||||
@@ -300,9 +305,9 @@ func TestTasks(t *testing.T) {
|
||||
taskAppID := ws.LatestBuild.Resources[0].Agents[0].Apps[0].ID
|
||||
|
||||
// Insert an app status for the workspace
|
||||
_, err := db.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
_, err = db.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: workspace.ID,
|
||||
WorkspaceID: task.WorkspaceID.UUID,
|
||||
CreatedAt: dbtime.Now(),
|
||||
AgentID: agentID,
|
||||
AppID: taskAppID,
|
||||
@@ -312,31 +317,34 @@ func TestTasks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch the task by ID via experimental API and verify fields.
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
task, err := exp.TaskByID(ctx, workspace.ID)
|
||||
updated, err := exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, workspace.ID, task.ID, "task ID should match workspace ID")
|
||||
assert.Equal(t, workspace.Name, task.Name, "task name should map from workspace name")
|
||||
assert.Equal(t, wantPrompt, task.InitialPrompt, "task prompt should match the AI Prompt parameter")
|
||||
assert.Equal(t, workspace.ID, task.WorkspaceID.UUID, "workspace id should match")
|
||||
assert.NotEmpty(t, task.Status, "task status should not be empty")
|
||||
assert.Equal(t, task.ID, updated.ID, "task ID should match")
|
||||
assert.Equal(t, task.Name, updated.Name, "task name should match")
|
||||
assert.Equal(t, wantPrompt, updated.InitialPrompt, "task prompt should match the AI Prompt parameter")
|
||||
assert.Equal(t, task.WorkspaceID.UUID, updated.WorkspaceID.UUID, "workspace id should match")
|
||||
assert.Equal(t, task.WorkspaceName, updated.WorkspaceName, "workspace name should match")
|
||||
assert.Equal(t, ws.LatestBuild.BuildNumber, updated.WorkspaceBuildNumber, "workspace build number should match")
|
||||
assert.Equal(t, agentID, updated.WorkspaceAgentID.UUID, "workspace agent id should match")
|
||||
assert.Equal(t, taskAppID, updated.WorkspaceAppID.UUID, "workspace app id should match")
|
||||
assert.NotEmpty(t, updated.WorkspaceStatus, "task status should not be empty")
|
||||
|
||||
// Stop the workspace
|
||||
coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)
|
||||
coderdtest.MustTransitionWorkspace(t, client, task.WorkspaceID.UUID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop)
|
||||
|
||||
// Verify that the previous status still remains
|
||||
updated, err := exp.TaskByID(ctx, workspace.ID)
|
||||
updated, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, updated.CurrentState, "current state should not be nil")
|
||||
assert.Equal(t, "all done", updated.CurrentState.Message)
|
||||
assert.Equal(t, codersdk.TaskStateComplete, updated.CurrentState.State)
|
||||
|
||||
// Start the workspace again
|
||||
coderdtest.MustTransitionWorkspace(t, client, workspace.ID, codersdk.WorkspaceTransitionStop, codersdk.WorkspaceTransitionStart)
|
||||
coderdtest.MustTransitionWorkspace(t, client, task.WorkspaceID.UUID, codersdk.WorkspaceTransitionStop, codersdk.WorkspaceTransitionStart)
|
||||
|
||||
// Verify that the status from the previous build is no longer present
|
||||
updated, err = exp.TaskByID(ctx, workspace.ID)
|
||||
updated, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, updated.CurrentState, "current state should be nil")
|
||||
})
|
||||
@@ -359,7 +367,8 @@ func TestTasks(t *testing.T) {
|
||||
Input: "delete me",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ws, err := client.Workspace(ctx, task.ID)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
@@ -368,7 +377,7 @@ func TestTasks(t *testing.T) {
|
||||
|
||||
// Poll until the workspace is deleted.
|
||||
for {
|
||||
dws, derr := client.DeletedWorkspace(ctx, task.ID)
|
||||
dws, derr := client.DeletedWorkspace(ctx, task.WorkspaceID.UUID)
|
||||
if derr == nil && dws.LatestBuild.Status == codersdk.WorkspaceStatusDeleted {
|
||||
break
|
||||
}
|
||||
@@ -439,7 +448,8 @@ func TestTasks(t *testing.T) {
|
||||
Input: "delete me not",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ws, err := client.Workspace(ctx, task.ID)
|
||||
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
@@ -466,36 +476,37 @@ func TestTasks(t *testing.T) {
|
||||
t.Run("IntegrationOK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
createStatusResponse := func(status string) string {
|
||||
return `
|
||||
{
|
||||
"$schema": "http://localhost:3284/schemas/StatusResponseBody.json",
|
||||
"status": "` + status + `"
|
||||
}
|
||||
`
|
||||
}
|
||||
statusResponse := createStatusResponse("stable")
|
||||
statusResponse := agentapisdk.StatusStable
|
||||
|
||||
// Start a fake AgentAPI that accepts GET /status and POST /message.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/status" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp := agentapisdk.GetStatusResponse{
|
||||
Status: statusResponse,
|
||||
}
|
||||
respBytes, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprint(w, statusResponse)
|
||||
w.Write(respBytes)
|
||||
return
|
||||
}
|
||||
if r.Method == http.MethodPost && r.URL.Path == "/message" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
assert.Equal(t, `{"content":"Hello, Agent!","type":"user"}`, string(b), "expected message content")
|
||||
expectedReq := agentapisdk.PostMessageParams{
|
||||
Content: "Hello, Agent!",
|
||||
Type: agentapisdk.MessageTypeUser,
|
||||
}
|
||||
expectedBytes, _ := json.Marshal(expectedReq)
|
||||
assert.Equal(t, string(expectedBytes), string(b), "expected message content")
|
||||
|
||||
resp := agentapisdk.PostMessageResponse{Ok: true}
|
||||
respBytes, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
io.WriteString(w, `{"ok": true}`)
|
||||
w.Write(respBytes)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
@@ -503,103 +514,105 @@ func TestTasks(t *testing.T) {
|
||||
defer srv.Close()
|
||||
|
||||
// Create an AI-capable template whose sidebar app points to our fake AgentAPI.
|
||||
authToken := uuid.NewString()
|
||||
template := createAITemplate(t, client, owner, withSidebarURL(srv.URL), withAgentToken(authToken))
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
ctx = testutil.Context(t, testutil.WaitLong)
|
||||
owner = coderdtest.CreateFirstUser(t, client)
|
||||
userClient, _ = coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
agentAuthToken = uuid.NewString()
|
||||
template = createAITemplate(t, client, owner, withAgentToken(agentAuthToken), withSidebarURL(srv.URL))
|
||||
exp = codersdk.NewExperimentalClient(userClient)
|
||||
)
|
||||
|
||||
// Create a workspace (task) from the AI-capable template.
|
||||
ws := coderdtest.CreateWorkspace(t, userClient, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
|
||||
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: "send a message"},
|
||||
}
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: "send me food",
|
||||
})
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid)
|
||||
|
||||
// Get the workspace and wait for it to be ready.
|
||||
ws, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
|
||||
|
||||
// Fetch the task by ID via experimental API and verify fields.
|
||||
task, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, task.WorkspaceBuildNumber)
|
||||
require.True(t, task.WorkspaceAgentID.Valid)
|
||||
require.True(t, task.WorkspaceAppID.Valid)
|
||||
|
||||
// Insert an app status for the workspace
|
||||
_, err = db.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: task.WorkspaceID.UUID,
|
||||
CreatedAt: dbtime.Now(),
|
||||
AgentID: task.WorkspaceAgentID.UUID,
|
||||
AppID: task.WorkspaceAppID.UUID,
|
||||
State: database.WorkspaceAppStatusStateComplete,
|
||||
Message: "all done",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start a fake agent so the workspace agent is connected before sending the message.
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
|
||||
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
|
||||
agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(agentAuthToken))
|
||||
_ = agenttest.New(t, userClient.URL, agentAuthToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, userClient, ws.ID).WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.ID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Lookup the sidebar app ID.
|
||||
w, err := client.Workspace(ctx, ws.ID)
|
||||
// Fetch the task by ID via experimental API and verify fields.
|
||||
task, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
var sidebarAppID uuid.UUID
|
||||
for _, res := range w.LatestBuild.Resources {
|
||||
for _, ag := range res.Agents {
|
||||
for _, app := range ag.Apps {
|
||||
if app.Slug == "task-sidebar" {
|
||||
sidebarAppID = app.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NotEqual(t, uuid.Nil, sidebarAppID)
|
||||
|
||||
// Make the sidebar app unhealthy initially.
|
||||
err = api.Database.UpdateWorkspaceAppHealthByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAppHealthByIDParams{
|
||||
ID: sidebarAppID,
|
||||
err = db.UpdateWorkspaceAppHealthByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAppHealthByIDParams{
|
||||
ID: task.WorkspaceAppID.UUID,
|
||||
Health: database.WorkspaceAppHealthUnhealthy,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
exp := codersdk.NewExperimentalClient(userClient)
|
||||
err = exp.TaskSend(ctx, "me", ws.ID, codersdk.TaskSendRequest{
|
||||
err = exp.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
Input: "Hello, Agent!",
|
||||
})
|
||||
require.Error(t, err, "wanted error due to unhealthy sidebar app")
|
||||
|
||||
// Make the sidebar app healthy.
|
||||
err = api.Database.UpdateWorkspaceAppHealthByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAppHealthByIDParams{
|
||||
ID: sidebarAppID,
|
||||
err = db.UpdateWorkspaceAppHealthByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAppHealthByIDParams{
|
||||
ID: task.WorkspaceAppID.UUID,
|
||||
Health: database.WorkspaceAppHealthHealthy,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
statusResponse = createStatusResponse("bad")
|
||||
statusResponse = agentapisdk.AgentStatus("bad")
|
||||
|
||||
err = exp.TaskSend(ctx, "me", ws.ID, codersdk.TaskSendRequest{
|
||||
err = exp.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
Input: "Hello, Agent!",
|
||||
})
|
||||
require.Error(t, err, "wanted error due to bad status")
|
||||
|
||||
statusResponse = createStatusResponse("stable")
|
||||
statusResponse = agentapisdk.StatusStable
|
||||
|
||||
// Send task input to the tasks sidebar app and expect 204.e
|
||||
err = exp.TaskSend(ctx, "me", ws.ID, codersdk.TaskSendRequest{
|
||||
Input: "Hello, Agent!",
|
||||
})
|
||||
require.NoError(t, err, "wanted no error due to healthy sidebar app and stable status")
|
||||
})
|
||||
|
||||
t.Run("MissingContent", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
template := createAITemplate(t, client, user)
|
||||
|
||||
// Create a workspace (task).
|
||||
ws := coderdtest.CreateWorkspace(t, client, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
|
||||
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: "do work"},
|
||||
}
|
||||
})
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
err := exp.TaskSend(ctx, "me", ws.ID, codersdk.TaskSendRequest{
|
||||
Input: "",
|
||||
//nolint:tparallel // Not intended to run in parallel.
|
||||
t.Run("SendOK", func(t *testing.T) {
|
||||
err = exp.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
Input: "Hello, Agent!",
|
||||
})
|
||||
require.NoError(t, err, "wanted no error due to healthy sidebar app and stable status")
|
||||
})
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
//nolint:tparallel // Not intended to run in parallel.
|
||||
t.Run("MissingContent", func(t *testing.T) {
|
||||
err = exp.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
|
||||
Input: "",
|
||||
})
|
||||
require.Error(t, err, "wanted error due to missing content")
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("TaskNotFound", func(t *testing.T) {
|
||||
@@ -619,106 +632,112 @@ func TestTasks(t *testing.T) {
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotATask", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Create a template without AI tasks.
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
ws := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
err := exp.TaskSend(ctx, "me", ws.ID, codersdk.TaskSendRequest{
|
||||
Input: "hello",
|
||||
})
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Logs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
messageResponse := `
|
||||
messageResponseData := agentapisdk.GetMessagesResponse{
|
||||
Messages: []agentapisdk.Message{
|
||||
{
|
||||
"$schema": "http://localhost:3284/schemas/MessagesResponseBody.json",
|
||||
"messages": [
|
||||
{
|
||||
"id": 0,
|
||||
"content": "Welcome, user!",
|
||||
"role": "agent",
|
||||
"time": "2025-09-25T10:42:48.751774125Z"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"content": "Hello, agent!",
|
||||
"role": "user",
|
||||
"time": "2025-09-25T10:46:42.880996296Z"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"content": "What would you like to work on today?",
|
||||
"role": "agent",
|
||||
"time": "2025-09-25T10:46:50.747761102Z"
|
||||
}
|
||||
]
|
||||
}
|
||||
`
|
||||
Id: 0,
|
||||
Content: "Welcome, user!",
|
||||
Role: agentapisdk.RoleAgent,
|
||||
Time: time.Date(2025, 9, 25, 10, 42, 48, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Content: "Hello, agent!",
|
||||
Role: agentapisdk.RoleUser,
|
||||
Time: time.Date(2025, 9, 25, 10, 46, 42, 0, time.UTC),
|
||||
},
|
||||
{
|
||||
Id: 2,
|
||||
Content: "What would you like to work on today?",
|
||||
Role: agentapisdk.RoleAgent,
|
||||
Time: time.Date(2025, 9, 25, 10, 46, 50, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
}
|
||||
messageResponseBytes, err := json.Marshal(messageResponseData)
|
||||
require.NoError(t, err)
|
||||
messageResponse := string(messageResponseBytes)
|
||||
|
||||
// Fake AgentAPI that returns a couple of messages.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/messages" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
io.WriteString(w, messageResponse)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
var shouldReturnError bool
|
||||
|
||||
// Template pointing sidebar app to our fake AgentAPI.
|
||||
authToken := uuid.NewString()
|
||||
template := createAITemplate(t, client, owner, withSidebarURL(srv.URL), withAgentToken(authToken))
|
||||
// Fake AgentAPI that returns a couple of messages or an error.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if shouldReturnError {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = io.WriteString(w, "boom")
|
||||
return
|
||||
}
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/messages" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
io.WriteString(w, messageResponse)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Create task workspace.
|
||||
ws := coderdtest.CreateWorkspace(t, client, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
|
||||
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: "show logs"},
|
||||
}
|
||||
})
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
// Create an AI-capable template whose sidebar app points to our fake AgentAPI.
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
ctx = testutil.Context(t, testutil.WaitLong)
|
||||
owner = coderdtest.CreateFirstUser(t, client)
|
||||
agentAuthToken = uuid.NewString()
|
||||
template = createAITemplate(t, client, owner, withAgentToken(agentAuthToken), withSidebarURL(srv.URL))
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
)
|
||||
|
||||
// Start a fake agent.
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
|
||||
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.ID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
|
||||
task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: "show logs",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid)
|
||||
|
||||
// Omit sidebar app health as undefined is OK.
|
||||
// Get the workspace and wait for it to be ready.
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
// Fetch the task by ID via experimental API and verify fields.
|
||||
task, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotZero(t, task.WorkspaceBuildNumber)
|
||||
require.True(t, task.WorkspaceAgentID.Valid)
|
||||
require.True(t, task.WorkspaceAppID.Valid)
|
||||
|
||||
// Insert an app status for the workspace
|
||||
_, err = db.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
|
||||
ID: uuid.New(),
|
||||
WorkspaceID: task.WorkspaceID.UUID,
|
||||
CreatedAt: dbtime.Now(),
|
||||
AgentID: task.WorkspaceAgentID.UUID,
|
||||
AppID: task.WorkspaceAppID.UUID,
|
||||
State: database.WorkspaceAppStatusStateComplete,
|
||||
Message: "all done",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start a fake agent so the workspace agent is connected before fetching logs.
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(agentAuthToken))
|
||||
_ = agenttest.New(t, client.URL, agentAuthToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.ID).WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
// Fetch the task by ID via experimental API and verify fields.
|
||||
task, err = exp.TaskByID(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint:tparallel // Not intended to run in parallel.
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
// Fetch logs.
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
resp, err := exp.TaskLogs(ctx, "me", ws.ID)
|
||||
resp, err := exp.TaskLogs(ctx, "me", task.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Logs, 3)
|
||||
assert.Equal(t, 0, resp.Logs[0].ID)
|
||||
@@ -734,38 +753,11 @@ func TestTasks(t *testing.T) {
|
||||
assert.Equal(t, "What would you like to work on today?", resp.Logs[2].Content)
|
||||
})
|
||||
|
||||
//nolint:tparallel // Not intended to run in parallel.
|
||||
t.Run("UpstreamError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Fake AgentAPI that returns 500 for messages.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = io.WriteString(w, "boom")
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
authToken := uuid.NewString()
|
||||
template := createAITemplate(t, client, owner, withSidebarURL(srv.URL), withAgentToken(authToken))
|
||||
ws := coderdtest.CreateWorkspace(t, client, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
|
||||
req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
|
||||
{Name: codersdk.AITaskPromptParameterName, Value: "show logs"},
|
||||
}
|
||||
})
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
// Start fake agent.
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
|
||||
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, ws.ID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
|
||||
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
_, err := exp.TaskLogs(ctx, "me", ws.ID)
|
||||
shouldReturnError = true
|
||||
t.Cleanup(func() { shouldReturnError = false })
|
||||
_, err := exp.TaskLogs(ctx, "me", task.ID)
|
||||
|
||||
var sdkErr *codersdk.Error
|
||||
require.Error(t, err)
|
||||
@@ -796,7 +788,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: "AI Prompt", Type: "string"}},
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
@@ -869,7 +861,7 @@ func TestTasksCreate(t *testing.T) {
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: "AI Prompt", Type: "string"}},
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
@@ -965,7 +957,212 @@ func TestTasksCreate(t *testing.T) {
|
||||
var sdkErr *codersdk.Error
|
||||
require.Error(t, err)
|
||||
require.ErrorAsf(t, err, &sdkErr, "error should be of type *codersdk.Error")
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("TaskTableCreatedAndLinked", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
taskPrompt = "Create a REST API"
|
||||
)
|
||||
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create a template with AI task support to test the new task data model.
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
task, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: taskPrompt,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, task.WorkspaceID.Valid)
|
||||
|
||||
ws, err := client.Workspace(ctx, task.WorkspaceID.UUID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
|
||||
|
||||
// Verify that the task was created in the tasks table with the correct
|
||||
// fields. This ensures the data model properly separates task records
|
||||
// from workspace records.
|
||||
dbCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
dbTask, err := db.GetTaskByID(dbCtx, task.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.OrganizationID, dbTask.OrganizationID)
|
||||
assert.Equal(t, user.UserID, dbTask.OwnerID)
|
||||
assert.Equal(t, task.Name, dbTask.Name)
|
||||
assert.True(t, dbTask.WorkspaceID.Valid)
|
||||
assert.Equal(t, ws.ID, dbTask.WorkspaceID.UUID)
|
||||
assert.Equal(t, version.ID, dbTask.TemplateVersionID)
|
||||
assert.Equal(t, taskPrompt, dbTask.Prompt)
|
||||
assert.False(t, dbTask.DeletedAt.Valid)
|
||||
|
||||
// Verify the bidirectional relationship works by looking up the task
|
||||
// via workspace ID.
|
||||
dbTaskByWs, err := db.GetTaskByWorkspaceID(dbCtx, ws.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, dbTask.ID, dbTaskByWs.ID)
|
||||
})
|
||||
|
||||
t.Run("TaskWithCustomName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
taskPrompt = "Build a dashboard"
|
||||
taskName = "my-custom-task"
|
||||
)
|
||||
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
task, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: taskPrompt,
|
||||
Name: taskName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, taskName, task.Name)
|
||||
|
||||
// Verify the custom name is preserved in the database record.
|
||||
dbCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
dbTask, err := db.GetTaskByID(dbCtx, task.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, taskName, dbTask.Name)
|
||||
})
|
||||
|
||||
t.Run("MultipleTasksForSameUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
task1, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: "First task",
|
||||
Name: "task-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
task2, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: template.ActiveVersionID,
|
||||
Input: "Second task",
|
||||
Name: "task-2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both tasks are stored independently and can be listed together.
|
||||
dbCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
tasks, err := db.ListTasks(dbCtx, database.ListTasksParams{
|
||||
OwnerID: user.UserID,
|
||||
OrganizationID: uuid.Nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(tasks), 2)
|
||||
|
||||
taskIDs := make(map[uuid.UUID]bool)
|
||||
for _, task := range tasks {
|
||||
taskIDs[task.ID] = true
|
||||
}
|
||||
assert.True(t, taskIDs[task1.ID], "task1 should be in the list")
|
||||
assert.True(t, taskIDs[task2.ID], "task2 should be in the list")
|
||||
})
|
||||
|
||||
t.Run("TaskLinkedToCorrectTemplateVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
version1 := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version1.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version1.ID)
|
||||
|
||||
version2 := coderdtest.UpdateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionApply: echo.ApplyComplete,
|
||||
ProvisionPlan: []*proto.Response{
|
||||
{Type: &proto.Response_Plan{Plan: &proto.PlanComplete{
|
||||
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
|
||||
HasAiTasks: true,
|
||||
}}},
|
||||
},
|
||||
}, template.ID)
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version2.ID)
|
||||
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Create a task using version 2 to verify the template_version_id is
|
||||
// stored correctly.
|
||||
task, err := expClient.CreateTask(ctx, "me", codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: version2.ID,
|
||||
Input: "Use version 2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the task references the correct template version, not just the
|
||||
// active one.
|
||||
dbCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
dbTask, err := db.GetTaskByID(dbCtx, task.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, version2.ID, dbTask.TemplateVersionID, "task should be linked to version 2")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Generated
+302
-62
@@ -115,9 +115,15 @@ const docTemplate = `{
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Cursor pagination after ID",
|
||||
"description": "Cursor pagination after ID (cannot be used with offset)",
|
||||
"name": "after_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Offset pagination (cannot be used with after_id)",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -145,39 +151,16 @@ const docTemplate = `{
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query for filtering tasks",
|
||||
"description": "Search query for filtering tasks. Supports: owner:\u003cusername/uuid/me\u003e, organization:\u003corg-name/uuid\u003e, status:\u003cstatus\u003e",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Return tasks after this ID for pagination",
|
||||
"name": "after_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"maximum": 100,
|
||||
"minimum": 1,
|
||||
"type": "integer",
|
||||
"default": 25,
|
||||
"description": "Maximum number of tasks to return",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"minimum": 0,
|
||||
"type": "integer",
|
||||
"default": 0,
|
||||
"description": "Offset for pagination",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/coderd.tasksListResponse"
|
||||
"$ref": "#/definitions/codersdk.TasksListResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,7 +206,7 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/experimental/tasks/{user}/{id}": {
|
||||
"/api/experimental/tasks/{user}/{task}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -247,7 +230,7 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
@@ -284,7 +267,7 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
@@ -296,7 +279,7 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/experimental/tasks/{user}/{id}/logs": {
|
||||
"/api/experimental/tasks/{user}/{task}/logs": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -320,7 +303,7 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
@@ -335,7 +318,7 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/experimental/tasks/{user}/{id}/send": {
|
||||
"/api/experimental/tasks/{user}/{task}/send": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
@@ -359,7 +342,7 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
@@ -954,6 +937,138 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/metrics": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Debug"
|
||||
],
|
||||
"summary": "Debug metrics",
|
||||
"operationId": "debug-metrics",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Debug"
|
||||
],
|
||||
"summary": "Debug pprof index",
|
||||
"operationId": "debug-pprof-index",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/cmdline": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Debug"
|
||||
],
|
||||
"summary": "Debug pprof cmdline",
|
||||
"operationId": "debug-pprof-cmdline",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/profile": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Debug"
|
||||
],
|
||||
"summary": "Debug pprof profile",
|
||||
"operationId": "debug-pprof-profile",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/symbol": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Debug"
|
||||
],
|
||||
"summary": "Debug pprof symbol",
|
||||
"operationId": "debug-pprof-symbol",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/trace": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": [
|
||||
"Debug"
|
||||
],
|
||||
"summary": "Debug pprof trace",
|
||||
"operationId": "debug-pprof-trace",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/tailnet": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -2944,6 +3059,45 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/oauth2/revoke": {
|
||||
"post": {
|
||||
"consumes": [
|
||||
"application/x-www-form-urlencoded"
|
||||
],
|
||||
"tags": [
|
||||
"Enterprise"
|
||||
],
|
||||
"summary": "Revoke OAuth2 tokens (RFC 7009).",
|
||||
"operationId": "oauth2-token-revocation",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Client ID for authentication",
|
||||
"name": "client_id",
|
||||
"in": "formData",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "The token to revoke",
|
||||
"name": "token",
|
||||
"in": "formData",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Hint about token type (access_token or refresh_token)",
|
||||
"name": "token_type_hint",
|
||||
"in": "formData"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Token successfully revoked"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/oauth2/tokens": {
|
||||
"post": {
|
||||
"produces": [
|
||||
@@ -11486,20 +11640,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"coderd.tasksListResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Task"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ACLAvailable": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -11549,9 +11689,8 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"initiator_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
@@ -11590,6 +11729,9 @@ const docTemplate = `{
|
||||
"codersdk.AIBridgeListInterceptionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"results": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
@@ -11738,6 +11880,12 @@ const docTemplate = `{
|
||||
"user_id"
|
||||
],
|
||||
"properties": {
|
||||
"allow_list": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.APIAllowListTarget"
|
||||
}
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -11971,6 +12119,7 @@ const docTemplate = `{
|
||||
"workspace:delete",
|
||||
"workspace:delete_agent",
|
||||
"workspace:read",
|
||||
"workspace:share",
|
||||
"workspace:ssh",
|
||||
"workspace:start",
|
||||
"workspace:stop",
|
||||
@@ -11988,6 +12137,7 @@ const docTemplate = `{
|
||||
"workspace_dormant:delete",
|
||||
"workspace_dormant:delete_agent",
|
||||
"workspace_dormant:read",
|
||||
"workspace_dormant:share",
|
||||
"workspace_dormant:ssh",
|
||||
"workspace_dormant:start",
|
||||
"workspace_dormant:stop",
|
||||
@@ -12167,6 +12317,7 @@ const docTemplate = `{
|
||||
"APIKeyScopeWorkspaceDelete",
|
||||
"APIKeyScopeWorkspaceDeleteAgent",
|
||||
"APIKeyScopeWorkspaceRead",
|
||||
"APIKeyScopeWorkspaceShare",
|
||||
"APIKeyScopeWorkspaceSsh",
|
||||
"APIKeyScopeWorkspaceStart",
|
||||
"APIKeyScopeWorkspaceStop",
|
||||
@@ -12184,6 +12335,7 @@ const docTemplate = `{
|
||||
"APIKeyScopeWorkspaceDormantDelete",
|
||||
"APIKeyScopeWorkspaceDormantDeleteAgent",
|
||||
"APIKeyScopeWorkspaceDormantRead",
|
||||
"APIKeyScopeWorkspaceDormantShare",
|
||||
"APIKeyScopeWorkspaceDormantSsh",
|
||||
"APIKeyScopeWorkspaceDormantStart",
|
||||
"APIKeyScopeWorkspaceDormantStop",
|
||||
@@ -13832,6 +13984,9 @@ const docTemplate = `{
|
||||
"docs_url": {
|
||||
"$ref": "#/definitions/serpent.URL"
|
||||
},
|
||||
"enable_authz_recording": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"enable_terraform_debug_mode": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -14750,7 +14905,15 @@ const docTemplate = `{
|
||||
"enum": [
|
||||
"bug",
|
||||
"chat",
|
||||
"docs"
|
||||
"docs",
|
||||
"star"
|
||||
]
|
||||
},
|
||||
"location": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"navbar",
|
||||
"dropdown"
|
||||
]
|
||||
},
|
||||
"name": {
|
||||
@@ -14923,6 +15086,9 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"username": {
|
||||
"type": "string"
|
||||
}
|
||||
@@ -15195,6 +15361,9 @@ const docTemplate = `{
|
||||
},
|
||||
"token": {
|
||||
"type": "string"
|
||||
},
|
||||
"token_revoke": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -15294,7 +15463,10 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"registration_access_token": {
|
||||
"type": "string"
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"registration_client_uri": {
|
||||
"type": "string"
|
||||
@@ -16926,6 +17098,7 @@ const docTemplate = `{
|
||||
"read",
|
||||
"read_personal",
|
||||
"ssh",
|
||||
"share",
|
||||
"unassign",
|
||||
"update",
|
||||
"update_personal",
|
||||
@@ -16944,6 +17117,7 @@ const docTemplate = `{
|
||||
"ActionRead",
|
||||
"ActionReadPersonal",
|
||||
"ActionSSH",
|
||||
"ActionShare",
|
||||
"ActionUnassign",
|
||||
"ActionUpdate",
|
||||
"ActionUpdatePersonal",
|
||||
@@ -17556,6 +17730,9 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"owner_avatar_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"owner_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -17566,19 +17743,15 @@ const docTemplate = `{
|
||||
"status": {
|
||||
"enum": [
|
||||
"pending",
|
||||
"starting",
|
||||
"running",
|
||||
"stopping",
|
||||
"stopped",
|
||||
"failed",
|
||||
"canceling",
|
||||
"canceled",
|
||||
"deleting",
|
||||
"deleted"
|
||||
"initializing",
|
||||
"active",
|
||||
"paused",
|
||||
"unknown",
|
||||
"error"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.WorkspaceStatus"
|
||||
"$ref": "#/definitions/codersdk.TaskStatus"
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -17595,6 +17768,10 @@ const docTemplate = `{
|
||||
"template_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"template_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -17631,6 +17808,28 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/uuid.NullUUID"
|
||||
}
|
||||
]
|
||||
},
|
||||
"workspace_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_status": {
|
||||
"enum": [
|
||||
"pending",
|
||||
"starting",
|
||||
"running",
|
||||
"stopping",
|
||||
"stopped",
|
||||
"failed",
|
||||
"canceling",
|
||||
"canceled",
|
||||
"deleting",
|
||||
"deleted"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.WorkspaceStatus"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -17715,6 +17914,39 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.TaskStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"pending",
|
||||
"initializing",
|
||||
"active",
|
||||
"paused",
|
||||
"unknown",
|
||||
"error"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"TaskStatusPending",
|
||||
"TaskStatusInitializing",
|
||||
"TaskStatusActive",
|
||||
"TaskStatusPaused",
|
||||
"TaskStatusUnknown",
|
||||
"TaskStatusError"
|
||||
]
|
||||
},
|
||||
"codersdk.TasksListResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Task"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.TelemetryConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -20242,6 +20474,7 @@ const docTemplate = `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ai_task_sidebar_app_id": {
|
||||
"description": "Deprecated: This field has been replaced with ` + "`" + `TaskAppID` + "`" + `",
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
@@ -20323,6 +20556,10 @@ const docTemplate = `{
|
||||
}
|
||||
]
|
||||
},
|
||||
"task_app_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"template_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -20745,6 +20982,9 @@ const docTemplate = `{
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"role": {
|
||||
"enum": [
|
||||
"admin",
|
||||
|
||||
Generated
+282
-62
@@ -91,9 +91,15 @@
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Cursor pagination after ID",
|
||||
"description": "Cursor pagination after ID (cannot be used with offset)",
|
||||
"name": "after_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Offset pagination (cannot be used with after_id)",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
@@ -119,39 +125,16 @@
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query for filtering tasks",
|
||||
"description": "Search query for filtering tasks. Supports: owner:\u003cusername/uuid/me\u003e, organization:\u003corg-name/uuid\u003e, status:\u003cstatus\u003e",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Return tasks after this ID for pagination",
|
||||
"name": "after_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"maximum": 100,
|
||||
"minimum": 1,
|
||||
"type": "integer",
|
||||
"default": 25,
|
||||
"description": "Maximum number of tasks to return",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"minimum": 0,
|
||||
"type": "integer",
|
||||
"default": 0,
|
||||
"description": "Offset for pagination",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/coderd.tasksListResponse"
|
||||
"$ref": "#/definitions/codersdk.TasksListResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,7 +178,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/experimental/tasks/{user}/{id}": {
|
||||
"/api/experimental/tasks/{user}/{task}": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -217,7 +200,7 @@
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
@@ -252,7 +235,7 @@
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
@@ -264,7 +247,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/experimental/tasks/{user}/{id}/logs": {
|
||||
"/api/experimental/tasks/{user}/{task}/logs": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
@@ -286,7 +269,7 @@
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
@@ -301,7 +284,7 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/experimental/tasks/{user}/{id}/send": {
|
||||
"/api/experimental/tasks/{user}/{task}/send": {
|
||||
"post": {
|
||||
"security": [
|
||||
{
|
||||
@@ -323,7 +306,7 @@
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Task ID",
|
||||
"name": "id",
|
||||
"name": "task",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
@@ -840,6 +823,126 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/metrics": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Debug"],
|
||||
"summary": "Debug metrics",
|
||||
"operationId": "debug-metrics",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Debug"],
|
||||
"summary": "Debug pprof index",
|
||||
"operationId": "debug-pprof-index",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/cmdline": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Debug"],
|
||||
"summary": "Debug pprof cmdline",
|
||||
"operationId": "debug-pprof-cmdline",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/profile": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Debug"],
|
||||
"summary": "Debug pprof profile",
|
||||
"operationId": "debug-pprof-profile",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/symbol": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Debug"],
|
||||
"summary": "Debug pprof symbol",
|
||||
"operationId": "debug-pprof-symbol",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/pprof/trace": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"tags": ["Debug"],
|
||||
"summary": "Debug pprof trace",
|
||||
"operationId": "debug-pprof-trace",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK"
|
||||
}
|
||||
},
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/debug/tailnet": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -2594,6 +2697,41 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/oauth2/revoke": {
|
||||
"post": {
|
||||
"consumes": ["application/x-www-form-urlencoded"],
|
||||
"tags": ["Enterprise"],
|
||||
"summary": "Revoke OAuth2 tokens (RFC 7009).",
|
||||
"operationId": "oauth2-token-revocation",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Client ID for authentication",
|
||||
"name": "client_id",
|
||||
"in": "formData",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "The token to revoke",
|
||||
"name": "token",
|
||||
"in": "formData",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Hint about token type (access_token or refresh_token)",
|
||||
"name": "token_type_hint",
|
||||
"in": "formData"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Token successfully revoked"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/oauth2/tokens": {
|
||||
"post": {
|
||||
"produces": ["application/json"],
|
||||
@@ -10198,20 +10336,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"coderd.tasksListResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Task"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ACLAvailable": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -10261,9 +10385,8 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"initiator_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
@@ -10302,6 +10425,9 @@
|
||||
"codersdk.AIBridgeListInterceptionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"results": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
@@ -10450,6 +10576,12 @@
|
||||
"user_id"
|
||||
],
|
||||
"properties": {
|
||||
"allow_list": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.APIAllowListTarget"
|
||||
}
|
||||
},
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -10675,6 +10807,7 @@
|
||||
"workspace:delete",
|
||||
"workspace:delete_agent",
|
||||
"workspace:read",
|
||||
"workspace:share",
|
||||
"workspace:ssh",
|
||||
"workspace:start",
|
||||
"workspace:stop",
|
||||
@@ -10692,6 +10825,7 @@
|
||||
"workspace_dormant:delete",
|
||||
"workspace_dormant:delete_agent",
|
||||
"workspace_dormant:read",
|
||||
"workspace_dormant:share",
|
||||
"workspace_dormant:ssh",
|
||||
"workspace_dormant:start",
|
||||
"workspace_dormant:stop",
|
||||
@@ -10871,6 +11005,7 @@
|
||||
"APIKeyScopeWorkspaceDelete",
|
||||
"APIKeyScopeWorkspaceDeleteAgent",
|
||||
"APIKeyScopeWorkspaceRead",
|
||||
"APIKeyScopeWorkspaceShare",
|
||||
"APIKeyScopeWorkspaceSsh",
|
||||
"APIKeyScopeWorkspaceStart",
|
||||
"APIKeyScopeWorkspaceStop",
|
||||
@@ -10888,6 +11023,7 @@
|
||||
"APIKeyScopeWorkspaceDormantDelete",
|
||||
"APIKeyScopeWorkspaceDormantDeleteAgent",
|
||||
"APIKeyScopeWorkspaceDormantRead",
|
||||
"APIKeyScopeWorkspaceDormantShare",
|
||||
"APIKeyScopeWorkspaceDormantSsh",
|
||||
"APIKeyScopeWorkspaceDormantStart",
|
||||
"APIKeyScopeWorkspaceDormantStop",
|
||||
@@ -12462,6 +12598,9 @@
|
||||
"docs_url": {
|
||||
"$ref": "#/definitions/serpent.URL"
|
||||
},
|
||||
"enable_authz_recording": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"enable_terraform_debug_mode": {
|
||||
"type": "boolean"
|
||||
},
|
||||
@@ -13354,7 +13493,11 @@
|
||||
"properties": {
|
||||
"icon": {
|
||||
"type": "string",
|
||||
"enum": ["bug", "chat", "docs"]
|
||||
"enum": ["bug", "chat", "docs", "star"]
|
||||
},
|
||||
"location": {
|
||||
"type": "string",
|
||||
"enum": ["navbar", "dropdown"]
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
@@ -13497,6 +13640,9 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"username": {
|
||||
"type": "string"
|
||||
}
|
||||
@@ -13769,6 +13915,9 @@
|
||||
},
|
||||
"token": {
|
||||
"type": "string"
|
||||
},
|
||||
"token_revoke": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13868,7 +14017,10 @@
|
||||
}
|
||||
},
|
||||
"registration_access_token": {
|
||||
"type": "string"
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"registration_client_uri": {
|
||||
"type": "string"
|
||||
@@ -15448,6 +15600,7 @@
|
||||
"read",
|
||||
"read_personal",
|
||||
"ssh",
|
||||
"share",
|
||||
"unassign",
|
||||
"update",
|
||||
"update_personal",
|
||||
@@ -15466,6 +15619,7 @@
|
||||
"ActionRead",
|
||||
"ActionReadPersonal",
|
||||
"ActionSSH",
|
||||
"ActionShare",
|
||||
"ActionUnassign",
|
||||
"ActionUpdate",
|
||||
"ActionUpdatePersonal",
|
||||
@@ -16064,6 +16218,9 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"owner_avatar_url": {
|
||||
"type": "string"
|
||||
},
|
||||
"owner_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -16074,19 +16231,15 @@
|
||||
"status": {
|
||||
"enum": [
|
||||
"pending",
|
||||
"starting",
|
||||
"running",
|
||||
"stopping",
|
||||
"stopped",
|
||||
"failed",
|
||||
"canceling",
|
||||
"canceled",
|
||||
"deleting",
|
||||
"deleted"
|
||||
"initializing",
|
||||
"active",
|
||||
"paused",
|
||||
"unknown",
|
||||
"error"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.WorkspaceStatus"
|
||||
"$ref": "#/definitions/codersdk.TaskStatus"
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -16103,6 +16256,10 @@
|
||||
"template_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"template_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -16139,6 +16296,28 @@
|
||||
"$ref": "#/definitions/uuid.NullUUID"
|
||||
}
|
||||
]
|
||||
},
|
||||
"workspace_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_status": {
|
||||
"enum": [
|
||||
"pending",
|
||||
"starting",
|
||||
"running",
|
||||
"stopping",
|
||||
"stopped",
|
||||
"failed",
|
||||
"canceling",
|
||||
"canceled",
|
||||
"deleting",
|
||||
"deleted"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.WorkspaceStatus"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -16212,6 +16391,39 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.TaskStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"pending",
|
||||
"initializing",
|
||||
"active",
|
||||
"paused",
|
||||
"unknown",
|
||||
"error"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"TaskStatusPending",
|
||||
"TaskStatusInitializing",
|
||||
"TaskStatusActive",
|
||||
"TaskStatusPaused",
|
||||
"TaskStatusUnknown",
|
||||
"TaskStatusError"
|
||||
]
|
||||
},
|
||||
"codersdk.TasksListResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.Task"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.TelemetryConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -18596,6 +18808,7 @@
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ai_task_sidebar_app_id": {
|
||||
"description": "Deprecated: This field has been replaced with `TaskAppID`",
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
@@ -18673,6 +18886,10 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"task_app_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"template_version_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -19073,6 +19290,9 @@
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"role": {
|
||||
"enum": ["admin", "use"],
|
||||
"allOf": [
|
||||
|
||||
+28
-15
@@ -2,6 +2,7 @@ package apikey
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
@@ -44,12 +45,17 @@ type CreateParams struct {
|
||||
// database representation. It is the responsibility of the caller to insert it
|
||||
// into the database.
|
||||
func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error) {
|
||||
keyID, keySecret, err := generateKey()
|
||||
// Length of an API Key ID.
|
||||
keyID, err := cryptorand.String(10)
|
||||
if err != nil {
|
||||
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key: %w", err)
|
||||
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key ID: %w", err)
|
||||
}
|
||||
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
// Length of an API Key secret.
|
||||
keySecret, hashedSecret, err := GenerateSecret(22)
|
||||
if err != nil {
|
||||
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key secret: %w", err)
|
||||
}
|
||||
|
||||
// Default expires at to now+lifetime, or use the configured value if not
|
||||
// set.
|
||||
@@ -120,7 +126,7 @@ func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error)
|
||||
ExpiresAt: params.ExpiresAt.UTC(),
|
||||
CreatedAt: dbtime.Now(),
|
||||
UpdatedAt: dbtime.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
HashedSecret: hashedSecret,
|
||||
LoginType: params.LoginType,
|
||||
Scopes: scopes,
|
||||
AllowList: params.AllowList,
|
||||
@@ -128,17 +134,24 @@ func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error)
|
||||
}, token, nil
|
||||
}
|
||||
|
||||
// generateKey a new ID and secret for an API key.
|
||||
func generateKey() (id string, secret string, err error) {
|
||||
// Length of an API Key ID.
|
||||
id, err = cryptorand.String(10)
|
||||
func GenerateSecret(length int) (secret string, hashed []byte, err error) {
|
||||
secret, err = cryptorand.String(length)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return "", nil, err
|
||||
}
|
||||
// Length of an API Key secret.
|
||||
secret, err = cryptorand.String(22)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return id, secret, nil
|
||||
hash := HashSecret(secret)
|
||||
return secret, hash, nil
|
||||
}
|
||||
|
||||
// ValidateHash compares a secret against an expected hashed secret.
|
||||
func ValidateHash(hashedSecret []byte, secret string) bool {
|
||||
hash := HashSecret(secret)
|
||||
return subtle.ConstantTimeCompare(hashedSecret, hash) == 1
|
||||
}
|
||||
|
||||
// HashSecret is the single function used to hash API key secrets.
|
||||
// Use this to ensure a consistent hashing algorithm.
|
||||
func HashSecret(secret string) []byte {
|
||||
hash := sha256.Sum256([]byte(secret))
|
||||
return hash[:]
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package apikey_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -126,8 +125,8 @@ func TestGenerate(t *testing.T) {
|
||||
require.Equal(t, key.ID, keytokens[0])
|
||||
|
||||
// Assert that the hashed secret is correct.
|
||||
hashed := sha256.Sum256([]byte(keytokens[1]))
|
||||
assert.ElementsMatch(t, hashed, key.HashedSecret)
|
||||
equal := apikey.ValidateHash(key.HashedSecret, keytokens[1])
|
||||
require.True(t, equal, "valid secret")
|
||||
|
||||
assert.Equal(t, tc.params.UserID, key.UserID)
|
||||
assert.WithinDuration(t, dbtime.Now(), key.CreatedAt, time.Second*5)
|
||||
@@ -173,3 +172,17 @@ func TestGenerate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalid just ensures the false case is asserted by some tests.
|
||||
// Otherwise, a function that just `returns true` might pass all tests incorrectly.
|
||||
func TestInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Falsef(t, apikey.ValidateHash([]byte{}, "secret"), "empty hash")
|
||||
|
||||
secret, hash, err := apikey.GenerateSecret(10)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Falsef(t, apikey.ValidateHash(hash, secret+"_"), "different secret")
|
||||
require.Falsef(t, apikey.ValidateHash(hash[:len(hash)-1], secret), "different hash length")
|
||||
}
|
||||
|
||||
@@ -51,6 +51,8 @@ func TestTokenCRUD(t *testing.T) {
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*6))
|
||||
require.Less(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*24*8))
|
||||
require.Equal(t, codersdk.APIKeyScopeAll, keys[0].Scope)
|
||||
require.Len(t, keys[0].AllowList, 1)
|
||||
require.Equal(t, "*:*", keys[0].AllowList[0].String())
|
||||
|
||||
// no update
|
||||
|
||||
@@ -86,6 +88,8 @@ func TestTokenScoped(t *testing.T) {
|
||||
require.EqualValues(t, len(keys), 1)
|
||||
require.Contains(t, res.Key, keys[0].ID)
|
||||
require.Equal(t, keys[0].Scope, codersdk.APIKeyScopeApplicationConnect)
|
||||
require.Len(t, keys[0].AllowList, 1)
|
||||
require.Equal(t, "*:*", keys[0].AllowList[0].String())
|
||||
}
|
||||
|
||||
// Ensure backward-compat: when a token is created using the legacy singular
|
||||
@@ -132,6 +136,8 @@ func TestTokenLegacySingularScopeCompat(t *testing.T) {
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, tc.scope, keys[0].Scope)
|
||||
require.ElementsMatch(t, keys[0].Scopes, tc.scopes)
|
||||
require.Len(t, keys[0].AllowList, 1)
|
||||
require.Equal(t, "*:*", keys[0].AllowList[0].String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -776,10 +776,6 @@ func TestExecutorWorkspaceAutostopNoWaitChangedMyMind(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExecutorAutostartMultipleOK(t *testing.T) {
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip(`This test only really works when using a "real" database, similar to a HA setup`)
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
@@ -1259,10 +1255,6 @@ func TestNotifications(t *testing.T) {
|
||||
func TestExecutorPrebuilds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
// Prebuild workspaces should not be autostopped when the deadline is reached.
|
||||
// After being claimed, the workspace should stop at the deadline.
|
||||
t.Run("OnlyStopsAfterClaimed", func(t *testing.T) {
|
||||
|
||||
+58
-6
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
httppprof "net/http/pprof"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -32,6 +33,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
httpSwagger "github.com/swaggo/http-swagger/v2"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
@@ -491,7 +493,7 @@ func New(options *Options) *API {
|
||||
// We add this middleware early, to make sure that authorization checks made
|
||||
// by other middleware get recorded.
|
||||
if buildinfo.IsDev() {
|
||||
r.Use(httpmw.RecordAuthzChecks)
|
||||
r.Use(httpmw.RecordAuthzChecks(options.DeploymentValues.EnableAuthzRecording.Value()))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -983,6 +985,16 @@ func New(options *Options) *API {
|
||||
r.Post("/", api.postOAuth2ProviderAppToken())
|
||||
})
|
||||
|
||||
// RFC 7009 Token Revocation Endpoint
|
||||
r.Route("/revoke", func(r chi.Router) {
|
||||
r.Use(
|
||||
// RFC 7009 endpoint uses OAuth2 client authentication, not API key
|
||||
httpmw.AsAuthzSystem(httpmw.ExtractOAuth2ProviderAppWithOAuth2Errors(options.Database)),
|
||||
)
|
||||
// POST /revoke is the standard OAuth2 token revocation endpoint per RFC 7009
|
||||
r.Post("/", api.revokeOAuth2Token())
|
||||
})
|
||||
|
||||
// RFC 7591 Dynamic Client Registration - Public endpoint
|
||||
r.Post("/register", api.postOAuth2ClientRegistration())
|
||||
|
||||
@@ -1020,11 +1032,15 @@ func New(options *Options) *API {
|
||||
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractOrganizationMembersParam(options.Database, api.HTTPAuth.Authorize))
|
||||
r.Get("/{id}", api.taskGet)
|
||||
r.Delete("/{id}", api.taskDelete)
|
||||
r.Post("/{id}/send", api.taskSend)
|
||||
r.Get("/{id}/logs", api.taskLogs)
|
||||
r.Post("/", api.tasksCreate)
|
||||
|
||||
r.Route("/{task}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractTaskParam(options.Database))
|
||||
r.Get("/", api.taskGet)
|
||||
r.Delete("/", api.taskDelete)
|
||||
r.Post("/send", api.taskSend)
|
||||
r.Get("/logs", api.taskLogs)
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/mcp", func(r chi.Router) {
|
||||
@@ -1512,7 +1528,8 @@ func New(options *Options) *API {
|
||||
r.Route("/debug", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
// Ensure only owners can access debug endpoints.
|
||||
// Ensure only users with the debug_info:read (e.g. only owners)
|
||||
// can view debug endpoints.
|
||||
func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceDebugInfo) {
|
||||
@@ -1545,6 +1562,41 @@ func New(options *Options) *API {
|
||||
})
|
||||
}
|
||||
r.Method("GET", "/expvar", expvar.Handler()) // contains DERP metrics as well as cmdline and memstats
|
||||
|
||||
r.Route("/pprof", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
// Some of the pprof handlers strip the `/debug/pprof`
|
||||
// prefix, so we need to strip our additional prefix as
|
||||
// well.
|
||||
return http.StripPrefix("/api/v2", next)
|
||||
})
|
||||
|
||||
// Serve the index HTML page.
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Redirect to include a trailing slash, otherwise links on
|
||||
// the generated HTML page will be broken.
|
||||
if !strings.HasSuffix(r.URL.Path, "/") {
|
||||
http.Redirect(w, r, "/api/v2/debug/pprof/", http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
httppprof.Index(w, r)
|
||||
})
|
||||
|
||||
// Handle any out of the box pprof handlers that don't get
|
||||
// dealt with by the default index handler. See httppprof.init.
|
||||
r.Get("/cmdline", httppprof.Cmdline)
|
||||
r.Get("/profile", httppprof.Profile)
|
||||
r.Get("/symbol", httppprof.Symbol)
|
||||
r.Get("/trace", httppprof.Trace)
|
||||
|
||||
// Index will handle any standard and custom runtime/pprof
|
||||
// profiles.
|
||||
r.Get("/*", httppprof.Index)
|
||||
})
|
||||
|
||||
r.Get("/metrics", promhttp.InstrumentMetricHandler(
|
||||
options.PrometheusRegistry, promhttp.HandlerFor(options.PrometheusRegistry, promhttp.HandlerOpts{}),
|
||||
).ServeHTTP)
|
||||
})
|
||||
// Manage OAuth2 applications that can use Coder as an OAuth2 provider.
|
||||
r.Route("/oauth2-provider", func(r chi.Router) {
|
||||
|
||||
@@ -160,8 +160,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
|
||||
t.Run(method+" "+route, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// This route is for compatibility purposes and is not documented.
|
||||
if route == "/workspaceagents/me/metadata" {
|
||||
// Wildcard routes break the swaggo parser, so we do not document
|
||||
// them.
|
||||
if strings.HasSuffix(route, "/*") {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ type CheckConstraint string
|
||||
|
||||
// CheckConstraint enums.
|
||||
const (
|
||||
CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys
|
||||
CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users
|
||||
CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users
|
||||
CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs
|
||||
|
||||
@@ -51,6 +51,13 @@ func ListLazy[F any, T any](convert func(F) T) func(list []F) []T {
|
||||
}
|
||||
}
|
||||
|
||||
func APIAllowListTarget(entry rbac.AllowListElement) codersdk.APIAllowListTarget {
|
||||
return codersdk.APIAllowListTarget{
|
||||
Type: codersdk.RBACResource(entry.Type),
|
||||
ID: entry.ID,
|
||||
}
|
||||
}
|
||||
|
||||
type ExternalAuthMeta struct {
|
||||
Authenticated bool
|
||||
ValidateError string
|
||||
@@ -189,6 +196,16 @@ func MinimalUser(user database.User) codersdk.MinimalUser {
|
||||
return codersdk.MinimalUser{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Name: user.Name,
|
||||
AvatarURL: user.AvatarURL,
|
||||
}
|
||||
}
|
||||
|
||||
func MinimalUserFromVisibleUser(user database.VisibleUser) codersdk.MinimalUser {
|
||||
return codersdk.MinimalUser{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Name: user.Name,
|
||||
AvatarURL: user.AvatarURL,
|
||||
}
|
||||
}
|
||||
@@ -197,7 +214,6 @@ func ReducedUser(user database.User) codersdk.ReducedUser {
|
||||
return codersdk.ReducedUser{
|
||||
MinimalUser: MinimalUser(user),
|
||||
Email: user.Email,
|
||||
Name: user.Name,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
LastSeenAt: user.LastSeenAt,
|
||||
@@ -374,6 +390,9 @@ func OAuth2ProviderApp(accessURL *url.URL, dbApp database.OAuth2ProviderApp) cod
|
||||
}).String(),
|
||||
// We do not currently support DeviceAuth.
|
||||
DeviceAuth: "",
|
||||
TokenRevoke: accessURL.ResolveReference(&url.URL{
|
||||
Path: "/oauth2/revoke",
|
||||
}).String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -927,7 +946,7 @@ func PreviewParameterValidation(v *previewtypes.ParameterValidation) codersdk.Pr
|
||||
}
|
||||
}
|
||||
|
||||
func AIBridgeInterception(interception database.AIBridgeInterception, tokenUsages []database.AIBridgeTokenUsage, userPrompts []database.AIBridgeUserPrompt, toolUsages []database.AIBridgeToolUsage) codersdk.AIBridgeInterception {
|
||||
func AIBridgeInterception(interception database.AIBridgeInterception, initiator database.VisibleUser, tokenUsages []database.AIBridgeTokenUsage, userPrompts []database.AIBridgeUserPrompt, toolUsages []database.AIBridgeToolUsage) codersdk.AIBridgeInterception {
|
||||
sdkTokenUsages := List(tokenUsages, AIBridgeTokenUsage)
|
||||
sort.Slice(sdkTokenUsages, func(i, j int) bool {
|
||||
// created_at ASC
|
||||
@@ -945,7 +964,7 @@ func AIBridgeInterception(interception database.AIBridgeInterception, tokenUsage
|
||||
})
|
||||
return codersdk.AIBridgeInterception{
|
||||
ID: interception.ID,
|
||||
InitiatorID: interception.InitiatorID,
|
||||
Initiator: MinimalUserFromVisibleUser(initiator),
|
||||
Provider: interception.Provider,
|
||||
Model: interception.Model,
|
||||
Metadata: jsonOrEmptyMap(interception.Metadata),
|
||||
|
||||
@@ -85,10 +85,6 @@ func TestNestedInTx(t *testing.T) {
|
||||
func testSQLDB(t testing.TB) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
connection, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -446,6 +446,34 @@ var (
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
subjectSystemOAuth2 = rbac.Subject{
|
||||
Type: rbac.SubjectTypeSystemOAuth,
|
||||
FriendlyName: "System OAuth2",
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleIdentifier{Name: "system-oauth2"},
|
||||
DisplayName: "System OAuth2",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
// OAuth2 resources - full CRUD permissions
|
||||
rbac.ResourceOauth2App.Type: rbac.ResourceOauth2App.AvailableActions(),
|
||||
rbac.ResourceOauth2AppSecret.Type: rbac.ResourceOauth2AppSecret.AvailableActions(),
|
||||
rbac.ResourceOauth2AppCodeToken.Type: rbac.ResourceOauth2AppCodeToken.AvailableActions(),
|
||||
|
||||
// API key permissions needed for OAuth2 token revocation
|
||||
rbac.ResourceApiKey.Type: {policy.ActionRead, policy.ActionDelete},
|
||||
|
||||
// Minimal read permissions that might be needed for OAuth2 operations
|
||||
rbac.ResourceUser.Type: {policy.ActionRead},
|
||||
rbac.ResourceOrganization.Type: {policy.ActionRead},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
subjectSystemReadProvisionerDaemons = rbac.Subject{
|
||||
Type: rbac.SubjectTypeSystemReadProvisionerDaemons,
|
||||
FriendlyName: "Provisioner Daemons Reader",
|
||||
@@ -643,6 +671,12 @@ func AsSystemRestricted(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectSystemRestricted)
|
||||
}
|
||||
|
||||
// AsSystemOAuth2 returns a context with an actor that has permissions
|
||||
// required for OAuth2 provider operations (token revocation, device codes, registration).
|
||||
func AsSystemOAuth2(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectSystemOAuth2)
|
||||
}
|
||||
|
||||
// AsSystemReadProvisionerDaemons returns a context with an actor that has permissions
|
||||
// to read provisioner daemons.
|
||||
func AsSystemReadProvisionerDaemons(ctx context.Context) context.Context {
|
||||
@@ -1436,6 +1470,14 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
|
||||
return q.db.CleanTailnetTunnels(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
// Shortcut if the user is an owner. The SQL filter is noticeable,
|
||||
// and this is an easy win for owners. Which is the common case.
|
||||
@@ -1470,6 +1512,13 @@ func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.Coun
|
||||
return q.db.CountInProgressPrebuilds(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountPendingNonActivePrebuilds(ctx context.Context) ([]database.CountPendingNonActivePrebuildsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.CountPendingNonActivePrebuilds(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) CountUnreadInboxNotificationsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceInboxNotification.WithOwner(userID.String())); err != nil {
|
||||
return 0, err
|
||||
@@ -1756,6 +1805,19 @@ func (q *querier) DeleteTailnetTunnel(ctx context.Context, arg database.DeleteTa
|
||||
return q.db.DeleteTailnetTunnel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
|
||||
task, err := q.db.GetTaskByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.TaskTable{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, task.RBACObject()); err != nil {
|
||||
return database.TaskTable{}, err
|
||||
}
|
||||
|
||||
return q.db.DeleteTask(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
@@ -1792,7 +1854,7 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
|
||||
return w.WorkspaceTable(), nil
|
||||
}
|
||||
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionUpdate, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
@@ -2420,7 +2482,7 @@ func (q *querier) GetOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) (d
|
||||
return q.db.GetOAuth2ProviderAppByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) {
|
||||
func (q *querier) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceOauth2App); err != nil {
|
||||
return database.OAuth2ProviderApp{}, err
|
||||
}
|
||||
@@ -3388,7 +3450,7 @@ func (q *querier) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (databa
|
||||
if err != nil {
|
||||
return database.GetWorkspaceACLByIDRow{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionCreate, workspace); err != nil {
|
||||
if err := q.authorizeContext(ctx, policy.ActionShare, workspace); err != nil {
|
||||
return database.GetWorkspaceACLByIDRow{}, err
|
||||
}
|
||||
return q.db.GetWorkspaceACLByID(ctx, id)
|
||||
@@ -3552,6 +3614,13 @@ func (q *querier) GetWorkspaceAgentsCreatedAfter(ctx context.Context, createdAt
|
||||
return q.db.GetWorkspaceAgentsCreatedAfter(ctx, createdAt)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceAgentsForMetrics(ctx context.Context) ([]database.GetWorkspaceAgentsForMetricsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetWorkspaceAgentsForMetrics(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) {
|
||||
workspace, err := q.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
@@ -3857,6 +3926,13 @@ func (q *querier) GetWorkspacesEligibleForTransition(ctx context.Context, now ti
|
||||
return q.db.GetWorkspacesEligibleForTransition(ctx, now)
|
||||
}
|
||||
|
||||
func (q *querier) GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]database.GetWorkspacesForWorkspaceMetricsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetWorkspacesForWorkspaceMetrics(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) InsertAIBridgeInterception(ctx context.Context, arg database.InsertAIBridgeInterceptionParams) (database.AIBridgeInterception, error) {
|
||||
return insert(q.log, q.auth, rbac.ResourceAibridgeInterception.WithOwner(arg.InitiatorID.String()), q.db.InsertAIBridgeInterception)(ctx, arg)
|
||||
}
|
||||
@@ -4439,7 +4515,7 @@ func (q *querier) InsertWorkspaceResourceMetadata(ctx context.Context, arg datab
|
||||
return q.db.InsertWorkspaceResourceMetadata(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.AIBridgeInterception, error) {
|
||||
func (q *querier) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
@@ -4806,6 +4882,14 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
|
||||
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
|
||||
}
|
||||
|
||||
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
|
||||
return []uuid.UUID{}, err
|
||||
}
|
||||
return q.db.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error {
|
||||
preset, err := q.db.GetPresetByID(ctx, arg.PresetID)
|
||||
if err != nil {
|
||||
@@ -4953,6 +5037,30 @@ func (q *querier) UpdateTailnetPeerStatusByCoordinator(ctx context.Context, arg
|
||||
return q.db.UpdateTailnetPeerStatusByCoordinator(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateTaskWorkspaceID(ctx context.Context, arg database.UpdateTaskWorkspaceIDParams) (database.TaskTable, error) {
|
||||
// An actor is allowed to update the workspace ID of a task if they are the
|
||||
// owner of the task and workspace or have the appropriate permissions.
|
||||
task, err := q.db.GetTaskByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.TaskTable{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, task.RBACObject()); err != nil {
|
||||
return database.TaskTable{}, err
|
||||
}
|
||||
|
||||
ws, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID.UUID)
|
||||
if err != nil {
|
||||
return database.TaskTable{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, ws.RBACObject()); err != nil {
|
||||
return database.TaskTable{}, err
|
||||
}
|
||||
|
||||
return q.db.UpdateTaskWorkspaceID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
|
||||
fetch := func(ctx context.Context, arg database.UpdateTemplateACLByIDParams) (database.Template, error) {
|
||||
return q.db.GetTemplateByID(ctx, arg.ID)
|
||||
@@ -5298,7 +5406,7 @@ func (q *querier) UpdateWorkspaceACLByID(ctx context.Context, arg database.Updat
|
||||
return w.WorkspaceTable(), nil
|
||||
}
|
||||
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionCreate, fetch, q.db.UpdateWorkspaceACLByID)(ctx, arg)
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.UpdateWorkspaceACLByID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg database.UpdateWorkspaceAgentConnectionByIDParams) error {
|
||||
@@ -5848,9 +5956,16 @@ func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg databas
|
||||
return q.CountConnectionLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.AIBridgeInterception, error) {
|
||||
func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
// TODO: Delete this function, all ListAIBridgeInterceptions should be authorized. For now just call ListAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.ListAIBridgeInterceptions(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) (int64, error) {
|
||||
// TODO: Delete this function, all CountAIBridgeInterceptions should be authorized. For now just call CountAIBridgeInterceptions on the authz querier.
|
||||
// This cannot be deleted for now because it's included in the
|
||||
// database.Store interface, so dbauthz needs to implement it.
|
||||
return q.CountAIBridgeInterceptions(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -641,6 +641,16 @@ func (s *MethodTestSuite) TestProvisionerJob() {
|
||||
dbm.EXPECT().UpdateProvisionerJobWithCancelByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(v.RBACObject(tpl), []policy.Action{policy.ActionRead, policy.ActionUpdate}).Returns()
|
||||
}))
|
||||
s.Run("UpdatePrebuildProvisionerJobWithCancel", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
arg := database.UpdatePrebuildProvisionerJobWithCancelParams{
|
||||
PresetID: uuid.NullUUID{UUID: uuid.New(), Valid: true},
|
||||
Now: dbtime.Now(),
|
||||
}
|
||||
jobIDs := []uuid.UUID{uuid.New(), uuid.New()}
|
||||
|
||||
dbm.EXPECT().UpdatePrebuildProvisionerJobWithCancel(gomock.Any(), arg).Return(jobIDs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourcePrebuiltWorkspace, policy.ActionUpdate).Returns(jobIDs)
|
||||
}))
|
||||
s.Run("GetProvisionerJobsByIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
org2 := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
@@ -1689,6 +1699,15 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(arg).Asserts()
|
||||
}))
|
||||
s.Run("GetWorkspaceAgentsForMetrics", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
row := testutil.Fake(s.T(), faker, database.GetWorkspaceAgentsForMetricsRow{})
|
||||
dbm.EXPECT().GetWorkspaceAgentsForMetrics(gomock.Any()).Return([]database.GetWorkspaceAgentsForMetricsRow{row}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceWorkspace, policy.ActionRead).Returns([]database.GetWorkspaceAgentsForMetricsRow{row})
|
||||
}))
|
||||
s.Run("GetWorkspacesForWorkspaceMetrics", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetWorkspacesForWorkspaceMetrics(gomock.Any()).Return([]database.GetWorkspacesForWorkspaceMetricsRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceWorkspace, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetAuthorizedWorkspaces", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetWorkspacesParams{}
|
||||
dbm.EXPECT().GetAuthorizedWorkspaces(gomock.Any(), arg, gomock.Any()).Return([]database.GetWorkspacesRow{}, nil).AnyTimes()
|
||||
@@ -1723,20 +1742,20 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
dbM.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
dbM.EXPECT().GetWorkspaceACLByID(gomock.Any(), ws.ID).Return(database.GetWorkspaceACLByIDRow{}, nil).AnyTimes()
|
||||
check.Args(ws.ID).Asserts(ws, policy.ActionCreate)
|
||||
check.Args(ws.ID).Asserts(ws, policy.ActionShare)
|
||||
}))
|
||||
s.Run("UpdateWorkspaceACLByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
arg := database.UpdateWorkspaceACLByIDParams{ID: w.ID}
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), w.ID).Return(w, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateWorkspaceACLByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionCreate)
|
||||
check.Args(arg).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("DeleteWorkspaceACLByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), w.ID).Return(w, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteWorkspaceACLByID(gomock.Any(), w.ID).Return(nil).AnyTimes()
|
||||
check.Args(w.ID).Asserts(w, policy.ActionUpdate)
|
||||
check.Args(w.ID).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
@@ -2353,6 +2372,16 @@ func (s *MethodTestSuite) TestTasks() {
|
||||
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
|
||||
check.Args(task.ID).Asserts(task, policy.ActionRead).Returns(task)
|
||||
}))
|
||||
s.Run("DeleteTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
task := testutil.Fake(s.T(), faker, database.Task{})
|
||||
arg := database.DeleteTaskParams{
|
||||
ID: task.ID,
|
||||
DeletedAt: dbtime.Now(),
|
||||
}
|
||||
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteTask(gomock.Any(), arg).Return(database.TaskTable{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(task, policy.ActionDelete).Returns(database.TaskTable{})
|
||||
}))
|
||||
s.Run("InsertTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
tpl := testutil.Fake(s.T(), faker, database.Template{})
|
||||
tv := testutil.Fake(s.T(), faker, database.TemplateVersion{
|
||||
@@ -2386,6 +2415,20 @@ func (s *MethodTestSuite) TestTasks() {
|
||||
|
||||
check.Args(arg).Asserts(task, policy.ActionUpdate).Returns(database.TaskWorkspaceApp{})
|
||||
}))
|
||||
s.Run("UpdateTaskWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
task := testutil.Fake(s.T(), faker, database.Task{})
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
arg := database.UpdateTaskWorkspaceIDParams{
|
||||
ID: task.ID,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
}
|
||||
|
||||
dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes()
|
||||
dbm.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateTaskWorkspaceID(gomock.Any(), arg).Return(database.TaskTable{}, nil).AnyTimes()
|
||||
|
||||
check.Args(arg).Asserts(task, policy.ActionUpdate, ws, policy.ActionUpdate).Returns(database.TaskTable{})
|
||||
}))
|
||||
s.Run("GetTaskByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
task := testutil.Fake(s.T(), faker, database.Task{})
|
||||
task.WorkspaceID = uuid.NullUUID{UUID: uuid.New(), Valid: true}
|
||||
@@ -2937,7 +2980,6 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
dbm.EXPECT().GetParameterSchemasByJobID(gomock.Any(), jobID).Return([]database.ParameterSchema{}, nil).AnyTimes()
|
||||
check.Args(jobID).
|
||||
Asserts(tpl, policy.ActionRead).
|
||||
ErrorsWithInMemDB(sql.ErrNoRows).
|
||||
Returns([]database.ParameterSchema{})
|
||||
}))
|
||||
s.Run("GetWorkspaceAppsByAgentIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -3180,7 +3222,7 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
}))
|
||||
s.Run("GetAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetAppSecurityKey(gomock.Any()).Return("", sql.ErrNoRows).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).ErrorsWithPG(sql.ErrNoRows)
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead).Errors(sql.ErrNoRows)
|
||||
}))
|
||||
s.Run("UpsertAppSecurityKey", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertAppSecurityKey(gomock.Any(), "foo").Return(nil).AnyTimes()
|
||||
@@ -3726,6 +3768,10 @@ func (s *MethodTestSuite) TestPrebuilds() {
|
||||
dbm.EXPECT().CountInProgressPrebuilds(gomock.Any()).Return([]database.CountInProgressPrebuildsRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("CountPendingNonActivePrebuilds", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().CountPendingNonActivePrebuilds(gomock.Any()).Return([]database.CountPendingNonActivePrebuildsRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceWorkspace.All(), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPresetsAtFailureLimit", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetPresetsAtFailureLimit(gomock.Any(), int64(0)).Return([]database.GetPresetsAtFailureLimitRow{}, nil).AnyTimes()
|
||||
check.Args(int64(0)).Asserts(rbac.ResourceTemplate.All(), policy.ActionViewInsights)
|
||||
@@ -3893,9 +3939,9 @@ func (s *MethodTestSuite) TestOAuth2ProviderApps() {
|
||||
}))
|
||||
s.Run("GetOAuth2ProviderAppByRegistrationToken", s.Subtest(func(db database.Store, check *expects) {
|
||||
app := dbgen.OAuth2ProviderApp(s.T(), db, database.OAuth2ProviderApp{
|
||||
RegistrationAccessToken: sql.NullString{String: "test-token", Valid: true},
|
||||
RegistrationAccessToken: []byte("test-token"),
|
||||
})
|
||||
check.Args(sql.NullString{String: "test-token", Valid: true}).Asserts(rbac.ResourceOauth2App, policy.ActionRead).Returns(app)
|
||||
check.Args([]byte("test-token")).Asserts(rbac.ResourceOauth2App, policy.ActionRead).Returns(app)
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -4528,14 +4574,28 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
|
||||
s.Run("ListAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeInterceptionsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return([]database.AIBridgeInterception{}, nil).AnyTimes()
|
||||
db.EXPECT().ListAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeInterceptionsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAuthorizedAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeInterceptionsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return([]database.AIBridgeInterception{}, nil).AnyTimes()
|
||||
db.EXPECT().ListAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeInterceptionsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("CountAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.CountAIBridgeInterceptionsParams{}
|
||||
db.EXPECT().CountAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("CountAuthorizedAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.CountAIBridgeInterceptionsParams{}
|
||||
db.EXPECT().CountAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
@@ -430,24 +430,6 @@ func (m *expects) Errors(err error) *expects {
|
||||
return m
|
||||
}
|
||||
|
||||
// ErrorsWithPG is optional. If it is never called, it will not be asserted.
|
||||
// It will only be asserted if the test is running with a Postgres database.
|
||||
func (m *expects) ErrorsWithPG(err error) *expects {
|
||||
if dbtestutil.WillUsePostgres() {
|
||||
return m.Errors(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ErrorsWithInMemDB is optional. If it is never called, it will not be asserted.
|
||||
// It will only be asserted if the test is running with an in-memory database.
|
||||
func (m *expects) ErrorsWithInMemDB(err error) *expects {
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
return m.Errors(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *expects) FailSystemObjectChecks() *expects {
|
||||
return m.WithSuccessAuthorizer(func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error {
|
||||
if obj.Type == rbac.ResourceSystem.Type {
|
||||
|
||||
@@ -55,14 +55,10 @@ type WorkspaceBuildBuilder struct {
|
||||
resources []*sdkproto.Resource
|
||||
params []database.WorkspaceBuildParameter
|
||||
agentToken string
|
||||
dispo workspaceBuildDisposition
|
||||
jobStatus database.ProvisionerJobStatus
|
||||
taskAppID uuid.UUID
|
||||
}
|
||||
|
||||
type workspaceBuildDisposition struct {
|
||||
starting bool
|
||||
}
|
||||
|
||||
// WorkspaceBuild generates a workspace build for the provided workspace.
|
||||
// Pass a database.Workspace{} with a nil ID to also generate a new workspace.
|
||||
// Omitting the template ID on a workspace will also generate a new template
|
||||
@@ -120,19 +116,23 @@ func (b WorkspaceBuildBuilder) WithAgent(mutations ...func([]*sdkproto.Agent) []
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.taskAppID = uuid.New()
|
||||
if seed == nil {
|
||||
seed = &sdkproto.App{}
|
||||
}
|
||||
|
||||
var err error
|
||||
//nolint: revive // returns modified struct
|
||||
b.taskAppID, err = uuid.Parse(takeFirst(seed.Id, uuid.NewString()))
|
||||
require.NoError(b.t, err)
|
||||
|
||||
return b.Params(database.WorkspaceBuildParameter{
|
||||
Name: codersdk.AITaskPromptParameterName,
|
||||
Value: "list me",
|
||||
}).WithAgent(func(a []*sdkproto.Agent) []*sdkproto.Agent {
|
||||
a[0].Apps = []*sdkproto.App{
|
||||
{
|
||||
Id: takeFirst(seed.Id, b.taskAppID.String()),
|
||||
Slug: takeFirst(seed.Slug, "vcode"),
|
||||
Id: b.taskAppID.String(),
|
||||
Slug: takeFirst(seed.Slug, "task-app"),
|
||||
Url: takeFirst(seed.Url, ""),
|
||||
},
|
||||
}
|
||||
@@ -141,8 +141,17 @@ func (b WorkspaceBuildBuilder) WithTask(seed *sdkproto.App) WorkspaceBuildBuilde
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) Starting() WorkspaceBuildBuilder {
|
||||
//nolint: revive // returns modified struct
|
||||
b.dispo.starting = true
|
||||
b.jobStatus = database.ProvisionerJobStatusRunning
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) Pending() WorkspaceBuildBuilder {
|
||||
b.jobStatus = database.ProvisionerJobStatusPending
|
||||
return b
|
||||
}
|
||||
|
||||
func (b WorkspaceBuildBuilder) Canceled() WorkspaceBuildBuilder {
|
||||
b.jobStatus = database.ProvisionerJobStatusCanceled
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -195,11 +204,11 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
|
||||
if b.ws.ID == uuid.Nil {
|
||||
// nolint: revive
|
||||
b.ws = dbgen.Workspace(b.t, b.db, b.ws)
|
||||
resp.Workspace = b.ws
|
||||
b.logger.Debug(context.Background(), "created workspace",
|
||||
slog.F("name", resp.Workspace.Name),
|
||||
slog.F("workspace_id", resp.Workspace.ID))
|
||||
slog.F("name", b.ws.Name),
|
||||
slog.F("workspace_id", b.ws.ID))
|
||||
}
|
||||
resp.Workspace = b.ws
|
||||
b.seed.WorkspaceID = b.ws.ID
|
||||
b.seed.InitiatorID = takeFirst(b.seed.InitiatorID, b.ws.OwnerID)
|
||||
|
||||
@@ -227,7 +236,11 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
|
||||
require.NoError(b.t, err, "insert job")
|
||||
b.logger.Debug(context.Background(), "inserted provisioner job", slog.F("job_id", job.ID))
|
||||
|
||||
if b.dispo.starting {
|
||||
switch b.jobStatus {
|
||||
case database.ProvisionerJobStatusPending:
|
||||
// Provisioner jobs are created in 'pending' status
|
||||
b.logger.Debug(context.Background(), "pending the provisioner job")
|
||||
case database.ProvisionerJobStatusRunning:
|
||||
// might need to do this multiple times if we got a template version
|
||||
// import job as well
|
||||
b.logger.Debug(context.Background(), "looping to acquire provisioner job")
|
||||
@@ -251,7 +264,23 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
case database.ProvisionerJobStatusCanceled:
|
||||
// Set provisioner job status to 'canceled'
|
||||
b.logger.Debug(context.Background(), "canceling the provisioner job")
|
||||
err = b.db.UpdateProvisionerJobWithCancelByID(ownerCtx, database.UpdateProvisionerJobWithCancelByIDParams{
|
||||
ID: jobID,
|
||||
CanceledAt: sql.NullTime{
|
||||
Time: dbtime.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
CompletedAt: sql.NullTime{
|
||||
Time: dbtime.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(b.t, err, "cancel job")
|
||||
default:
|
||||
// By default, consider jobs in 'succeeded' status
|
||||
b.logger.Debug(context.Background(), "completing the provisioner job")
|
||||
err = b.db.UpdateProvisionerJobWithCompleteByID(ownerCtx, database.UpdateProvisionerJobWithCompleteByIDParams{
|
||||
ID: job.ID,
|
||||
@@ -273,6 +302,30 @@ func (b WorkspaceBuildBuilder) Do() WorkspaceResponse {
|
||||
slog.F("workspace_id", resp.Workspace.ID),
|
||||
slog.F("build_number", resp.Build.BuildNumber))
|
||||
|
||||
// If this is a task workspace, link it to the workspace build.
|
||||
task, err := b.db.GetTaskByWorkspaceID(ownerCtx, resp.Workspace.ID)
|
||||
if err != nil {
|
||||
if b.taskAppID != uuid.Nil {
|
||||
require.Fail(b.t, "task app configured but failed to get task by workspace id", err)
|
||||
}
|
||||
} else {
|
||||
if b.taskAppID == uuid.Nil {
|
||||
require.Fail(b.t, "task app not configured but workspace is a task workspace")
|
||||
}
|
||||
|
||||
app := mustWorkspaceAppByWorkspaceAndBuildAndAppID(ownerCtx, b.t, b.db, resp.Workspace.ID, resp.Build.BuildNumber, b.taskAppID)
|
||||
_, err = b.db.UpsertTaskWorkspaceApp(ownerCtx, database.UpsertTaskWorkspaceAppParams{
|
||||
TaskID: task.ID,
|
||||
WorkspaceBuildNumber: resp.Build.BuildNumber,
|
||||
WorkspaceAgentID: uuid.NullUUID{UUID: app.AgentID, Valid: true},
|
||||
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
|
||||
})
|
||||
require.NoError(b.t, err, "upsert task workspace app")
|
||||
b.logger.Debug(context.Background(), "linked task to workspace build",
|
||||
slog.F("task_id", task.ID),
|
||||
slog.F("build_number", resp.Build.BuildNumber))
|
||||
}
|
||||
|
||||
for i := range b.params {
|
||||
b.params[i].WorkspaceBuildID = resp.Build.ID
|
||||
}
|
||||
@@ -543,6 +596,12 @@ func (t TemplateVersionBuilder) Do() TemplateVersionResponse {
|
||||
t.params[i] = dbgen.TemplateVersionParameter(t.t, t.db, param)
|
||||
}
|
||||
|
||||
// Update response with template and version
|
||||
if resp.Template.ID == uuid.Nil && version.TemplateID.Valid {
|
||||
template, err := t.db.GetTemplateByID(ownerCtx, version.TemplateID.UUID)
|
||||
require.NoError(t.t, err)
|
||||
resp.Template = template
|
||||
}
|
||||
resp.TemplateVersion = version
|
||||
return resp
|
||||
}
|
||||
@@ -623,3 +682,30 @@ func takeFirst[Value comparable](values ...Value) Value {
|
||||
return v != empty
|
||||
})
|
||||
}
|
||||
|
||||
// mustWorkspaceAppByWorkspaceAndBuildAndAppID finds a workspace app by
|
||||
// workspace ID, build number, and app ID. It returns the workspace app
|
||||
// if found, otherwise fails the test.
|
||||
func mustWorkspaceAppByWorkspaceAndBuildAndAppID(ctx context.Context, t testing.TB, db database.Store, workspaceID uuid.UUID, buildNumber int32, appID uuid.UUID) database.WorkspaceApp {
|
||||
t.Helper()
|
||||
|
||||
agents, err := db.GetWorkspaceAgentsByWorkspaceAndBuildNumber(ctx, database.GetWorkspaceAgentsByWorkspaceAndBuildNumberParams{
|
||||
WorkspaceID: workspaceID,
|
||||
BuildNumber: buildNumber,
|
||||
})
|
||||
require.NoError(t, err, "get workspace agents")
|
||||
require.NotEmpty(t, agents, "no agents found for workspace")
|
||||
|
||||
for _, agent := range agents {
|
||||
apps, err := db.GetWorkspaceAppsByAgentID(ctx, agent.ID)
|
||||
require.NoError(t, err, "get workspace apps")
|
||||
for _, app := range apps {
|
||||
if app.ID == appID {
|
||||
return app
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.FailNow(t, "could not find workspace app", "workspaceID=%s buildNumber=%d appID=%s", workspaceID, buildNumber, appID)
|
||||
return database.WorkspaceApp{} // Unreachable.
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package dbgen
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -20,6 +19,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -161,8 +161,8 @@ func Template(t testing.TB, db database.Store, seed database.Template) database.
|
||||
|
||||
func APIKey(t testing.TB, db database.Store, seed database.APIKey, munge ...func(*database.InsertAPIKeyParams)) (key database.APIKey, token string) {
|
||||
id, _ := cryptorand.String(10)
|
||||
secret, _ := cryptorand.String(22)
|
||||
hashed := sha256.Sum256([]byte(secret))
|
||||
secret, hashed, err := apikey.GenerateSecret(22)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := seed.IPAddress
|
||||
if !ip.Valid {
|
||||
@@ -179,7 +179,7 @@ func APIKey(t testing.TB, db database.Store, seed database.APIKey, munge ...func
|
||||
ID: takeFirst(seed.ID, id),
|
||||
// 0 defaults to 86400 at the db layer
|
||||
LifetimeSeconds: takeFirst(seed.LifetimeSeconds, 0),
|
||||
HashedSecret: takeFirstSlice(seed.HashedSecret, hashed[:]),
|
||||
HashedSecret: takeFirstSlice(seed.HashedSecret, hashed),
|
||||
IPAddress: ip,
|
||||
UserID: takeFirst(seed.UserID, uuid.New()),
|
||||
LastUsed: takeFirst(seed.LastUsed, dbtime.Now()),
|
||||
@@ -194,7 +194,7 @@ func APIKey(t testing.TB, db database.Store, seed database.APIKey, munge ...func
|
||||
for _, fn := range munge {
|
||||
fn(¶ms)
|
||||
}
|
||||
key, err := db.InsertAPIKey(genCtx, params)
|
||||
key, err = db.InsertAPIKey(genCtx, params)
|
||||
require.NoError(t, err, "insert api key")
|
||||
return key, fmt.Sprintf("%s-%s", key.ID, secret)
|
||||
}
|
||||
@@ -980,16 +980,15 @@ func WorkspaceResourceMetadatums(t testing.TB, db database.Store, seed database.
|
||||
}
|
||||
|
||||
func WorkspaceProxy(t testing.TB, db database.Store, orig database.WorkspaceProxy) (database.WorkspaceProxy, string) {
|
||||
secret, err := cryptorand.HexString(64)
|
||||
secret, hashedSecret, err := apikey.GenerateSecret(64)
|
||||
require.NoError(t, err, "generate secret")
|
||||
hashedSecret := sha256.Sum256([]byte(secret))
|
||||
|
||||
proxy, err := db.InsertWorkspaceProxy(genCtx, database.InsertWorkspaceProxyParams{
|
||||
ID: takeFirst(orig.ID, uuid.New()),
|
||||
Name: takeFirst(orig.Name, testutil.GetRandomName(t)),
|
||||
DisplayName: takeFirst(orig.DisplayName, testutil.GetRandomName(t)),
|
||||
Icon: takeFirst(orig.Icon, testutil.GetRandomName(t)),
|
||||
TokenHashedSecret: hashedSecret[:],
|
||||
TokenHashedSecret: hashedSecret,
|
||||
CreatedAt: takeFirst(orig.CreatedAt, dbtime.Now()),
|
||||
UpdatedAt: takeFirst(orig.UpdatedAt, dbtime.Now()),
|
||||
DerpEnabled: takeFirst(orig.DerpEnabled, false),
|
||||
@@ -1259,7 +1258,7 @@ func OAuth2ProviderApp(t testing.TB, db database.Store, seed database.OAuth2Prov
|
||||
Jwks: seed.Jwks, // pqtype.NullRawMessage{} is not comparable, use existing value
|
||||
SoftwareID: takeFirst(seed.SoftwareID, sql.NullString{}),
|
||||
SoftwareVersion: takeFirst(seed.SoftwareVersion, sql.NullString{}),
|
||||
RegistrationAccessToken: takeFirst(seed.RegistrationAccessToken, sql.NullString{}),
|
||||
RegistrationAccessToken: seed.RegistrationAccessToken,
|
||||
RegistrationClientUri: takeFirst(seed.RegistrationClientUri, sql.NullString{}),
|
||||
})
|
||||
require.NoError(t, err, "insert oauth2 app")
|
||||
|
||||
@@ -5,7 +5,6 @@ package dbmetrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
@@ -187,6 +186,13 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("CountAIBridgeInterceptions").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAuditLogs(ctx, arg)
|
||||
@@ -208,6 +214,13 @@ func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]data
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountPendingNonActivePrebuilds(ctx context.Context) ([]database.CountPendingNonActivePrebuildsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountPendingNonActivePrebuilds(ctx)
|
||||
m.queryLatencies.WithLabelValues("CountPendingNonActivePrebuilds").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountUnreadInboxNotificationsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountUnreadInboxNotificationsByUserID(ctx, userID)
|
||||
@@ -467,6 +480,13 @@ func (m queryMetricsStore) DeleteTailnetTunnel(ctx context.Context, arg database
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteTask(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteTask").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
@@ -1090,7 +1110,7 @@ func (m queryMetricsStore) GetOAuth2ProviderAppByID(ctx context.Context, id uuid
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) {
|
||||
func (m queryMetricsStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetOAuth2ProviderAppByRegistrationToken(ctx, registrationAccessToken)
|
||||
m.queryLatencies.WithLabelValues("GetOAuth2ProviderAppByRegistrationToken").Observe(time.Since(start).Seconds())
|
||||
@@ -1972,6 +1992,13 @@ func (m queryMetricsStore) GetWorkspaceAgentsCreatedAfter(ctx context.Context, c
|
||||
return agents, err
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceAgentsForMetrics(ctx context.Context) ([]database.GetWorkspaceAgentsForMetricsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspaceAgentsForMetrics(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetWorkspaceAgentsForMetrics").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) {
|
||||
start := time.Now()
|
||||
agents, err := m.s.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID)
|
||||
@@ -2224,6 +2251,13 @@ func (m queryMetricsStore) GetWorkspacesEligibleForTransition(ctx context.Contex
|
||||
return workspaces, err
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]database.GetWorkspacesForWorkspaceMetricsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetWorkspacesForWorkspaceMetrics(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetWorkspacesForWorkspaceMetrics").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) InsertAIBridgeInterception(ctx context.Context, arg database.InsertAIBridgeInterceptionParams) (database.AIBridgeInterception, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.InsertAIBridgeInterception(ctx, arg)
|
||||
@@ -2693,7 +2727,7 @@ func (m queryMetricsStore) InsertWorkspaceResourceMetadata(ctx context.Context,
|
||||
return metadata, err
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.AIBridgeInterception, error) {
|
||||
func (m queryMetricsStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeInterceptions(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeInterceptions").Observe(time.Since(start).Seconds())
|
||||
@@ -2973,6 +3007,13 @@ func (m queryMetricsStore) UpdateOrganizationDeletedByID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdatePrebuildProvisionerJobWithCancel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdatePrebuildProvisionerJobWithCancel").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdatePresetPrebuildStatus(ctx, arg)
|
||||
@@ -3043,6 +3084,13 @@ func (m queryMetricsStore) UpdateTailnetPeerStatusByCoordinator(ctx context.Cont
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateTaskWorkspaceID(ctx context.Context, arg database.UpdateTaskWorkspaceIDParams) (database.TaskTable, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateTaskWorkspaceID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateTaskWorkspaceID").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
|
||||
start := time.Now()
|
||||
err := m.s.UpdateTemplateACLByID(ctx, arg)
|
||||
@@ -3701,9 +3749,16 @@ func (m queryMetricsStore) CountAuthorizedConnectionLogs(ctx context.Context, ar
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.AIBridgeInterception, error) {
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeInterceptions").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeInterceptions").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ package dbmock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sql "database/sql"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
@@ -248,6 +247,21 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAIBridgeInterceptions", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions indicates an expected call of CountAIBridgeInterceptions.
|
||||
func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg)
|
||||
}
|
||||
|
||||
// CountAuditLogs mocks base method.
|
||||
func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -263,6 +277,21 @@ func (mr *MockStoreMockRecorder) CountAuditLogs(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuditLogs", reflect.TypeOf((*MockStore)(nil).CountAuditLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// CountAuthorizedAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeInterceptions", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAuthorizedAIBridgeInterceptions indicates an expected call of CountAuthorizedAIBridgeInterceptions.
|
||||
func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// CountAuthorizedAuditLogs mocks base method.
|
||||
func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -323,6 +352,21 @@ func (mr *MockStoreMockRecorder) CountInProgressPrebuilds(ctx any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountInProgressPrebuilds", reflect.TypeOf((*MockStore)(nil).CountInProgressPrebuilds), ctx)
|
||||
}
|
||||
|
||||
// CountPendingNonActivePrebuilds mocks base method.
|
||||
func (m *MockStore) CountPendingNonActivePrebuilds(ctx context.Context) ([]database.CountPendingNonActivePrebuildsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountPendingNonActivePrebuilds", ctx)
|
||||
ret0, _ := ret[0].([]database.CountPendingNonActivePrebuildsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountPendingNonActivePrebuilds indicates an expected call of CountPendingNonActivePrebuilds.
|
||||
func (mr *MockStoreMockRecorder) CountPendingNonActivePrebuilds(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountPendingNonActivePrebuilds", reflect.TypeOf((*MockStore)(nil).CountPendingNonActivePrebuilds), ctx)
|
||||
}
|
||||
|
||||
// CountUnreadInboxNotificationsByUserID mocks base method.
|
||||
func (m *MockStore) CountUnreadInboxNotificationsByUserID(ctx context.Context, userID uuid.UUID) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -850,6 +894,21 @@ func (mr *MockStoreMockRecorder) DeleteTailnetTunnel(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetTunnel", reflect.TypeOf((*MockStore)(nil).DeleteTailnetTunnel), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteTask mocks base method.
|
||||
func (m *MockStore) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) (database.TaskTable, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteTask", ctx, arg)
|
||||
ret0, _ := ret[0].(database.TaskTable)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteTask indicates an expected call of DeleteTask.
|
||||
func (mr *MockStoreMockRecorder) DeleteTask(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockStore)(nil).DeleteTask), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2280,7 +2339,7 @@ func (mr *MockStoreMockRecorder) GetOAuth2ProviderAppByID(ctx, id any) *gomock.C
|
||||
}
|
||||
|
||||
// GetOAuth2ProviderAppByRegistrationToken mocks base method.
|
||||
func (m *MockStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken sql.NullString) (database.OAuth2ProviderApp, error) {
|
||||
func (m *MockStore) GetOAuth2ProviderAppByRegistrationToken(ctx context.Context, registrationAccessToken []byte) (database.OAuth2ProviderApp, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOAuth2ProviderAppByRegistrationToken", ctx, registrationAccessToken)
|
||||
ret0, _ := ret[0].(database.OAuth2ProviderApp)
|
||||
@@ -4199,6 +4258,21 @@ func (mr *MockStoreMockRecorder) GetWorkspaceAgentsCreatedAfter(ctx, createdAt a
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsCreatedAfter", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsCreatedAfter), ctx, createdAt)
|
||||
}
|
||||
|
||||
// GetWorkspaceAgentsForMetrics mocks base method.
|
||||
func (m *MockStore) GetWorkspaceAgentsForMetrics(ctx context.Context) ([]database.GetWorkspaceAgentsForMetricsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWorkspaceAgentsForMetrics", ctx)
|
||||
ret0, _ := ret[0].([]database.GetWorkspaceAgentsForMetricsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetWorkspaceAgentsForMetrics indicates an expected call of GetWorkspaceAgentsForMetrics.
|
||||
func (mr *MockStoreMockRecorder) GetWorkspaceAgentsForMetrics(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentsForMetrics", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentsForMetrics), ctx)
|
||||
}
|
||||
|
||||
// GetWorkspaceAgentsInLatestBuildByWorkspaceID mocks base method.
|
||||
func (m *MockStore) GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgent, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4739,6 +4813,21 @@ func (mr *MockStoreMockRecorder) GetWorkspacesEligibleForTransition(ctx, now any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspacesEligibleForTransition", reflect.TypeOf((*MockStore)(nil).GetWorkspacesEligibleForTransition), ctx, now)
|
||||
}
|
||||
|
||||
// GetWorkspacesForWorkspaceMetrics mocks base method.
|
||||
func (m *MockStore) GetWorkspacesForWorkspaceMetrics(ctx context.Context) ([]database.GetWorkspacesForWorkspaceMetricsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetWorkspacesForWorkspaceMetrics", ctx)
|
||||
ret0, _ := ret[0].([]database.GetWorkspacesForWorkspaceMetricsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetWorkspacesForWorkspaceMetrics indicates an expected call of GetWorkspacesForWorkspaceMetrics.
|
||||
func (mr *MockStoreMockRecorder) GetWorkspacesForWorkspaceMetrics(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspacesForWorkspaceMetrics", reflect.TypeOf((*MockStore)(nil).GetWorkspacesForWorkspaceMetrics), ctx)
|
||||
}
|
||||
|
||||
// InTx mocks base method.
|
||||
func (m *MockStore) InTx(arg0 func(database.Store) error, arg1 *database.TxOptions) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -5744,10 +5833,10 @@ func (mr *MockStoreMockRecorder) InsertWorkspaceResourceMetadata(ctx, arg any) *
|
||||
}
|
||||
|
||||
// ListAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.AIBridgeInterception, error) {
|
||||
func (m *MockStore) ListAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeInterceptions", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.AIBridgeInterception)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeInterceptionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -5804,10 +5893,10 @@ func (mr *MockStoreMockRecorder) ListAIBridgeUserPromptsByInterceptionIDs(ctx, i
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.AIBridgeInterception, error) {
|
||||
func (m *MockStore) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeInterceptionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeInterceptions", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.AIBridgeInterception)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeInterceptionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -6377,6 +6466,21 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOrganizationDeletedByID", reflect.TypeOf((*MockStore)(nil).UpdateOrganizationDeletedByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
|
||||
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePrebuildProvisionerJobWithCancel", ctx, arg)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdatePrebuildProvisionerJobWithCancel indicates an expected call of UpdatePrebuildProvisionerJobWithCancel.
|
||||
func (mr *MockStoreMockRecorder) UpdatePrebuildProvisionerJobWithCancel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePrebuildProvisionerJobWithCancel", reflect.TypeOf((*MockStore)(nil).UpdatePrebuildProvisionerJobWithCancel), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdatePresetPrebuildStatus mocks base method.
|
||||
func (m *MockStore) UpdatePresetPrebuildStatus(ctx context.Context, arg database.UpdatePresetPrebuildStatusParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6518,6 +6622,21 @@ func (mr *MockStoreMockRecorder) UpdateTailnetPeerStatusByCoordinator(ctx, arg a
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTailnetPeerStatusByCoordinator", reflect.TypeOf((*MockStore)(nil).UpdateTailnetPeerStatusByCoordinator), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateTaskWorkspaceID mocks base method.
|
||||
func (m *MockStore) UpdateTaskWorkspaceID(ctx context.Context, arg database.UpdateTaskWorkspaceIDParams) (database.TaskTable, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateTaskWorkspaceID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.TaskTable)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateTaskWorkspaceID indicates an expected call of UpdateTaskWorkspaceID.
|
||||
func (mr *MockStoreMockRecorder) UpdateTaskWorkspaceID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskWorkspaceID", reflect.TypeOf((*MockStore)(nil).UpdateTaskWorkspaceID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateTemplateACLByID mocks base method.
|
||||
func (m *MockStore) UpdateTemplateACLByID(ctx context.Context, arg database.UpdateTemplateACLByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -52,10 +52,6 @@ func (w *wrapUpsertDB) UpsertTemplateUsageStats(ctx context.Context) error {
|
||||
func TestRollup_TwoInstancesUseLocking(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("Skipping test; only works with PostgreSQL.")
|
||||
}
|
||||
|
||||
db, ps := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
|
||||
@@ -23,13 +23,6 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// WillUsePostgres returns true if a call to NewDB() will return a real, postgres-backed Store and Pubsub.
|
||||
// TODO(hugodutka): since we removed the in-memory database, this is always true,
|
||||
// and we need to remove this function. https://github.com/coder/internal/issues/758
|
||||
func WillUsePostgres() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type options struct {
|
||||
fixedTimezone string
|
||||
dumpOnFailure bool
|
||||
@@ -75,10 +68,6 @@ func withReturnSQLDB(f func(*sql.DB)) Option {
|
||||
func NewDBWithSQLDB(t testing.TB, opts ...Option) (database.Store, pubsub.Pubsub, *sql.DB) {
|
||||
t.Helper()
|
||||
|
||||
if !WillUsePostgres() {
|
||||
t.Fatal("cannot use NewDBWithSQLDB without PostgreSQL, consider adding `if !dbtestutil.WillUsePostgres() { t.Skip() }` to this test")
|
||||
}
|
||||
|
||||
var sqlDB *sql.DB
|
||||
opts = append(opts, withReturnSQLDB(func(db *sql.DB) {
|
||||
sqlDB = db
|
||||
|
||||
@@ -20,9 +20,6 @@ func TestMain(m *testing.M) {
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
connect, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
@@ -37,9 +34,6 @@ func TestOpen(t *testing.T) {
|
||||
|
||||
func TestOpen_InvalidDBFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
_, err := dbtestutil.Open(t, dbtestutil.WithDBFrom("__invalid__"))
|
||||
require.Error(t, err)
|
||||
@@ -49,9 +43,6 @@ func TestOpen_InvalidDBFrom(t *testing.T) {
|
||||
|
||||
func TestOpen_ValidDBFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
// first check if we can create a new template db
|
||||
dsn, err := dbtestutil.Open(t, dbtestutil.WithDBFrom(""))
|
||||
@@ -115,9 +106,6 @@ func TestOpen_ValidDBFrom(t *testing.T) {
|
||||
func TestOpen_Panic(t *testing.T) {
|
||||
t.Skip("unskip this to manually test that we don't leak a database into postgres")
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
_, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
@@ -127,9 +115,6 @@ func TestOpen_Panic(t *testing.T) {
|
||||
func TestOpen_Timeout(t *testing.T) {
|
||||
t.Skip("unskip this and set a short timeout to manually test that we don't leak a database into postgres")
|
||||
t.Parallel()
|
||||
if !dbtestutil.WillUsePostgres() {
|
||||
t.Skip("this test requires postgres")
|
||||
}
|
||||
|
||||
_, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user