Compare commits

..

8 Commits

Author SHA1 Message Date
Dean Sheather ee8e8cb805 fix: initialize pseudo console with default size for SSH sessions [2.26] (#20491)
> Resolved an invalid parameter error (-2147024809) during PTY creation
on Windows 11 22H2 (and other versions) when connecting via JetBrains Toolbox,
which spawns the native SSH client with `-tt`, forcing PTY allocation
even though there is no "terminal" on the client side to query for a size.
>
> CreatePseudoConsole doesn't accept a 0x0 (zero width, zero height)
console size, and unfortunately the official Microsoft documentation
neither states the minimum valid values nor explicitly prohibits 0x0.
>
> Looking at real-world implementations, every example found uses
reasonable non-zero values.
>
> I tested this with a local Windows VM registered to dev.coder.com, i.e.
an externally managed workspace.
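
A minimal Go sketch of the idea behind the fix (illustrative only, not the agent's actual code; the helper name and the 80x24 fallback are assumptions): clamp a zero size to a conventional default before creating the pseudo console.

```go
package main

import "fmt"

// defaultPTYSize is a hypothetical helper: when `ssh -tt` forces PTY
// allocation without a real client-side terminal, the requested size can
// arrive as 0x0, which CreatePseudoConsole rejects with -2147024809
// ("the parameter is incorrect"). Falling back to 80x24 avoids that.
func defaultPTYSize(width, height uint16) (uint16, uint16) {
	if width == 0 {
		width = 80 // assumed default; any non-zero width works
	}
	if height == 0 {
		height = 24 // assumed default; any non-zero height works
	}
	return width, height
}

func main() {
	w, h := defaultPTYSize(0, 0)
	fmt.Printf("creating pseudo console at %dx%d instead of 0x0\n", w, h)
}
```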

Relates to #20468

Co-authored-by: Faur Ioan-Aurel <fioan89@gmail.com>
2025-10-28 22:47:03 +11:00
Danielle Maywood 4793806569 chore: upgrade coder/clistat to v1.1.1 (#20322) (#20324)
coder/clistat has received a handful of bug fixes, so we're backporting
them to 2.26.

---

Cherry-picked from 9bef5de30d
2025-10-16 15:29:05 +01:00
Dean Sheather 03440f6ae2 fix: avoid connection logging crashes in agent [2.26] (#20306)
# For release 2.26

- Ignore errors when reporting a connection from the server, just log
them instead
- Translate connection log IP `localhost` to `127.0.0.1` on both the
server and the agent
- Temporary fix: convert invalid IPs to `127.0.0.1`, since the database
forbids NULL (see the sketch below)
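
A minimal Go sketch of the two IP translations above (illustrative only; the real server/agent code and function names differ):

```go
package main

import (
	"fmt"
	"net/netip"
)

// normalizeConnectionIP maps "localhost" to 127.0.0.1 and falls back to
// 127.0.0.1 for anything that doesn't parse as an IP, since the connection
// log column in the database can't be NULL. Hypothetical helper, not the
// actual coder/coder function.
func normalizeConnectionIP(raw string) string {
	if raw == "localhost" {
		return "127.0.0.1"
	}
	if _, err := netip.ParseAddr(raw); err != nil {
		return "127.0.0.1"
	}
	return raw
}

func main() {
	for _, in := range []string{"localhost", "not-an-ip", "10.0.0.5"} {
		fmt.Println(in, "->", normalizeConnectionIP(in))
	}
}
```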

Relates to #20194
2025-10-16 01:28:10 +11:00
Cian Johnston 7afe6c813b fix(coderd): ensure agent WebSocket conn is cleaned up (#19711) (#20094)
Co-authored-by: Danielle Maywood <danielle@themaywoods.com>
2025-10-01 15:37:56 -05:00
Kacper Sawicki 536920459d [2.26 backport] perf(enterprise): remove expensive GetWorkspaces query from entitlements (#19747) (#19756)
This PR addresses a significant database load issue: the GetWorkspaces
query was causing performance problems in the license entitlements code.

cherry picked from commit 3074547
2025-09-11 09:48:32 +02:00
Kacper Sawicki c0f1b9d73e [2.26 backport] fix: pin pg_dump version when generating schema (#19696) (#19765)
This is required by #19756 

The latest releases of all `pg_dump` major versions, going back to 13,
started inserting `\restrict`/`\unrestrict` keywords into dumps. This
currently breaks sqlc in `gen/dump` and our migration check script. Full
details of the postgres change are available here:
https://git.postgresql.org/gitweb/?p=postgresql.git;a=commitdiff;h=575f54d4c

To fix, we'll always use the `pg_dump` in our postgres 13.21 docker
image for schema dumps, instead of what's on the runner/local machine.
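
As a rough illustration (not the repo's actual `gen/dump` tooling), a Go sketch that shells out to the `pg_dump` inside a pinned `postgres:13.21` image; the flags, env var name, and `--network=host` are assumptions:

```go
package main

import (
	"log"
	"os"
	"os/exec"
)

func main() {
	// Run pg_dump from the pinned image so the dump format doesn't depend on
	// whatever pg_dump version happens to be on the runner/local machine.
	cmd := exec.Command(
		"docker", "run", "--rm", "--network=host",
		"postgres:13.21",
		"pg_dump", "--schema-only", "--no-owner", "--no-privileges",
		os.Getenv("PG_CONNECTION_URL"), // illustrative env var name
	)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		log.Fatalf("pg_dump via docker failed: %v", err)
	}
}
```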

Coder doesn't restore from postgres dumps, so we're not vulnerable to
attacks that would be patched by the latest postgres version.
Regardless, we'll unpin ASAP.

Once sqlc is updated to handle these keywords, we need to start
stripping them when comparing the schema in the migration check script,
and then we can unpin the pg_dump version. This is being tracked at
https://github.com/coder/internal/issues/965

Co-authored-by: Ethan <39577870+ethanndickson@users.noreply.github.com>
2025-09-11 09:32:58 +02:00
Stephen Kirby a056cb6577 chore: add last commit from cherry-pick list for release (#19679)
Co-authored-by: Spike Curtis <spike@coder.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-02 19:19:04 -05:00
Stephen Kirby 0a73f842b3 fix: merge cherry-picked items for v2.26.0 (#19678)
Co-authored-by: Cian Johnston <cian@coder.com>
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
Co-authored-by: Hugo Dutka <hugo@coder.com>
Co-authored-by: Kacper Sawicki <kacper@coder.com>
Co-authored-by: Atif Ali <atif@coder.com>
Co-authored-by: Danielle Maywood <danielle@themaywoods.com>
Co-authored-by: Susana Ferreira <susana@coder.com>
Co-authored-by: Brett Kolodny <brettkolodny@gmail.com>
Co-authored-by: Dean Sheather <dean@deansheather.com>
Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
2025-09-02 16:19:17 -05:00
1343 changed files with 29620 additions and 84308 deletions
+2 -10
View File
@@ -91,9 +91,6 @@
## Systematic Debugging Approach
YOU MUST ALWAYS find the root cause of any issue you are debugging
YOU MUST NEVER fix a symptom or add a workaround instead of finding a root cause, even if it is faster.
### Multi-Issue Problem Solving
When facing multiple failing tests or complex integration issues:
@@ -101,21 +98,16 @@ When facing multiple failing tests or complex integration issues:
1. **Identify Root Causes**:
- Run failing tests individually to isolate issues
- Use LSP tools to trace through call chains
- Read Error Messages Carefully: Check both compilation and runtime errors
- Reproduce Consistently: Ensure you can reliably reproduce the issue before investigating
- Check Recent Changes: What changed that could have caused this? Git diff, recent commits, etc.
- When You Don't Know: Say "I don't understand X" rather than pretending to know
- Check both compilation and runtime errors
2. **Fix in Logical Order**:
- Address compilation issues first (imports, syntax)
- Fix authorization and RBAC issues next
- Resolve business logic and validation issues
- Handle edge cases and race conditions last
- IF your first fix doesn't work, STOP and re-analyze rather than adding more fixes
3. **Verification Strategy**:
- Always Test each fix individually before moving to next issue
- Verify Before Continuing: Did your test work? If not, form new hypothesis - don't add more fixes
- Test each fix individually before moving to next issue
- Use `make lint` and `make gen` after database changes
- Verify RFC compliance with actual specifications
- Run comprehensive test suites before considering complete
+3 -7
View File
@@ -40,15 +40,11 @@
- Use proper error types
- Pattern: `xerrors.Errorf("failed to X: %w", err)`
## Naming Conventions
### Naming Conventions
- Names MUST tell what code does, not how it's implemented or its history
- Follow Go and TypeScript naming conventions
- When changing code, never document the old behavior or the behavior change
- NEVER use implementation details in names (e.g., "ZodValidator", "MCPWrapper", "JSONParser")
- NEVER use temporal/historical context in names (e.g., "LegacyHandler", "UnifiedTool", "ImprovedInterface", "EnhancedParser")
- NEVER use pattern names unless they add clarity (e.g., prefer "Tool" over "ToolFactory")
- Use clear, descriptive names
- Abbreviate only when obvious
- Follow Go and TypeScript naming conventions
### Comments
+3 -11
View File
@@ -1,21 +1,13 @@
#!/bin/sh
install_devcontainer_cli() {
set -e
echo "🔧 Installing DevContainer CLI..."
cd "$(dirname "$0")/../tools/devcontainer-cli"
npm ci --omit=dev
ln -sf "$(pwd)/node_modules/.bin/devcontainer" "$(npm config get prefix)/bin/devcontainer"
npm install -g @devcontainers/cli@0.80.0 --integrity=sha512-w2EaxgjyeVGyzfA/KUEZBhyXqu/5PyWNXcnrXsZOBrt3aN2zyGiHrXoG54TF6K0b5DSCF01Rt5fnIyrCeFzFKw==
}
install_ssh_config() {
echo "🔑 Installing SSH configuration..."
if [ -d /mnt/home/coder/.ssh ]; then
rsync -a /mnt/home/coder/.ssh/ ~/.ssh/
chmod 0700 ~/.ssh
else
echo "⚠️ SSH directory not found."
fi
rsync -a /mnt/home/coder/.ssh/ ~/.ssh/
chmod 0700 ~/.ssh
}
install_git_config() {
-26
View File
@@ -1,26 +0,0 @@
{
"name": "devcontainer-cli",
"version": "1.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "devcontainer-cli",
"version": "1.0.0",
"dependencies": {
"@devcontainers/cli": "^0.80.0"
}
},
"node_modules/@devcontainers/cli": {
"version": "0.80.0",
"resolved": "https://registry.npmjs.org/@devcontainers/cli/-/cli-0.80.0.tgz",
"integrity": "sha512-w2EaxgjyeVGyzfA/KUEZBhyXqu/5PyWNXcnrXsZOBrt3aN2zyGiHrXoG54TF6K0b5DSCF01Rt5fnIyrCeFzFKw==",
"bin": {
"devcontainer": "devcontainer.js"
},
"engines": {
"node": "^16.13.0 || >=18.0.0"
}
}
}
}
@@ -1,8 +0,0 @@
{
"name": "devcontainer-cli",
"private": true,
"version": "1.0.0",
"dependencies": {
"@devcontainers/cli": "^0.80.0"
}
}
-1
View File
@@ -26,6 +26,5 @@ ignorePatterns:
- pattern: "claude.ai"
- pattern: "splunk.com"
- pattern: "stackoverflow.com/questions"
- pattern: "developer.hashicorp.com/terraform/language"
aliveStatusCodes:
- 200
+1 -1
View File
@@ -4,7 +4,7 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.24.10"
default: "1.24.6"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
+1 -1
View File
@@ -16,7 +16,7 @@ runs:
- name: Setup Node
uses: actions/setup-node@0a44ba7841725637a19e28fa30b79a866c81b0a6 # v4.0.4
with:
node-version: 22.19.0
node-version: 20.19.4
# See https://github.com/actions/setup-node#caching-global-packages-data
cache: "pnpm"
cache-dependency-path: ${{ inputs.directory }}/pnpm-lock.yaml
-4
View File
@@ -80,9 +80,6 @@ updates:
mui:
patterns:
- "@mui*"
radix:
patterns:
- "@radix-ui/*"
react:
patterns:
- "react"
@@ -107,7 +104,6 @@ updates:
- dependency-name: "*"
update-types:
- version-update:semver-major
- dependency-name: "@playwright/test"
open-pull-requests-limit: 15
- package-ecosystem: "terraform"
@@ -0,0 +1,34 @@
app = "sao-paulo-coder"
primary_region = "gru"
[experimental]
entrypoint = ["/bin/sh", "-c", "CODER_DERP_SERVER_RELAY_URL=\"http://[${FLY_PRIVATE_IP}]:3000\" /opt/coder wsproxy server"]
auto_rollback = true
[build]
image = "ghcr.io/coder/coder-preview:main"
[env]
CODER_ACCESS_URL = "https://sao-paulo.fly.dev.coder.com"
CODER_HTTP_ADDRESS = "0.0.0.0:3000"
CODER_PRIMARY_ACCESS_URL = "https://dev.coder.com"
CODER_WILDCARD_ACCESS_URL = "*--apps.sao-paulo.fly.dev.coder.com"
CODER_VERBOSE = "true"
[http_service]
internal_port = 3000
force_https = true
auto_stop_machines = true
auto_start_machines = true
min_machines_running = 0
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
[http_service.concurrency]
type = "requests"
soft_limit = 50
hard_limit = 100
[[vm]]
cpu_kind = "shared"
cpus = 2
memory_mb = 512
-4
View File
@@ -1,5 +1 @@
<!--
If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting.
-->
+202 -96
View File
@@ -4,7 +4,6 @@ on:
push:
branches:
- main
- release/*
pull_request:
workflow_dispatch:
@@ -35,12 +34,12 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -124,7 +123,7 @@ jobs:
# runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
# steps:
# - name: Checkout
# uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
# uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
# with:
# fetch-depth: 1
# # See: https://github.com/stefanzweifel/git-auto-commit-action?tab=readme-ov-file#commits-made-by-this-action-do-not-trigger-new-workflow-runs
@@ -157,12 +156,12 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -181,7 +180,7 @@ jobs:
echo "LINT_CACHE_DIR=$dir" >> "$GITHUB_ENV"
- name: golangci-lint cache
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4
with:
path: |
${{ env.LINT_CACHE_DIR }}
@@ -191,7 +190,7 @@ jobs:
# Check for any typos
- name: Check for typos
uses: crate-ci/typos@80c8a4945eec0f6d464eaf9e65ed98ef085283d1 # v1.38.1
uses: crate-ci/typos@52bd719c2c91f9d676e2aa359fc8e0db8925e6d8 # v1.35.3
with:
config: .github/workflows/typos.toml
@@ -204,7 +203,7 @@ jobs:
# Needed for helm chart linting
- name: Install helm
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0
with:
version: v3.9.2
@@ -235,12 +234,12 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -291,12 +290,12 @@ jobs:
timeout-minutes: 7
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -341,7 +340,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
@@ -367,7 +366,7 @@ jobs:
uses: coder/setup-ramdisk-action@e1100847ab2d7bcd9d14bcda8f2d1b0f07b36f1b # v0.1.0
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -376,6 +375,13 @@ jobs:
id: go-paths
uses: ./.github/actions/setup-go-paths
- name: Download Go Build Cache
id: download-go-build-cache
uses: ./.github/actions/test-cache/download
with:
key-prefix: test-go-build-${{ runner.os }}-${{ runner.arch }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Setup Go
uses: ./.github/actions/setup-go
with:
@@ -383,7 +389,8 @@ jobs:
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
use-cache: true
# Cache is already downloaded above
use-cache: false
- name: Setup Terraform
uses: ./.github/actions/setup-tf
@@ -492,11 +499,17 @@ jobs:
make test
- name: Upload failed test db dumps
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-db-dump-${{matrix.os}}
path: "**/*.test.sql"
- name: Upload Go Build Cache
uses: ./.github/actions/test-cache/upload
with:
cache-key: ${{ steps.download-go-build-cache.outputs.cache-key }}
cache-path: ${{ steps.go-paths.outputs.cached-dirs }}
- name: Upload Test Cache
uses: ./.github/actions/test-cache/upload
with:
@@ -518,6 +531,9 @@ jobs:
with:
api-key: ${{ secrets.DATADOG_API_KEY }}
# NOTE: this could instead be defined as a matrix strategy, but we want to
# only block merging if tests on postgres 13 fail. Using a matrix strategy
# here makes the check in the above `required` job rather complicated.
test-go-pg-17:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
needs:
@@ -530,12 +546,12 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -579,12 +595,12 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -639,12 +655,12 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -666,12 +682,12 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -699,12 +715,12 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -748,7 +764,7 @@ jobs:
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: failed-test-videos${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/*.webm
@@ -756,7 +772,7 @@ jobs:
- name: Upload pprof dumps
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: debug-pprof-dumps${{ matrix.variant.premium && '-premium' || '' }}
path: ./site/test-results/**/debug-pprof-*.txt
@@ -771,12 +787,12 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
# 👇 Ensures Chromatic can read your full git history
fetch-depth: 0
@@ -792,7 +808,7 @@ jobs:
# the check to pass. This is desired in PRs, but not in mainline.
- name: Publish to Chromatic (non-mainline)
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@58d9ffb36c90c97a02d061544ecc849cc4a242a9 # v13.1.3
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -824,7 +840,7 @@ jobs:
# infinitely "in progress" in mainline unless we re-review each build.
- name: Publish to Chromatic (mainline)
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
uses: chromaui/action@bc2d84ad2b60813a67d995c5582d696104a19383 # v13.3.2
uses: chromaui/action@58d9ffb36c90c97a02d061544ecc849cc4a242a9 # v13.1.3
env:
NODE_OPTIONS: "--max_old_space_size=4096"
STORYBOOK: true
@@ -852,12 +868,12 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
# 0 is required here for version.sh to work.
fetch-depth: 0
@@ -906,12 +922,10 @@ jobs:
required:
runs-on: ubuntu-latest
needs:
- changes
- fmt
- lint
- gen
- test-go-pg
- test-go-pg-17
- test-go-race-pg
- test-js
- test-e2e
@@ -923,19 +937,17 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Ensure required checks
run: | # zizmor: ignore[template-injection] We're just reading needs.x.result here, no risk of injection
echo "Checking required checks"
echo "- changes: ${{ needs.changes.result }}"
echo "- fmt: ${{ needs.fmt.result }}"
echo "- lint: ${{ needs.lint.result }}"
echo "- gen: ${{ needs.gen.result }}"
echo "- test-go-pg: ${{ needs.test-go-pg.result }}"
echo "- test-go-pg-17: ${{ needs.test-go-pg-17.result }}"
echo "- test-go-race-pg: ${{ needs.test-go-race-pg.result }}"
echo "- test-js: ${{ needs.test-js.result }}"
echo "- test-e2e: ${{ needs.test-e2e.result }}"
@@ -956,12 +968,12 @@ jobs:
needs: changes
# We always build the dylibs on Go changes to verify we're not merging unbuildable code,
# but they need only be signed and uploaded on coder/coder main.
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
runs-on: ${{ github.repository_owner == 'coder' && 'depot-macos-latest' || 'macos-latest' }}
steps:
# Harden Runner doesn't work on macOS
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -984,7 +996,7 @@ jobs:
uses: ./.github/actions/setup-go
- name: Install rcodesign
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
run: |
set -euo pipefail
wget -O /tmp/rcodesign.tar.gz https://github.com/indygreg/apple-platform-rs/releases/download/apple-codesign%2F0.22.0/apple-codesign-0.22.0-macos-universal.tar.gz
@@ -995,7 +1007,7 @@ jobs:
rm /tmp/rcodesign.tar.gz
- name: Setup Apple Developer certificate and API key
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
run: |
set -euo pipefail
touch /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
@@ -1016,13 +1028,13 @@ jobs:
make gen/mark-fresh
make build/coder-dylib
env:
CODER_SIGN_DARWIN: ${{ (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) && '1' || '0' }}
CODER_SIGN_DARWIN: ${{ github.ref == 'refs/heads/main' && '1' || '0' }}
AC_CERTIFICATE_FILE: /tmp/apple_cert.p12
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
- name: Upload build artifacts
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -1031,7 +1043,7 @@ jobs:
retention-days: 7
- name: Delete Apple Developer certificate and API key
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
run: rm -f /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
check-build:
@@ -1043,12 +1055,12 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -1081,7 +1093,7 @@ jobs:
needs:
- changes
- build-dylib
if: (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) && needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
if: github.ref == 'refs/heads/main' && needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-22.04' }}
permissions:
# Necessary to push docker images to ghcr.io.
@@ -1098,18 +1110,18 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -1144,7 +1156,7 @@ jobs:
# Necessary for signing Windows binaries.
- name: Setup Java
uses: actions/setup-java@dded0888837ed1f317902acf8a20df0ad188d165 # v5.0.0
uses: actions/setup-java@c5195efecf7bdfc987ee8bae7a71cb8b11521c00 # v4.7.1
with:
distribution: "zulu"
java-version: "11.0"
@@ -1177,17 +1189,17 @@ jobs:
# Setup GCloud for signing Windows binaries.
- name: Authenticate to Google Cloud
id: gcloud_auth
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
with:
workload_identity_provider: ${{ vars.GCP_CODE_SIGNING_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_CODE_SIGNING_SERVICE_ACCOUNT }}
token_format: "access_token"
- name: Setup GCloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # v2.2.0
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -1234,45 +1246,40 @@ jobs:
id: build-docker
env:
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
CODER_IMAGE_TAG_PREFIX: main
DOCKER_CLI_EXPERIMENTAL: "enabled"
run: |
set -euxo pipefail
# build Docker images for each architecture
version="$(./scripts/version.sh)"
tag="${version//+/-}"
tag="main-${version//+/-}"
echo "tag=$tag" >> "$GITHUB_OUTPUT"
# build images for each architecture
# note: omitting the -j argument to avoid race conditions when pushing
make build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
# only push if we are on main branch or release branch
if [[ "${GITHUB_REF}" == "refs/heads/main" || "${GITHUB_REF}" == refs/heads/release/* ]]; then
# only push if we are on main branch
if [ "${GITHUB_REF}" == "refs/heads/main" ]; then
# build and push multi-arch manifest, this depends on the other images
# being pushed so will automatically push them
# note: omitting the -j argument to avoid race conditions when pushing
make push/build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
# Define specific tags
tags=("$tag")
if [ "${GITHUB_REF}" == "refs/heads/main" ]; then
tags+=("main" "latest")
elif [[ "${GITHUB_REF}" == refs/heads/release/* ]]; then
tags+=("release-${GITHUB_REF#refs/heads/release/}")
fi
tags=("$tag" "main" "latest")
# Create and push a multi-arch manifest for each tag
# we are adding `latest` tag and keeping `main` for backward
# compatibality
for t in "${tags[@]}"; do
echo "Pushing multi-arch manifest for tag: $t"
# shellcheck disable=SC2046
./scripts/build_docker_multiarch.sh \
--push \
--target "ghcr.io/coder/coder-preview:$t" \
--version "$version" \
$(cat build/coder_"$version"_linux_{amd64,arm64,armv7}.tag)
# shellcheck disable=SC2046
./scripts/build_docker_multiarch.sh \
--push \
--target "ghcr.io/coder/coder-preview:$t" \
--version "$version" \
$(cat build/coder_"$version"_linux_{amd64,arm64,armv7}.tag)
done
fi
@@ -1316,7 +1323,7 @@ jobs:
id: attest_main
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
with:
subject-name: "ghcr.io/coder/coder-preview:main"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1353,7 +1360,7 @@ jobs:
id: attest_latest
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
with:
subject-name: "ghcr.io/coder/coder-preview:latest"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1390,7 +1397,7 @@ jobs:
id: attest_version
if: github.ref == 'refs/heads/main'
continue-on-error: true
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
with:
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
predicate-type: "https://slsa.dev/provenance/v1"
@@ -1454,7 +1461,7 @@ jobs:
- name: Upload build artifacts
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: coder
path: |
@@ -1463,28 +1470,112 @@ jobs:
./build/*.deb
retention-days: 7
# Deploy is handled in deploy.yaml so we can apply concurrency limits.
deploy:
name: "deploy"
runs-on: ubuntu-latest
timeout-minutes: 30
needs:
- changes
- build
if: |
(github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/'))
github.ref == 'refs/heads/main' && !github.event.pull_request.head.repo.fork
&& needs.changes.outputs.docs-only == 'false'
&& !github.event.pull_request.head.repo.fork
uses: ./.github/workflows/deploy.yaml
with:
image: ${{ needs.build.outputs.IMAGE }}
permissions:
contents: read
id-token: write
packages: write # to retag image as dogfood
secrets:
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
FLY_PARIS_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
FLY_JNB_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
with:
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # v2.2.0
- name: Set up Flux CLI
uses: fluxcd/flux2/action@6bf37f6a560fd84982d67f853162e4b3c2235edb # v2.6.4
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.5.1"
- name: Get Cluster Credentials
uses: google-github-actions/get-gke-credentials@8e574c49425fa7efed1e74650a449bfa6a23308a # v2.3.4
with:
cluster_name: dogfood-v2
location: us-central1-a
project_id: coder-dogfood-v2
- name: Reconcile Flux
run: |
set -euxo pipefail
flux --namespace flux-system reconcile source git flux-system
flux --namespace flux-system reconcile source git coder-main
flux --namespace flux-system reconcile kustomization flux-system
flux --namespace flux-system reconcile kustomization coder
flux --namespace flux-system reconcile source chart coder-coder
flux --namespace flux-system reconcile source chart coder-coder-provisioner
flux --namespace coder reconcile helmrelease coder
flux --namespace coder reconcile helmrelease coder-provisioner
# Just updating Flux is usually not enough. The Helm release may get
# redeployed, but unless something causes the Deployment to update the
# pods won't be recreated. It's important that the pods get recreated,
# since we use `imagePullPolicy: Always` to ensure we're running the
# latest image.
- name: Rollout Deployment
run: |
set -euxo pipefail
kubectl --namespace coder rollout restart deployment/coder
kubectl --namespace coder rollout status deployment/coder
kubectl --namespace coder rollout restart deployment/coder-provisioner
kubectl --namespace coder rollout status deployment/coder-provisioner
kubectl --namespace coder rollout restart deployment/coder-provisioner-tagged
kubectl --namespace coder rollout status deployment/coder-provisioner-tagged
deploy-wsproxies:
runs-on: ubuntu-latest
needs: build
if: github.ref == 'refs/heads/main' && !github.event.pull_request.head.repo.fork
steps:
- name: Harden Runner
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
- name: Setup flyctl
uses: superfly/flyctl-actions/setup-flyctl@fc53c09e1bc3be6f54706524e3b82c4f462f77be # v1.5
- name: Deploy workspace proxies
run: |
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
flyctl deploy --image "$IMAGE" --app sao-paulo-coder --config ./.github/fly-wsproxies/sao-paulo-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SAO_PAULO" --yes
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
env:
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
IMAGE: ${{ needs.build.outputs.IMAGE }}
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SAO_PAULO: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
# sqlc-vet runs a postgres docker container, runs Coder migrations, and then
# runs sqlc-vet to ensure all queries are valid. This catches any mistakes
@@ -1495,12 +1586,12 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -1523,7 +1614,6 @@ jobs:
steps:
- name: Send Slack notification
run: |
ESCAPED_PROMPT=$(printf "%s" "<@U09LQ75AHKR> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
curl -X POST -H 'Content-type: application/json' \
--data '{
"blocks": [
@@ -1535,6 +1625,23 @@ jobs:
"emoji": true
}
},
{
"type": "section",
"fields": [
{
"type": "mrkdwn",
"text": "*Workflow:*\n'"${GITHUB_WORKFLOW}"'"
},
{
"type": "mrkdwn",
"text": "*Committer:*\n'"${GITHUB_ACTOR}"'"
},
{
"type": "mrkdwn",
"text": "*Commit:*\n'"${GITHUB_SHA}"'"
}
]
},
{
"type": "section",
"text": {
@@ -1546,7 +1653,7 @@ jobs:
"type": "section",
"text": {
"type": "mrkdwn",
"text": '"$ESCAPED_PROMPT"'
"text": "<@U08TJ4YNCA3> investigate this CI failure. Check logs, search for existing issues, use git blame to find who last modified failing tests, create issue in coder/internal (not public repo), use title format \"flake: TestName\" for flaky tests, and assign to the person from git blame."
}
}
]
@@ -1554,4 +1661,3 @@ jobs:
env:
SLACK_WEBHOOK: ${{ secrets.CI_FAILURE_SLACK_WEBHOOK }}
RUN_URL: "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
BLINK_CI_FAILURE_PROMPT: ${{ vars.BLINK_CI_FAILURE_PROMPT }}
+1 -1
View File
@@ -53,7 +53,7 @@ jobs:
if: ${{ github.event_name == 'pull_request_target' && !github.event.pull_request.draft }}
steps:
- name: release-labels
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
# This script ensures PR title and labels are in sync:
#
-172
View File
@@ -1,172 +0,0 @@
name: deploy
on:
# Via workflow_call, called from ci.yaml
workflow_call:
inputs:
image:
description: "Image and tag to potentially deploy. Current branch will be validated against should-deploy check."
required: true
type: string
secrets:
FLY_API_TOKEN:
required: true
FLY_PARIS_CODER_PROXY_SESSION_TOKEN:
required: true
FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN:
required: true
FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN:
required: true
FLY_JNB_CODER_PROXY_SESSION_TOKEN:
required: true
permissions:
contents: read
concurrency:
group: ${{ github.workflow }} # no per-branch concurrency
cancel-in-progress: false
jobs:
# Determines if the given branch should be deployed to dogfood.
should-deploy:
name: should-deploy
runs-on: ubuntu-latest
outputs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0
persist-credentials: false
- name: Check if deploy is enabled
id: check
run: |
set -euo pipefail
verdict="$(./scripts/should_deploy.sh)"
echo "verdict=$verdict" >> "$GITHUB_OUTPUT"
deploy:
name: "deploy"
runs-on: ubuntu-latest
timeout-minutes: 30
needs: should-deploy
if: needs.should-deploy.outputs.verdict == 'DEPLOY'
permissions:
contents: read
id-token: write
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0
persist-credentials: false
- name: GHCR Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
with:
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
- name: Set up Flux CLI
uses: fluxcd/flux2/action@4a15fa6a023259353ef750acf1c98fe88407d4d0 # v2.7.2
with:
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
version: "2.7.0"
- name: Get Cluster Credentials
uses: google-github-actions/get-gke-credentials@3da1e46a907576cefaa90c484278bb5b259dd395 # v3.0.0
with:
cluster_name: dogfood-v2
location: us-central1-a
project_id: coder-dogfood-v2
# Retag image as dogfood while maintaining the multi-arch manifest
- name: Tag image as dogfood
run: docker buildx imagetools create --tag "ghcr.io/coder/coder-preview:dogfood" "$IMAGE"
env:
IMAGE: ${{ inputs.image }}
- name: Reconcile Flux
run: |
set -euxo pipefail
flux --namespace flux-system reconcile source git flux-system
flux --namespace flux-system reconcile source git coder-main
flux --namespace flux-system reconcile kustomization flux-system
flux --namespace flux-system reconcile kustomization coder
flux --namespace flux-system reconcile source chart coder-coder
flux --namespace flux-system reconcile source chart coder-coder-provisioner
flux --namespace coder reconcile helmrelease coder
flux --namespace coder reconcile helmrelease coder-provisioner
flux --namespace coder reconcile helmrelease coder-provisioner-tagged
flux --namespace coder reconcile helmrelease coder-provisioner-tagged-prebuilds
# Just updating Flux is usually not enough. The Helm release may get
# redeployed, but unless something causes the Deployment to update the
# pods won't be recreated. It's important that the pods get recreated,
# since we use `imagePullPolicy: Always` to ensure we're running the
# latest image.
- name: Rollout Deployment
run: |
set -euxo pipefail
kubectl --namespace coder rollout restart deployment/coder
kubectl --namespace coder rollout status deployment/coder
kubectl --namespace coder rollout restart deployment/coder-provisioner
kubectl --namespace coder rollout status deployment/coder-provisioner
kubectl --namespace coder rollout restart deployment/coder-provisioner-tagged
kubectl --namespace coder rollout status deployment/coder-provisioner-tagged
kubectl --namespace coder rollout restart deployment/coder-provisioner-tagged-prebuilds
kubectl --namespace coder rollout status deployment/coder-provisioner-tagged-prebuilds
deploy-wsproxies:
runs-on: ubuntu-latest
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
fetch-depth: 0
persist-credentials: false
- name: Setup flyctl
uses: superfly/flyctl-actions/setup-flyctl@fc53c09e1bc3be6f54706524e3b82c4f462f77be # v1.5
- name: Deploy workspace proxies
run: |
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
env:
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
IMAGE: ${{ inputs.image }}
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
+4 -4
View File
@@ -38,17 +38,17 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
- name: Docker login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -62,7 +62,7 @@ jobs:
# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1.15.0
with:
project: wl5hnrrkns
context: base-build-context
+2 -2
View File
@@ -23,14 +23,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
- name: Setup Node
uses: ./.github/actions/setup-node
- uses: tj-actions/changed-files@dbf178ceecb9304128c8e0648591d71208c6e2c9 # v45.0.7
- uses: tj-actions/changed-files@f963b3f3562b00b6d2dd25efc390eb04e51ef6c6 # v45.0.7
id: changed-files
with:
files: |
+9 -9
View File
@@ -26,21 +26,21 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
- name: Setup Nix
uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
uses: nixbuild/nix-quick-install-action@63ca48f939ee3b8d835f4126562537df0fee5b91 # v32
with:
# Pinning to 2.28 here, as Nix gets a "error: [json.exception.type_error.302] type must be array, but is string"
# on version 2.29 and above.
nix_version: "2.28.5"
nix_version: "2.28.4"
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
with:
@@ -82,13 +82,13 @@ jobs:
- name: Login to DockerHub
if: github.ref == 'refs/heads/main'
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and push Non-Nix image
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1.15.0
with:
project: b4q6ltmpzh
token: ${{ secrets.DEPOT_TOKEN }}
@@ -125,12 +125,12 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
@@ -138,7 +138,7 @@ jobs:
uses: ./.github/actions/setup-tf
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
with:
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
+21 -6
View File
@@ -27,7 +27,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
@@ -53,7 +53,7 @@ jobs:
uses: coder/setup-ramdisk-action@e1100847ab2d7bcd9d14bcda8f2d1b0f07b36f1b # v0.1.0
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
@@ -143,7 +143,7 @@ jobs:
DB=ci gotestsum \
--format standard-quiet --packages "./..." \
-- -timeout=20m -v -p "$NUM_PARALLEL_PACKAGES" -parallel="$NUM_PARALLEL_TESTS" "$TESTCOUNT"
-- -timeout=20m -v -p $NUM_PARALLEL_PACKAGES -parallel=$NUM_PARALLEL_TESTS $TESTCOUNT
- name: Upload Embedded Postgres Cache
uses: ./.github/actions/embedded-pg-cache/upload
@@ -170,7 +170,6 @@ jobs:
steps:
- name: Send Slack notification
run: |
ESCAPED_PROMPT=$(printf "%s" "<@U09LQ75AHKR> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
curl -X POST -H 'Content-type: application/json' \
--data '{
"blocks": [
@@ -182,6 +181,23 @@ jobs:
"emoji": true
}
},
{
"type": "section",
"fields": [
{
"type": "mrkdwn",
"text": "*Workflow:*\n'"${GITHUB_WORKFLOW}"'"
},
{
"type": "mrkdwn",
"text": "*Committer:*\n'"${GITHUB_ACTOR}"'"
},
{
"type": "mrkdwn",
"text": "*Commit:*\n'"${GITHUB_SHA}"'"
}
]
},
{
"type": "section",
"text": {
@@ -193,7 +209,7 @@ jobs:
"type": "section",
"text": {
"type": "mrkdwn",
"text": '"$ESCAPED_PROMPT"'
"text": "<@U08TJ4YNCA3> investigate this CI failure. Check logs, search for existing issues, use git blame to find who last modified failing tests, create issue in coder/internal (not public repo), use title format \"flake: TestName\" for flaky tests, and assign to the person from git blame."
}
}
]
@@ -201,4 +217,3 @@ jobs:
env:
SLACK_WEBHOOK: ${{ secrets.CI_FAILURE_SLACK_WEBHOOK }}
RUN_URL: "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
BLINK_CI_FAILURE_PROMPT: ${{ vars.BLINK_CI_FAILURE_PROMPT }}
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
+14 -15
View File
@@ -39,12 +39,12 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
@@ -76,12 +76,12 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -184,12 +184,12 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Find Comment
uses: peter-evans/find-comment@b30e6a3c0ed37e7c023ccd3f1db5c6c0b0c23aad # v4.0.0
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
id: fc
with:
issue-number: ${{ needs.get_info.outputs.PR_NUMBER }}
@@ -199,7 +199,7 @@ jobs:
- name: Comment on PR
id: comment_id
uses: peter-evans/create-or-update-comment@e8674b075228eee787fea43ef493e45ece1004c9 # v5.0.0
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
with:
comment-id: ${{ steps.fc.outputs.comment-id }}
issue-number: ${{ needs.get_info.outputs.PR_NUMBER }}
@@ -228,12 +228,12 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -248,7 +248,7 @@ jobs:
uses: ./.github/actions/setup-sqlc
- name: GHCR Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
@@ -337,7 +337,7 @@ jobs:
kubectl create namespace "pr${PR_NUMBER}"
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
@@ -370,7 +370,6 @@ jobs:
helm repo add bitnami https://charts.bitnami.com/bitnami
helm install coder-db bitnami/postgresql \
--namespace "pr${PR_NUMBER}" \
--set image.repository=bitnamilegacy/postgresql \
--set auth.username=coder \
--set auth.password=coder \
--set auth.database=coder \
@@ -491,7 +490,7 @@ jobs:
PASSWORD: ${{ steps.setup_deployment.outputs.password }}
- name: Find Comment
uses: peter-evans/find-comment@b30e6a3c0ed37e7c023ccd3f1db5c6c0b0c23aad # v4.0.0
uses: peter-evans/find-comment@3eae4d37986fb5a8592848f6a574fdf654e61f9e # v3.1.0
id: fc
with:
issue-number: ${{ env.PR_NUMBER }}
@@ -500,7 +499,7 @@ jobs:
direction: last
- name: Comment on PR
uses: peter-evans/create-or-update-comment@e8674b075228eee787fea43ef493e45ece1004c9 # v5.0.0
uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0
env:
STATUS: ${{ needs.get_info.outputs.NEW == 'true' && 'Created' || 'Updated' }}
with:
+1 -1
View File
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
+24 -24
View File
@@ -37,7 +37,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Allow only maintainers/admins
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@v7.0.1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -65,7 +65,7 @@ jobs:
steps:
# Harden Runner doesn't work on macOS.
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -131,7 +131,7 @@ jobs:
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
- name: Upload build artifacts
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: dylibs
path: |
@@ -164,12 +164,12 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -239,7 +239,7 @@ jobs:
cat "$CODER_RELEASE_NOTES_FILE"
- name: Docker Login
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -253,7 +253,7 @@ jobs:
# Necessary for signing Windows binaries.
- name: Setup Java
uses: actions/setup-java@dded0888837ed1f317902acf8a20df0ad188d165 # v5.0.0
uses: actions/setup-java@c5195efecf7bdfc987ee8bae7a71cb8b11521c00 # v4.7.1
with:
distribution: "zulu"
java-version: "11.0"
@@ -317,17 +317,17 @@ jobs:
# Setup GCloud for signing Windows binaries.
- name: Authenticate to Google Cloud
id: gcloud_auth
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
with:
workload_identity_provider: ${{ vars.GCP_CODE_SIGNING_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_CODE_SIGNING_SERVICE_ACCOUNT }}
token_format: "access_token"
- name: Setup GCloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # v2.2.0
- name: Download dylibs
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
with:
name: dylibs
path: ./build
@@ -397,7 +397,7 @@ jobs:
# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
if: steps.image-base-tag.outputs.tag != ''
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1.15.0
with:
project: wl5hnrrkns
context: base-build-context
@@ -454,7 +454,7 @@ jobs:
id: attest_base
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
continue-on-error: true
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
with:
subject-name: ${{ steps.image-base-tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -570,7 +570,7 @@ jobs:
id: attest_main
if: ${{ !inputs.dry_run }}
continue-on-error: true
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
with:
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -614,7 +614,7 @@ jobs:
id: attest_latest
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
continue-on-error: true
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
with:
subject-name: ${{ steps.latest_tag.outputs.tag }}
predicate-type: "https://slsa.dev/provenance/v1"
@@ -734,13 +734,13 @@ jobs:
CREATED_LATEST_TAG: ${{ steps.build_docker.outputs.created_latest_tag }}
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
with:
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
- name: Setup GCloud SDK
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # 3.0.1
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # 2.2.0
- name: Publish Helm Chart
if: ${{ !inputs.dry_run }}
@@ -761,7 +761,7 @@ jobs:
- name: Upload artifacts to actions (if dry-run)
if: ${{ inputs.dry_run }}
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: release-artifacts
path: |
@@ -777,7 +777,7 @@ jobs:
- name: Upload latest sbom artifact to actions (if dry-run)
if: inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: latest-sbom-artifact
path: ./coder_latest_sbom.spdx.json
@@ -785,7 +785,7 @@ jobs:
- name: Send repository-dispatch event
if: ${{ !inputs.dry_run }}
uses: peter-evans/repository-dispatch@5fc4efd1a4797ddb68ffd0714a238564e4cc0e6f # v4.0.0
uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0
with:
token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
repository: coder/packages
@@ -802,7 +802,7 @@ jobs:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
@@ -878,7 +878,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
@@ -888,7 +888,7 @@ jobs:
GH_TOKEN: ${{ secrets.CDRCI_GITHUB_TOKEN }}
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -971,12 +971,12 @@ jobs:
if: ${{ !inputs.dry_run }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 1
persist-credentials: false
+5 -5
@@ -20,17 +20,17 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: "Checkout code"
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
- name: "Run analysis"
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
uses: ossf/scorecard-action@05b42c624433fc40578a4040d5cf5e36ddca8cde # v2.4.2
with:
results_file: results.sarif
results_format: sarif
@@ -39,7 +39,7 @@ jobs:
# Upload the results as artifacts.
- name: "Upload artifact"
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: SARIF file
path: results.sarif
@@ -47,6 +47,6 @@ jobs:
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
with:
sarif_file: results.sarif
+9 -9
@@ -27,12 +27,12 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
@@ -40,7 +40,7 @@ jobs:
uses: ./.github/actions/setup-go
- name: Initialize CodeQL
uses: github/codeql-action/init@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/init@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
with:
languages: go, javascript
@@ -50,7 +50,7 @@ jobs:
rm Makefile
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/analyze@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
- name: Send Slack notification on failure
if: ${{ failure() }}
@@ -69,12 +69,12 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
fetch-depth: 0
persist-credentials: false
@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8
uses: aquasecurity/trivy-action@dc5a429b52fcf669ce959baa2c2dd26090d2a6c4
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
@@ -154,13 +154,13 @@ jobs:
severity: "CRITICAL,HIGH"
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@4e94bd11f71e507f7f87df81788dff88d1dacbfb # v3.29.5
uses: github/codeql-action/upload-sarif@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
with:
sarif_file: trivy-results.sarif
category: "Trivy"
- name: Upload Trivy scan results as an artifact
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: trivy
path: trivy-results.sarif
+8 -8
@@ -18,12 +18,12 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: stale
uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
@@ -44,7 +44,7 @@ jobs:
# Start with the oldest issues, always.
ascending: true
- name: "Close old issues labeled likely-no"
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -96,12 +96,12 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout repository
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
- name: Run delete-old-branches-action
@@ -120,12 +120,12 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Delete PR Cleanup workflow runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
@@ -134,7 +134,7 @@ jobs:
delete_workflow_pattern: pr-cleanup.yaml
- name: Delete PR Deploy workflow skipped runs
uses: Mattraks/delete-workflow-runs@ab482449ba468316e9a8801e092d0405715c5e6d # v2.1.0
uses: Mattraks/delete-workflow-runs@39f0bbed25d76b34de5594dceab824811479e5de # v2.0.6
with:
token: ${{ github.token }}
repository: ${{ github.repository }}
-217
@@ -1,217 +0,0 @@
name: AI Triage Automation
on:
issues:
types:
- labeled
workflow_dispatch:
inputs:
issue_url:
description: "GitHub Issue URL to process"
required: true
type: string
template_name:
description: "Coder template to use for workspace"
required: true
default: "coder"
type: string
template_preset:
description: "Template preset to use"
required: true
default: "none"
type: string
prefix:
description: "Prefix for workspace name"
required: false
default: "traiage"
type: string
jobs:
traiage:
name: Triage GitHub Issue with Claude Code
runs-on: ubuntu-latest
if: github.event.label.name == 'traiage' || github.event_name == 'workflow_dispatch'
timeout-minutes: 30
env:
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
permissions:
contents: read
issues: write
actions: write
steps:
# This is only required for testing locally using nektos/act, so leaving commented out.
# An alternative is to use a larger or custom image.
# - name: Install Github CLI
# id: install-gh
# run: |
# (type -p wget >/dev/null || (sudo apt update && sudo apt install wget -y)) \
# && sudo mkdir -p -m 755 /etc/apt/keyrings \
# && out=$(mktemp) && wget -nv -O$out https://cli.github.com/packages/githubcli-archive-keyring.gpg \
# && cat $out | sudo tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
# && sudo chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
# && sudo mkdir -p -m 755 /etc/apt/sources.list.d \
# && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | sudo tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
# && sudo apt update \
# && sudo apt install gh -y
- name: Determine Inputs
id: determine-inputs
if: always()
env:
GITHUB_ACTOR: ${{ github.actor }}
GITHUB_EVENT_ISSUE_HTML_URL: ${{ github.event.issue.html_url }}
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
INPUTS_TEMPLATE_NAME: ${{ inputs.template_name || 'coder' }}
INPUTS_TEMPLATE_PRESET: ${{ inputs.template_preset || 'none'}}
INPUTS_PREFIX: ${{ inputs.prefix || 'traiage' }}
GH_TOKEN: ${{ github.token }}
run: |
echo "Using template name: ${INPUTS_TEMPLATE_NAME}"
echo "template_name=${INPUTS_TEMPLATE_NAME}" >> "${GITHUB_OUTPUT}"
echo "Using template preset: ${INPUTS_TEMPLATE_PRESET}"
echo "template_preset=${INPUTS_TEMPLATE_PRESET}" >> "${GITHUB_OUTPUT}"
echo "Using prefix: ${INPUTS_PREFIX}"
echo "prefix=${INPUTS_PREFIX}" >> "${GITHUB_OUTPUT}"
# For workflow_dispatch, use the actor who triggered it
# For issues events, use the issue author.
if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
if ! GITHUB_USER_ID=$(gh api "users/${GITHUB_ACTOR}" --jq '.id'); then
echo "::error::Failed to get GitHub user ID for actor ${GITHUB_ACTOR}"
exit 1
fi
echo "Using workflow_dispatch actor: ${GITHUB_ACTOR} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_ACTOR}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${INPUTS_ISSUE_URL}"
echo "issue_url=${INPUTS_ISSUE_URL}" >> "${GITHUB_OUTPUT}"
exit 0
elif [[ "${GITHUB_EVENT_NAME}" == "issues" ]]; then
GITHUB_USER_ID=${GITHUB_EVENT_USER_ID}
echo "Using issue author: ${GITHUB_EVENT_USER_LOGIN} (ID: ${GITHUB_USER_ID})"
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
echo "github_username=${GITHUB_EVENT_USER_LOGIN}" >> "${GITHUB_OUTPUT}"
echo "Using issue URL: ${GITHUB_EVENT_ISSUE_HTML_URL}"
echo "issue_url=${GITHUB_EVENT_ISSUE_HTML_URL}" >> "${GITHUB_OUTPUT}"
exit 0
else
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
exit 1
fi
- name: Verify push access
env:
GITHUB_REPOSITORY: ${{ github.repository }}
GH_TOKEN: ${{ github.token }}
GITHUB_USERNAME: ${{ steps.determine-inputs.outputs.github_username }}
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
run: |
# Query the actor's permission on this repo
can_push="$(gh api "/repos/${GITHUB_REPOSITORY}/collaborators/${GITHUB_USERNAME}/permission" --jq '.user.permissions.push')"
if [[ "${can_push}" != "true" ]]; then
echo "::error title=Access Denied::${GITHUB_USERNAME} does not have push access to ${GITHUB_REPOSITORY}"
exit 1
fi
- name: Extract context key from issue
id: extract-context
env:
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
GH_TOKEN: ${{ github.token }}
run: |
issue_number="$(gh issue view "${ISSUE_URL}" --json number --jq '.number')"
context_key="gh-${issue_number}"
echo "context_key=${context_key}" >> "${GITHUB_OUTPUT}"
echo "CONTEXT_KEY=${context_key}" >> "${GITHUB_ENV}"
- name: Download and install Coder binary
shell: bash
env:
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
run: |
if [ "${{ runner.arch }}" == "ARM64" ]; then
ARCH="arm64"
else
ARCH="amd64"
fi
mkdir -p "${HOME}/.local/bin"
curl -fsSL --compressed "$CODER_URL/bin/coder-linux-${ARCH}" -o "${HOME}/.local/bin/coder"
chmod +x "${HOME}/.local/bin/coder"
export PATH="$HOME/.local/bin:$PATH"
coder version
coder whoami
echo "$HOME/.local/bin" >> "${GITHUB_PATH}"
- name: Get Coder username from GitHub actor
id: get-coder-username
env:
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
GH_TOKEN: ${{ github.token }}
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
run: |
user_json=$(
coder users list --github-user-id="${GITHUB_USER_ID}" --output=json
)
coder_username=$(jq -r 'first | .username' <<< "$user_json")
[[ -z "${coder_username}" || "${coder_username}" == "null" ]] && echo "No Coder user with GitHub user ID ${GITHUB_USER_ID} found" && exit 1
echo "coder_username=${coder_username}" >> "${GITHUB_OUTPUT}"
- name: Checkout repository
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
persist-credentials: false
fetch-depth: 0
# TODO(Cian): this is a good use-case for 'recipes'
- name: Create Coder task
id: create-task
env:
CODER_USERNAME: ${{ steps.get-coder-username.outputs.coder_username }}
CONTEXT_KEY: ${{ steps.extract-context.outputs.context_key }}
GH_TOKEN: ${{ github.token }}
GITHUB_REPOSITORY: ${{ github.repository }}
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
PREFIX: ${{ steps.determine-inputs.outputs.prefix }}
RUN_ID: ${{ github.run_id }}
TEMPLATE_NAME: ${{ steps.determine-inputs.outputs.template_name }}
TEMPLATE_PARAMETERS: ${{ secrets.TRAIAGE_TEMPLATE_PARAMETERS }}
TEMPLATE_PRESET: ${{ steps.determine-inputs.outputs.template_preset }}
run: |
# Fetch issue description using `gh` CLI
#shellcheck disable=SC2016 # The template string should not be subject to shell expansion
issue_description=$(gh issue view "${ISSUE_URL}" \
--json 'title,body,comments' \
--template '{{printf "%s\n\n%s\n\nComments:\n" .title .body}}{{range $k, $v := .comments}} - {{index $v.author "login"}}: {{printf "%s\n" $v.body}}{{end}}')
# Write a prompt to PROMPT_FILE
PROMPT=$(cat <<EOF
Fix ${ISSUE_URL}
Analyze the below GitHub issue description, understand the root cause, and make appropriate changes to resolve the issue.
---
${issue_description}
EOF
)
export PROMPT
export TASK_NAME="${PREFIX}-${CONTEXT_KEY}-${RUN_ID}"
echo "Creating task: $TASK_NAME"
./scripts/traiage.sh create
if [[ "${ISSUE_URL}" == "https://github.com/${GITHUB_REPOSITORY}"* ]]; then
gh issue comment "${ISSUE_URL}" --body "Task created: https://dev.coder.com/tasks/${CODER_USERNAME}/${TASK_NAME}" --create-if-none --edit-last
else
echo "Skipping comment on other repo."
fi
echo "TASK_NAME=${CODER_USERNAME}/${TASK_NAME}" >> "${GITHUB_OUTPUT}"
echo "TASK_NAME=${CODER_USERNAME}/${TASK_NAME}" >> "${GITHUB_ENV}"
-1
@@ -1,6 +1,5 @@
[default]
extend-ignore-identifiers-re = ["gho_.*"]
extend-ignore-re = ["(#|//)\\s*spellchecker:ignore-next-line\\n.*"]
[default.extend-identifiers]
alog = "alog"
+3 -3
@@ -21,17 +21,17 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
with:
persist-credentials: false
- name: Check Markdown links
uses: umbrelladocs/action-linkspector@652f85bc57bb1e7d4327260decc10aa68f7694c3 # v1.4.0
uses: umbrelladocs/action-linkspector@874d01cae9fd488e3077b08952093235bd626977 # v1.3.7
id: markdown-link-check
# checks all markdown files from /docs including all subfolders
with:
-4
@@ -1,4 +0,0 @@
rules:
cache-poisoning:
ignore:
- "ci.yaml:184"
-5
@@ -12,9 +12,6 @@ node_modules/
vendor/
yarn-error.log
# Test output files
test-output/
# VSCode settings.
**/.vscode/*
# Allow VSCode recommendations and default settings in project root.
@@ -89,5 +86,3 @@ result
__debug_bin*
**/.claude/settings.local.json
/.env
+1 -11
@@ -169,16 +169,6 @@ linters-settings:
- name: var-declaration
- name: var-naming
- name: waitgroup-by-value
usetesting:
# Only os-setenv is enabled because we migrated to usetesting from another linter that
# only covered os-setenv.
os-setenv: true
os-create-temp: false
os-mkdir-temp: false
os-temp-dir: false
os-chdir: false
context-background: false
context-todo: false
# irrelevant as of Go v1.22: https://go.dev/blog/loopvar-preview
govet:
@@ -262,6 +252,7 @@ linters:
# - wastedassign
- staticcheck
- tenv
# In Go, it's possible for a package to test its internal functionality
# without testing any exported functions. This is enabled to promote
# decomposing a package before testing its internals. A function caller
@@ -274,5 +265,4 @@ linters:
- typecheck
- unconvert
- unused
- usetesting
- dupl
+1 -3
@@ -54,13 +54,11 @@
}
},
"tailwindCSS.classFunctions": ["cva", "cn"],
"[css][html][markdown][yaml]": {
"editor.defaultFormatter": "esbenp.prettier-vscode"
},
"typos.config": ".github/workflows/typos.toml",
"[markdown]": {
"editor.defaultFormatter": "DavidAnson.vscode-markdownlint"
},
"biome.lsp.bin": "site/node_modules/.bin/biome"
}
}
+19 -40
@@ -1,41 +1,11 @@
# Coder Development Guidelines
You are an experienced, pragmatic software engineer. You don't over-engineer a solution when a simple one is possible.
Rule #1: If you want exception to ANY rule, YOU MUST STOP and get explicit permission first. BREAKING THE LETTER OR SPIRIT OF THE RULES IS FAILURE.
## Foundational rules
- Doing it right is better than doing it fast. You are not in a rush. NEVER skip steps or take shortcuts.
- Tedious, systematic work is often the correct solution. Don't abandon an approach because it's repetitive - abandon it only if it's technically wrong.
- Honesty is a core value.
## Our relationship
- Act as a critical peer reviewer. Your job is to disagree with me when I'm wrong, not to please me. Prioritize accuracy and reasoning over agreement.
- YOU MUST speak up immediately when you don't know something or we're in over our heads
- YOU MUST call out bad ideas, unreasonable expectations, and mistakes - I depend on this
- NEVER be agreeable just to be nice - I NEED your HONEST technical judgment
- NEVER write the phrase "You're absolutely right!" You are not a sycophant. We're working together because I value your opinion. Do not agree with me unless you can justify it with evidence or reasoning.
- YOU MUST ALWAYS STOP and ask for clarification rather than making assumptions.
- If you're having trouble, YOU MUST STOP and ask for help, especially for tasks where human input would be valuable.
- When you disagree with my approach, YOU MUST push back. Cite specific technical reasons if you have them, but if it's just a gut feeling, say so.
- If you're uncomfortable pushing back out loud, just say "Houston, we have a problem". I'll know what you mean.
- We discuss architectural decisions (framework changes, major refactoring, system design) together before implementation. Routine fixes and clear implementations don't need discussion.
## Proactiveness
When asked to do something, just do it - including obvious follow-up actions needed to complete the task properly.
Only pause to ask for confirmation when:
- Multiple valid approaches exist and the choice matters
- The action would delete or significantly restructure existing code
- You genuinely don't understand what's being asked
- Your partner asked a question (answer the question, don't jump to implementation)
@.claude/docs/WORKFLOWS.md
@.cursorrules
@README.md
@package.json
## Essential Commands
## 🚀 Essential Commands
| Task | Command | Notes |
|-------------------|--------------------------|----------------------------------|
@@ -51,13 +21,22 @@ Only pause to ask for confirmation when:
| **Format** | `make fmt` | Auto-format code |
| **Clean** | `make clean` | Clean build artifacts |
### Frontend Commands (site directory)
- `pnpm build` - Build frontend
- `pnpm dev` - Run development server
- `pnpm check` - Run code checks
- `pnpm format` - Format frontend code
- `pnpm lint` - Lint frontend code
- `pnpm test` - Run frontend tests
### Documentation Commands
- `pnpm run format-docs` - Format markdown tables in docs
- `pnpm run lint-docs` - Lint and fix markdown files
- `pnpm run storybook` - Run Storybook (from site directory)
## Critical Patterns
## 🔧 Critical Patterns
### Database Changes (ALWAYS FOLLOW)
@@ -99,7 +78,7 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
```
## Quick Reference
## 📋 Quick Reference
### Full workflows available in imported WORKFLOWS.md
@@ -109,14 +88,14 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
- [ ] Check if feature touches database - you'll need migrations
- [ ] Check if feature touches audit logs - update `enterprise/audit/table.go`
## Architecture
## 🏗️ Architecture
- **coderd**: Main API service
- **provisionerd**: Infrastructure provisioning
- **Agents**: Workspace services (SSH, port forwarding)
- **Database**: PostgreSQL with `dbauthz` authorization
## Testing
## 🧪 Testing
### Race Condition Prevention
@@ -133,21 +112,21 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
NEVER use `time.Sleep` to mitigate timing issues. If an issue
seems like it should use `time.Sleep`, read through https://github.com/coder/quartz and specifically the [README](https://github.com/coder/quartz/blob/main/README.md) to better understand how to handle timing issues.
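
As an illustration of this rule, here is a minimal sketch (not taken from the repository) that replaces a sleep-based wait with a bounded poll using testify's `require.Eventually`, the same pattern the agent tests later in this diff rely on; code that owns its own timers should inject a `coder/quartz` mock clock instead.

```go
package agent_test

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/coder/coder/v2/testutil"
)

// TestWaitWithoutSleep is a hypothetical example of the "never
// time.Sleep" rule: poll for the condition with a bounded wait
// instead of sleeping for a fixed duration and hoping it is enough.
func TestWaitWithoutSleep(t *testing.T) {
	t.Parallel()

	done := make(chan struct{})
	go func() {
		// Stand-in for asynchronous work, e.g. an agent connecting.
		close(done)
	}()

	// Bad:  time.Sleep(5 * time.Second) followed by an assertion.
	// Good: poll until the condition holds or the deadline expires.
	require.Eventually(t, func() bool {
		select {
		case <-done:
			return true
		default:
			return false
		}
	}, testutil.WaitShort, testutil.IntervalFast)
}
```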
## Code Style
## 🎯 Code Style
### Detailed guidelines in imported WORKFLOWS.md
- Follow [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- Commit format: `type(scope): message`
## Detailed Development Guides
## 📚 Detailed Development Guides
@.claude/docs/OAUTH2.md
@.claude/docs/TESTING.md
@.claude/docs/TROUBLESHOOTING.md
@.claude/docs/DATABASE.md
## Common Pitfalls
## 🚨 Common Pitfalls
1. **Audit table errors** → Update `enterprise/audit/table.go`
2. **OAuth2 errors** → Return RFC-compliant format
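
For pitfall 2, a hedged sketch follows (a hypothetical helper, not the repository's actual handler) of the error shape RFC 6749 section 5.2 requires: a JSON body with an `error` code and an optional `error_description`, rather than a generic API error payload.

```go
package oauth2example

import (
	"encoding/json"
	"net/http"
)

// oauth2Error mirrors the error body defined by RFC 6749 section 5.2.
type oauth2Error struct {
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description,omitempty"`
}

// writeOAuth2Error is a hypothetical helper: OAuth2 endpoints should
// return this shape (e.g. "invalid_request", "invalid_grant") rather
// than a generic API error response.
func writeOAuth2Error(w http.ResponseWriter, status int, code, detail string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	_ = json.NewEncoder(w).Encode(oauth2Error{Error: code, ErrorDescription: detail})
}
```

A token endpoint would then respond with, for example, `writeOAuth2Error(w, http.StatusBadRequest, "invalid_request", "missing client_id")`.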
+12
@@ -18,6 +18,18 @@ coderd/rbac/ @Emyrk
scripts/apitypings/ @Emyrk
scripts/gensite/ @aslilac
site/ @aslilac @Parkreiner
site/src/hooks/ @Parkreiner
# These rules intentionally do not specify any owners. More specific rules
# override less specific rules, so these files are "ignored" by the site/ rule.
site/e2e/google/protobuf/timestampGenerated.ts
site/e2e/provisionerGenerated.ts
site/src/api/countriesGenerated.ts
site/src/api/rbacresourcesGenerated.ts
site/src/api/typesGenerated.ts
site/src/testHelpers/entities.ts
site/CLAUDE.md
# The blood and guts of the autostop algorithm, which is quite complex and
# requires elite ball knowledge of most of the scheduling code to make changes
# without inadvertently affecting other parts of the codebase.
+2 -55
@@ -561,7 +561,7 @@ endif
# Note: we don't run zizmor in the lint target because it takes a while. CI
# runs it explicitly.
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint
.PHONY: lint
lint/site-icons:
@@ -614,11 +614,6 @@ lint/actions/zizmor:
.
.PHONY: lint/actions/zizmor
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
lint/check-scopes: coderd/database/dump.sql
go run ./scripts/check-scopes
.PHONY: lint/check-scopes
# All files generated by the database should be added here, and this can be used
# as a target for jobs that need to run after the database is generated.
DB_GEN_FILES := \
@@ -635,23 +630,16 @@ TAILNETTEST_MOCKS := \
tailnet/tailnettest/workspaceupdatesprovidermock.go \
tailnet/tailnettest/subscriptionmock.go
AIBRIDGED_MOCKS := \
enterprise/aibridged/aibridgedmock/clientmock.go \
enterprise/aibridged/aibridgedmock/poolmock.go
GEN_FILES := \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
$(DB_GEN_FILES) \
$(SITE_GEN_FILES) \
coderd/rbac/object_gen.go \
codersdk/rbacresources_gen.go \
coderd/rbac/scopes_constants_gen.go \
codersdk/apikey_scopes_gen.go \
docs/admin/integrations/prometheus.md \
docs/reference/cli/index.md \
docs/admin/security/audit-logs.md \
@@ -665,8 +653,7 @@ GEN_FILES := \
agent/agentcontainers/acmock/acmock.go \
agent/agentcontainers/dcspec/dcspec_gen.go \
coderd/httpmw/loggermw/loggermock/loggermock.go \
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
$(AIBRIDGED_MOCKS)
codersdk/workspacesdk/agentconnmock/agentconnmock.go
# all gen targets should be added here and to gen/mark-fresh
gen: gen/db gen/golden-files $(GEN_FILES)
@@ -676,7 +663,6 @@ gen/db: $(DB_GEN_FILES)
.PHONY: gen/db
gen/golden-files: \
agent/unit/testdata/.gen-golden \
cli/testdata/.gen-golden \
coderd/.gen-golden \
coderd/notifications/.gen-golden \
@@ -697,13 +683,11 @@ gen/mark-fresh:
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
$(DB_GEN_FILES) \
site/src/api/typesGenerated.ts \
coderd/rbac/object_gen.go \
codersdk/rbacresources_gen.go \
coderd/rbac/scopes_constants_gen.go \
site/src/api/rbacresourcesGenerated.ts \
site/src/api/countriesGenerated.ts \
docs/admin/integrations/prometheus.md \
@@ -720,7 +704,6 @@ gen/mark-fresh:
agent/agentcontainers/dcspec/dcspec_gen.go \
coderd/httpmw/loggermw/loggermock/loggermock.go \
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
$(AIBRIDGED_MOCKS) \
"
for file in $$files; do
@@ -768,10 +751,6 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
go generate ./codersdk/workspacesdk/agentconnmock/
touch "$@"
$(AIBRIDGED_MOCKS): enterprise/aibridged/client.go enterprise/aibridged/pool.go
go generate ./enterprise/aibridged/aibridgedmock/
touch "$@"
agent/agentcontainers/dcspec/dcspec_gen.go: \
node_modules/.installed \
agent/agentcontainers/dcspec/devContainer.base.schema.json \
@@ -822,14 +801,6 @@ vpn/vpn.pb.go: vpn/vpn.proto
--go_opt=paths=source_relative \
./vpn/vpn.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
--go-drpc_out=. \
--go-drpc_opt=paths=source_relative \
./enterprise/aibridged/proto/aibridged.proto
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
go run -C ./scripts/apitypings main.go > $@
@@ -856,15 +827,6 @@ coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/mai
rmdir -v "$$tempdir"
touch "$@"
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go
# Generate typed low-level ScopeName constants from RBACPermissions
# Write to a temp file first to avoid truncating the package during build
# since the generator imports the rbac package.
tempfile=$(shell mktemp /tmp/scopes_constants_gen.XXXXXX)
go run ./scripts/typegen/main.go rbac scopenames > "$$tempfile"
mv -v "$$tempfile" coderd/rbac/scopes_constants_gen.go
touch "$@"
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
# Do not overwrite codersdk/rbacresources_gen.go directly, as it would make the file empty, breaking
# the `codersdk` package and any parallel build targets.
@@ -872,12 +834,6 @@ codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/m
mv /tmp/rbacresources_gen.go codersdk/rbacresources_gen.go
touch "$@"
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go
# Generate SDK constants for external API key scopes.
go run ./scripts/apikeyscopesgen > /tmp/apikey_scopes_gen.go
mv /tmp/apikey_scopes_gen.go codersdk/apikey_scopes_gen.go
touch "$@"
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/typegen/main.go rbac typescript > "$@"
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
@@ -953,10 +909,6 @@ clean/golden-files:
-type f -name '*.golden' -delete
.PHONY: clean/golden-files
agent/unit/testdata/.gen-golden: $(wildcard agent/unit/testdata/*.golden) $(GO_SRC_FILES) $(wildcard agent/unit/*_test.go)
TZ=UTC go test ./agent/unit -run="TestGraph" -update
touch "$@"
cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(wildcard cli/*.tpl) $(GO_SRC_FILES) $(wildcard cli/*_test.go)
TZ=UTC go test ./cli -run="Test(CommandHelp|ServerYAML|ErrorExamples|.*Golden)" -update
touch "$@"
@@ -1182,8 +1134,3 @@ endif
dogfood/coder/nix.hash: flake.nix flake.lock
sha256sum flake.nix flake.lock >./dogfood/coder/nix.hash
# Count the number of test databases created per test package.
count-test-databases:
PGPASSWORD=postgres psql -h localhost -U postgres -d coder_testing -P pager=off -c 'SELECT test_package, count(*) as count from test_databases GROUP BY test_package ORDER BY count DESC'
.PHONY: count-test-databases
+16 -6
@@ -74,6 +74,7 @@ type Options struct {
LogDir string
TempDir string
ScriptDataDir string
ExchangeToken func(ctx context.Context) (string, error)
Client Client
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
@@ -98,7 +99,6 @@ type Client interface {
proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error,
)
tailnet.DERPMapRewriter
agentsdk.RefreshableSessionTokenProvider
}
type Agent interface {
@@ -131,6 +131,11 @@ func New(options Options) Agent {
}
options.ScriptDataDir = options.TempDir
}
if options.ExchangeToken == nil {
options.ExchangeToken = func(_ context.Context) (string, error) {
return "", nil
}
}
if options.ReportMetadataInterval == 0 {
options.ReportMetadataInterval = time.Second
}
@@ -167,6 +172,7 @@ func New(options Options) Agent {
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
@@ -197,6 +203,7 @@ func New(options Options) Agent {
// coordinator during shut down.
close(a.coordDisconnected)
a.announcementBanners.Store(new([]codersdk.BannerConfig))
a.sessionToken.Store(new(string))
a.init()
return a
}
@@ -205,6 +212,7 @@ type agent struct {
clock quartz.Clock
logger slog.Logger
client Client
exchangeToken func(ctx context.Context) (string, error)
tailnetListenPort uint16
filesystem afero.Fs
logDir string
@@ -246,6 +254,7 @@ type agent struct {
scriptRunner *agentscripts.Runner
announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated.
announcementBannersRefreshInterval time.Duration
sessionToken atomic.Pointer[string]
sshServer *agentssh.Server
sshMaxTimeout time.Duration
blockFileTransfer bool
@@ -785,7 +794,7 @@ func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentC
// log a warning.
// Related to https://github.com/coder/coder/issues/20194
logger.Warn(ctx, "failed to report connection to server", slog.Error(err))
// keep going, we still need to remove it from the slice
// no continue here, we still need to remove it from the slice
} else {
logger.Debug(ctx, "successfully reported connection")
}
@@ -918,10 +927,11 @@ func (a *agent) run() (retErr error) {
// This allows the agent to refresh its token if necessary.
// For instance identity this is required, since the instance
// may not have re-provisioned, but a new agent ID was created.
err := a.client.RefreshToken(a.hardCtx)
sessionToken, err := a.exchangeToken(a.hardCtx)
if err != nil {
return xerrors.Errorf("refresh token: %w", err)
return xerrors.Errorf("exchange token: %w", err)
}
a.sessionToken.Store(&sessionToken)
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx)
@@ -1087,7 +1097,7 @@ func (a *agent) handleManifest(manifestOK *checkpoint) func(ctx context.Context,
if err != nil {
return xerrors.Errorf("fetch metadata: %w", err)
}
a.logger.Info(ctx, "fetched manifest")
a.logger.Info(ctx, "fetched manifest", slog.F("manifest", mp))
manifest, err := agentsdk.ManifestFromProto(mp)
if err != nil {
a.logger.Critical(ctx, "failed to convert manifest", slog.F("manifest", mp), slog.Error(err))
@@ -1360,7 +1370,7 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error)
"CODER_WORKSPACE_OWNER_NAME": manifest.OwnerName,
// Specific Coder subcommands require the agent token exposed!
"CODER_AGENT_TOKEN": a.client.GetSessionToken(),
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
+57 -92
@@ -22,6 +22,7 @@ import (
"slices"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@@ -1807,12 +1808,11 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
//nolint:dogsled
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
idConnectionReport := uuid.New()
id := uuid.New()
// Test that the connection is reported. This must be tested in the
// first connection because we care about verifying all of these.
netConn0, err := conn.ReconnectingPTY(ctx, idConnectionReport, 80, 80, "bash --norc")
netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
_ = netConn0.Close()
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")
@@ -2028,8 +2028,7 @@ func runSubAgentMain() int {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
req = req.WithContext(ctx)
client := &http.Client{}
resp, err := client.Do(req)
resp, err := http.DefaultClient.Do(req)
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "agent connection failed: %v\n", err)
return 11
@@ -2927,11 +2926,11 @@ func TestAgent_Speedtest(t *testing.T) {
func TestAgent_Reconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
// After the agent is disconnected from a coordinator, it's supposed
// to reconnect!
fCoordinator := tailnettest.NewFakeCoordinator()
coordinator := tailnet.NewCoordinator(logger)
defer coordinator.Close()
agentID := uuid.New()
statsCh := make(chan *proto.Stats, 50)
@@ -2943,24 +2942,27 @@ func TestAgent_Reconnect(t *testing.T) {
DERPMap: derpMap,
},
statsCh,
fCoordinator,
coordinator,
)
defer client.Close()
initialized := atomic.Int32{}
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
initialized.Add(1)
return "", nil
},
Client: client,
Logger: logger.Named("agent"),
})
defer closer.Close()
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
require.Equal(t, client.GetNumRefreshTokenCalls(), 1)
close(call1.Resps) // hang up
// expect reconnect
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
// Check that the agent refreshes the token when it reconnects.
require.Equal(t, client.GetNumRefreshTokenCalls(), 2)
closer.Close()
require.Eventually(t, func() bool {
return coordinator.Node(agentID) != nil
}, testutil.WaitShort, testutil.IntervalFast)
client.LastWorkspaceAgent()
require.Eventually(t, func() bool {
return initialized.Load() == 2
}, testutil.WaitShort, testutil.IntervalFast)
}
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
@@ -2982,6 +2984,9 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
defer client.Close()
filesystem := afero.NewMemMapFs()
closer := agent.New(agent.Options{
ExchangeToken: func(ctx context.Context) (string, error) {
return "", nil
},
Client: client,
Logger: logger.Named("agent"),
Filesystem: filesystem,
@@ -3010,6 +3015,9 @@ func TestAgent_DebugServer(t *testing.T) {
conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{
DERPMap: derpMap,
}, 0, func(c *agenttest.Client, o *agent.Options) {
o.ExchangeToken = func(context.Context) (string, error) {
return "token", nil
}
o.LogDir = logDir
})
@@ -3462,7 +3470,11 @@ func TestAgent_Metrics_SSH(t *testing.T) {
registry := prometheus.NewRegistry()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{
// Make sure we always get a DERP connection for
// currently_reachable_peers.
DisableDirectConnections: true,
}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.PrometheusRegistry = registry
})
@@ -3477,31 +3489,16 @@ func TestAgent_Metrics_SSH(t *testing.T) {
err = session.Shell()
require.NoError(t, err)
expected := []struct {
Name string
Type proto.Stats_Metric_Type
CheckFn func(float64) error
Labels []*proto.Stats_Metric_Label
}{
expected := []*proto.Stats_Metric{
{
Name: "agent_reconnecting_pty_connections_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_reconnecting_pty_connections_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "agent_sessions_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 1 {
return nil
}
return xerrors.Errorf("expected 1, got %f", v)
},
Name: "agent_sessions_total",
Type: proto.Stats_Metric_COUNTER,
Value: 1,
Labels: []*proto.Stats_Metric_Label{
{
Name: "magic_type",
@@ -3514,44 +3511,24 @@ func TestAgent_Metrics_SSH(t *testing.T) {
},
},
{
Name: "agent_ssh_server_failed_connections_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_ssh_server_failed_connections_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "agent_ssh_server_sftp_connections_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_ssh_server_sftp_connections_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "agent_ssh_server_sftp_server_errors_total",
Type: proto.Stats_Metric_COUNTER,
CheckFn: func(v float64) error {
if v == 0 {
return nil
}
return xerrors.Errorf("expected 0, got %f", v)
},
Name: "agent_ssh_server_sftp_server_errors_total",
Type: proto.Stats_Metric_COUNTER,
Value: 0,
},
{
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
CheckFn: func(float64) error {
// We can't reliably ping a peer here, and networking is out of
// scope of this test, so we just test that the metric exists
// with the correct labels.
return nil
},
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
Value: 1,
Labels: []*proto.Stats_Metric_Label{
{
Name: "connection_type",
@@ -3560,11 +3537,9 @@ func TestAgent_Metrics_SSH(t *testing.T) {
},
},
{
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
CheckFn: func(float64) error {
return nil
},
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
Value: 0,
Labels: []*proto.Stats_Metric_Label{
{
Name: "connection_type",
@@ -3573,20 +3548,9 @@ func TestAgent_Metrics_SSH(t *testing.T) {
},
},
{
Name: "coderd_agentstats_startup_script_seconds",
Type: proto.Stats_Metric_GAUGE,
CheckFn: func(f float64) error {
if f >= 0 {
return nil
}
return xerrors.Errorf("expected >= 0, got %f", f)
},
Labels: []*proto.Stats_Metric_Label{
{
Name: "success",
Value: "true",
},
},
Name: "coderd_agentstats_startup_script_seconds",
Type: proto.Stats_Metric_GAUGE,
Value: 1,
},
}
@@ -3608,10 +3572,11 @@ func TestAgent_Metrics_SSH(t *testing.T) {
for _, m := range mf.GetMetric() {
assert.Equal(t, expected[i].Name, mf.GetName())
assert.Equal(t, expected[i].Type.String(), mf.GetType().String())
// Value is max expected
if expected[i].Type == proto.Stats_Metric_GAUGE {
assert.NoError(t, expected[i].CheckFn(m.GetGauge().GetValue()), "check fn for %s failed", expected[i].Name)
assert.GreaterOrEqualf(t, expected[i].Value, m.GetGauge().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetGauge().GetValue())
} else if expected[i].Type == proto.Stats_Metric_COUNTER {
assert.NoError(t, expected[i].CheckFn(m.GetCounter().GetValue()), "check fn for %s failed", expected[i].Name)
assert.GreaterOrEqualf(t, expected[i].Value, m.GetCounter().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetCounter().GetValue())
}
for j, lbl := range expected[i].Labels {
assert.Equal(t, m.GetLabel()[j], &promgo.LabelPair{
+2
@@ -682,6 +682,8 @@ func (api *API) updaterLoop() {
} else {
prevErr = nil
}
default:
api.logger.Debug(api.ctx, "updater loop ticker skipped, update in progress")
}
return nil // Always nil to keep the ticker going.
+9 -1
@@ -1,6 +1,7 @@
package agenttest
import (
"context"
"net/url"
"testing"
@@ -30,11 +31,18 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
}
if o.Client == nil {
agentClient := agentsdk.New(coderURL, agentsdk.WithFixedToken(agentToken))
agentClient := agentsdk.New(coderURL)
agentClient.SetSessionToken(agentToken)
agentClient.SDK.SetLogger(log)
o.Client = agentClient
}
if o.ExchangeToken == nil {
o.ExchangeToken = func(_ context.Context) (string, error) {
return agentToken, nil
}
}
if o.LogDir == "" {
o.LogDir = t.TempDir()
}
+4 -30
@@ -3,7 +3,6 @@ package agenttest
import (
"context"
"io"
"net/http"
"slices"
"sync"
"sync/atomic"
@@ -29,7 +28,6 @@ import (
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)
const statsInterval = 500 * time.Millisecond
@@ -88,34 +86,10 @@ type Client struct {
fakeAgentAPI *FakeAgentAPI
LastWorkspaceAgent func()
mu sync.Mutex // Protects following.
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
refreshTokenCalls int
}
func (*Client) AsRequestOption() codersdk.RequestOption {
return func(_ *http.Request) {}
}
func (*Client) SetDialOption(*websocket.DialOptions) {}
func (*Client) GetSessionToken() string {
return "agenttest-token"
}
func (c *Client) RefreshToken(context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
c.refreshTokenCalls++
return nil
}
func (c *Client) GetNumRefreshTokenCalls() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.refreshTokenCalls
mu sync.Mutex // Protects following.
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
}
func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}
-3
@@ -60,9 +60,6 @@ func (a *agent) apiHandler() http.Handler {
r.Get("/api/v0/listening-ports", lp.handler)
r.Get("/api/v0/netcheck", a.HandleNetcheck)
r.Post("/api/v0/list-directory", a.HandleLS)
r.Get("/api/v0/read-file", a.HandleReadFile)
r.Post("/api/v0/write-file", a.HandleWriteFile)
r.Post("/api/v0/edit-files", a.HandleEditFiles)
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)
+1 -2
@@ -63,7 +63,6 @@ func NewAppHealthReporterWithClock(
// run a ticker for each app health check.
var mu sync.RWMutex
failures := make(map[uuid.UUID]int, 0)
client := &http.Client{}
for _, nextApp := range apps {
if !shouldStartTicker(nextApp) {
continue
@@ -92,7 +91,7 @@ func NewAppHealthReporterWithClock(
if err != nil {
return err
}
res, err := client.Do(req)
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
-273
@@ -1,273 +0,0 @@
package agent
import (
"context"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strconv"
"syscall"
"github.com/icholy/replace"
"github.com/spf13/afero"
"golang.org/x/text/transform"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type HTTPResponseCode = int
func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
offset := parser.PositiveInt64(query, 0, "offset")
limit := parser.PositiveInt64(query, 0, "limit")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}
status, err := a.streamFile(ctx, rw, path, offset, limit)
if err != nil {
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: err.Error(),
})
return
}
}
func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) {
if !filepath.IsAbs(path) {
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
}
f, err := a.filesystem.Open(path)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrNotExist):
status = http.StatusNotFound
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
}
return status, err
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return http.StatusInternalServerError, err
}
if stat.IsDir() {
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
}
size := stat.Size()
if limit == 0 {
limit = size
}
bytesRemaining := max(size-offset, 0)
bytesToRead := min(bytesRemaining, limit)
// Relying on just the file name for the mime type for now.
mimeType := mime.TypeByExtension(filepath.Ext(path))
if mimeType == "" {
mimeType = "application/octet-stream"
}
rw.Header().Set("Content-Type", mimeType)
rw.Header().Set("Content-Length", strconv.FormatInt(bytesToRead, 10))
rw.WriteHeader(http.StatusOK)
reader := io.NewSectionReader(f, offset, bytesToRead)
_, err = io.Copy(rw, reader)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
a.logger.Error(ctx, "workspace agent read file", slog.Error(err))
}
return 0, nil
}
func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}
status, err := a.writeFile(ctx, r, path)
if err != nil {
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: fmt.Sprintf("Successfully wrote to %q", path),
})
}
func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (HTTPResponseCode, error) {
if !filepath.IsAbs(path) {
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
}
dir := filepath.Dir(path)
err := a.filesystem.MkdirAll(dir, 0o755)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.ENOTDIR):
status = http.StatusBadRequest
}
return status, err
}
f, err := a.filesystem.Create(path)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
case errors.Is(err, syscall.EISDIR):
status = http.StatusBadRequest
}
return status, err
}
defer f.Close()
_, err = io.Copy(f, r.Body)
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
}
return 0, nil
}
func (a *agent) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req workspacesdk.FileEditRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if len(req.Files) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "must specify at least one file",
})
return
}
var combinedErr error
status := http.StatusOK
for _, edit := range req.Files {
s, err := a.editFile(r.Context(), edit.Path, edit.Edits)
// Keep the highest response status, so 500 will be preferred over 400, etc.
if s > status {
status = s
}
if err != nil {
combinedErr = errors.Join(combinedErr, err)
}
}
if combinedErr != nil {
httpapi.Write(ctx, rw, status, codersdk.Response{
Message: combinedErr.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: "Successfully edited file(s)",
})
}
func (a *agent) editFile(ctx context.Context, path string, edits []workspacesdk.FileEdit) (int, error) {
if path == "" {
return http.StatusBadRequest, xerrors.New("\"path\" is required")
}
if !filepath.IsAbs(path) {
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
}
if len(edits) == 0 {
return http.StatusBadRequest, xerrors.New("must specify at least one edit")
}
f, err := a.filesystem.Open(path)
if err != nil {
status := http.StatusInternalServerError
switch {
case errors.Is(err, os.ErrNotExist):
status = http.StatusNotFound
case errors.Is(err, os.ErrPermission):
status = http.StatusForbidden
}
return status, err
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return http.StatusInternalServerError, err
}
if stat.IsDir() {
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
}
transforms := make([]transform.Transformer, len(edits))
for i, edit := range edits {
transforms[i] = replace.String(edit.Search, edit.Replace)
}
tmpfile, err := afero.TempFile(a.filesystem, "", filepath.Base(path))
if err != nil {
return http.StatusInternalServerError, err
}
defer tmpfile.Close()
_, err = io.Copy(tmpfile, replace.Chain(f, transforms...))
if err != nil {
if rerr := a.filesystem.Remove(tmpfile.Name()); rerr != nil {
a.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
}
err = a.filesystem.Rename(tmpfile.Name(), path)
if err != nil {
return http.StatusInternalServerError, err
}
return 0, nil
}
-722
@@ -1,722 +0,0 @@
package agent_test
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"syscall"
"testing"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
type testFs struct {
afero.Fs
// intercept lets tests simulate a failing call by returning an error.
intercept func(call, file string) error
}
func newTestFs(base afero.Fs, intercept func(call, file string) error) *testFs {
return &testFs{
Fs: base,
intercept: intercept,
}
}
func (fs *testFs) Open(name string) (afero.File, error) {
if err := fs.intercept("open", name); err != nil {
return nil, err
}
return fs.Fs.Open(name)
}
func (fs *testFs) Create(name string) (afero.File, error) {
if err := fs.intercept("create", name); err != nil {
return nil, err
}
// Unlike os, afero lets you create files where directories already exist and
// lets you nest them underneath files, somehow.
stat, err := fs.Fs.Stat(name)
if err == nil && stat.IsDir() {
return nil, &os.PathError{
Op: "open",
Path: name,
Err: syscall.EISDIR,
}
}
stat, err = fs.Fs.Stat(filepath.Dir(name))
if err == nil && !stat.IsDir() {
return nil, &os.PathError{
Op: "open",
Path: name,
Err: syscall.ENOTDIR,
}
}
return fs.Fs.Create(name)
}
func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
if err := fs.intercept("mkdirall", name); err != nil {
return err
}
// Unlike os, afero lets you create directories where files already exist and
// lets you nest them underneath files somehow.
stat, err := fs.Fs.Stat(filepath.Dir(name))
if err == nil && !stat.IsDir() {
return &os.PathError{
Op: "mkdir",
Path: name,
Err: syscall.ENOTDIR,
}
}
stat, err = fs.Fs.Stat(name)
if err == nil && !stat.IsDir() {
return &os.PathError{
Op: "mkdir",
Path: name,
Err: syscall.ENOTDIR,
}
}
return fs.Fs.MkdirAll(name, mode)
}
func (fs *testFs) Rename(oldName, newName string) error {
if err := fs.intercept("rename", newName); err != nil {
return err
}
return fs.Fs.Rename(oldName, newName)
}
func TestReadFile(t *testing.T) {
t.Parallel()
tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms")
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
if file == noPermsFilePath {
return os.ErrPermission
}
return nil
})
})
dirPath := filepath.Join(tmpdir, "a-directory")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)
filePath := filepath.Join(tmpdir, "file")
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
require.NoError(t, err)
imagePath := filepath.Join(tmpdir, "file.png")
err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644)
require.NoError(t, err)
tests := []struct {
name string
path string
limit int64
offset int64
bytes []byte
mimeType string
errCode int
error string
}{
{
name: "NoPath",
path: "",
errCode: http.StatusBadRequest,
error: "\"path\" is required",
},
{
name: "RelativePathDotSlash",
path: "./relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "RelativePath",
path: "also-relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "NegativeLimit",
path: filePath,
limit: -10,
errCode: http.StatusBadRequest,
error: "value is negative",
},
{
name: "NegativeOffset",
path: filePath,
offset: -10,
errCode: http.StatusBadRequest,
error: "value is negative",
},
{
name: "NonExistent",
path: filepath.Join(tmpdir, "does-not-exist"),
errCode: http.StatusNotFound,
error: "file does not exist",
},
{
name: "IsDir",
path: dirPath,
errCode: http.StatusBadRequest,
error: "not a file",
},
{
name: "NoPermissions",
path: noPermsFilePath,
errCode: http.StatusForbidden,
error: "permission denied",
},
{
name: "Defaults",
path: filePath,
bytes: []byte("content"),
mimeType: "application/octet-stream",
},
{
name: "Limit1",
path: filePath,
limit: 1,
bytes: []byte("c"),
mimeType: "application/octet-stream",
},
{
name: "Offset1",
path: filePath,
offset: 1,
bytes: []byte("ontent"),
mimeType: "application/octet-stream",
},
{
name: "Limit1Offset2",
path: filePath,
limit: 1,
offset: 2,
bytes: []byte("n"),
mimeType: "application/octet-stream",
},
{
name: "Limit7Offset0",
path: filePath,
limit: 7,
offset: 0,
bytes: []byte("content"),
mimeType: "application/octet-stream",
},
{
name: "Limit100",
path: filePath,
limit: 100,
bytes: []byte("content"),
mimeType: "application/octet-stream",
},
{
name: "Offset7",
path: filePath,
offset: 7,
bytes: []byte{},
mimeType: "application/octet-stream",
},
{
name: "Offset100",
path: filePath,
offset: 100,
bytes: []byte{},
mimeType: "application/octet-stream",
},
{
name: "MimeTypePng",
path: imagePath,
bytes: []byte("not really an image"),
mimeType: "image/png",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
reader, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit)
if tt.errCode != 0 {
require.Error(t, err)
cerr := coderdtest.SDKError(t, err)
require.Contains(t, cerr.Error(), tt.error)
require.Equal(t, tt.errCode, cerr.StatusCode())
} else {
require.NoError(t, err)
defer reader.Close()
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
require.Equal(t, tt.bytes, bytes)
require.Equal(t, tt.mimeType, mimeType)
}
})
}
}
func TestWriteFile(t *testing.T) {
t.Parallel()
tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
if file == noPermsFilePath || file == noPermsDirPath {
return os.ErrPermission
}
return nil
})
})
dirPath := filepath.Join(tmpdir, "directory")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)
filePath := filepath.Join(tmpdir, "file")
err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
require.NoError(t, err)
notDirErr := "not a directory"
if runtime.GOOS == "windows" {
notDirErr = "cannot find the path"
}
tests := []struct {
name string
path string
bytes []byte
errCode int
error string
}{
{
name: "NoPath",
path: "",
errCode: http.StatusBadRequest,
error: "\"path\" is required",
},
{
name: "RelativePathDotSlash",
path: "./relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "RelativePath",
path: "also-relative",
errCode: http.StatusBadRequest,
error: "file path must be absolute",
},
{
name: "NonExistent",
path: filepath.Join(tmpdir, "/nested/does-not-exist"),
bytes: []byte("now it does exist"),
},
{
name: "IsDir",
path: dirPath,
errCode: http.StatusBadRequest,
error: "is a directory",
},
{
name: "IsNotDir",
path: filepath.Join(filePath, "file2"),
errCode: http.StatusBadRequest,
error: notDirErr,
},
{
name: "NoPermissionsFile",
path: noPermsFilePath,
errCode: http.StatusForbidden,
error: "permission denied",
},
{
name: "NoPermissionsDir",
path: filepath.Join(noPermsDirPath, "within-no-perm-dir"),
errCode: http.StatusForbidden,
error: "permission denied",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
reader := bytes.NewReader(tt.bytes)
err := conn.WriteFile(ctx, tt.path, reader)
if tt.errCode != 0 {
require.Error(t, err)
cerr := coderdtest.SDKError(t, err)
require.Contains(t, cerr.Error(), tt.error)
require.Equal(t, tt.errCode, cerr.StatusCode())
} else {
require.NoError(t, err)
b, err := afero.ReadFile(fs, tt.path)
require.NoError(t, err)
require.Equal(t, tt.bytes, b)
}
})
}
}
func TestEditFiles(t *testing.T) {
t.Parallel()
tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
failRenameFilePath := filepath.Join(tmpdir, "fail-rename")
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
if file == noPermsFilePath {
return &os.PathError{
Op: call,
Path: file,
Err: os.ErrPermission,
}
} else if file == failRenameFilePath && call == "rename" {
return xerrors.New("rename failed")
}
return nil
})
})
dirPath := filepath.Join(tmpdir, "directory")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)
tests := []struct {
name string
contents map[string]string
edits []workspacesdk.FileEdits
expected map[string]string
errCode int
errors []string
}{
{
name: "NoFiles",
errCode: http.StatusBadRequest,
errors: []string{"must specify at least one file"},
},
{
name: "NoPath",
errCode: http.StatusBadRequest,
edits: []workspacesdk.FileEdits{
{
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
errors: []string{"\"path\" is required"},
},
{
name: "RelativePathDotSlash",
edits: []workspacesdk.FileEdits{
{
Path: "./relative",
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"file path must be absolute"},
},
{
name: "RelativePath",
edits: []workspacesdk.FileEdits{
{
Path: "also-relative",
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"file path must be absolute"},
},
{
name: "NoEdits",
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "no-edits"),
},
},
errCode: http.StatusBadRequest,
errors: []string{"must specify at least one edit"},
},
{
name: "NonExistent",
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "does-not-exist"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
errCode: http.StatusNotFound,
errors: []string{"file does not exist"},
},
{
name: "IsDir",
edits: []workspacesdk.FileEdits{
{
Path: dirPath,
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
errCode: http.StatusBadRequest,
errors: []string{"not a file"},
},
{
name: "NoPermissions",
edits: []workspacesdk.FileEdits{
{
Path: noPermsFilePath,
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
errCode: http.StatusForbidden,
errors: []string{"permission denied"},
},
{
name: "FailRename",
contents: map[string]string{failRenameFilePath: "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: failRenameFilePath,
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
errCode: http.StatusInternalServerError,
errors: []string{"rename failed"},
},
{
name: "Edit1",
contents: map[string]string{filepath.Join(tmpdir, "edit1"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "edit1"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
},
{
name: "EditEdit", // Edits affect previous edits.
contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "edit-edit"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo",
Replace: "bar",
},
{
Search: "bar",
Replace: "qux",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
},
{
name: "Multiline",
contents: map[string]string{filepath.Join(tmpdir, "multiline"): "foo\nbar\nbaz\nqux"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "multiline"),
Edits: []workspacesdk.FileEdit{
{
Search: "bar\nbaz",
Replace: "frob",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "multiline"): "foo\nfrob\nqux"},
},
{
name: "Multifile",
contents: map[string]string{
filepath.Join(tmpdir, "file1"): "file 1",
filepath.Join(tmpdir, "file2"): "file 2",
filepath.Join(tmpdir, "file3"): "file 3",
},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "file1"),
Edits: []workspacesdk.FileEdit{
{
Search: "file",
Replace: "edited1",
},
},
},
{
Path: filepath.Join(tmpdir, "file2"),
Edits: []workspacesdk.FileEdit{
{
Search: "file",
Replace: "edited2",
},
},
},
{
Path: filepath.Join(tmpdir, "file3"),
Edits: []workspacesdk.FileEdit{
{
Search: "file",
Replace: "edited3",
},
},
},
},
expected: map[string]string{
filepath.Join(tmpdir, "file1"): "edited1 1",
filepath.Join(tmpdir, "file2"): "edited2 2",
filepath.Join(tmpdir, "file3"): "edited3 3",
},
},
{
name: "MultiError",
contents: map[string]string{
filepath.Join(tmpdir, "file8"): "file 8",
},
edits: []workspacesdk.FileEdits{
{
Path: noPermsFilePath,
Edits: []workspacesdk.FileEdit{
{
Search: "file",
Replace: "edited7",
},
},
},
{
Path: filepath.Join(tmpdir, "file8"),
Edits: []workspacesdk.FileEdit{
{
Search: "file",
Replace: "edited8",
},
},
},
{
Path: filepath.Join(tmpdir, "file9"),
Edits: []workspacesdk.FileEdit{
{
Search: "file",
Replace: "edited9",
},
},
},
},
expected: map[string]string{
filepath.Join(tmpdir, "file8"): "edited8 8",
},
// Higher status codes will override lower ones, so in this case the 404
// takes priority over the 403.
errCode: http.StatusNotFound,
errors: []string{
fmt.Sprintf("%s: permission denied", noPermsFilePath),
"file9: file does not exist",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
for path, content := range tt.contents {
err := afero.WriteFile(fs, path, []byte(content), 0o644)
require.NoError(t, err)
}
err := conn.EditFiles(ctx, workspacesdk.FileEditRequest{Files: tt.edits})
if tt.errCode != 0 {
require.Error(t, err)
cerr := coderdtest.SDKError(t, err)
for _, error := range tt.errors {
require.Contains(t, cerr.Error(), error)
}
require.Equal(t, tt.errCode, cerr.StatusCode())
} else {
require.NoError(t, err)
}
for path, expect := range tt.expected {
b, err := afero.ReadFile(fs, path)
require.NoError(t, err)
require.Equal(t, expect, string(b))
}
})
}
}
@@ -1,350 +0,0 @@
package backedpipe
import (
"context"
"io"
"sync"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/singleflight"
"golang.org/x/xerrors"
)
var (
ErrPipeClosed = xerrors.New("pipe is closed")
ErrPipeAlreadyConnected = xerrors.New("pipe is already connected")
ErrReconnectionInProgress = xerrors.New("reconnection already in progress")
ErrReconnectFailed = xerrors.New("reconnect failed")
ErrInvalidSequenceNumber = xerrors.New("remote sequence number exceeds local sequence")
ErrReconnectWriterFailed = xerrors.New("reconnect writer failed")
)
// connectionState represents the current state of the BackedPipe connection.
type connectionState int
const (
// connected indicates the pipe is connected and operational.
connected connectionState = iota
// disconnected indicates the pipe is not connected but not closed.
disconnected
// reconnecting indicates a reconnection attempt is in progress.
reconnecting
// closed indicates the pipe is permanently closed.
closed
)
// ErrorEvent represents an error from a reader or writer with connection generation info.
type ErrorEvent struct {
Err error
Component string // "reader" or "writer"
Generation uint64 // connection generation when error occurred
}
const (
// Default buffer capacity used by the writer - 64MB
DefaultBufferSize = 64 * 1024 * 1024
)
// Reconnector is an interface for establishing connections when the BackedPipe needs to reconnect.
// Implementations should:
// 1. Establish a new connection to the remote side
// 2. Exchange sequence numbers with the remote side
// 3. Return the new connection and the remote's reader sequence number
//
// The readerSeqNum parameter is the local reader's current sequence number
// (total bytes successfully read from the remote). This must be sent to the
// remote so it can replay its data to us starting from this number.
//
// The returned remoteReaderSeqNum should be the remote side's reader sequence
// number (how many bytes of our outbound data it has successfully read). This
// informs our writer where to resume (i.e., which bytes to replay to the remote).
type Reconnector interface {
Reconnect(ctx context.Context, readerSeqNum uint64) (conn io.ReadWriteCloser, remoteReaderSeqNum uint64, err error)
}
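// A minimal Reconnector sketch (an assumed example, not part of this package):
// it dials through a caller-supplied function and runs whatever sequence-number
// handshake the transport defines.
//
//	type dialReconnector struct {
//		dial func(ctx context.Context) (io.ReadWriteCloser, error)
//		// handshake sends readerSeqNum to the remote and returns the remote
//		// reader's sequence number; its framing is an assumption here.
//		handshake func(conn io.ReadWriteCloser, readerSeqNum uint64) (uint64, error)
//	}
//
//	func (d *dialReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
//		conn, err := d.dial(ctx)
//		if err != nil {
//			return nil, 0, err
//		}
//		remoteReaderSeqNum, err := d.handshake(conn, readerSeqNum)
//		if err != nil {
//			_ = conn.Close()
//			return nil, 0, err
//		}
//		return conn, remoteReaderSeqNum, nil
//	}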
// BackedPipe provides a reliable bidirectional byte stream over unreliable network connections.
// It orchestrates a BackedReader and BackedWriter to provide transparent reconnection
// and data replay capabilities.
type BackedPipe struct {
ctx context.Context
cancel context.CancelFunc
mu sync.RWMutex
reader *BackedReader
writer *BackedWriter
reconnector Reconnector
conn io.ReadWriteCloser
// State machine
state connectionState
connGen uint64 // Increments on each successful reconnection
// Unified error handling with generation filtering
errChan chan ErrorEvent
// singleflight group to dedupe concurrent ForceReconnect calls
sf singleflight.Group
// Track first error per generation to avoid duplicate reconnections
lastErrorGen uint64
}
// NewBackedPipe creates a new BackedPipe with default options and the specified reconnector.
// The pipe starts disconnected and must be connected using Connect().
func NewBackedPipe(ctx context.Context, reconnector Reconnector) *BackedPipe {
pipeCtx, cancel := context.WithCancel(ctx)
errChan := make(chan ErrorEvent, 1)
bp := &BackedPipe{
ctx: pipeCtx,
cancel: cancel,
reconnector: reconnector,
state: disconnected,
connGen: 0, // Start with generation 0
errChan: errChan,
}
// Create reader and writer with typed error channel for generation-aware error reporting
bp.reader = NewBackedReader(errChan)
bp.writer = NewBackedWriter(DefaultBufferSize, errChan)
// Start error handler goroutine
go bp.handleErrors()
return bp
}
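// Typical lifecycle, sketched from how the tests exercise the pipe (the
// reconnector value is an assumption):
//
//	bp := NewBackedPipe(ctx, reconnector)
//	defer bp.Close()
//	if err := bp.Connect(); err != nil {
//		return err
//	}
//	// Writes are buffered and replayed after reconnection; reads block while
//	// disconnected and resume once a new connection is provided.
//	_, _ = bp.Write([]byte("hello"))
//	buf := make([]byte, 1024)
//	_, _ = bp.Read(buf)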
// Connect establishes the initial connection using the reconnect function.
func (bp *BackedPipe) Connect() error {
bp.mu.Lock()
defer bp.mu.Unlock()
if bp.state == closed {
return ErrPipeClosed
}
if bp.state == connected {
return ErrPipeAlreadyConnected
}
// Use internal context for the actual reconnect operation to ensure
// Close() reliably cancels any in-flight attempt.
return bp.reconnectLocked()
}
// Read implements io.Reader by delegating to the BackedReader.
func (bp *BackedPipe) Read(p []byte) (int, error) {
return bp.reader.Read(p)
}
// Write implements io.Writer by delegating to the BackedWriter.
func (bp *BackedPipe) Write(p []byte) (int, error) {
bp.mu.RLock()
writer := bp.writer
state := bp.state
bp.mu.RUnlock()
if state == closed {
return 0, io.EOF
}
return writer.Write(p)
}
// Close closes the pipe and all underlying connections.
func (bp *BackedPipe) Close() error {
bp.mu.Lock()
defer bp.mu.Unlock()
if bp.state == closed {
return nil
}
bp.state = closed
bp.cancel() // Cancel main context
// Close all components in parallel to avoid deadlocks
//
// IMPORTANT: The connection must be closed first to unblock any
// readers or writers that might be holding the mutex on Read/Write
var g errgroup.Group
if bp.conn != nil {
conn := bp.conn
g.Go(func() error {
return conn.Close()
})
bp.conn = nil
}
if bp.reader != nil {
reader := bp.reader
g.Go(func() error {
return reader.Close()
})
}
if bp.writer != nil {
writer := bp.writer
g.Go(func() error {
return writer.Close()
})
}
// Wait for all close operations to complete and return any error
return g.Wait()
}
// Connected returns whether the pipe is currently connected.
func (bp *BackedPipe) Connected() bool {
bp.mu.RLock()
defer bp.mu.RUnlock()
return bp.state == connected && bp.reader.Connected() && bp.writer.Connected()
}
// reconnectLocked handles the reconnection logic. Must be called with write lock held.
func (bp *BackedPipe) reconnectLocked() error {
if bp.state == reconnecting {
return ErrReconnectionInProgress
}
bp.state = reconnecting
defer func() {
// Only reset to disconnected if we're still in reconnecting state
// (successful reconnection will set state to connected)
if bp.state == reconnecting {
bp.state = disconnected
}
}()
// Close existing connection if any
if bp.conn != nil {
_ = bp.conn.Close()
bp.conn = nil
}
// Increment the generation and update both reader and writer.
// We do it now to track even the connections that fail during
// Reconnect.
bp.connGen++
bp.reader.SetGeneration(bp.connGen)
bp.writer.SetGeneration(bp.connGen)
// Reconnect reader and writer
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go bp.reader.Reconnect(seqNum, newR)
// Get the precise reader sequence number from the reader while it holds its lock
readerSeqNum, ok := <-seqNum
if !ok {
// Reader was closed during reconnection
return ErrReconnectFailed
}
// Perform reconnect using the exact sequence number we just received
conn, remoteReaderSeqNum, err := bp.reconnector.Reconnect(bp.ctx, readerSeqNum)
if err != nil {
// Unblock reader reconnect
newR <- nil
return ErrReconnectFailed
}
// Provide the new connection to the reader (reader still holds its lock)
newR <- conn
// Replay our outbound data from the remote's reader sequence number
writerReconnectErr := bp.writer.Reconnect(remoteReaderSeqNum, conn)
if writerReconnectErr != nil {
return ErrReconnectWriterFailed
}
// Success - update state
bp.conn = conn
bp.state = connected
return nil
}
// handleErrors listens for connection errors from reader/writer and triggers reconnection.
// It filters errors from old connections and ensures only the first error per generation
// triggers reconnection.
func (bp *BackedPipe) handleErrors() {
for {
select {
case <-bp.ctx.Done():
return
case errorEvt := <-bp.errChan:
bp.handleConnectionError(errorEvt)
}
}
}
// handleConnectionError handles errors from either reader or writer components.
// It filters errors from old connections and ensures only one reconnection per generation.
func (bp *BackedPipe) handleConnectionError(errorEvt ErrorEvent) {
bp.mu.Lock()
defer bp.mu.Unlock()
// Skip if already closed
if bp.state == closed {
return
}
// Filter errors from old connections (lower generation)
if errorEvt.Generation < bp.connGen {
return
}
// Skip if not connected (already disconnected or reconnecting)
if bp.state != connected {
return
}
// Skip if we've already seen an error for this generation
if bp.lastErrorGen >= errorEvt.Generation {
return
}
// This is the first error for this generation
bp.lastErrorGen = errorEvt.Generation
// Mark as disconnected
bp.state = disconnected
// Try to reconnect using internal context
reconnectErr := bp.reconnectLocked()
if reconnectErr != nil {
// Reconnection failed - log or handle as needed
// For now, we'll just continue and wait for manual reconnection
_ = errorEvt.Err // Use the original error from the component
_ = errorEvt.Component // Component info available for potential logging by higher layers
}
}
// ForceReconnect forces a reconnection attempt immediately.
// This can be used to force a reconnection, for example when the caller knows a new connection can be established.
// It prevents duplicate reconnections when called concurrently.
func (bp *BackedPipe) ForceReconnect() error {
// Deduplicate concurrent ForceReconnect calls so only one reconnection
// attempt runs at a time from this API. Use the pipe's internal context
// to ensure Close() cancels any in-flight attempt.
_, err, _ := bp.sf.Do("force-reconnect", func() (interface{}, error) {
bp.mu.Lock()
defer bp.mu.Unlock()
if bp.state == closed {
return nil, io.EOF
}
// Don't force reconnect if already reconnecting
if bp.state == reconnecting {
return nil, ErrReconnectionInProgress
}
return nil, bp.reconnectLocked()
})
return err
}
@@ -1,989 +0,0 @@
package backedpipe_test
import (
"bytes"
"context"
"io"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
"github.com/coder/coder/v2/testutil"
)
// mockConnection implements io.ReadWriteCloser for testing
type mockConnection struct {
mu sync.Mutex
readBuffer bytes.Buffer
writeBuffer bytes.Buffer
closed bool
readError error
writeError error
closeError error
readFunc func([]byte) (int, error)
writeFunc func([]byte) (int, error)
seqNum uint64
}
func newMockConnection() *mockConnection {
return &mockConnection{}
}
func (mc *mockConnection) Read(p []byte) (int, error) {
mc.mu.Lock()
defer mc.mu.Unlock()
if mc.readFunc != nil {
return mc.readFunc(p)
}
if mc.readError != nil {
return 0, mc.readError
}
return mc.readBuffer.Read(p)
}
func (mc *mockConnection) Write(p []byte) (int, error) {
mc.mu.Lock()
defer mc.mu.Unlock()
if mc.writeFunc != nil {
return mc.writeFunc(p)
}
if mc.writeError != nil {
return 0, mc.writeError
}
return mc.writeBuffer.Write(p)
}
func (mc *mockConnection) Close() error {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.closed = true
return mc.closeError
}
func (mc *mockConnection) WriteString(s string) {
mc.mu.Lock()
defer mc.mu.Unlock()
_, _ = mc.readBuffer.WriteString(s)
}
func (mc *mockConnection) ReadString() string {
mc.mu.Lock()
defer mc.mu.Unlock()
return mc.writeBuffer.String()
}
func (mc *mockConnection) SetReadError(err error) {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.readError = err
}
func (mc *mockConnection) SetWriteError(err error) {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.writeError = err
}
func (mc *mockConnection) Reset() {
mc.mu.Lock()
defer mc.mu.Unlock()
mc.readBuffer.Reset()
mc.writeBuffer.Reset()
mc.readError = nil
mc.writeError = nil
mc.closed = false
}
// mockReconnector implements the Reconnector interface for testing
type mockReconnector struct {
mu sync.Mutex
connections []*mockConnection
connectionIndex int
callCount int
signalChan chan struct{}
}
// Reconnect implements the Reconnector interface
func (m *mockReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.callCount++
if m.connectionIndex >= len(m.connections) {
return nil, 0, xerrors.New("no more connections available")
}
conn := m.connections[m.connectionIndex]
m.connectionIndex++
// Signal when reconnection happens
if m.connectionIndex > 1 {
select {
case m.signalChan <- struct{}{}:
default:
}
}
// Determine remoteReaderSeqNum (how many bytes of our outbound data the remote has read)
var remoteReaderSeqNum uint64
switch {
case m.callCount == 1:
remoteReaderSeqNum = 0
case conn.seqNum != 0:
remoteReaderSeqNum = conn.seqNum
default:
// Default to 0 if unspecified
remoteReaderSeqNum = 0
}
return conn, remoteReaderSeqNum, nil
}
// GetCallCount returns the current call count in a thread-safe manner
func (m *mockReconnector) GetCallCount() int {
m.mu.Lock()
defer m.mu.Unlock()
return m.callCount
}
// mockReconnectFunc creates a unified reconnector with all behaviors enabled
func mockReconnectFunc(connections ...*mockConnection) (*mockReconnector, chan struct{}) {
signalChan := make(chan struct{}, 1)
reconnector := &mockReconnector{
connections: connections,
signalChan: signalChan,
}
return reconnector, signalChan
}
// blockingReconnector is a reconnector that blocks on a channel for deterministic testing
type blockingReconnector struct {
conn1 *mockConnection
conn2 *mockConnection
callCount int
blockChan <-chan struct{}
blockedChan chan struct{}
mu sync.Mutex
signalOnce sync.Once // Ensure we only signal once for the first actual reconnect
}
func (b *blockingReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
b.mu.Lock()
b.callCount++
currentCall := b.callCount
b.mu.Unlock()
if currentCall == 1 {
// Initial connect
return b.conn1, 0, nil
}
// Signal that we're about to block, but only once for the first reconnect attempt
// This ensures we properly test singleflight deduplication
b.signalOnce.Do(func() {
select {
case b.blockedChan <- struct{}{}:
default:
// If channel is full, don't block
}
})
// For subsequent calls, block until channel is closed
select {
case <-b.blockChan:
// Channel closed, proceed with reconnection
case <-ctx.Done():
return nil, 0, ctx.Err()
}
return b.conn2, 0, nil
}
// GetCallCount returns the current call count in a thread-safe manner
func (b *blockingReconnector) GetCallCount() int {
b.mu.Lock()
defer b.mu.Unlock()
return b.callCount
}
func mockBlockingReconnectFunc(conn1, conn2 *mockConnection, blockChan <-chan struct{}) (*blockingReconnector, chan struct{}) {
blockedChan := make(chan struct{}, 1)
reconnector := &blockingReconnector{
conn1: conn1,
conn2: conn2,
blockChan: blockChan,
blockedChan: blockedChan,
}
return reconnector, blockedChan
}
// eofTestReconnector is a custom reconnector for the EOF test case
type eofTestReconnector struct {
mu sync.Mutex
conn1 io.ReadWriteCloser
conn2 io.ReadWriteCloser
callCount int
}
func (e *eofTestReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
e.mu.Lock()
defer e.mu.Unlock()
e.callCount++
if e.callCount == 1 {
return e.conn1, 0, nil
}
if e.callCount == 2 {
// Second call is the reconnection after EOF
// Return 5 to indicate remote has read all 5 bytes of "hello"
return e.conn2, 5, nil
}
return nil, 0, xerrors.New("no more connections")
}
// GetCallCount returns the current call count in a thread-safe manner
func (e *eofTestReconnector) GetCallCount() int {
e.mu.Lock()
defer e.mu.Unlock()
return e.callCount
}
func TestBackedPipe_NewBackedPipe(t *testing.T) {
t.Parallel()
ctx := context.Background()
reconnectFn, _ := mockReconnectFunc(newMockConnection())
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
defer bp.Close()
require.NotNil(t, bp)
require.False(t, bp.Connected())
}
func TestBackedPipe_Connect(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnector, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
err := bp.Connect()
require.NoError(t, err)
require.True(t, bp.Connected())
require.Equal(t, 1, reconnector.GetCallCount())
}
func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
defer bp.Close()
err := bp.Connect()
require.NoError(t, err)
// Second connect should fail
err = bp.Connect()
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrPipeAlreadyConnected)
}
func TestBackedPipe_ConnectAfterClose(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
err := bp.Close()
require.NoError(t, err)
err = bp.Connect()
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrPipeClosed)
}
func TestBackedPipe_BasicReadWrite(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
defer bp.Close()
err := bp.Connect()
require.NoError(t, err)
// Write data
n, err := bp.Write([]byte("hello"))
require.NoError(t, err)
require.Equal(t, 5, n)
// Simulate data coming back
conn.WriteString("world")
// Read data
buf := make([]byte, 10)
n, err = bp.Read(buf)
require.NoError(t, err)
require.Equal(t, 5, n)
require.Equal(t, "world", string(buf[:n]))
}
func TestBackedPipe_WriteBeforeConnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
defer bp.Close()
// Write before connecting should block
writeComplete := make(chan error, 1)
go func() {
_, err := bp.Write([]byte("hello"))
writeComplete <- err
}()
// Verify write is blocked
select {
case <-writeComplete:
t.Fatal("Write should have blocked when disconnected")
case <-time.After(100 * time.Millisecond):
// Expected - write is blocked
}
// Connect should unblock the write
err := bp.Connect()
require.NoError(t, err)
// Write should now complete
err = testutil.RequireReceive(ctx, t, writeComplete)
require.NoError(t, err)
// Check that data was replayed to connection
require.Equal(t, "hello", conn.ReadString())
}
func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) {
t.Parallel()
ctx := context.Background()
testCtx := testutil.Context(t, testutil.WaitShort)
reconnectFn, _ := mockReconnectFunc(newMockConnection())
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
defer bp.Close()
// Start a read that should block
readDone := make(chan struct{})
readStarted := make(chan struct{}, 1)
var readErr error
go func() {
defer close(readDone)
readStarted <- struct{}{} // Signal that we're about to start the read
buf := make([]byte, 10)
_, readErr = bp.Read(buf)
}()
// Wait for the goroutine to start
testutil.TryReceive(testCtx, t, readStarted)
// Ensure the read is actually blocked by verifying it hasn't completed
require.Eventually(t, func() bool {
select {
case <-readDone:
t.Fatal("Read should be blocked when disconnected")
return false
default:
// Good, still blocked
return true
}
}, testutil.WaitShort, testutil.IntervalMedium)
// Close should unblock the read
bp.Close()
testutil.TryReceive(testCtx, t, readDone)
require.Equal(t, io.EOF, readErr)
}
func TestBackedPipe_Reconnection(t *testing.T) {
t.Parallel()
ctx := context.Background()
testCtx := testutil.Context(t, testutil.WaitShort)
conn1 := newMockConnection()
conn2 := newMockConnection()
conn2.seqNum = 17 // Remote has received 17 bytes, so replay from sequence 17
reconnectFn, signalChan := mockReconnectFunc(conn1, conn2)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
defer bp.Close()
// Initial connect
err := bp.Connect()
require.NoError(t, err)
// Write some data before failure
bp.Write([]byte("before disconnect***"))
// Simulate connection failure
conn1.SetReadError(xerrors.New("connection lost"))
conn1.SetWriteError(xerrors.New("connection lost"))
// Trigger a write to cause the pipe to notice the failure
_, _ = bp.Write([]byte("trigger failure "))
testutil.RequireReceive(testCtx, t, signalChan)
// Wait for reconnection to complete
require.Eventually(t, func() bool {
return bp.Connected()
}, testutil.WaitShort, testutil.IntervalFast, "pipe should reconnect")
replayedData := conn2.ReadString()
require.Equal(t, "***trigger failure ", replayedData, "Should replay exactly the data written after sequence 17")
// Verify that new writes work with the reconnected pipe
_, err = bp.Write([]byte("new data after reconnect"))
require.NoError(t, err)
// Read all data from the connection (replayed + new data)
allData := conn2.ReadString()
require.Equal(t, "***trigger failure new data after reconnect", allData, "Should have replayed data plus new data")
}
func TestBackedPipe_Close(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
err := bp.Connect()
require.NoError(t, err)
err = bp.Close()
require.NoError(t, err)
require.True(t, conn.closed)
// Operations after close should fail
_, err = bp.Read(make([]byte, 10))
require.Equal(t, io.EOF, err)
_, err = bp.Write([]byte("test"))
require.Equal(t, io.EOF, err)
}
func TestBackedPipe_CloseIdempotent(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
err := bp.Close()
require.NoError(t, err)
// Second close should be no-op
err = bp.Close()
require.NoError(t, err)
}
func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) {
t.Parallel()
ctx := context.Background()
failingReconnector := &mockReconnector{
connections: nil, // No connections available
}
bp := backedpipe.NewBackedPipe(ctx, failingReconnector)
defer bp.Close()
err := bp.Connect()
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrReconnectFailed)
require.False(t, bp.Connected())
}
func TestBackedPipe_ForceReconnect(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn1 := newMockConnection()
conn2 := newMockConnection()
// Set conn2 sequence number to 9 to indicate remote has read all 9 bytes of "test data"
conn2.seqNum = 9
reconnector, _ := mockReconnectFunc(conn1, conn2)
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
// Initial connect
err := bp.Connect()
require.NoError(t, err)
require.True(t, bp.Connected())
require.Equal(t, 1, reconnector.GetCallCount())
// Write some data to the first connection
_, err = bp.Write([]byte("test data"))
require.NoError(t, err)
require.Equal(t, "test data", conn1.ReadString())
// Force a reconnection
err = bp.ForceReconnect()
require.NoError(t, err)
require.True(t, bp.Connected())
require.Equal(t, 2, reconnector.GetCallCount())
// Since the mock returns the proper sequence number, no data should be replayed
// The new connection should be empty
require.Equal(t, "", conn2.ReadString())
// Verify that data can still be written and read after forced reconnection
_, err = bp.Write([]byte("new data"))
require.NoError(t, err)
require.Equal(t, "new data", conn2.ReadString())
// Verify that reads work with the new connection
conn2.WriteString("response data")
buf := make([]byte, 20)
n, err := bp.Read(buf)
require.NoError(t, err)
require.Equal(t, 13, n)
require.Equal(t, "response data", string(buf[:n]))
}
func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
// Close the pipe first
err := bp.Close()
require.NoError(t, err)
// Try to force reconnect when closed
err = bp.ForceReconnect()
require.Error(t, err)
require.Equal(t, io.EOF, err)
}
func TestBackedPipe_StateTransitionsAndGenerationTracking(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn1 := newMockConnection()
conn2 := newMockConnection()
conn3 := newMockConnection()
reconnector, signalChan := mockReconnectFunc(conn1, conn2, conn3)
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
// Initial state should be disconnected
require.False(t, bp.Connected())
// Connect should transition to connected
err := bp.Connect()
require.NoError(t, err)
require.True(t, bp.Connected())
require.Equal(t, 1, reconnector.GetCallCount())
// Write some data
_, err = bp.Write([]byte("test data gen 1"))
require.NoError(t, err)
// Simulate connection failure by setting errors on connection
conn1.SetReadError(xerrors.New("connection lost"))
conn1.SetWriteError(xerrors.New("connection lost"))
// Trigger a write to cause the pipe to notice the failure
_, _ = bp.Write([]byte("trigger failure"))
// Wait for reconnection signal
testutil.RequireReceive(testutil.Context(t, testutil.WaitShort), t, signalChan)
// Wait for reconnection to complete
require.Eventually(t, func() bool {
return bp.Connected()
}, testutil.WaitShort, testutil.IntervalFast, "should reconnect")
require.Equal(t, 2, reconnector.GetCallCount())
// Force another reconnection
err = bp.ForceReconnect()
require.NoError(t, err)
require.True(t, bp.Connected())
require.Equal(t, 3, reconnector.GetCallCount())
// Close should transition to closed state
err = bp.Close()
require.NoError(t, err)
require.False(t, bp.Connected())
// Operations on closed pipe should fail
err = bp.Connect()
require.Equal(t, backedpipe.ErrPipeClosed, err)
err = bp.ForceReconnect()
require.Equal(t, io.EOF, err)
}
func TestBackedPipe_GenerationFiltering(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn1 := newMockConnection()
conn2 := newMockConnection()
reconnector, _ := mockReconnectFunc(conn1, conn2)
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
// Connect
err := bp.Connect()
require.NoError(t, err)
require.True(t, bp.Connected())
// Simulate multiple rapid errors from the same connection generation
// Only the first one should trigger reconnection
conn1.SetReadError(xerrors.New("error 1"))
conn1.SetWriteError(xerrors.New("error 2"))
// Trigger multiple errors quickly
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = bp.Write([]byte("trigger error 1"))
}()
go func() {
defer wg.Done()
_, _ = bp.Write([]byte("trigger error 2"))
}()
// Wait for both writes to complete
wg.Wait()
// Wait for reconnection to complete
require.Eventually(t, func() bool {
return bp.Connected()
}, testutil.WaitShort, testutil.IntervalFast, "should reconnect once")
// Should have only reconnected once despite multiple errors
require.Equal(t, 2, reconnector.GetCallCount()) // Initial connect + 1 reconnect
}
func TestBackedPipe_DuplicateReconnectionPrevention(t *testing.T) {
t.Parallel()
ctx := context.Background()
testCtx := testutil.Context(t, testutil.WaitShort)
// Create a blocking reconnector for deterministic testing
conn1 := newMockConnection()
conn2 := newMockConnection()
blockChan := make(chan struct{})
reconnector, blockedChan := mockBlockingReconnectFunc(conn1, conn2, blockChan)
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
// Initial connect
err := bp.Connect()
require.NoError(t, err)
require.Equal(t, 1, reconnector.GetCallCount(), "should have exactly 1 call after initial connect")
// We'll use channels to coordinate the test execution:
// 1. Start all goroutines but have them wait
// 2. Release the first one and wait for it to block
// 3. Release the others while the first is still blocked
const numConcurrent = 3
startSignals := make([]chan struct{}, numConcurrent)
startedSignals := make([]chan struct{}, numConcurrent)
for i := range startSignals {
startSignals[i] = make(chan struct{})
startedSignals[i] = make(chan struct{})
}
errors := make([]error, numConcurrent)
var wg sync.WaitGroup
// Start all goroutines
for i := 0; i < numConcurrent; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
// Wait for the signal to start
<-startSignals[idx]
// Signal that we're about to call ForceReconnect
close(startedSignals[idx])
errors[idx] = bp.ForceReconnect()
}(i)
}
// Start the first ForceReconnect and wait for it to block
close(startSignals[0])
<-startedSignals[0]
// Wait for the first reconnect to actually start and block
testutil.RequireReceive(testCtx, t, blockedChan)
// Now start all the other ForceReconnect calls
// They should all join the same singleflight operation
for i := 1; i < numConcurrent; i++ {
close(startSignals[i])
}
// Wait for all additional goroutines to have started their calls
for i := 1; i < numConcurrent; i++ {
<-startedSignals[i]
}
// At this point, one reconnect has started and is blocked,
// and all other goroutines have called ForceReconnect and should be
// waiting on the same singleflight operation.
// Due to singleflight, only one reconnect should have been attempted.
require.Equal(t, 2, reconnector.GetCallCount(), "should have exactly 2 calls: initial connect + 1 reconnect due to singleflight")
// Release the blocking reconnect function
close(blockChan)
// Wait for all ForceReconnect calls to complete
wg.Wait()
// All calls should succeed (they share the same result from singleflight)
for i, err := range errors {
require.NoError(t, err, "ForceReconnect %d should succeed", i, err)
}
// Final verification: call count should still be exactly 2
require.Equal(t, 2, reconnector.GetCallCount(), "final call count should be exactly 2: initial connect + 1 singleflight reconnect")
}
func TestBackedPipe_SingleReconnectionOnMultipleErrors(t *testing.T) {
t.Parallel()
ctx := context.Background()
testCtx := testutil.Context(t, testutil.WaitShort)
// Create connections for initial connect and reconnection
conn1 := newMockConnection()
conn2 := newMockConnection()
reconnector, signalChan := mockReconnectFunc(conn1, conn2)
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
// Initial connect
err := bp.Connect()
require.NoError(t, err)
require.True(t, bp.Connected())
require.Equal(t, 1, reconnector.GetCallCount())
// Write some initial data to establish the connection
_, err = bp.Write([]byte("initial data"))
require.NoError(t, err)
// Set up both read and write errors on the connection
conn1.SetReadError(xerrors.New("read connection lost"))
conn1.SetWriteError(xerrors.New("write connection lost"))
// Trigger write error (this will trigger reconnection)
go func() {
_, _ = bp.Write([]byte("trigger write error"))
}()
// Wait for reconnection to start
testutil.RequireReceive(testCtx, t, signalChan)
// Wait for reconnection to complete
require.Eventually(t, func() bool {
return bp.Connected()
}, testutil.WaitShort, testutil.IntervalFast, "should reconnect after write error")
// Verify that only one reconnection occurred
require.Equal(t, 2, reconnector.GetCallCount(), "should have exactly 2 calls: initial connect + 1 reconnection")
require.True(t, bp.Connected(), "should be connected after reconnection")
}
func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn := newMockConnection()
reconnector, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
// Don't connect initially, just force reconnect
err := bp.ForceReconnect()
require.NoError(t, err)
require.True(t, bp.Connected())
require.Equal(t, 1, reconnector.GetCallCount())
// Verify we can write and read
_, err = bp.Write([]byte("test"))
require.NoError(t, err)
require.Equal(t, "test", conn.ReadString())
conn.WriteString("response")
buf := make([]byte, 10)
n, err := bp.Read(buf)
require.NoError(t, err)
require.Equal(t, 8, n)
require.Equal(t, "response", string(buf[:n]))
}
func TestBackedPipe_EOFTriggersReconnection(t *testing.T) {
t.Parallel()
ctx := context.Background()
// Create connections where we can control when EOF occurs
conn1 := newMockConnection()
conn2 := newMockConnection()
conn2.WriteString("newdata") // Pre-populate conn2 with data
// Make conn1 return EOF after reading "world"
hasReadData := false
conn1.readFunc = func(p []byte) (int, error) {
// Don't lock here - the Read method already holds the lock
// First time: return "world"
if !hasReadData && conn1.readBuffer.Len() > 0 {
n, _ := conn1.readBuffer.Read(p)
hasReadData = true
return n, nil
}
// After that: return EOF
return 0, io.EOF
}
conn1.WriteString("world")
reconnector := &eofTestReconnector{
conn1: conn1,
conn2: conn2,
}
bp := backedpipe.NewBackedPipe(ctx, reconnector)
defer bp.Close()
// Initial connect
err := bp.Connect()
require.NoError(t, err)
require.Equal(t, 1, reconnector.GetCallCount())
// Write some data
_, err = bp.Write([]byte("hello"))
require.NoError(t, err)
buf := make([]byte, 10)
// First read should succeed
n, err := bp.Read(buf)
require.NoError(t, err)
require.Equal(t, 5, n)
require.Equal(t, "world", string(buf[:n]))
// Next read will encounter EOF and should trigger reconnection
// After reconnection, it should read from conn2
n, err = bp.Read(buf)
require.NoError(t, err)
require.Equal(t, 7, n)
require.Equal(t, "newdata", string(buf[:n]))
// Verify reconnection happened
require.Equal(t, 2, reconnector.GetCallCount())
// Verify the pipe is still connected and functional
require.True(t, bp.Connected())
// Further writes should go to the new connection
_, err = bp.Write([]byte("aftereof"))
require.NoError(t, err)
require.Equal(t, "aftereof", conn2.ReadString())
}
func BenchmarkBackedPipe_Write(b *testing.B) {
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
bp.Connect()
b.Cleanup(func() {
_ = bp.Close()
})
data := make([]byte, 1024) // 1KB writes
b.ResetTimer()
for i := 0; i < b.N; i++ {
bp.Write(data)
}
}
func BenchmarkBackedPipe_Read(b *testing.B) {
ctx := context.Background()
conn := newMockConnection()
reconnectFn, _ := mockReconnectFunc(conn)
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
bp.Connect()
b.Cleanup(func() {
_ = bp.Close()
})
buf := make([]byte, 1024)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Fill connection with fresh data for each iteration
conn.WriteString(string(buf))
bp.Read(buf)
}
}
@@ -1,166 +0,0 @@
package backedpipe
import (
"io"
"sync"
)
// BackedReader wraps an unreliable io.Reader and makes it resilient to disconnections.
// It tracks sequence numbers for all bytes read and can handle reconnection,
// blocking reads when disconnected instead of erroring.
type BackedReader struct {
mu sync.Mutex
cond *sync.Cond
reader io.Reader
sequenceNum uint64
closed bool
// Error channel for generation-aware error reporting
errorEventChan chan<- ErrorEvent
// Current connection generation for error reporting
currentGen uint64
}
// NewBackedReader creates a new BackedReader with generation-aware error reporting.
// The reader is initially disconnected and must be connected using Reconnect before
// reads will succeed. The errorEventChan will receive ErrorEvent structs containing
// error details, component info, and connection generation.
func NewBackedReader(errorEventChan chan<- ErrorEvent) *BackedReader {
if errorEventChan == nil {
panic("error event channel cannot be nil")
}
br := &BackedReader{
errorEventChan: errorEventChan,
}
br.cond = sync.NewCond(&br.mu)
return br
}
// Read implements io.Reader. It blocks when disconnected until either:
// 1. A reconnection is established
// 2. The reader is closed
//
// When connected, it reads from the underlying reader and updates sequence numbers.
// Connection failures are automatically detected and reported to the higher layer via the error event channel.
func (br *BackedReader) Read(p []byte) (int, error) {
br.mu.Lock()
defer br.mu.Unlock()
for {
// Step 1: Wait until we have a reader or are closed
for br.reader == nil && !br.closed {
br.cond.Wait()
}
if br.closed {
return 0, io.EOF
}
// Step 2: Perform the read while holding the mutex
// This ensures proper synchronization with Reconnect and Close operations
n, err := br.reader.Read(p)
br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract
if err == nil {
return n, nil
}
// Mark reader as disconnected so future reads will wait for reconnection
br.reader = nil
// Notify parent of error with generation information
select {
case br.errorEventChan <- ErrorEvent{
Err: err,
Component: "reader",
Generation: br.currentGen,
}:
default:
// Channel is full, drop the error.
// This is not a problem, because we set the reader to nil
// and block until reconnected so no new errors will be sent
// until the pipe processes the error and reconnects.
}
// If we got some data before the error, return it now
if n > 0 {
return n, nil
}
}
}
// Reconnect coordinates reconnection using channels for better synchronization.
// The seqNum channel is used to send the current sequence number to the caller.
// The newR channel is used to receive the new reader from the caller.
// This allows for better coordination during the reconnection process.
func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) {
// Grab the lock
br.mu.Lock()
defer br.mu.Unlock()
if br.closed {
// Close the channel to indicate closed state
close(seqNum)
return
}
// Get the sequence number to send to the other side via seqNum channel
seqNum <- br.sequenceNum
close(seqNum)
// Wait for the reconnect to complete, via newR channel, and give us a new io.Reader
newReader := <-newR
// If reconnection fails while we are starting it, the caller sends nil on newR
if newReader == nil {
// Reconnection failed, keep current state
return
}
// Reconnection successful
br.reader = newReader
// Notify any waiting reads via the cond
br.cond.Broadcast()
}
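// Caller-side coordination, sketched from how BackedPipe drives a reconnection
// (the newConn value is an assumption):
//
//	seqNum := make(chan uint64, 1)
//	newR := make(chan io.Reader, 1)
//	go br.Reconnect(seqNum, newR)
//	readerSeqNum, ok := <-seqNum // closed without a value if the reader is already closed
//	if !ok {
//		return
//	}
//	// ... negotiate the new connection using readerSeqNum ...
//	newR <- newConn // send nil instead to abort and leave the reader disconnected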
// Close the reader and wake up any blocked reads.
// After closing, all Read calls will return io.EOF.
func (br *BackedReader) Close() error {
br.mu.Lock()
defer br.mu.Unlock()
if br.closed {
return nil
}
br.closed = true
br.reader = nil
// Wake up any blocked reads
br.cond.Broadcast()
return nil
}
// SequenceNum returns the current sequence number (total bytes read).
func (br *BackedReader) SequenceNum() uint64 {
br.mu.Lock()
defer br.mu.Unlock()
return br.sequenceNum
}
// Connected returns whether the reader is currently connected.
func (br *BackedReader) Connected() bool {
br.mu.Lock()
defer br.mu.Unlock()
return br.reader != nil
}
// SetGeneration sets the current connection generation for error reporting.
func (br *BackedReader) SetGeneration(generation uint64) {
br.mu.Lock()
defer br.mu.Unlock()
br.currentGen = generation
}
@@ -1,603 +0,0 @@
package backedpipe_test
import (
"context"
"io"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
"github.com/coder/coder/v2/testutil"
)
// mockReader implements io.Reader with controllable behavior for testing
type mockReader struct {
mu sync.Mutex
data []byte
pos int
err error
readFunc func([]byte) (int, error)
}
func newMockReader(data string) *mockReader {
return &mockReader{data: []byte(data)}
}
func (mr *mockReader) Read(p []byte) (int, error) {
mr.mu.Lock()
defer mr.mu.Unlock()
if mr.readFunc != nil {
return mr.readFunc(p)
}
if mr.err != nil {
return 0, mr.err
}
if mr.pos >= len(mr.data) {
return 0, io.EOF
}
n := copy(p, mr.data[mr.pos:])
mr.pos += n
return n, nil
}
func (mr *mockReader) setError(err error) {
mr.mu.Lock()
defer mr.mu.Unlock()
mr.err = err
}
func TestBackedReader_NewBackedReader(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
require.NotNil(t, br)
require.Equal(t, uint64(0), br.SequenceNum())
require.False(t, br.Connected())
}
func TestBackedReader_BasicReadOperation(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
reader := newMockReader("hello world")
// Connect the reader
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Get sequence number from reader
seq := testutil.RequireReceive(ctx, t, seqNum)
require.Equal(t, uint64(0), seq)
// Send new reader
testutil.RequireSend(ctx, t, newR, io.Reader(reader))
// Read data
buf := make([]byte, 5)
n, err := br.Read(buf)
require.NoError(t, err)
require.Equal(t, 5, n)
require.Equal(t, "hello", string(buf))
require.Equal(t, uint64(5), br.SequenceNum())
// Read more data
n, err = br.Read(buf)
require.NoError(t, err)
require.Equal(t, 5, n)
require.Equal(t, " worl", string(buf))
require.Equal(t, uint64(10), br.SequenceNum())
}
func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
// Start a read operation that should block
readDone := make(chan struct{})
var readErr error
var readBuf []byte
var readN int
go func() {
defer close(readDone)
buf := make([]byte, 10)
readN, readErr = br.Read(buf)
readBuf = buf[:readN]
}()
// Ensure the read is actually blocked by verifying it hasn't completed
// and that the reader is not connected
select {
case <-readDone:
t.Fatal("Read should be blocked when disconnected")
default:
// Read is still blocked, which is what we want
}
require.False(t, br.Connected(), "Reader should not be connected")
// Connect and the read should unblock
reader := newMockReader("test")
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Get sequence number and send new reader
testutil.RequireReceive(ctx, t, seqNum)
testutil.RequireSend(ctx, t, newR, io.Reader(reader))
// Wait for read to complete
testutil.TryReceive(ctx, t, readDone)
require.NoError(t, readErr)
require.Equal(t, "test", string(readBuf))
}
func TestBackedReader_ReconnectionAfterFailure(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
reader1 := newMockReader("first")
// Initial connection
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Get sequence number and send new reader
testutil.RequireReceive(ctx, t, seqNum)
testutil.RequireSend(ctx, t, newR, io.Reader(reader1))
// Read some data
buf := make([]byte, 5)
n, err := br.Read(buf)
require.NoError(t, err)
require.Equal(t, "first", string(buf[:n]))
require.Equal(t, uint64(5), br.SequenceNum())
// Simulate connection failure
reader1.setError(xerrors.New("connection lost"))
// Start a read that will block due to connection failure
readDone := make(chan error, 1)
go func() {
_, err := br.Read(buf)
readDone <- err
}()
// Wait for the error to be reported via error channel
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
require.Error(t, receivedErrorEvent.Err)
require.Equal(t, "reader", receivedErrorEvent.Component)
require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")
// Verify read is still blocked
select {
case err := <-readDone:
t.Fatalf("Read should still be blocked, but completed with: %v", err)
default:
// Good, still blocked
}
// Verify disconnection
require.False(t, br.Connected())
// Reconnect with new reader
reader2 := newMockReader("second")
seqNum2 := make(chan uint64, 1)
newR2 := make(chan io.Reader, 1)
go br.Reconnect(seqNum2, newR2)
// Get sequence number and send new reader
seq := testutil.RequireReceive(ctx, t, seqNum2)
require.Equal(t, uint64(5), seq) // Should return current sequence number
testutil.RequireSend(ctx, t, newR2, io.Reader(reader2))
// Wait for read to unblock and succeed with new data
readErr := testutil.RequireReceive(ctx, t, readDone)
require.NoError(t, readErr) // Should succeed with new reader
require.True(t, br.Connected())
}
func TestBackedReader_Close(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
reader := newMockReader("test")
// Connect
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Get sequence number and send new reader
testutil.RequireReceive(ctx, t, seqNum)
testutil.RequireSend(ctx, t, newR, io.Reader(reader))
// First, read all available data
buf := make([]byte, 10)
n, err := br.Read(buf)
require.NoError(t, err)
require.Equal(t, 4, n) // "test" is 4 bytes
// Close the reader before EOF triggers reconnection
err = br.Close()
require.NoError(t, err)
// After close, reads should return EOF
n, err = br.Read(buf)
require.Equal(t, 0, n)
require.Equal(t, io.EOF, err)
// Subsequent reads should return EOF
_, err = br.Read(buf)
require.Equal(t, io.EOF, err)
}
func TestBackedReader_CloseIdempotent(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
err := br.Close()
require.NoError(t, err)
// Second close should be no-op
err = br.Close()
require.NoError(t, err)
}
func TestBackedReader_ReconnectAfterClose(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
err := br.Close()
require.NoError(t, err)
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Should get 0 sequence number for closed reader
seq := testutil.TryReceive(ctx, t, seqNum)
require.Equal(t, uint64(0), seq)
}
// Helper function to reconnect a reader using channels
func reconnectReader(ctx context.Context, t testing.TB, br *backedpipe.BackedReader, reader io.Reader) {
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Get sequence number and send new reader
testutil.RequireReceive(ctx, t, seqNum)
testutil.RequireSend(ctx, t, newR, reader)
}
func TestBackedReader_SequenceNumberTracking(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
reader := newMockReader("0123456789")
reconnectReader(ctx, t, br, reader)
// Read in chunks and verify sequence number
buf := make([]byte, 3)
n, err := br.Read(buf)
require.NoError(t, err)
require.Equal(t, 3, n)
require.Equal(t, uint64(3), br.SequenceNum())
n, err = br.Read(buf)
require.NoError(t, err)
require.Equal(t, 3, n)
require.Equal(t, uint64(6), br.SequenceNum())
n, err = br.Read(buf)
require.NoError(t, err)
require.Equal(t, 3, n)
require.Equal(t, uint64(9), br.SequenceNum())
}
func TestBackedReader_EOFHandling(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
reader := newMockReader("test")
reconnectReader(ctx, t, br, reader)
// Read all data
buf := make([]byte, 10)
n, err := br.Read(buf)
require.NoError(t, err)
require.Equal(t, 4, n)
require.Equal(t, "test", string(buf[:n]))
// Next read should encounter EOF, which triggers disconnection
// The read should block waiting for reconnection
readDone := make(chan struct{})
var readErr error
var readN int
go func() {
defer close(readDone)
readN, readErr = br.Read(buf)
}()
// Wait for EOF to be reported via error channel
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
require.Equal(t, io.EOF, receivedErrorEvent.Err)
require.Equal(t, "reader", receivedErrorEvent.Component)
// Reader should be disconnected after EOF
require.False(t, br.Connected())
// Read should still be blocked
select {
case <-readDone:
t.Fatal("Read should be blocked waiting for reconnection after EOF")
default:
// Good, still blocked
}
// Reconnect with new data
reader2 := newMockReader("more")
reconnectReader(ctx, t, br, reader2)
// Wait for the blocked read to complete with new data
testutil.TryReceive(ctx, t, readDone)
require.NoError(t, readErr)
require.Equal(t, 4, readN)
require.Equal(t, "more", string(buf[:readN]))
}
func BenchmarkBackedReader_Read(b *testing.B) {
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
buf := make([]byte, 1024)
// Create a reader that never returns EOF; it simply refills the buffer on every call
reader := &mockReader{
readFunc: func(p []byte) (int, error) {
// Fill buffer with 'x' characters - never EOF
for i := range p {
p[i] = 'x'
}
return len(p), nil
},
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
reconnectReader(ctx, b, br, reader)
b.ResetTimer()
for i := 0; i < b.N; i++ {
br.Read(buf)
}
}
func TestBackedReader_PartialReads(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
// Create a reader that returns partial reads
reader := &mockReader{
readFunc: func(p []byte) (int, error) {
// Always return just 1 byte at a time
if len(p) == 0 {
return 0, nil
}
p[0] = 'A'
return 1, nil
},
}
reconnectReader(ctx, t, br, reader)
// Read multiple times
buf := make([]byte, 10)
for i := 0; i < 5; i++ {
n, err := br.Read(buf)
require.NoError(t, err)
require.Equal(t, 1, n)
require.Equal(t, byte('A'), buf[0])
}
require.Equal(t, uint64(5), br.SequenceNum())
}
func TestBackedReader_CloseWhileBlockedOnUnderlyingReader(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
// Create a reader that blocks on Read calls but can be unblocked
readStarted := make(chan struct{}, 1)
readUnblocked := make(chan struct{})
blockingReader := &mockReader{
readFunc: func(p []byte) (int, error) {
select {
case readStarted <- struct{}{}:
default:
}
<-readUnblocked // Block until signaled
// After unblocking, return an error to simulate connection failure
return 0, xerrors.New("connection interrupted")
},
}
// Connect the blocking reader
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Get sequence number and send blocking reader
testutil.RequireReceive(ctx, t, seqNum)
testutil.RequireSend(ctx, t, newR, io.Reader(blockingReader))
// Start a read that will block on the underlying reader
readDone := make(chan struct{})
var readErr error
var readN int
go func() {
defer close(readDone)
buf := make([]byte, 10)
readN, readErr = br.Read(buf)
}()
// Wait for the read to start and block on the underlying reader
testutil.RequireReceive(ctx, t, readStarted)
// Verify read is blocked by checking that it hasn't completed
// and ensuring we have adequate time for it to reach the blocking state
require.Eventually(t, func() bool {
select {
case <-readDone:
t.Fatal("Read should be blocked on underlying reader")
return false
default:
// Good, still blocked
return true
}
}, testutil.WaitShort, testutil.IntervalMedium)
// Start Close() in a goroutine since it will block until the underlying read completes
closeDone := make(chan error, 1)
go func() {
closeDone <- br.Close()
}()
// Verify Close() is also blocked waiting for the underlying read
select {
case <-closeDone:
t.Fatal("Close should be blocked until underlying read completes")
case <-time.After(10 * time.Millisecond):
// Good, Close is blocked
}
// Unblock the underlying reader, which will cause both the read and close to complete
close(readUnblocked)
// Wait for both the read and close to complete
testutil.TryReceive(ctx, t, readDone)
closeErr := testutil.RequireReceive(ctx, t, closeDone)
require.NoError(t, closeErr)
// The read should return EOF because Close() was called while it was blocked,
// even though the underlying reader returned an error
require.Equal(t, 0, readN)
require.Equal(t, io.EOF, readErr)
// Subsequent reads should return EOF since the reader is now closed
buf := make([]byte, 10)
n, err := br.Read(buf)
require.Equal(t, 0, n)
require.Equal(t, io.EOF, err)
}
func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
br := backedpipe.NewBackedReader(errChan)
reader1 := newMockReader("initial")
// Initial connection
seqNum := make(chan uint64, 1)
newR := make(chan io.Reader, 1)
go br.Reconnect(seqNum, newR)
// Get sequence number and send initial reader
testutil.RequireReceive(ctx, t, seqNum)
testutil.RequireSend(ctx, t, newR, io.Reader(reader1))
// Read initial data
buf := make([]byte, 10)
n, err := br.Read(buf)
require.NoError(t, err)
require.Equal(t, "initial", string(buf[:n]))
// Simulate connection failure
reader1.setError(xerrors.New("connection lost"))
// Start a read that will block waiting for reconnection
readDone := make(chan struct{})
var readErr error
var readN int
go func() {
defer close(readDone)
readN, readErr = br.Read(buf)
}()
// Wait for the error to be reported (indicating disconnection)
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
require.Error(t, receivedErrorEvent.Err)
require.Equal(t, "reader", receivedErrorEvent.Component)
require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")
// Verify read is blocked waiting for reconnection
select {
case <-readDone:
t.Fatal("Read should be blocked waiting for reconnection")
default:
// Good, still blocked
}
// Verify reader is disconnected
require.False(t, br.Connected())
// Close the BackedReader while read is blocked waiting for reconnection
err = br.Close()
require.NoError(t, err)
// The read should unblock and return EOF
testutil.TryReceive(ctx, t, readDone)
require.Equal(t, 0, readN)
require.Equal(t, io.EOF, readErr)
}
@@ -1,243 +0,0 @@
package backedpipe
import (
"io"
"os"
"sync"
"golang.org/x/xerrors"
)
var (
ErrWriterClosed = xerrors.New("cannot reconnect closed writer")
ErrNilWriter = xerrors.New("new writer cannot be nil")
ErrFutureSequence = xerrors.New("cannot replay from future sequence")
ErrReplayDataUnavailable = xerrors.New("failed to read replay data")
ErrReplayFailed = xerrors.New("replay failed")
ErrPartialReplay = xerrors.New("partial replay")
)
// BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections.
// It maintains a ring buffer of recent writes for replay during reconnection.
type BackedWriter struct {
mu sync.Mutex
cond *sync.Cond
writer io.Writer
buffer *ringBuffer
sequenceNum uint64 // total bytes written
closed bool
// Error channel for generation-aware error reporting
errorEventChan chan<- ErrorEvent
// Current connection generation for error reporting
currentGen uint64
}
// NewBackedWriter creates a new BackedWriter with generation-aware error reporting.
// The writer is initially disconnected and will block writes until connected.
// The errorEventChan will receive ErrorEvent structs containing error details,
// component info, and connection generation. The buffer capacity must be > 0.
func NewBackedWriter(capacity int, errorEventChan chan<- ErrorEvent) *BackedWriter {
if capacity <= 0 {
panic("backed writer capacity must be > 0")
}
if errorEventChan == nil {
panic("error event channel cannot be nil")
}
bw := &BackedWriter{
buffer: newRingBuffer(capacity),
errorEventChan: errorEventChan,
}
bw.cond = sync.NewCond(&bw.mu)
return bw
}
// blockUntilConnectedOrClosed blocks until either a writer is available or the BackedWriter is closed.
// Returns os.ErrClosed if closed while waiting, nil if connected. You must hold the mutex to call this.
func (bw *BackedWriter) blockUntilConnectedOrClosed() error {
for bw.writer == nil && !bw.closed {
bw.cond.Wait()
}
if bw.closed {
return os.ErrClosed
}
return nil
}
// Write implements io.Writer.
// When connected, it writes to both the ring buffer (to preserve data in case we need to replay it)
// and the underlying writer.
// If the underlying write fails, the writer is marked as disconnected and the write blocks
// until reconnection occurs.
func (bw *BackedWriter) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
bw.mu.Lock()
defer bw.mu.Unlock()
// Block until connected
if err := bw.blockUntilConnectedOrClosed(); err != nil {
return 0, err
}
// Write to buffer
bw.buffer.Write(p)
bw.sequenceNum += uint64(len(p))
// Try to write to underlying writer
n, err := bw.writer.Write(p)
if err == nil && n != len(p) {
err = io.ErrShortWrite
}
if err != nil {
// Connection failed or partial write, mark as disconnected
bw.writer = nil
// Notify parent of error with generation information
select {
case bw.errorEventChan <- ErrorEvent{
Err: err,
Component: "writer",
Generation: bw.currentGen,
}:
default:
// Channel is full, drop the error.
// This is safe: we set the writer to nil and block until reconnected,
// so no new errors will be sent until the pipe has processed this one
// and reconnected.
}
// Block until reconnected - reconnection will replay this data
if err := bw.blockUntilConnectedOrClosed(); err != nil {
return 0, err
}
// Don't retry - reconnection replay handled it
return len(p), nil
}
// Write succeeded
return len(p), nil
}
// Reconnect replaces the current writer with a new one and replays data from the specified
// sequence number. If the requested sequence number is no longer in the buffer,
// returns an error indicating data loss.
//
// IMPORTANT: You must close the current writer, if any, before calling this method.
// Otherwise, if a Write operation is currently blocked in the underlying writer's
// Write method, this method will deadlock waiting for the mutex that Write holds.
func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) error {
bw.mu.Lock()
defer bw.mu.Unlock()
if bw.closed {
return ErrWriterClosed
}
if newWriter == nil {
return ErrNilWriter
}
// Check if we can replay from the requested sequence number
if replayFromSeq > bw.sequenceNum {
return ErrFutureSequence
}
// Calculate how many bytes we need to replay
replayBytes := bw.sequenceNum - replayFromSeq
var replayData []byte
if replayBytes > 0 {
// Get the last replayBytes from buffer
// If the buffer doesn't have enough data (some was evicted),
// ReadLast will return an error
var err error
// Safe conversion: The check above (replayFromSeq > bw.sequenceNum) ensures
// replayBytes = bw.sequenceNum - replayFromSeq is always <= bw.sequenceNum.
// Since sequence numbers are much smaller than maxInt, the uint64->int conversion is safe.
//nolint:gosec // Safe conversion: replayBytes <= sequenceNum, which is much less than maxInt
replayData, err = bw.buffer.ReadLast(int(replayBytes))
if err != nil {
return ErrReplayDataUnavailable
}
}
// Clear the current writer first in case replay fails
bw.writer = nil
// Replay data if needed. We keep the mutex held during replay to ensure
// no concurrent operations can interfere with the reconnection process.
if len(replayData) > 0 {
n, err := newWriter.Write(replayData)
if err != nil {
// Reconnect failed, writer remains nil
return ErrReplayFailed
}
if n != len(replayData) {
// Reconnect failed, writer remains nil
return ErrPartialReplay
}
}
// Set new writer only after successful replay. This ensures no concurrent
// writes can interfere with the replay operation.
bw.writer = newWriter
// Wake up any operations waiting for connection
bw.cond.Broadcast()
return nil
}
// Close closes the writer and prevents further writes.
// After closing, all Write calls will return os.ErrClosed.
// The error return exists only to keep Close consistent with io.Closer;
// this implementation never returns a non-nil error.
//
// IMPORTANT: You must close the current underlying writer, if any, before calling
// this method. Otherwise, if a Write operation is currently blocked in the
// underlying writer's Write method, this method will deadlock waiting for the
// mutex that Write holds.
func (bw *BackedWriter) Close() error {
bw.mu.Lock()
defer bw.mu.Unlock()
if bw.closed {
return nil
}
bw.closed = true
bw.writer = nil
// Wake up any blocked operations
bw.cond.Broadcast()
return nil
}
// SequenceNum returns the current sequence number (total bytes written).
func (bw *BackedWriter) SequenceNum() uint64 {
bw.mu.Lock()
defer bw.mu.Unlock()
return bw.sequenceNum
}
// Connected returns whether the writer is currently connected.
func (bw *BackedWriter) Connected() bool {
bw.mu.Lock()
defer bw.mu.Unlock()
return bw.writer != nil
}
// SetGeneration sets the current connection generation for error reporting.
func (bw *BackedWriter) SetGeneration(generation uint64) {
bw.mu.Lock()
defer bw.mu.Unlock()
bw.currentGen = generation
}
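For orientation, below is a minimal sketch of how a caller might drive BackedReader and BackedWriter together with the shared ErrorEvent channel. The dial helper, the package name, and the reconnect loop are illustrative assumptions only; they are not part of the files in this compare.
package backedpipeexample
import (
	"io"
	"net"
	"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
)
// dial stands in for whatever re-establishes the underlying transport
// (an assumption for this sketch only).
func dial() (net.Conn, error) {
	return net.Dial("tcp", "127.0.0.1:9999")
}
// runPipe keeps both halves attached to a transport and reconnects them
// whenever either side reports an ErrorEvent. Backoff, shutdown, and the
// peer handshake that would exchange sequence numbers are omitted.
func runPipe(br *backedpipe.BackedReader, bw *backedpipe.BackedWriter, errs <-chan backedpipe.ErrorEvent) error {
	var gen uint64
	for {
		conn, err := dial()
		if err != nil {
			return err
		}
		gen++
		bw.SetGeneration(gen) // tag writer errors with this connection's generation
		// Reader handshake: Reconnect first reports the reader's current
		// sequence number, then accepts the new io.Reader.
		seq := make(chan uint64, 1)
		newR := make(chan io.Reader, 1)
		go br.Reconnect(seq, newR)
		<-seq // a real peer would replay its unseen writes from this point
		newR <- conn
		// Writer handshake: replay everything the peer's reader has not seen.
		// Passing 0 replays all bytes still held in the ring buffer.
		if err := bw.Reconnect(0, conn); err != nil {
			_ = conn.Close()
			return err
		}
		// Block until either component reports a failure, then close the
		// broken transport before reconnecting (see the doc comments on
		// Reconnect and Close above).
		<-errs
		_ = conn.Close()
	}
}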
@@ -1,992 +0,0 @@
package backedpipe_test
import (
"bytes"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
"github.com/coder/coder/v2/testutil"
)
// mockWriter implements io.Writer with controllable behavior for testing
type mockWriter struct {
mu sync.Mutex
buffer bytes.Buffer
err error
writeFunc func([]byte) (int, error)
writeCalls int
}
func newMockWriter() *mockWriter {
return &mockWriter{}
}
// newBackedWriterForTest creates a BackedWriter with a small buffer for testing eviction behavior
func newBackedWriterForTest(bufferSize int) *backedpipe.BackedWriter {
errChan := make(chan backedpipe.ErrorEvent, 1)
return backedpipe.NewBackedWriter(bufferSize, errChan)
}
func (mw *mockWriter) Write(p []byte) (int, error) {
mw.mu.Lock()
defer mw.mu.Unlock()
mw.writeCalls++
if mw.writeFunc != nil {
return mw.writeFunc(p)
}
if mw.err != nil {
return 0, mw.err
}
return mw.buffer.Write(p)
}
func (mw *mockWriter) Len() int {
mw.mu.Lock()
defer mw.mu.Unlock()
return mw.buffer.Len()
}
func (mw *mockWriter) Reset() {
mw.mu.Lock()
defer mw.mu.Unlock()
mw.buffer.Reset()
mw.writeCalls = 0
mw.err = nil
mw.writeFunc = nil
}
func (mw *mockWriter) setError(err error) {
mw.mu.Lock()
defer mw.mu.Unlock()
mw.err = err
}
func TestBackedWriter_NewBackedWriter(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
require.NotNil(t, bw)
require.Equal(t, uint64(0), bw.SequenceNum())
require.False(t, bw.Connected())
}
func TestBackedWriter_WriteBlocksWhenDisconnected(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Write should block when disconnected
writeComplete := make(chan struct{})
var writeErr error
var n int
go func() {
defer close(writeComplete)
n, writeErr = bw.Write([]byte("hello"))
}()
// Verify write is blocked
select {
case <-writeComplete:
t.Fatal("Write should have blocked when disconnected")
case <-time.After(50 * time.Millisecond):
// Expected - write is blocked
}
// Connect and verify write completes
writer := newMockWriter()
err := bw.Reconnect(0, writer)
require.NoError(t, err)
// Write should now complete
testutil.TryReceive(ctx, t, writeComplete)
require.NoError(t, writeErr)
require.Equal(t, 5, n)
require.Equal(t, uint64(5), bw.SequenceNum())
require.Equal(t, []byte("hello"), writer.buffer.Bytes())
}
func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
writer := newMockWriter()
// Connect
err := bw.Reconnect(0, writer)
require.NoError(t, err)
require.True(t, bw.Connected())
// Write should go to both buffer and underlying writer
n, err := bw.Write([]byte("hello"))
require.NoError(t, err)
require.Equal(t, 5, n)
// Data should be buffered
require.Equal(t, uint64(5), bw.SequenceNum())
// Check underlying writer
require.Equal(t, []byte("hello"), writer.buffer.Bytes())
}
func TestBackedWriter_BlockOnWriteFailure(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
writer := newMockWriter()
// Connect
err := bw.Reconnect(0, writer)
require.NoError(t, err)
// Cause write to fail
writer.setError(xerrors.New("write failed"))
// Write should block when underlying writer fails, not succeed immediately
writeComplete := make(chan struct{})
var writeErr error
var n int
go func() {
defer close(writeComplete)
n, writeErr = bw.Write([]byte("hello"))
}()
// Verify write is blocked
select {
case <-writeComplete:
t.Fatal("Write should have blocked when underlying writer fails")
case <-time.After(50 * time.Millisecond):
// Expected - write is blocked
}
// Wait for error event which implies writer was marked disconnected
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
require.Contains(t, receivedErrorEvent.Err.Error(), "write failed")
require.Equal(t, "writer", receivedErrorEvent.Component)
require.False(t, bw.Connected())
// Reconnect with working writer and verify write completes
writer2 := newMockWriter()
err = bw.Reconnect(0, writer2) // Replay from beginning
require.NoError(t, err)
// Write should now complete
testutil.TryReceive(ctx, t, writeComplete)
require.NoError(t, writeErr)
require.Equal(t, 5, n)
require.Equal(t, uint64(5), bw.SequenceNum())
require.Equal(t, []byte("hello"), writer2.buffer.Bytes())
}
func TestBackedWriter_ReplayOnReconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Connect initially to write some data
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
// Write some data while connected
_, err = bw.Write([]byte("hello"))
require.NoError(t, err)
_, err = bw.Write([]byte(" world"))
require.NoError(t, err)
require.Equal(t, uint64(11), bw.SequenceNum())
// Disconnect by causing a write failure
writer1.setError(xerrors.New("connection lost"))
// Write should block when underlying writer fails
writeComplete := make(chan struct{})
var writeErr error
var n int
go func() {
defer close(writeComplete)
n, writeErr = bw.Write([]byte("test"))
}()
// Verify write is blocked
select {
case <-writeComplete:
t.Fatal("Write should have blocked when underlying writer fails")
case <-time.After(50 * time.Millisecond):
// Expected - write is blocked
}
// Wait for error event which implies writer was marked disconnected
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")
require.Equal(t, "writer", receivedErrorEvent.Component)
require.False(t, bw.Connected())
// Reconnect with new writer and request replay from beginning
writer2 := newMockWriter()
err = bw.Reconnect(0, writer2)
require.NoError(t, err)
// Write should now complete
select {
case <-writeComplete:
// Expected - write completed
case <-time.After(100 * time.Millisecond):
t.Fatal("Write should have completed after reconnection")
}
require.NoError(t, writeErr)
require.Equal(t, 4, n)
// Should have replayed all data including the failed write that was buffered
require.Equal(t, []byte("hello worldtest"), writer2.buffer.Bytes())
// Write new data should go to both
_, err = bw.Write([]byte("!"))
require.NoError(t, err)
require.Equal(t, []byte("hello worldtest!"), writer2.buffer.Bytes())
}
func TestBackedWriter_PartialReplay(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Connect initially to write some data
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
// Write some data
_, err = bw.Write([]byte("hello"))
require.NoError(t, err)
_, err = bw.Write([]byte(" world"))
require.NoError(t, err)
_, err = bw.Write([]byte("!"))
require.NoError(t, err)
// Reconnect with new writer and request replay from middle
writer2 := newMockWriter()
err = bw.Reconnect(5, writer2) // From " world!"
require.NoError(t, err)
// Should have replayed only the requested portion
require.Equal(t, []byte(" world!"), writer2.buffer.Bytes())
}
func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Connect initially to write some data
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
_, err = bw.Write([]byte("hello"))
require.NoError(t, err)
writer2 := newMockWriter()
err = bw.Reconnect(10, writer2) // Future sequence
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrFutureSequence)
}
func TestBackedWriter_ReplayDataLoss(t *testing.T) {
t.Parallel()
bw := newBackedWriterForTest(10) // Small buffer for testing
// Connect initially to write some data
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
// Fill buffer beyond capacity to cause eviction
_, err = bw.Write([]byte("0123456789")) // Fills buffer exactly
require.NoError(t, err)
_, err = bw.Write([]byte("abcdef")) // Should evict "012345"
require.NoError(t, err)
writer2 := newMockWriter()
err = bw.Reconnect(0, writer2) // Try to replay from evicted data
// With the new error handling, this should fail because we can't read all the data
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable)
}
func TestBackedWriter_BufferEviction(t *testing.T) {
t.Parallel()
bw := newBackedWriterForTest(5) // Very small buffer for testing
// Connect initially
writer := newMockWriter()
err := bw.Reconnect(0, writer)
require.NoError(t, err)
// Write data that will cause eviction
n, err := bw.Write([]byte("abcde"))
require.NoError(t, err)
require.Equal(t, 5, n)
// Write more to cause eviction
n, err = bw.Write([]byte("fg"))
require.NoError(t, err)
require.Equal(t, 2, n)
// Verify that the buffer contains only the latest data after eviction
// Total sequence number should be 7 (5 + 2)
require.Equal(t, uint64(7), bw.SequenceNum())
// Try to reconnect from the beginning - this should fail because
// the early data was evicted from the buffer
writer2 := newMockWriter()
err = bw.Reconnect(0, writer2)
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable)
// However, reconnecting from a sequence that's still in the buffer should work
// The buffer should contain the last 5 bytes: "cdefg"
writer3 := newMockWriter()
err = bw.Reconnect(2, writer3) // From sequence 2, should replay "cdefg"
require.NoError(t, err)
require.Equal(t, []byte("cdefg"), writer3.buffer.Bytes())
require.True(t, bw.Connected())
}
func TestBackedWriter_Close(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
writer := newMockWriter()
bw.Reconnect(0, writer)
err := bw.Close()
require.NoError(t, err)
// Writes after close should fail
_, err = bw.Write([]byte("test"))
require.Equal(t, os.ErrClosed, err)
// Reconnect after close should fail
err = bw.Reconnect(0, newMockWriter())
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrWriterClosed)
}
func TestBackedWriter_CloseIdempotent(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
err := bw.Close()
require.NoError(t, err)
// Second close should be no-op
err = bw.Close()
require.NoError(t, err)
}
func TestBackedWriter_ReconnectDuringReplay(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Connect initially to write some data
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
_, err = bw.Write([]byte("hello world"))
require.NoError(t, err)
// Create a writer that fails during replay
writer2 := &mockWriter{
err: backedpipe.ErrReplayFailed,
}
err = bw.Reconnect(0, writer2)
require.Error(t, err)
require.ErrorIs(t, err, backedpipe.ErrReplayFailed)
require.False(t, bw.Connected())
}
func TestBackedWriter_BlockOnPartialWrite(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Create writer that does partial writes
writer := &mockWriter{
writeFunc: func(p []byte) (int, error) {
if len(p) > 3 {
return 3, nil // Only write first 3 bytes
}
return len(p), nil
},
}
bw.Reconnect(0, writer)
// Write should block due to partial write
writeComplete := make(chan struct{})
var writeErr error
var n int
go func() {
defer close(writeComplete)
n, writeErr = bw.Write([]byte("hello"))
}()
// Verify write is blocked
select {
case <-writeComplete:
t.Fatal("Write should have blocked when underlying writer does partial write")
case <-time.After(50 * time.Millisecond):
// Expected - write is blocked
}
// Wait for error event which implies writer was marked disconnected
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
require.Contains(t, receivedErrorEvent.Err.Error(), "short write")
require.Equal(t, "writer", receivedErrorEvent.Component)
require.False(t, bw.Connected())
// Reconnect with working writer and verify write completes
writer2 := newMockWriter()
err := bw.Reconnect(0, writer2) // Replay from beginning
require.NoError(t, err)
// Write should now complete
testutil.TryReceive(ctx, t, writeComplete)
require.NoError(t, writeErr)
require.Equal(t, 5, n)
require.Equal(t, uint64(5), bw.SequenceNum())
require.Equal(t, []byte("hello"), writer2.buffer.Bytes())
}
func TestBackedWriter_WriteUnblocksOnReconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Start a single write that should block
writeResult := make(chan error, 1)
go func() {
_, err := bw.Write([]byte("test"))
writeResult <- err
}()
// Verify write is blocked
select {
case <-writeResult:
t.Fatal("Write should have blocked when disconnected")
case <-time.After(50 * time.Millisecond):
// Expected - write is blocked
}
// Connect and verify write completes
writer := newMockWriter()
err := bw.Reconnect(0, writer)
require.NoError(t, err)
// Write should now complete
err = testutil.RequireReceive(ctx, t, writeResult)
require.NoError(t, err)
// Write should have been written to the underlying writer
require.Equal(t, "test", writer.buffer.String())
}
func TestBackedWriter_CloseUnblocksWaitingWrites(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Start a write that should block
writeComplete := make(chan error, 1)
go func() {
_, err := bw.Write([]byte("test"))
writeComplete <- err
}()
// Verify write is blocked
select {
case <-writeComplete:
t.Fatal("Write should have blocked when disconnected")
case <-time.After(50 * time.Millisecond):
// Expected - write is blocked
}
// Close the writer
err := bw.Close()
require.NoError(t, err)
// Write should now complete with error
err = testutil.RequireReceive(ctx, t, writeComplete)
require.Equal(t, os.ErrClosed, err)
}
func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
writer := newMockWriter()
// Connect initially
err := bw.Reconnect(0, writer)
require.NoError(t, err)
// Write should succeed when connected
_, err = bw.Write([]byte("hello"))
require.NoError(t, err)
// Cause disconnection - the write should now block instead of returning an error
writer.setError(xerrors.New("connection lost"))
// This write should block
writeComplete := make(chan error, 1)
go func() {
_, err := bw.Write([]byte("world"))
writeComplete <- err
}()
// Verify write is blocked
select {
case <-writeComplete:
t.Fatal("Write should have blocked after disconnection")
case <-time.After(50 * time.Millisecond):
// Expected - write is blocked
}
// Wait for error event which implies writer was marked disconnected
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")
require.Equal(t, "writer", receivedErrorEvent.Component)
require.False(t, bw.Connected())
// Reconnect and verify write completes
writer2 := newMockWriter()
err = bw.Reconnect(5, writer2) // Replay from after "hello"
require.NoError(t, err)
err = testutil.RequireReceive(ctx, t, writeComplete)
require.NoError(t, err)
// Check that only "world" was written during replay (not duplicated)
require.Equal(t, []byte("world"), writer2.buffer.Bytes()) // Only "world" since we replayed from sequence 5
}
func TestBackedWriter_ConcurrentWriteAndClose(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Don't connect initially - this will cause writes to block in blockUntilConnectedOrClosed()
writeStarted := make(chan struct{}, 1)
// Start a write operation that will block waiting for connection
writeComplete := make(chan struct{})
var writeErr error
var n int
go func() {
defer close(writeComplete)
// Signal that we're about to start the write
writeStarted <- struct{}{}
// This write will block in blockUntilConnectedOrClosed() since no writer is connected
n, writeErr = bw.Write([]byte("hello"))
}()
// Wait for write goroutine to start
ctx := testutil.Context(t, testutil.WaitShort)
testutil.RequireReceive(ctx, t, writeStarted)
// Ensure the write is actually blocked by repeatedly checking that:
// 1. The write hasn't completed yet
// 2. The writer is still not connected
// We use require.Eventually to give it a fair chance to reach the blocking state
require.Eventually(t, func() bool {
select {
case <-writeComplete:
t.Fatal("Write should be blocked when no writer is connected")
return false
default:
// Write is still blocked, which is what we want
return !bw.Connected()
}
}, testutil.WaitShort, testutil.IntervalMedium)
// Close the writer while the write is blocked waiting for connection
closeErr := bw.Close()
require.NoError(t, closeErr)
// Wait for write to complete
select {
case <-writeComplete:
// Good, write completed
case <-ctx.Done():
t.Fatal("Write did not complete in time")
}
// The write should have failed with os.ErrClosed because Close() was called
// while it was waiting for connection
require.ErrorIs(t, writeErr, os.ErrClosed)
require.Equal(t, 0, n)
// Subsequent writes should also fail
n, err := bw.Write([]byte("world"))
require.Equal(t, 0, n)
require.ErrorIs(t, err, os.ErrClosed)
}
func TestBackedWriter_ConcurrentWriteAndReconnect(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Initial connection
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
// Write some initial data
_, err = bw.Write([]byte("initial"))
require.NoError(t, err)
// Start reconnection which will block new writes
replayStarted := make(chan struct{}, 1) // Buffered to prevent race condition
replayCanComplete := make(chan struct{})
writer2 := &mockWriter{
writeFunc: func(p []byte) (int, error) {
// Signal that replay has started
select {
case replayStarted <- struct{}{}:
default:
// Signal already sent, which is fine
}
// Wait for test to allow replay to complete
<-replayCanComplete
return len(p), nil
},
}
// Start the reconnection in a goroutine so we can control timing
reconnectComplete := make(chan error, 1)
go func() {
reconnectComplete <- bw.Reconnect(0, writer2)
}()
ctx := testutil.Context(t, testutil.WaitShort)
// Wait for replay to start
testutil.RequireReceive(ctx, t, replayStarted)
// Now start a write operation that will be blocked by the ongoing reconnect
writeStarted := make(chan struct{}, 1)
writeComplete := make(chan struct{})
var writeErr error
var n int
go func() {
defer close(writeComplete)
// Signal that we're about to start the write
writeStarted <- struct{}{}
// This write should be blocked during reconnect
n, writeErr = bw.Write([]byte("blocked"))
}()
// Wait for write to start
testutil.RequireReceive(ctx, t, writeStarted)
// Use a small timeout to ensure the write goroutine has a chance to get blocked
// on the mutex before we check if it's still blocked
writeCheckTimer := time.NewTimer(testutil.IntervalFast)
defer writeCheckTimer.Stop()
select {
case <-writeComplete:
t.Fatal("Write should be blocked during reconnect")
case <-writeCheckTimer.C:
// Write is still blocked after a reasonable wait
}
// Allow replay to complete, which will allow reconnect to finish
close(replayCanComplete)
// Wait for reconnection to complete
select {
case reconnectErr := <-reconnectComplete:
require.NoError(t, reconnectErr)
case <-ctx.Done():
t.Fatal("Reconnect did not complete in time")
}
// Wait for write to complete
<-writeComplete
// Write should succeed after reconnection completes
require.NoError(t, writeErr)
require.Equal(t, 7, n) // "blocked" is 7 bytes
// Verify the writer is connected
require.True(t, bw.Connected())
}
func TestBackedWriter_ConcurrentReconnectAndClose(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Initial connection and write some data
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
_, err = bw.Write([]byte("test data"))
require.NoError(t, err)
// Start reconnection with slow replay
reconnectStarted := make(chan struct{}, 1)
replayCanComplete := make(chan struct{})
reconnectComplete := make(chan struct{})
var reconnectErr error
go func() {
defer close(reconnectComplete)
writer2 := &mockWriter{
writeFunc: func(p []byte) (int, error) {
// Signal that replay has started
select {
case reconnectStarted <- struct{}{}:
default:
}
// Wait for test to allow replay to complete
<-replayCanComplete
return len(p), nil
},
}
reconnectErr = bw.Reconnect(0, writer2)
}()
// Wait for reconnection to start
ctx := testutil.Context(t, testutil.WaitShort)
testutil.RequireReceive(ctx, t, reconnectStarted)
// Start Close() in a separate goroutine since it will block until Reconnect() completes
closeStarted := make(chan struct{}, 1)
closeComplete := make(chan error, 1)
go func() {
closeStarted <- struct{}{} // Signal that Close() is starting
closeComplete <- bw.Close()
}()
// Wait for Close() to start, then give it a moment to attempt to acquire the mutex
testutil.RequireReceive(ctx, t, closeStarted)
closeCheckTimer := time.NewTimer(testutil.IntervalFast)
defer closeCheckTimer.Stop()
select {
case <-closeComplete:
t.Fatal("Close should be blocked during reconnect")
case <-closeCheckTimer.C:
// Good, Close is still blocked after a reasonable wait
}
// Allow replay to complete so reconnection can finish
close(replayCanComplete)
// Wait for reconnect to complete
select {
case <-reconnectComplete:
// Good, reconnect completed
case <-ctx.Done():
t.Fatal("Reconnect did not complete in time")
}
// Wait for close to complete
select {
case closeErr := <-closeComplete:
require.NoError(t, closeErr)
case <-ctx.Done():
t.Fatal("Close did not complete in time")
}
// With mutex held during replay, Close() waits for Reconnect() to finish.
// So Reconnect() should succeed, then Close() runs and closes the writer.
require.NoError(t, reconnectErr)
// Verify writer is closed (Close() ran after Reconnect() completed)
require.False(t, bw.Connected())
}
func TestBackedWriter_MultipleWritesDuringReconnect(t *testing.T) {
t.Parallel()
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Initial connection
writer1 := newMockWriter()
err := bw.Reconnect(0, writer1)
require.NoError(t, err)
// Write some initial data
_, err = bw.Write([]byte("initial"))
require.NoError(t, err)
// Start multiple write operations
numWriters := 5
var wg sync.WaitGroup
writeResults := make([]error, numWriters)
writesStarted := make(chan struct{}, numWriters)
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Signal that this write is starting
writesStarted <- struct{}{}
data := []byte{byte('A' + id)}
_, writeResults[id] = bw.Write(data)
}(i)
}
// Wait for all writes to start
ctx := testutil.Context(t, testutil.WaitLong)
for i := 0; i < numWriters; i++ {
testutil.RequireReceive(ctx, t, writesStarted)
}
// Use a timer to ensure all write goroutines have had a chance to start executing
// and potentially get blocked on the mutex before we start the reconnection
writesReadyTimer := time.NewTimer(testutil.IntervalFast)
defer writesReadyTimer.Stop()
<-writesReadyTimer.C
// Start reconnection with controlled replay
replayStarted := make(chan struct{}, 1)
replayCanComplete := make(chan struct{})
writer2 := &mockWriter{
writeFunc: func(p []byte) (int, error) {
// Signal that replay has started
select {
case replayStarted <- struct{}{}:
default:
}
// Wait for test to allow replay to complete
<-replayCanComplete
return len(p), nil
},
}
// Start reconnection in a goroutine so we can control timing
reconnectComplete := make(chan error, 1)
go func() {
reconnectComplete <- bw.Reconnect(0, writer2)
}()
// Wait for replay to start
testutil.RequireReceive(ctx, t, replayStarted)
// Allow replay to complete
close(replayCanComplete)
// Wait for reconnection to complete
select {
case reconnectErr := <-reconnectComplete:
require.NoError(t, reconnectErr)
case <-ctx.Done():
t.Fatal("Reconnect did not complete in time")
}
// Wait for all writes to complete
wg.Wait()
// All writes should succeed
for i, err := range writeResults {
require.NoError(t, err, "Write %d should succeed", i)
}
// Verify the writer is connected
require.True(t, bw.Connected())
}
func BenchmarkBackedWriter_Write(b *testing.B) {
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) // 64KB buffer
writer := newMockWriter()
bw.Reconnect(0, writer)
data := bytes.Repeat([]byte("x"), 1024) // 1KB writes
b.ResetTimer()
for i := 0; i < b.N; i++ {
bw.Write(data)
}
}
func BenchmarkBackedWriter_Reconnect(b *testing.B) {
errChan := make(chan backedpipe.ErrorEvent, 1)
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
// Connect initially to fill buffer with data
initialWriter := newMockWriter()
err := bw.Reconnect(0, initialWriter)
if err != nil {
b.Fatal(err)
}
// Fill buffer with data
data := bytes.Repeat([]byte("x"), 1024)
for i := 0; i < 32; i++ {
bw.Write(data)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
writer := newMockWriter()
bw.Reconnect(0, writer)
}
}
@@ -1,129 +0,0 @@
package backedpipe
import "golang.org/x/xerrors"
// ringBuffer implements an efficient circular buffer with a fixed-size allocation.
// This implementation is not thread-safe and relies on external synchronization.
type ringBuffer struct {
buffer []byte
start int // index of first valid byte
end int // index of last valid byte (-1 when empty)
}
// newRingBuffer creates a new ring buffer with the specified capacity.
// Capacity must be > 0.
func newRingBuffer(capacity int) *ringBuffer {
if capacity <= 0 {
panic("ring buffer capacity must be > 0")
}
return &ringBuffer{
buffer: make([]byte, capacity),
end: -1, // -1 indicates empty buffer
}
}
// Size returns the current number of bytes in the buffer.
func (rb *ringBuffer) Size() int {
if rb.end == -1 {
return 0 // Buffer is empty
}
if rb.start <= rb.end {
return rb.end - rb.start + 1
}
// Buffer wraps around
return len(rb.buffer) - rb.start + rb.end + 1
}
// Write writes data to the ring buffer. If the buffer would overflow,
// it evicts the oldest data to make room for new data.
func (rb *ringBuffer) Write(data []byte) {
if len(data) == 0 {
return
}
capacity := len(rb.buffer)
// If data is larger than capacity, only keep the last capacity bytes
if len(data) > capacity {
data = data[len(data)-capacity:]
// Clear buffer and write new data
rb.start = 0
rb.end = -1 // Will be set properly below
}
// Calculate how much we need to evict to fit new data
spaceNeeded := len(data)
availableSpace := capacity - rb.Size()
if spaceNeeded > availableSpace {
bytesToEvict := spaceNeeded - availableSpace
rb.evict(bytesToEvict)
}
// Write starting just past the current end (index 0 when the buffer is empty)
writePos := (rb.end + 1) % capacity
if writePos+len(data) <= capacity {
// No wrap needed - single copy
copy(rb.buffer[writePos:], data)
rb.end = (rb.end + len(data)) % capacity
} else {
// Need to wrap around - two copies
firstChunk := capacity - writePos
copy(rb.buffer[writePos:], data[:firstChunk])
copy(rb.buffer[0:], data[firstChunk:])
rb.end = len(data) - firstChunk - 1
}
}
// evict removes the specified number of bytes from the beginning of the buffer.
func (rb *ringBuffer) evict(count int) {
if count >= rb.Size() {
// Evict everything
rb.start = 0
rb.end = -1
return
}
rb.start = (rb.start + count) % len(rb.buffer)
// Buffer remains non-empty after partial eviction
}
// ReadLast returns the last n bytes from the buffer.
// If n is greater than the available data, returns an error.
// If n is negative, returns an error.
func (rb *ringBuffer) ReadLast(n int) ([]byte, error) {
if n < 0 {
return nil, xerrors.New("cannot read negative number of bytes")
}
if n == 0 {
return nil, nil
}
size := rb.Size()
// If requested more than available, return error
if n > size {
return nil, xerrors.Errorf("requested %d bytes but only %d available", n, size)
}
result := make([]byte, n)
capacity := len(rb.buffer)
// Calculate where to start reading from (n bytes before the end)
startOffset := size - n
actualStart := (rb.start + startOffset) % capacity
// Copy the last n bytes
if actualStart+n <= capacity {
// No wrap needed
copy(result, rb.buffer[actualStart:actualStart+n])
} else {
// Need to wrap around
firstChunk := capacity - actualStart
copy(result[0:firstChunk], rb.buffer[actualStart:capacity])
copy(result[firstChunk:], rb.buffer[0:n-firstChunk])
}
return result, nil
}
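Since the ring buffer is what bounds replay, a reconnect can only be served while the requested window is still buffered. A hedged sketch of that relationship follows; the helper name is hypothetical, and the logic simply mirrors the checks in BackedWriter.Reconnect and ReadLast above.
// canReplayFrom reports whether a writer that has written seq bytes in total,
// of which buffered bytes are still held in its ring buffer, could satisfy
// Reconnect(replayFrom, ...).
func canReplayFrom(replayFrom, seq uint64, buffered int) bool {
	if replayFrom > seq {
		return false // would fail with ErrFutureSequence
	}
	need := seq - replayFrom
	// ReadLast(int(need)) only succeeds while the window is still buffered;
	// otherwise Reconnect fails with ErrReplayDataUnavailable.
	return need <= uint64(buffered)
}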
@@ -1,261 +0,0 @@
package backedpipe
import (
"bytes"
"os"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
if runtime.GOOS == "windows" {
// Don't run goleak on windows tests, they're super flaky right now.
// See: https://github.com/coder/coder/issues/8954
os.Exit(m.Run())
}
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}
func TestRingBuffer_NewRingBuffer(t *testing.T) {
t.Parallel()
rb := newRingBuffer(100)
// Test that we can write and read from the buffer
rb.Write([]byte("test"))
data, err := rb.ReadLast(4)
require.NoError(t, err)
require.Equal(t, []byte("test"), data)
}
func TestRingBuffer_WriteAndRead(t *testing.T) {
t.Parallel()
rb := newRingBuffer(10)
// Write some data
rb.Write([]byte("hello"))
// Read last 4 bytes
data, err := rb.ReadLast(4)
require.NoError(t, err)
require.Equal(t, "ello", string(data))
// Write more data
rb.Write([]byte("world"))
// Read last 5 bytes
data, err = rb.ReadLast(5)
require.NoError(t, err)
require.Equal(t, "world", string(data))
// Read last 3 bytes
data, err = rb.ReadLast(3)
require.NoError(t, err)
require.Equal(t, "rld", string(data))
// Read more than available (should be 10 bytes total)
_, err = rb.ReadLast(15)
require.Error(t, err)
require.Contains(t, err.Error(), "requested 15 bytes but only")
}
func TestRingBuffer_OverflowEviction(t *testing.T) {
t.Parallel()
rb := newRingBuffer(5)
// Fill buffer
rb.Write([]byte("abcde"))
// Overflow should evict oldest data
rb.Write([]byte("fg"))
// Should now contain "cdefg"
data, err := rb.ReadLast(5)
require.NoError(t, err)
require.Equal(t, []byte("cdefg"), data)
}
func TestRingBuffer_LargeWrite(t *testing.T) {
t.Parallel()
rb := newRingBuffer(5)
// Write data larger than capacity
rb.Write([]byte("abcdefghij"))
// Should contain last 5 bytes
data, err := rb.ReadLast(5)
require.NoError(t, err)
require.Equal(t, []byte("fghij"), data)
}
func TestRingBuffer_WrapAround(t *testing.T) {
t.Parallel()
rb := newRingBuffer(5)
// Fill buffer
rb.Write([]byte("abcde"))
// Write more to cause wrap-around
rb.Write([]byte("fgh"))
// Should contain "defgh"
data, err := rb.ReadLast(5)
require.NoError(t, err)
require.Equal(t, []byte("defgh"), data)
// Test reading last 3 bytes after wrap
data, err = rb.ReadLast(3)
require.NoError(t, err)
require.Equal(t, []byte("fgh"), data)
}
func TestRingBuffer_ReadLastEdgeCases(t *testing.T) {
t.Parallel()
rb := newRingBuffer(3)
// Write some data (5 bytes to a 3-byte buffer, so only last 3 bytes remain)
rb.Write([]byte("hello"))
// Test reading negative count
data, err := rb.ReadLast(-1)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot read negative number of bytes")
require.Nil(t, data)
// Test reading zero bytes
data, err = rb.ReadLast(0)
require.NoError(t, err)
require.Nil(t, data)
// Test reading more than available (buffer has 3 bytes, try to read 10)
_, err = rb.ReadLast(10)
require.Error(t, err)
require.Contains(t, err.Error(), "requested 10 bytes but only 3 available")
// Test reading exact amount available
data, err = rb.ReadLast(3)
require.NoError(t, err)
require.Equal(t, []byte("llo"), data)
}
func TestRingBuffer_EmptyWrite(t *testing.T) {
t.Parallel()
rb := newRingBuffer(10)
// Write empty data
rb.Write([]byte{})
// Buffer should still be empty
_, err := rb.ReadLast(5)
require.Error(t, err)
require.Contains(t, err.Error(), "requested 5 bytes but only 0 available")
}
func TestRingBuffer_MultipleWrites(t *testing.T) {
t.Parallel()
rb := newRingBuffer(10)
// Write data in chunks
rb.Write([]byte("ab"))
rb.Write([]byte("cd"))
rb.Write([]byte("ef"))
data, err := rb.ReadLast(6)
require.NoError(t, err)
require.Equal(t, []byte("abcdef"), data)
// Test partial reads
data, err = rb.ReadLast(4)
require.NoError(t, err)
require.Equal(t, []byte("cdef"), data)
data, err = rb.ReadLast(2)
require.NoError(t, err)
require.Equal(t, []byte("ef"), data)
}
func TestRingBuffer_EdgeCaseEviction(t *testing.T) {
t.Parallel()
rb := newRingBuffer(3)
// Write data that will cause eviction
rb.Write([]byte("abc"))
// Write more to cause eviction
rb.Write([]byte("d"))
// Should now contain "bcd"
data, err := rb.ReadLast(3)
require.NoError(t, err)
require.Equal(t, []byte("bcd"), data)
}
func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) {
t.Parallel()
rb := newRingBuffer(8)
// Fill buffer
rb.Write([]byte("12345678"))
// Evict some and add more to create complex wrap scenario
rb.Write([]byte("abcd"))
data, err := rb.ReadLast(8)
require.NoError(t, err)
require.Equal(t, []byte("5678abcd"), data)
// Add more
rb.Write([]byte("xyz"))
data, err = rb.ReadLast(8)
require.NoError(t, err)
require.Equal(t, []byte("8abcdxyz"), data)
// Test reading various amounts from the end
data, err = rb.ReadLast(7)
require.NoError(t, err)
require.Equal(t, []byte("abcdxyz"), data)
data, err = rb.ReadLast(4)
require.NoError(t, err)
require.Equal(t, []byte("dxyz"), data)
}
// Benchmark tests for performance validation
func BenchmarkRingBuffer_Write(b *testing.B) {
rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks
data := bytes.Repeat([]byte("x"), 1024) // 1KB writes
b.ResetTimer()
for i := 0; i < b.N; i++ {
rb.Write(data)
}
}
func BenchmarkRingBuffer_ReadLast(b *testing.B) {
rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks
// Fill buffer with test data
for i := 0; i < 64; i++ {
rb.Write(bytes.Repeat([]byte("x"), 1024))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := rb.ReadLast((i % 100) + 1)
if err != nil {
b.Fatal(err)
}
}
}
@@ -11,39 +11,23 @@ import (
"strings"
"github.com/shirou/gopsutil/v4/disk"
"github.com/spf13/afero"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
var WindowsDriveRegex = regexp.MustCompile(`^[a-zA-Z]:\\$`)
func (a *agent) HandleLS(rw http.ResponseWriter, r *http.Request) {
func (*agent) HandleLS(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// An absolute path may be optionally provided, otherwise a path split into an
// array must be provided in the body (which can be relative).
query := r.URL.Query()
parser := httpapi.NewQueryParamParser()
path := parser.String(query, "", "path")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
var query LSRequest
if !httpapi.Read(ctx, rw, r, &query) {
return
}
var req workspacesdk.LSRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
resp, err := listFiles(a.filesystem, path, req)
resp, err := listFiles(query)
if err != nil {
status := http.StatusInternalServerError
switch {
@@ -62,66 +46,58 @@ func (a *agent) HandleLS(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func listFiles(fs afero.Fs, path string, query workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
absolutePathString := path
if absolutePathString != "" {
if !filepath.IsAbs(path) {
return workspacesdk.LSResponse{}, xerrors.Errorf("path must be absolute: %q", path)
}
} else {
var fullPath []string
switch query.Relativity {
case workspacesdk.LSRelativityHome:
home, err := os.UserHomeDir()
if err != nil {
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to get user home directory: %w", err)
}
fullPath = []string{home}
case workspacesdk.LSRelativityRoot:
if runtime.GOOS == "windows" {
if len(query.Path) == 0 {
return listDrives()
}
if !WindowsDriveRegex.MatchString(query.Path[0]) {
return workspacesdk.LSResponse{}, xerrors.Errorf("invalid drive letter %q", query.Path[0])
}
} else {
fullPath = []string{"/"}
}
default:
return workspacesdk.LSResponse{}, xerrors.Errorf("unsupported relativity type %q", query.Relativity)
}
fullPath = append(fullPath, query.Path...)
fullPathRelative := filepath.Join(fullPath...)
var err error
absolutePathString, err = filepath.Abs(fullPathRelative)
func listFiles(query LSRequest) (LSResponse, error) {
var fullPath []string
switch query.Relativity {
case LSRelativityHome:
home, err := os.UserHomeDir()
if err != nil {
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to get absolute path of %q: %w", fullPathRelative, err)
return LSResponse{}, xerrors.Errorf("failed to get user home directory: %w", err)
}
fullPath = []string{home}
case LSRelativityRoot:
if runtime.GOOS == "windows" {
if len(query.Path) == 0 {
return listDrives()
}
if !WindowsDriveRegex.MatchString(query.Path[0]) {
return LSResponse{}, xerrors.Errorf("invalid drive letter %q", query.Path[0])
}
} else {
fullPath = []string{"/"}
}
default:
return LSResponse{}, xerrors.Errorf("unsupported relativity type %q", query.Relativity)
}
fullPath = append(fullPath, query.Path...)
fullPathRelative := filepath.Join(fullPath...)
absolutePathString, err := filepath.Abs(fullPathRelative)
if err != nil {
return LSResponse{}, xerrors.Errorf("failed to get absolute path of %q: %w", fullPathRelative, err)
}
// codeql[go/path-injection] - The intent is to allow the user to navigate to any directory in their workspace.
f, err := fs.Open(absolutePathString)
f, err := os.Open(absolutePathString)
if err != nil {
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to open directory %q: %w", absolutePathString, err)
return LSResponse{}, xerrors.Errorf("failed to open directory %q: %w", absolutePathString, err)
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to stat directory %q: %w", absolutePathString, err)
return LSResponse{}, xerrors.Errorf("failed to stat directory %q: %w", absolutePathString, err)
}
if !stat.IsDir() {
return workspacesdk.LSResponse{}, xerrors.Errorf("path %q is not a directory", absolutePathString)
return LSResponse{}, xerrors.Errorf("path %q is not a directory", absolutePathString)
}
// `contents` may be partially populated even if the operation fails midway.
contents, _ := f.Readdir(-1)
respContents := make([]workspacesdk.LSFile, 0, len(contents))
contents, _ := f.ReadDir(-1)
respContents := make([]LSFile, 0, len(contents))
for _, file := range contents {
respContents = append(respContents, workspacesdk.LSFile{
respContents = append(respContents, LSFile{
Name: file.Name(),
AbsolutePathString: filepath.Join(absolutePathString, file.Name()),
IsDir: file.IsDir(),
@@ -129,7 +105,7 @@ func listFiles(fs afero.Fs, path string, query workspacesdk.LSRequest) (workspac
}
// Sort alphabetically: directories then files
slices.SortFunc(respContents, func(a, b workspacesdk.LSFile) int {
slices.SortFunc(respContents, func(a, b LSFile) int {
if a.IsDir && !b.IsDir {
return -1
}
@@ -141,35 +117,35 @@ func listFiles(fs afero.Fs, path string, query workspacesdk.LSRequest) (workspac
absolutePath := pathToArray(absolutePathString)
return workspacesdk.LSResponse{
return LSResponse{
AbsolutePath: absolutePath,
AbsolutePathString: absolutePathString,
Contents: respContents,
}, nil
}
func listDrives() (workspacesdk.LSResponse, error) {
func listDrives() (LSResponse, error) {
// disk.Partitions() will return partitions even if there was a failure to
// get one. Any errored partitions will not be returned.
partitionStats, err := disk.Partitions(true)
if err != nil && len(partitionStats) == 0 {
// Only return the error if there were no partitions returned.
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to get partitions: %w", err)
return LSResponse{}, xerrors.Errorf("failed to get partitions: %w", err)
}
contents := make([]workspacesdk.LSFile, 0, len(partitionStats))
contents := make([]LSFile, 0, len(partitionStats))
for _, a := range partitionStats {
// Drive letters on Windows have a trailing separator as part of their name.
// i.e. `os.Open("C:")` does not work, but `os.Open("C:\\")` does.
name := a.Mountpoint + string(os.PathSeparator)
contents = append(contents, workspacesdk.LSFile{
contents = append(contents, LSFile{
Name: name,
AbsolutePathString: name,
IsDir: true,
})
}
return workspacesdk.LSResponse{
return LSResponse{
AbsolutePath: []string{},
AbsolutePathString: "",
Contents: contents,
@@ -187,3 +163,36 @@ func pathToArray(path string) []string {
}
return out
}
type LSRequest struct {
// e.g. [], ["repos", "coder"],
Path []string `json:"path"`
// Whether the supplied path is relative to the user's home directory,
// or the root directory.
Relativity LSRelativity `json:"relativity"`
}
type LSResponse struct {
AbsolutePath []string `json:"absolute_path"`
// Returned so clients can display the full path to the user, and
// copy it to configure file sync
// e.g. Windows: "C:\\Users\\coder"
// Linux: "/home/coder"
AbsolutePathString string `json:"absolute_path_string"`
Contents []LSFile `json:"contents"`
}
type LSFile struct {
Name string `json:"name"`
// e.g. "C:\\Users\\coder\\hello.txt"
// "/home/coder/hello.txt"
AbsolutePathString string `json:"absolute_path_string"`
IsDir bool `json:"is_dir"`
}
type LSRelativity string
const (
LSRelativityRoot LSRelativity = "root"
LSRelativityHome LSRelativity = "home"
)
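// A minimal usage sketch (illustration only) of the package-local types above:
// list ~/repos relative to the user's home directory. listFiles is the
// unexported helper defined earlier in this file; the fmt import is assumed.
func exampleListHomeRepos() error {
	req := LSRequest{
		Path:       []string{"repos"},
		Relativity: LSRelativityHome,
	}
	resp, err := listFiles(req)
	if err != nil {
		// Propagates os.ErrNotExist, os.ErrPermission, or an
		// "is not a directory" error, as exercised by the tests below.
		return err
	}
	for _, f := range resp.Contents {
		fmt.Printf("%s (dir=%v)\n", f.AbsolutePathString, f.IsDir)
	}
	return nil
}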
+38 -76
@@ -6,103 +6,67 @@ import (
"runtime"
"testing"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
type testFs struct {
afero.Fs
}
func newTestFs(base afero.Fs) *testFs {
return &testFs{
Fs: base,
}
}
func (*testFs) Open(name string) (afero.File, error) {
return nil, os.ErrPermission
}
func TestListFilesWithQueryParam(t *testing.T) {
t.Parallel()
fs := afero.NewMemMapFs()
query := workspacesdk.LSRequest{}
_, err := listFiles(fs, "not-relative", query)
require.Error(t, err)
require.Contains(t, err.Error(), "must be absolute")
tmpDir := t.TempDir()
err = fs.MkdirAll(tmpDir, 0o755)
require.NoError(t, err)
res, err := listFiles(fs, tmpDir, query)
require.NoError(t, err)
require.Len(t, res.Contents, 0)
}
func TestListFilesNonExistentDirectory(t *testing.T) {
t.Parallel()
fs := afero.NewMemMapFs()
query := workspacesdk.LSRequest{
query := LSRequest{
Path: []string{"idontexist"},
Relativity: workspacesdk.LSRelativityHome,
Relativity: LSRelativityHome,
}
_, err := listFiles(fs, "", query)
_, err := listFiles(query)
require.ErrorIs(t, err, os.ErrNotExist)
}
func TestListFilesPermissionDenied(t *testing.T) {
t.Parallel()
fs := newTestFs(afero.NewMemMapFs())
if runtime.GOOS == "windows" {
t.Skip("creating an unreadable-by-user directory is non-trivial on Windows")
}
home, err := os.UserHomeDir()
require.NoError(t, err)
tmpDir := t.TempDir()
reposDir := filepath.Join(tmpDir, "repos")
err = fs.MkdirAll(reposDir, 0o000)
err = os.Mkdir(reposDir, 0o000)
require.NoError(t, err)
rel, err := filepath.Rel(home, reposDir)
require.NoError(t, err)
query := workspacesdk.LSRequest{
query := LSRequest{
Path: pathToArray(rel),
Relativity: workspacesdk.LSRelativityHome,
Relativity: LSRelativityHome,
}
_, err = listFiles(fs, "", query)
_, err = listFiles(query)
require.ErrorIs(t, err, os.ErrPermission)
}
func TestListFilesNotADirectory(t *testing.T) {
t.Parallel()
fs := afero.NewMemMapFs()
home, err := os.UserHomeDir()
require.NoError(t, err)
tmpDir := t.TempDir()
err = fs.MkdirAll(tmpDir, 0o755)
require.NoError(t, err)
filePath := filepath.Join(tmpDir, "file.txt")
err = afero.WriteFile(fs, filePath, []byte("content"), 0o600)
err = os.WriteFile(filePath, []byte("content"), 0o600)
require.NoError(t, err)
rel, err := filepath.Rel(home, filePath)
require.NoError(t, err)
query := workspacesdk.LSRequest{
query := LSRequest{
Path: pathToArray(rel),
Relativity: workspacesdk.LSRelativityHome,
Relativity: LSRelativityHome,
}
_, err = listFiles(fs, "", query)
_, err = listFiles(query)
require.ErrorContains(t, err, "is not a directory")
}
@@ -112,7 +76,7 @@ func TestListFilesSuccess(t *testing.T) {
tc := []struct {
name string
baseFunc func(t *testing.T) string
relativity workspacesdk.LSRelativity
relativity LSRelativity
}{
{
name: "home",
@@ -121,7 +85,7 @@ func TestListFilesSuccess(t *testing.T) {
require.NoError(t, err)
return home
},
relativity: workspacesdk.LSRelativityHome,
relativity: LSRelativityHome,
},
{
name: "root",
@@ -131,7 +95,7 @@ func TestListFilesSuccess(t *testing.T) {
}
return "/"
},
relativity: workspacesdk.LSRelativityRoot,
relativity: LSRelativityRoot,
},
}
@@ -140,20 +104,19 @@ func TestListFilesSuccess(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
fs := afero.NewMemMapFs()
base := tc.baseFunc(t)
tmpDir := t.TempDir()
reposDir := filepath.Join(tmpDir, "repos")
err := fs.MkdirAll(reposDir, 0o755)
err := os.Mkdir(reposDir, 0o755)
require.NoError(t, err)
downloadsDir := filepath.Join(tmpDir, "Downloads")
err = fs.MkdirAll(downloadsDir, 0o755)
err = os.Mkdir(downloadsDir, 0o755)
require.NoError(t, err)
textFile := filepath.Join(tmpDir, "file.txt")
err = afero.WriteFile(fs, textFile, []byte("content"), 0o600)
err = os.WriteFile(textFile, []byte("content"), 0o600)
require.NoError(t, err)
var queryComponents []string
@@ -166,16 +129,16 @@ func TestListFilesSuccess(t *testing.T) {
queryComponents = pathToArray(rel)
}
query := workspacesdk.LSRequest{
query := LSRequest{
Path: queryComponents,
Relativity: tc.relativity,
}
resp, err := listFiles(fs, "", query)
resp, err := listFiles(query)
require.NoError(t, err)
require.Equal(t, tmpDir, resp.AbsolutePathString)
// Output is sorted
require.Equal(t, []workspacesdk.LSFile{
require.Equal(t, []LSFile{
{
Name: "Downloads",
AbsolutePathString: downloadsDir,
@@ -203,44 +166,43 @@ func TestListFilesListDrives(t *testing.T) {
t.Skip("skipping test on non-Windows OS")
}
fs := afero.NewOsFs()
query := workspacesdk.LSRequest{
query := LSRequest{
Path: []string{},
Relativity: workspacesdk.LSRelativityRoot,
Relativity: LSRelativityRoot,
}
resp, err := listFiles(fs, "", query)
resp, err := listFiles(query)
require.NoError(t, err)
require.Contains(t, resp.Contents, workspacesdk.LSFile{
require.Contains(t, resp.Contents, LSFile{
Name: "C:\\",
AbsolutePathString: "C:\\",
IsDir: true,
})
query = workspacesdk.LSRequest{
query = LSRequest{
Path: []string{"C:\\"},
Relativity: workspacesdk.LSRelativityRoot,
Relativity: LSRelativityRoot,
}
resp, err = listFiles(fs, "", query)
resp, err = listFiles(query)
require.NoError(t, err)
query = workspacesdk.LSRequest{
query = LSRequest{
Path: resp.AbsolutePath,
Relativity: workspacesdk.LSRelativityRoot,
Relativity: LSRelativityRoot,
}
resp, err = listFiles(fs, "", query)
resp, err = listFiles(query)
require.NoError(t, err)
// System directory should always exist
require.Contains(t, resp.Contents, workspacesdk.LSFile{
require.Contains(t, resp.Contents, LSFile{
Name: "Windows",
AbsolutePathString: "C:\\Windows",
IsDir: true,
})
query = workspacesdk.LSRequest{
query = LSRequest{
// Network drives are not supported.
Path: []string{"\\sshfs\\work"},
Relativity: workspacesdk.LSRelativityRoot,
Relativity: LSRelativityRoot,
}
resp, err = listFiles(fs, "", query)
resp, err = listFiles(query)
require.ErrorContains(t, err, "drive")
}
-5
@@ -25,7 +25,6 @@ import (
// screenReconnectingPTY provides a reconnectable PTY via `screen`.
type screenReconnectingPTY struct {
logger slog.Logger
execer agentexec.Execer
command *pty.Cmd
@@ -63,7 +62,6 @@ type screenReconnectingPTY struct {
// own which causes it to spawn with the specified size.
func newScreen(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *screenReconnectingPTY {
rpty := &screenReconnectingPTY{
logger: logger,
execer: execer,
command: cmd,
metrics: options.Metrics,
@@ -175,7 +173,6 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
ptty, process, err := rpty.doAttach(ctx, conn, height, width, logger)
if err != nil {
logger.Debug(ctx, "unable to attach to screen reconnecting pty", slog.Error(err))
if errors.Is(err, context.Canceled) {
// Likely the process was too short-lived and canceled the version command.
// TODO: Is it worth distinguishing between that and a cancel from the
@@ -185,7 +182,6 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
}
return err
}
logger.Debug(ctx, "attached to screen reconnecting pty")
defer func() {
// Log only for debugging since the process might have already exited on its
@@ -407,7 +403,6 @@ func (rpty *screenReconnectingPTY) Wait() {
}
func (rpty *screenReconnectingPTY) Close(err error) {
rpty.logger.Debug(context.Background(), "closing screen reconnecting pty", slog.Error(err))
// The closing state change will be handled by the lifecycle.
rpty.state.setState(StateClosing, err)
}
-174
@@ -1,174 +0,0 @@
package unit
import (
"fmt"
"sync"
"golang.org/x/xerrors"
"gonum.org/v1/gonum/graph/encoding/dot"
"gonum.org/v1/gonum/graph/simple"
"gonum.org/v1/gonum/graph/topo"
)
// Graph provides a bidirectional interface over gonum's directed graph implementation.
// While the underlying gonum graph is directed, we overlay bidirectional semantics
// by distinguishing between forward and reverse edges. Wanting and being wanted by
// other units are related but different concepts that have different graph traversal
// implications when Units update their status.
//
// The graph stores edge types to represent different relationships between units,
// allowing for domain-specific semantics beyond simple connectivity.
type Graph[EdgeType, VertexType comparable] struct {
mu sync.RWMutex
// The underlying gonum graph. It stores vertices and edges without knowing about the types of the vertices and edges.
gonumGraph *simple.DirectedGraph
// Maps vertices to their IDs so that a gonum vertex ID can be used to lookup the vertex type.
vertexToID map[VertexType]int64
// Maps vertex IDs to their types so that a vertex type can be used to lookup the gonum vertex ID.
idToVertex map[int64]VertexType
// The next ID to assign to a vertex.
nextID int64
// Store edge types by "fromID->toID" key. This is used to lookup the edge type for a given edge.
edgeTypes map[string]EdgeType
}
// Edge is a convenience type for representing an edge in the graph.
// It encapsulates the from and to vertices and the edge type itself.
type Edge[EdgeType, VertexType comparable] struct {
From VertexType
To VertexType
Edge EdgeType
}
// AddEdge adds an edge to the graph. It initializes the graph and metadata on first use,
// checks for cycles, and adds the edge to the gonum graph.
func (g *Graph[EdgeType, VertexType]) AddEdge(from, to VertexType, edge EdgeType) error {
g.mu.Lock()
defer g.mu.Unlock()
if g.gonumGraph == nil {
g.gonumGraph = simple.NewDirectedGraph()
g.vertexToID = make(map[VertexType]int64)
g.idToVertex = make(map[int64]VertexType)
g.edgeTypes = make(map[string]EdgeType)
g.nextID = 1
}
fromID := g.getOrCreateVertexID(from)
toID := g.getOrCreateVertexID(to)
if g.canReach(to, from) {
return xerrors.Errorf("adding edge (%v -> %v) would create a cycle", from, to)
}
g.gonumGraph.SetEdge(simple.Edge{F: simple.Node(fromID), T: simple.Node(toID)})
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
g.edgeTypes[edgeKey] = edge
return nil
}
// GetForwardAdjacentVertices returns all the edges that originate from the given vertex.
func (g *Graph[EdgeType, VertexType]) GetForwardAdjacentVertices(from VertexType) []Edge[EdgeType, VertexType] {
g.mu.RLock()
defer g.mu.RUnlock()
fromID, exists := g.vertexToID[from]
if !exists {
return []Edge[EdgeType, VertexType]{}
}
edges := []Edge[EdgeType, VertexType]{}
toNodes := g.gonumGraph.From(fromID)
for toNodes.Next() {
toID := toNodes.Node().ID()
to := g.idToVertex[toID]
// Get the edge type
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
edgeType := g.edgeTypes[edgeKey]
edges = append(edges, Edge[EdgeType, VertexType]{From: from, To: to, Edge: edgeType})
}
return edges
}
// GetReverseAdjacentVertices returns all the edges that terminate at the given vertex.
func (g *Graph[EdgeType, VertexType]) GetReverseAdjacentVertices(to VertexType) []Edge[EdgeType, VertexType] {
g.mu.RLock()
defer g.mu.RUnlock()
toID, exists := g.vertexToID[to]
if !exists {
return []Edge[EdgeType, VertexType]{}
}
edges := []Edge[EdgeType, VertexType]{}
fromNodes := g.gonumGraph.To(toID)
for fromNodes.Next() {
fromID := fromNodes.Node().ID()
from := g.idToVertex[fromID]
// Get the edge type
edgeKey := fmt.Sprintf("%d->%d", fromID, toID)
edgeType := g.edgeTypes[edgeKey]
edges = append(edges, Edge[EdgeType, VertexType]{From: from, To: to, Edge: edgeType})
}
return edges
}
// getOrCreateVertexID returns the ID for a vertex, creating it if it doesn't exist.
func (g *Graph[EdgeType, VertexType]) getOrCreateVertexID(vertex VertexType) int64 {
if id, exists := g.vertexToID[vertex]; exists {
return id
}
id := g.nextID
g.nextID++
g.vertexToID[vertex] = id
g.idToVertex[id] = vertex
// Add the node to the gonum graph
g.gonumGraph.AddNode(simple.Node(id))
return id
}
// canReach checks if there is a path from the start vertex to the end vertex.
func (g *Graph[EdgeType, VertexType]) canReach(start, end VertexType) bool {
if start == end {
return true
}
startID, startExists := g.vertexToID[start]
endID, endExists := g.vertexToID[end]
if !startExists || !endExists {
return false
}
// Use gonum's built-in path existence check
return topo.PathExistsIn(g.gonumGraph, simple.Node(startID), simple.Node(endID))
}
// ToDOT exports the graph to DOT format for visualization
func (g *Graph[EdgeType, VertexType]) ToDOT(name string) (string, error) {
g.mu.RLock()
defer g.mu.RUnlock()
if g.gonumGraph == nil {
return "", xerrors.New("graph is not initialized")
}
// Marshal the graph to DOT format
dotBytes, err := dot.Marshal(g.gonumGraph, name, "", " ")
if err != nil {
return "", xerrors.Errorf("failed to marshal graph to DOT: %w", err)
}
return string(dotBytes), nil
}
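// A minimal usage sketch (illustration only) of the bidirectional semantics
// described in the Graph doc comment, using string vertex and edge types.
func exampleGraphUsage() error {
	g := &Graph[string, string]{}
	if err := g.AddEdge("unit1", "unit2", "completed"); err != nil {
		return err
	}
	// Forward: edges originating at unit1 (what unit1 wants).
	_ = g.GetForwardAdjacentVertices("unit1") // [{unit1 unit2 completed}]
	// Reverse: edges terminating at unit2 (who wants unit2).
	_ = g.GetReverseAdjacentVertices("unit2") // [{unit1 unit2 completed}]
	// Adding unit2 -> unit1 is rejected because it would create a cycle.
	if err := g.AddEdge("unit2", "unit1", "started"); err == nil {
		return xerrors.New("expected cycle error")
	}
	return nil
}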
-454
@@ -1,454 +0,0 @@
// Package unit_test provides tests for the unit package.
//
// DOT Graph Testing:
// The graph tests use golden files for DOT representation verification.
// To update the golden files:
// make gen/golden-files
//
// The golden files contain the expected DOT representation and can be easily
// inspected, version controlled, and updated when the graph structure changes.
package unit_test
import (
"bytes"
"flag"
"fmt"
"os"
"path/filepath"
"sync"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/unit"
"github.com/coder/coder/v2/cryptorand"
)
type testGraphEdge string
const (
testEdgeStarted testGraphEdge = "started"
testEdgeCompleted testGraphEdge = "completed"
)
type testGraphVertex struct {
Name string
}
type (
testGraph = unit.Graph[testGraphEdge, *testGraphVertex]
testEdge = unit.Edge[testGraphEdge, *testGraphVertex]
)
// randInt generates a random integer in the range [0, limit).
func randInt(limit int) int {
if limit <= 0 {
return 0
}
n, err := cryptorand.Int63n(int64(limit))
if err != nil {
return 0
}
return int(n)
}
// UpdateGoldenFiles indicates golden files should be updated.
// To update the golden files:
// make gen/golden-files
var UpdateGoldenFiles = flag.Bool("update", false, "update .golden files")
// assertDOTGraph requires that the graph's DOT representation matches the golden file
func assertDOTGraph(t *testing.T, graph *testGraph, goldenName string) {
t.Helper()
dot, err := graph.ToDOT(goldenName)
require.NoError(t, err)
goldenFile := filepath.Join("testdata", goldenName+".golden")
if *UpdateGoldenFiles {
t.Logf("update golden file for: %q: %s", goldenName, goldenFile)
err := os.MkdirAll(filepath.Dir(goldenFile), 0o755)
require.NoError(t, err, "want no error creating golden file directory")
err = os.WriteFile(goldenFile, []byte(dot), 0o600)
require.NoError(t, err, "update golden file")
}
expected, err := os.ReadFile(goldenFile)
require.NoError(t, err, "read golden file, run \"make gen/golden-files\" and commit the changes")
// Normalize line endings for cross-platform compatibility
expected = normalizeLineEndings(expected)
normalizedDot := normalizeLineEndings([]byte(dot))
assert.Empty(t, cmp.Diff(string(expected), string(normalizedDot)), "golden file mismatch (-want +got): %s, run \"make gen/golden-files\", verify and commit the changes", goldenFile)
}
// normalizeLineEndings ensures that all line endings are normalized to \n.
// Required for Windows compatibility.
func normalizeLineEndings(content []byte) []byte {
content = bytes.ReplaceAll(content, []byte("\r\n"), []byte("\n"))
content = bytes.ReplaceAll(content, []byte("\r"), []byte("\n"))
return content
}
func TestGraph(t *testing.T) {
t.Parallel()
testFuncs := map[string]func(t *testing.T) *unit.Graph[testGraphEdge, *testGraphVertex]{
"ForwardAndReverseEdges": func(t *testing.T) *unit.Graph[testGraphEdge, *testGraphVertex] {
graph := &unit.Graph[testGraphEdge, *testGraphVertex]{}
unit1 := &testGraphVertex{Name: "unit1"}
unit2 := &testGraphVertex{Name: "unit2"}
unit3 := &testGraphVertex{Name: "unit3"}
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unit1, unit3, testEdgeStarted)
require.NoError(t, err)
// Check for forward edge
vertices := graph.GetForwardAdjacentVertices(unit1)
require.Len(t, vertices, 2)
// Unit 1 depends on the completion of Unit2
require.Contains(t, vertices, testEdge{
From: unit1,
To: unit2,
Edge: testEdgeCompleted,
})
// Unit 1 depends on the start of Unit3
require.Contains(t, vertices, testEdge{
From: unit1,
To: unit3,
Edge: testEdgeStarted,
})
// Check for reverse edges
unit2ReverseEdges := graph.GetReverseAdjacentVertices(unit2)
require.Len(t, unit2ReverseEdges, 1)
// Unit 2 must be completed before Unit 1 can start
require.Contains(t, unit2ReverseEdges, testEdge{
From: unit1,
To: unit2,
Edge: testEdgeCompleted,
})
unit3ReverseEdges := graph.GetReverseAdjacentVertices(unit3)
require.Len(t, unit3ReverseEdges, 1)
// Unit 3 must be started before Unit 1 can complete
require.Contains(t, unit3ReverseEdges, testEdge{
From: unit1,
To: unit3,
Edge: testEdgeStarted,
})
return graph
},
"SelfReference": func(t *testing.T) *testGraph {
graph := &testGraph{}
unit1 := &testGraphVertex{Name: "unit1"}
err := graph.AddEdge(unit1, unit1, testEdgeCompleted)
require.Error(t, err)
require.ErrorContains(t, err, fmt.Sprintf("adding edge (%v -> %v) would create a cycle", unit1, unit1))
return graph
},
"Cycle": func(t *testing.T) *testGraph {
graph := &testGraph{}
unit1 := &testGraphVertex{Name: "unit1"}
unit2 := &testGraphVertex{Name: "unit2"}
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unit2, unit1, testEdgeStarted)
require.Error(t, err)
require.ErrorContains(t, err, fmt.Sprintf("adding edge (%v -> %v) would create a cycle", unit2, unit1))
return graph
},
"MultipleDependenciesSameStatus": func(t *testing.T) *testGraph {
graph := &testGraph{}
unit1 := &testGraphVertex{Name: "unit1"}
unit2 := &testGraphVertex{Name: "unit2"}
unit3 := &testGraphVertex{Name: "unit3"}
unit4 := &testGraphVertex{Name: "unit4"}
// Unit1 depends on completion of both unit2 and unit3 (same status type)
err := graph.AddEdge(unit1, unit2, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unit1, unit3, testEdgeCompleted)
require.NoError(t, err)
// Unit1 also depends on starting of unit4 (different status type)
err = graph.AddEdge(unit1, unit4, testEdgeStarted)
require.NoError(t, err)
// Check that unit1 has 3 forward dependencies
forwardEdges := graph.GetForwardAdjacentVertices(unit1)
require.Len(t, forwardEdges, 3)
// Verify all expected dependencies exist
expectedDependencies := []testEdge{
{From: unit1, To: unit2, Edge: testEdgeCompleted},
{From: unit1, To: unit3, Edge: testEdgeCompleted},
{From: unit1, To: unit4, Edge: testEdgeStarted},
}
for _, expected := range expectedDependencies {
require.Contains(t, forwardEdges, expected)
}
// Check reverse dependencies
unit2ReverseEdges := graph.GetReverseAdjacentVertices(unit2)
require.Len(t, unit2ReverseEdges, 1)
require.Contains(t, unit2ReverseEdges, testEdge{
From: unit1, To: unit2, Edge: testEdgeCompleted,
})
unit3ReverseEdges := graph.GetReverseAdjacentVertices(unit3)
require.Len(t, unit3ReverseEdges, 1)
require.Contains(t, unit3ReverseEdges, testEdge{
From: unit1, To: unit3, Edge: testEdgeCompleted,
})
unit4ReverseEdges := graph.GetReverseAdjacentVertices(unit4)
require.Len(t, unit4ReverseEdges, 1)
require.Contains(t, unit4ReverseEdges, testEdge{
From: unit1, To: unit4, Edge: testEdgeStarted,
})
return graph
},
}
for testName, testFunc := range testFuncs {
var graph *testGraph
t.Run(testName, func(t *testing.T) {
t.Parallel()
graph = testFunc(t)
assertDOTGraph(t, graph, testName)
})
}
}
func TestGraphThreadSafety(t *testing.T) {
t.Parallel()
t.Run("ConcurrentReadWrite", func(t *testing.T) {
t.Parallel()
graph := &testGraph{}
var wg sync.WaitGroup
const numWriters = 50
const numReaders = 100
const operationsPerWriter = 1000
const operationsPerReader = 2000
barrier := make(chan struct{})
// Launch writers
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(writerID int) {
defer wg.Done()
<-barrier
for j := 0; j < operationsPerWriter; j++ {
from := &testGraphVertex{Name: fmt.Sprintf("writer-%d-%d", writerID, j)}
to := &testGraphVertex{Name: fmt.Sprintf("writer-%d-%d", writerID, j+1)}
graph.AddEdge(from, to, testEdgeCompleted)
}
}(i)
}
// Launch readers
readerResults := make([]struct {
panicked bool
readCount int
}, numReaders)
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func(readerID int) {
defer wg.Done()
<-barrier
defer func() {
if r := recover(); r != nil {
readerResults[readerID].panicked = true
}
}()
readCount := 0
for j := 0; j < operationsPerReader; j++ {
// Create a test vertex and read
testUnit := &testGraphVertex{Name: fmt.Sprintf("test-reader-%d-%d", readerID, j)}
forwardEdges := graph.GetForwardAdjacentVertices(testUnit)
reverseEdges := graph.GetReverseAdjacentVertices(testUnit)
// Just verify no panics (results may be nil for non-existent vertices)
_ = forwardEdges
_ = reverseEdges
readCount++
}
readerResults[readerID].readCount = readCount
}(i)
}
close(barrier)
wg.Wait()
// Verify no panics occurred in readers
for i, result := range readerResults {
require.False(t, result.panicked, "reader %d panicked", i)
require.Equal(t, operationsPerReader, result.readCount, "reader %d should have performed expected reads", i)
}
})
t.Run("ConcurrentCycleDetection", func(t *testing.T) {
t.Parallel()
graph := &testGraph{}
// Pre-create chain: A→B→C→D
unitA := &testGraphVertex{Name: "A"}
unitB := &testGraphVertex{Name: "B"}
unitC := &testGraphVertex{Name: "C"}
unitD := &testGraphVertex{Name: "D"}
err := graph.AddEdge(unitA, unitB, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unitB, unitC, testEdgeCompleted)
require.NoError(t, err)
err = graph.AddEdge(unitC, unitD, testEdgeCompleted)
require.NoError(t, err)
barrier := make(chan struct{})
var wg sync.WaitGroup
const numGoroutines = 50
cycleErrors := make([]error, numGoroutines)
// Launch goroutines trying to add D→A (creates cycle)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
<-barrier
err := graph.AddEdge(unitD, unitA, testEdgeCompleted)
cycleErrors[goroutineID] = err
}(i)
}
close(barrier)
wg.Wait()
// Verify all attempts correctly returned cycle error
for i, err := range cycleErrors {
require.Error(t, err, "goroutine %d should have detected cycle", i)
require.Contains(t, err.Error(), "would create a cycle")
}
// Verify graph remains valid (original chain intact)
dot, err := graph.ToDOT("test")
require.NoError(t, err)
require.NotEmpty(t, dot)
})
t.Run("ConcurrentToDOT", func(t *testing.T) {
t.Parallel()
graph := &testGraph{}
// Pre-populate graph
for i := 0; i < 20; i++ {
from := &testGraphVertex{Name: fmt.Sprintf("dot-unit-%d", i)}
to := &testGraphVertex{Name: fmt.Sprintf("dot-unit-%d", i+1)}
err := graph.AddEdge(from, to, testEdgeCompleted)
require.NoError(t, err)
}
barrier := make(chan struct{})
var wg sync.WaitGroup
const numReaders = 100
const numWriters = 20
dotResults := make([]string, numReaders)
// Launch readers calling ToDOT
dotErrors := make([]error, numReaders)
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func(readerID int) {
defer wg.Done()
<-barrier
dot, err := graph.ToDOT(fmt.Sprintf("test-%d", readerID))
dotErrors[readerID] = err
if err == nil {
dotResults[readerID] = dot
}
}(i)
}
// Launch writers adding edges
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(writerID int) {
defer wg.Done()
<-barrier
from := &testGraphVertex{Name: fmt.Sprintf("writer-dot-%d", writerID)}
to := &testGraphVertex{Name: fmt.Sprintf("writer-dot-target-%d", writerID)}
graph.AddEdge(from, to, testEdgeCompleted)
}(i)
}
close(barrier)
wg.Wait()
// Verify no errors occurred during DOT generation
for i, err := range dotErrors {
require.NoError(t, err, "DOT generation error at index %d", i)
}
// Verify all DOT results are valid
for i, dot := range dotResults {
require.NotEmpty(t, dot, "DOT result %d should not be empty", i)
}
})
}
func BenchmarkGraph_ConcurrentMixedOperations(b *testing.B) {
graph := &testGraph{}
var wg sync.WaitGroup
const numGoroutines = 200
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Launch goroutines performing random operations
for j := 0; j < numGoroutines; j++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
operationCount := 0
for operationCount < 50 {
operation := float32(randInt(100)) / 100.0
if operation < 0.6 { // 60% reads
// Read operation
testUnit := &testGraphVertex{Name: fmt.Sprintf("bench-read-%d-%d", goroutineID, operationCount)}
forwardEdges := graph.GetForwardAdjacentVertices(testUnit)
reverseEdges := graph.GetReverseAdjacentVertices(testUnit)
// Just verify no panics (results may be nil for non-existent vertices)
_ = forwardEdges
_ = reverseEdges
} else { // 40% writes
// Write operation
from := &testGraphVertex{Name: fmt.Sprintf("bench-write-%d-%d", goroutineID, operationCount)}
to := &testGraphVertex{Name: fmt.Sprintf("bench-write-target-%d-%d", goroutineID, operationCount)}
graph.AddEdge(from, to, testEdgeCompleted)
}
operationCount++
}
}(j)
}
wg.Wait()
}
}
-8
@@ -1,8 +0,0 @@
strict digraph Cycle {
// Node definitions.
1;
2;
// Edge definitions.
1 -> 2;
}
-10
@@ -1,10 +0,0 @@
strict digraph ForwardAndReverseEdges {
// Node definitions.
1;
2;
3;
// Edge definitions.
1 -> 2;
1 -> 3;
}
@@ -1,12 +0,0 @@
strict digraph MultipleDependenciesSameStatus {
// Node definitions.
1;
2;
3;
4;
// Edge definitions.
1 -> 2;
1 -> 3;
1 -> 4;
}
-4
@@ -1,4 +0,0 @@
strict digraph SelfReference {
// Node definitions.
1;
}
+11 -5
@@ -6,7 +6,10 @@
"defaultBranch": "main"
},
"files": {
"includes": ["**", "!**/pnpm-lock.yaml"],
"includes": [
"**",
"!**/pnpm-lock.yaml"
],
"ignoreUnknown": true
},
"linter": {
@@ -45,14 +48,13 @@
"options": {
"paths": {
"@mui/material": "Use @mui/material/<name> instead. See: https://material-ui.com/guides/minimizing-bundle-size/.",
"@mui/icons-material": "Use @mui/icons-material/<name> instead. See: https://material-ui.com/guides/minimizing-bundle-size/.",
"@mui/material/Avatar": "Use components/Avatar/Avatar instead.",
"@mui/material/Alert": "Use components/Alert/Alert instead.",
"@mui/material/Popover": "Use components/Popover/Popover instead.",
"@mui/material/Typography": "Use native HTML elements instead. Eg: <span>, <p>, <h1>, etc.",
"@mui/material/Box": "Use a <div> instead.",
"@mui/material/Button": "Use a components/Button/Button instead.",
"@mui/material/styles": "Import from @emotion/react instead.",
"@mui/material/Table*": "Import from components/Table/Table instead.",
"lodash": "Use lodash/<name> instead."
}
}
@@ -67,7 +69,11 @@
"noConsole": {
"level": "error",
"options": {
"allow": ["error", "info", "warn"]
"allow": [
"error",
"info",
"warn"
]
}
}
},
@@ -76,5 +82,5 @@
}
}
},
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
"$schema": "https://biomejs.dev/schemas/2.2.0/schema.json"
}
+91 -12
@@ -15,6 +15,7 @@ import (
"strings"
"time"
"cloud.google.com/go/compute/metadata"
"golang.org/x/xerrors"
"gopkg.in/natefinch/lumberjack.v2"
@@ -37,8 +38,9 @@ import (
"github.com/coder/coder/v2/codersdk/agentsdk"
)
func workspaceAgent() *serpent.Command {
func (r *RootCmd) workspaceAgent() *serpent.Command {
var (
auth string
logDir string
scriptDataDir string
pprofAddress string
@@ -57,7 +59,6 @@ func workspaceAgent() *serpent.Command {
devcontainerProjectDiscovery bool
devcontainerDiscoveryAutostart bool
)
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
Use: "agent",
Short: `Starts the Coder workspace agent.`,
@@ -175,14 +176,12 @@ func workspaceAgent() *serpent.Command {
version := buildinfo.Version()
logger.Info(ctx, "agent is starting now",
slog.F("url", agentAuth.agentURL),
slog.F("auth", agentAuth.agentAuth),
slog.F("url", r.agentURL),
slog.F("auth", auth),
slog.F("version", version),
)
client, err := agentAuth.CreateClient()
if err != nil {
return xerrors.Errorf("create agent client: %w", err)
}
client := agentsdk.New(r.agentURL)
client.SDK.SetLogger(logger)
// Set a reasonable timeout so requests can't hang forever!
// The timeout needs to be reasonably long, because requests
@@ -191,7 +190,7 @@ func workspaceAgent() *serpent.Command {
client.SDK.HTTPClient.Timeout = 30 * time.Second
// Attach header transport so we process --agent-header and
// --agent-header-command flags
headerTransport, err := headerTransport(ctx, &agentAuth.agentURL, agentHeader, agentHeaderCommand)
headerTransport, err := headerTransport(ctx, r.agentURL, agentHeader, agentHeaderCommand)
if err != nil {
return xerrors.Errorf("configure header transport: %w", err)
}
@@ -215,6 +214,68 @@ func workspaceAgent() *serpent.Command {
ignorePorts[port] = "debug"
}
// exchangeToken returns a session token.
// This is abstracted to allow for the same looping condition
// regardless of instance identity auth type.
var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error)
switch auth {
case "token":
token, _ := inv.ParsedFlags().GetString(varAgentToken)
if token == "" {
tokenFile, _ := inv.ParsedFlags().GetString(varAgentTokenFile)
if tokenFile != "" {
tokenBytes, err := os.ReadFile(tokenFile)
if err != nil {
return xerrors.Errorf("read token file %q: %w", tokenFile, err)
}
token = strings.TrimSpace(string(tokenBytes))
}
}
if token == "" {
return xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth")
}
client.SetSessionToken(token)
case "google-instance-identity":
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var gcpClient *metadata.Client
gcpClientRaw := ctx.Value("gcp-client")
if gcpClientRaw != nil {
gcpClient, _ = gcpClientRaw.(*metadata.Client)
}
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
return client.AuthGoogleInstanceIdentity(ctx, "", gcpClient)
}
case "aws-instance-identity":
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var awsClient *http.Client
awsClientRaw := ctx.Value("aws-client")
if awsClientRaw != nil {
awsClient, _ = awsClientRaw.(*http.Client)
if awsClient != nil {
client.SDK.HTTPClient = awsClient
}
}
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
return client.AuthAWSInstanceIdentity(ctx)
}
case "azure-instance-identity":
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
var azureClient *http.Client
azureClientRaw := ctx.Value("azure-client")
if azureClientRaw != nil {
azureClient, _ = azureClientRaw.(*http.Client)
if azureClient != nil {
client.SDK.HTTPClient = azureClient
}
}
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
return client.AuthAzureInstanceIdentity(ctx)
}
}
executablePath, err := os.Executable()
if err != nil {
return xerrors.Errorf("getting os executable: %w", err)
@@ -282,7 +343,18 @@ func workspaceAgent() *serpent.Command {
LogDir: logDir,
ScriptDataDir: scriptDataDir,
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
TailnetListenPort: uint16(tailnetListenPort),
TailnetListenPort: uint16(tailnetListenPort),
ExchangeToken: func(ctx context.Context) (string, error) {
if exchangeToken == nil {
return client.SDK.SessionToken(), nil
}
resp, err := exchangeToken(ctx)
if err != nil {
return "", err
}
client.SetSessionToken(resp.SessionToken)
return resp.SessionToken, nil
},
EnvironmentVariables: environmentVariables,
IgnorePorts: ignorePorts,
SSHMaxTimeout: sshMaxTimeout,
@@ -293,7 +365,7 @@ func workspaceAgent() *serpent.Command {
Execer: execer,
Devcontainers: devcontainers,
DevcontainerAPIOptions: []agentcontainers.Option{
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
agentcontainers.WithSubAgentURL(r.agentURL.String()),
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
},
@@ -328,6 +400,13 @@ func workspaceAgent() *serpent.Command {
}
cmd.Options = serpent.OptionSet{
{
Flag: "auth",
Default: "token",
Description: "Specify the authentication type to use for the agent.",
Env: "CODER_AGENT_AUTH",
Value: serpent.StringOf(&auth),
},
{
Flag: "log-dir",
Default: os.TempDir(),
@@ -450,7 +529,7 @@ func workspaceAgent() *serpent.Command {
Value: serpent.BoolOf(&devcontainerDiscoveryAutostart),
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
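// A minimal sketch (illustration only) of how the ExchangeToken callback wired
// above is consumed: the agent calls it whenever it needs a valid session
// token, so the instance-identity modes can re-mint tokens while plain token
// auth simply returns the static value. The real call sites live in the agent
// package; refreshSessionToken is a hypothetical helper for illustration.
func refreshSessionToken(ctx context.Context, client *agentsdk.Client, exchange func(context.Context) (string, error)) error {
	token, err := exchange(ctx)
	if err != nil {
		return xerrors.Errorf("exchange token: %w", err)
	}
	client.SetSessionToken(token)
	return nil
}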
+157
@@ -1,6 +1,7 @@
package cli_test
import (
"context"
"fmt"
"net/http"
"os"
@@ -10,6 +11,7 @@ import (
"sync/atomic"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -19,7 +21,10 @@ import (
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
@@ -59,6 +64,158 @@ func TestWorkspaceAgent(t *testing.T) {
}, testutil.WaitLong, testutil.IntervalMedium)
})
t.Run("Azure", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AzureCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "azure-client", metadataClient),
)
ctx := inv.Context()
clitest.Start(t, inv)
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})
t.Run("AWS", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
AWSCertificates: certificates,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
inv = inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "aws-client", metadataClient),
)
clitest.Start(t, inv)
ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
})
t.Run("GoogleCloud", func(t *testing.T) {
t.Parallel()
instanceID := "instanceidentifier"
validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
db, ps := dbtestutil.NewDB(t,
dbtestutil.WithDumpOnFailure(),
)
client := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
GoogleTokenValidator: validator,
})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: memberUser.ID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
return agents
}).Do()
inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
clitest.SetupConfig(t, member, cfg)
clitest.Start(t,
inv.WithContext(
//nolint:revive,staticcheck
context.WithValue(inv.Context(), "gcp-client", metadataClient),
),
)
ctx := inv.Context()
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
MatchResources(matchAgentWithVersion).
Wait()
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
resources := workspace.LatestBuild.Resources
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
sshClient, err := dialer.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
key := "CODER_AGENT_TOKEN"
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
}
token, err := session.CombinedOutput(command)
require.NoError(t, err)
_, err = uuid.Parse(strings.TrimSpace(string(token)))
require.NoError(t, err)
})
t.Run("PostStartup", func(t *testing.T) {
t.Parallel()
-78
@@ -1,78 +0,0 @@
package cli
import (
"encoding/csv"
"strings"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
var (
_ pflag.SliceValue = &AllowListFlag{}
_ pflag.Value = &AllowListFlag{}
)
// AllowListFlag implements pflag.SliceValue for codersdk.APIAllowListTarget entries.
type AllowListFlag []codersdk.APIAllowListTarget
func AllowListFlagOf(al *[]codersdk.APIAllowListTarget) *AllowListFlag {
return (*AllowListFlag)(al)
}
func (a AllowListFlag) String() string {
return strings.Join(a.GetSlice(), ",")
}
func (a AllowListFlag) Value() []codersdk.APIAllowListTarget {
return []codersdk.APIAllowListTarget(a)
}
func (AllowListFlag) Type() string { return "allow-list" }
func (a *AllowListFlag) Set(set string) error {
values, err := csv.NewReader(strings.NewReader(set)).Read()
if err != nil {
return xerrors.Errorf("parse allow list entries as csv: %w", err)
}
for _, v := range values {
if err := a.Append(v); err != nil {
return err
}
}
return nil
}
func (a *AllowListFlag) Append(value string) error {
value = strings.TrimSpace(value)
if value == "" {
return xerrors.New("allow list entry cannot be empty")
}
var target codersdk.APIAllowListTarget
if err := target.UnmarshalText([]byte(value)); err != nil {
return err
}
*a = append(*a, target)
return nil
}
func (a *AllowListFlag) Replace(items []string) error {
*a = []codersdk.APIAllowListTarget{}
for _, item := range items {
if err := a.Append(item); err != nil {
return err
}
}
return nil
}
func (a *AllowListFlag) GetSlice() []string {
out := make([]string, len(*a))
for i, entry := range *a {
out[i] = entry.String()
}
return out
}
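// A minimal usage sketch (illustration only) of the removed flag type: Set
// splits its input as CSV and validates each entry via
// codersdk.APIAllowListTarget.UnmarshalText. The entry strings below are
// placeholders; the accepted format is defined by that UnmarshalText method.
func exampleAllowList() error {
	var targets []codersdk.APIAllowListTarget
	flag := AllowListFlagOf(&targets)
	if err := flag.Set("entry-one,entry-two"); err != nil {
		// Invalid or empty entries are rejected here.
		return err
	}
	entries := flag.GetSlice() // round-trips the parsed targets back to strings
	_ = entries
	return nil
}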
+3 -6
@@ -12,21 +12,18 @@ import (
)
func (r *RootCmd) autoupdate() *serpent.Command {
client := new(codersdk.Client)
cmd := &serpent.Command{
Annotations: workspaceCommand,
Use: "autoupdate <workspace> <always|never>",
Short: "Toggle auto-update policy for a workspace",
Middleware: serpent.Chain(
serpent.RequireNArgs(2),
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
policy := strings.ToLower(inv.Args[1])
err = validateAutoUpdatePolicy(policy)
err := validateAutoUpdatePolicy(policy)
if err != nil {
return xerrors.Errorf("validate policy: %w", err)
}
+1 -26
@@ -53,9 +53,6 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
t := time.NewTimer(0)
defer t.Stop()
startTime := time.Now()
baseInterval := opts.FetchInterval
for {
select {
case <-ctx.Done():
@@ -71,11 +68,7 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
return
}
fetchedAgent <- fetchAgent{agent: agent}
// Adjust the interval based on how long we've been waiting.
elapsed := time.Since(startTime)
currentInterval := GetProgressiveInterval(baseInterval, elapsed)
t.Reset(currentInterval)
t.Reset(opts.FetchInterval)
}
}
}()
@@ -300,24 +293,6 @@ func safeDuration(sw *stageWriter, a, b *time.Time) time.Duration {
return a.Sub(*b)
}
// GetProgressiveInterval returns an interval that increases over time.
// The interval starts at baseInterval and increases to
// a maximum of baseInterval * 16 over time.
func GetProgressiveInterval(baseInterval time.Duration, elapsed time.Duration) time.Duration {
switch {
case elapsed < 60*time.Second:
return baseInterval // 500ms for first 60 seconds
case elapsed < 2*time.Minute:
return baseInterval * 2 // 1s for next 1 minute
case elapsed < 5*time.Minute:
return baseInterval * 4 // 2s for next 3 minutes
case elapsed < 10*time.Minute:
return baseInterval * 8 // 4s for next 5 minutes
default:
return baseInterval * 16 // 8s after 10 minutes
}
}
type closeFunc func() error
func (c closeFunc) Close() error {
-28
@@ -866,31 +866,3 @@ func TestConnDiagnostics(t *testing.T) {
})
}
}
func TestGetProgressiveInterval(t *testing.T) {
t.Parallel()
baseInterval := 500 * time.Millisecond
testCases := []struct {
name string
elapsed time.Duration
expected time.Duration
}{
{"first_minute", 30 * time.Second, baseInterval},
{"second_minute", 90 * time.Second, baseInterval * 2},
{"third_to_fifth_minute", 3 * time.Minute, baseInterval * 4},
{"sixth_to_tenth_minute", 7 * time.Minute, baseInterval * 8},
{"after_ten_minutes", 15 * time.Minute, baseInterval * 16},
{"boundary_first_minute", 59 * time.Second, baseInterval},
{"boundary_second_minute", 61 * time.Second, baseInterval * 2},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := cliui.GetProgressiveInterval(baseInterval, tc.elapsed)
require.Equal(t, tc.expected, result)
})
}
}
+9 -18
@@ -296,23 +296,22 @@ func renderTable(out any, sort string, headers table.Row, filterColumns []string
// returned. If the table tag is malformed, an error is returned.
//
// The returned name is transformed from "snake_case" to "normal text".
func parseTableStructTag(field reflect.StructField) (name string, defaultSort, noSortOpt, recursive, skipParentName, emptyNil bool, err error) {
func parseTableStructTag(field reflect.StructField) (name string, defaultSort, noSortOpt, recursive, skipParentName bool, err error) {
tags, err := structtag.Parse(string(field.Tag))
if err != nil {
return "", false, false, false, false, false, xerrors.Errorf("parse struct field tag %q: %w", string(field.Tag), err)
return "", false, false, false, false, xerrors.Errorf("parse struct field tag %q: %w", string(field.Tag), err)
}
tag, err := tags.Get("table")
if err != nil || tag.Name == "-" {
// tags.Get only returns an error if the tag is not found.
return "", false, false, false, false, false, nil
return "", false, false, false, false, nil
}
defaultSortOpt := false
noSortOpt = false
recursiveOpt := false
skipParentNameOpt := false
emptyNilOpt := false
for _, opt := range tag.Options {
switch opt {
case "default_sort":
@@ -327,14 +326,12 @@ func parseTableStructTag(field reflect.StructField) (name string, defaultSort, n
// make sure the child name is unique across all nested structs in the parent.
recursiveOpt = true
skipParentNameOpt = true
case "empty_nil":
emptyNilOpt = true
default:
return "", false, false, false, false, false, xerrors.Errorf("unknown option %q in struct field tag", opt)
return "", false, false, false, false, xerrors.Errorf("unknown option %q in struct field tag", opt)
}
}
return strings.ReplaceAll(tag.Name, "_", " "), defaultSortOpt, noSortOpt, recursiveOpt, skipParentNameOpt, emptyNilOpt, nil
return strings.ReplaceAll(tag.Name, "_", " "), defaultSortOpt, noSortOpt, recursiveOpt, skipParentNameOpt, nil
}
func isStructOrStructPointer(t reflect.Type) bool {
@@ -361,7 +358,7 @@ func typeToTableHeaders(t reflect.Type, requireDefault bool) ([]string, string,
noSortOpt := false
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
name, defaultSort, noSort, recursive, skip, _, err := parseTableStructTag(field)
name, defaultSort, noSort, recursive, skip, err := parseTableStructTag(field)
if err != nil {
return nil, "", xerrors.Errorf("parse struct tags for field %q in type %q: %w", field.Name, t.String(), err)
}
@@ -438,7 +435,7 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
for i := 0; i < val.NumField(); i++ {
field := val.Type().Field(i)
fieldVal := val.Field(i)
name, _, _, recursive, skip, emptyNil, err := parseTableStructTag(field)
name, _, _, recursive, skip, err := parseTableStructTag(field)
if err != nil {
return nil, xerrors.Errorf("parse struct tags for field %q in type %T: %w", field.Name, val, err)
}
@@ -446,14 +443,8 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
continue
}
fieldType := field.Type
// If empty_nil is set and this is a nil pointer, use a zero value.
if emptyNil && fieldVal.Kind() == reflect.Pointer && fieldVal.IsNil() {
fieldVal = reflect.New(fieldType.Elem())
}
// Recurse if it's a struct.
fieldType := field.Type
if recursive {
if !isStructOrStructPointer(fieldType) {
return nil, xerrors.Errorf("field %q in type %q is marked as recursive but does not contain a struct or a pointer to a struct", field.Name, fieldType.String())
@@ -476,7 +467,7 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
}
// Otherwise, we just use the field value.
row[name] = fieldVal.Interface()
row[name] = val.Field(i).Interface()
}
return row, nil
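// A minimal sketch (illustration only) of a struct using the `table:` tags
// parsed by parseTableStructTag above. Field and type names are invented for
// illustration; only tag options visible in this file and its tests are used.
type exampleTableRow struct {
	Name   string      `table:"name,default_sort"`
	Status string      `table:"status"`
	Owner  *exampleRef `table:"ignored,recursive_inline"` // child columns are inlined into the parent row
	Token  string      `table:"-"`                        // excluded from table output
}

type exampleRef struct {
	Username string `table:"username"`
}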
-72
@@ -400,78 +400,6 @@ foo <nil> 10 [a, b, c] foo1 11 foo2 12 fo
})
})
})
t.Run("EmptyNil", func(t *testing.T) {
t.Parallel()
type emptyNilTest struct {
Name string `table:"name,default_sort"`
EmptyOnNil *string `table:"empty_on_nil,empty_nil"`
NormalBehavior *string `table:"normal_behavior"`
}
value := "value"
in := []emptyNilTest{
{
Name: "has_value",
EmptyOnNil: &value,
NormalBehavior: &value,
},
{
Name: "has_nil",
EmptyOnNil: nil,
NormalBehavior: nil,
},
}
expected := `
NAME EMPTY ON NIL NORMAL BEHAVIOR
has_nil <nil>
has_value value value
`
out, err := cliui.DisplayTable(in, "", nil)
log.Println("rendered table:\n" + out)
require.NoError(t, err)
compareTables(t, expected, out)
})
t.Run("EmptyNilWithRecursiveInline", func(t *testing.T) {
t.Parallel()
type nestedData struct {
Name string `table:"name"`
}
type inlineTest struct {
Nested *nestedData `table:"ignored,recursive_inline,empty_nil"`
Count int `table:"count,default_sort"`
}
in := []inlineTest{
{
Nested: &nestedData{
Name: "alice",
},
Count: 1,
},
{
Nested: nil,
Count: 2,
},
}
expected := `
NAME COUNT
alice 1
2
`
out, err := cliui.DisplayTable(in, "", nil)
log.Println("rendered table:\n" + out)
require.NoError(t, err)
compareTables(t, expected, out)
})
}
// compareTables normalizes the incoming table lines
+3 -5
@@ -236,6 +236,7 @@ func (r *RootCmd) configSSH() *serpent.Command {
dryRun bool
coderCliPath string
)
client := new(codersdk.Client)
cmd := &serpent.Command{
Annotations: workspaceCommand,
Use: "config-ssh",
@@ -252,13 +253,9 @@ func (r *RootCmd) configSSH() *serpent.Command {
),
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
ctx := inv.Context()
if sshConfigOpts.waitEnum != "auto" && sshConfigOpts.skipProxyCommand {
@@ -283,6 +280,7 @@ func (r *RootCmd) configSSH() *serpent.Command {
out = inv.Stderr
}
var err error
coderBinary := coderCliPath
if coderBinary == "" {
coderBinary, err = currentBinPath(out)
+4 -6
@@ -135,13 +135,11 @@ func Test_sshConfigSplitOnCoderSection(t *testing.T) {
}
}
// This test tries to mimic the behavior of OpenSSH when executing e.g. a ProxyCommand.
// nolint:paralleltest
// This test tries to mimic the behavior of OpenSSH
// when executing e.g. a ProxyCommand.
// nolint:tparallel
func Test_sshConfigProxyCommandEscape(t *testing.T) {
// Don't run this test, or any of its subtests in parallel. The test works by writing a file and then immediately
// executing it. Other tests might also exec a subprocess, and if they do in parallel, there is a small race
// condition where our file is open when they fork, and remains open while we attempt to execute it, causing
// a "text file busy" error.
t.Parallel()
tests := []struct {
name string
+3 -5
@@ -50,6 +50,7 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
// shares the same name across multiple organizations.
orgContext = NewOrganizationContext()
)
client := new(codersdk.Client)
cmd := &serpent.Command{
Annotations: workspaceCommand,
Use: "create [workspace]",
@@ -60,12 +61,9 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
Command: "coder create <username>/<workspace_name>",
},
),
Middleware: serpent.Chain(r.InitClient(client)),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
var err error
workspaceOwner := codersdk.Me
if len(inv.Args) >= 1 {
workspaceOwner, workspaceName, err = splitNamedWorkspace(inv.Args[0])
+2 -5
@@ -16,6 +16,7 @@ func (r *RootCmd) deleteWorkspace() *serpent.Command {
orphan bool
prov buildFlags
)
client := new(codersdk.Client)
cmd := &serpent.Command{
Annotations: workspaceCommand,
Use: "delete <workspace>",
@@ -28,13 +29,9 @@ func (r *RootCmd) deleteWorkspace() *serpent.Command {
),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
if err != nil {
return err
+6
@@ -185,6 +185,9 @@ func TestDelete(t *testing.T) {
t.Run("WarnNoProvisioners", func(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
store, ps, db := dbtestutil.NewDBWithSQLDB(t)
client, closeDaemon := coderdtest.NewWithProvisionerCloser(t, &coderdtest.Options{
@@ -225,6 +228,9 @@ func TestDelete(t *testing.T) {
t.Run("Prebuilt workspace delete permissions", func(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
// Setup
db, pb := dbtestutil.NewDB(t, dbtestutil.WithDumpOnFailure())
+23
@@ -0,0 +1,23 @@
package cli
import "github.com/coder/serpent"
func (r *RootCmd) expCmd() *serpent.Command {
cmd := &serpent.Command{
Use: "exp",
Short: "Internal commands for testing and experimentation. These are prone to breaking changes with no notice.",
Handler: func(i *serpent.Invocation) error {
return i.Command.HelpHandler(i)
},
Hidden: true,
Children: []*serpent.Command{
r.scaletestCmd(),
r.errorExample(),
r.mcpCommand(),
r.promptExample(),
r.rptyCommand(),
r.tasksCommand(),
},
}
return cmd
}
-12
@@ -1,12 +0,0 @@
package cli
import (
boundarycli "github.com/coder/boundary/cli"
"github.com/coder/serpent"
)
func (*RootCmd) boundary() *serpent.Command {
cmd := boundarycli.BaseCommand() // Package coder/boundary/cli exports a "base command" designed to be integrated as a subcommand.
cmd.Use += " [args...]" // The base command looks like `boundary -- command`. Serpent adds the flags piece, but we need to add the args.
return cmd
}
-33
@@ -1,33 +0,0 @@
package cli_test
import (
"testing"
"github.com/stretchr/testify/assert"
boundarycli "github.com/coder/boundary/cli"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
// Actually testing the functionality of coder/boundary takes place in the
// coder/boundary repo, since it's a dependency of coder.
// Here we simply want to verify that integrating it as a subcommand doesn't break anything.
func TestBoundarySubcommand(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
inv, _ := clitest.New(t, "exp", "boundary", "--help")
pty := ptytest.New(t).Attach(inv)
go func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err)
}()
// Expect the --help output to include the short description.
// We're simply confirming that `coder boundary --help` ran without a runtime error as
// a good chunk of serpents self validation logic happens at runtime.
pty.ExpectMatch(boundarycli.BaseCommand().Short)
}
+9 -15
@@ -56,7 +56,7 @@ func (r *RootCmd) mcpConfigure() *serpent.Command {
},
Children: []*serpent.Command{
r.mcpConfigureClaudeDesktop(),
mcpConfigureClaudeCode(),
r.mcpConfigureClaudeCode(),
r.mcpConfigureCursor(),
},
}
@@ -117,7 +117,7 @@ func (*RootCmd) mcpConfigureClaudeDesktop() *serpent.Command {
return cmd
}
func mcpConfigureClaudeCode() *serpent.Command {
func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command {
var (
claudeAPIKey string
claudeConfigPath string
@@ -131,7 +131,6 @@ func mcpConfigureClaudeCode() *serpent.Command {
deprecatedCoderMCPClaudeAPIKey string
)
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
Use: "claude-code <project-directory>",
Short: "Configure the Claude Code server. You will need to run this command for each project you want to use. Specify the project directory as the first argument.",
@@ -149,7 +148,7 @@ func mcpConfigureClaudeCode() *serpent.Command {
binPath = testBinaryName
}
configureClaudeEnv := map[string]string{}
agentClient, err := agentAuth.CreateClient()
agentClient, err := r.createAgentClient()
if err != nil {
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
} else {
@@ -293,7 +292,6 @@ func mcpConfigureClaudeCode() *serpent.Command {
},
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
@@ -399,20 +397,15 @@ type mcpServer struct {
func (r *RootCmd) mcpServer() *serpent.Command {
var (
client = new(codersdk.Client)
instructions string
allowedTools []string
appStatusSlug string
aiAgentAPIURL url.URL
)
agentAuth := &AgentAuth{}
cmd := &serpent.Command{
return &serpent.Command{
Use: "server",
Handler: func(inv *serpent.Invocation) error {
client, err := r.TryInitClient(inv)
if err != nil {
return err
}
var lastReport taskReport
// Create a queue that skips duplicates and preserves summaries.
queue := cliutil.NewQueue[taskReport](512).WithPredicate(func(report taskReport) (taskReport, bool) {
@@ -501,7 +494,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
}
// Try to create an agent client for status reporting. Not validated.
agentClient, err := agentAuth.CreateClient()
agentClient, err := r.createAgentClient()
if err == nil {
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
srv.agentClient = agentClient
@@ -552,6 +545,9 @@ func (r *RootCmd) mcpServer() *serpent.Command {
return srv.startServer(ctx, inv, instructions, allowedTools)
},
Short: "Start the Coder MCP server.",
Middleware: serpent.Chain(
r.TryInitClient(client),
),
Options: []serpent.Option{
{
Name: "instructions",
@@ -583,8 +579,6 @@ func (r *RootCmd) mcpServer() *serpent.Command {
},
},
}
agentAuth.AttachOptions(cmd, false)
return cmd
}
func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation) {
+5 -5
@@ -22,17 +22,16 @@ import (
)
func (r *RootCmd) rptyCommand() *serpent.Command {
var args handleRPTYArgs
var (
client = new(codersdk.Client)
args handleRPTYArgs
)
cmd := &serpent.Command{
Handler: func(inv *serpent.Invocation) error {
if r.disableDirect {
return xerrors.New("direct connections are disabled, but you can try websocat ;-)")
}
client, err := r.InitClient(inv)
if err != nil {
return err
}
args.NamedWorkspace = inv.Args[0]
args.Command = inv.Args[1:]
return handleRPTY(inv, client, args)
@@ -40,6 +39,7 @@ func (r *RootCmd) rptyCommand() *serpent.Command {
Long: "Establish an RPTY session with a workspace/agent. This uses the same mechanism as the Web Terminal.",
Middleware: serpent.Chain(
serpent.RequireRangeArgs(1, -1),
r.InitClient(client),
),
Options: []serpent.Option{
{
+118 -675
@@ -32,17 +32,14 @@ import (
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/scaletest/agentconn"
"github.com/coder/coder/v2/scaletest/autostart"
"github.com/coder/coder/v2/scaletest/createusers"
"github.com/coder/coder/v2/scaletest/createworkspaces"
"github.com/coder/coder/v2/scaletest/dashboard"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/loadtestutil"
"github.com/coder/coder/v2/scaletest/reconnectingpty"
"github.com/coder/coder/v2/scaletest/workspacebuild"
"github.com/coder/coder/v2/scaletest/workspacetraffic"
"github.com/coder/coder/v2/scaletest/workspaceupdates"
"github.com/coder/serpent"
)
@@ -58,13 +55,8 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
Children: []*serpent.Command{
r.scaletestCleanup(),
r.scaletestDashboard(),
r.scaletestDynamicParameters(),
r.scaletestCreateWorkspaces(),
r.scaletestWorkspaceUpdates(),
r.scaletestWorkspaceTraffic(),
r.scaletestAutostart(),
r.scaletestNotifications(),
r.scaletestSMTP(),
},
}
@@ -139,111 +131,80 @@ func (s *scaletestTracingFlags) provider(ctx context.Context) (trace.TracerProvi
}, true, nil
}
type concurrencyFlags struct {
cleanup bool
concurrency int64
}
func (c *concurrencyFlags) attach(opts *serpent.OptionSet) {
concurrencyLong, concurrencyEnv, concurrencyDescription := "concurrency", "CODER_SCALETEST_CONCURRENCY", "Number of concurrent jobs to run. 0 means unlimited."
if c.cleanup {
concurrencyLong, concurrencyEnv, concurrencyDescription = "cleanup-"+concurrencyLong, "CODER_SCALETEST_CLEANUP_CONCURRENCY", strings.ReplaceAll(concurrencyDescription, "jobs", "cleanup jobs")
}
*opts = append(*opts, serpent.Option{
Flag: concurrencyLong,
Env: concurrencyEnv,
Description: concurrencyDescription,
Default: "1",
Value: serpent.Int64Of(&c.concurrency),
})
}
func (c *concurrencyFlags) toStrategy() harness.ExecutionStrategy {
switch c.concurrency {
case 1:
return harness.LinearExecutionStrategy{}
case 0:
return harness.ConcurrentExecutionStrategy{}
default:
return harness.ParallelExecutionStrategy{
Limit: int(c.concurrency),
}
}
}
type timeoutFlags struct {
type scaletestStrategyFlags struct {
cleanup bool
concurrency int64
timeout time.Duration
timeoutPerJob time.Duration
}
func (t *timeoutFlags) attach(opts *serpent.OptionSet) {
func (s *scaletestStrategyFlags) attach(opts *serpent.OptionSet) {
concurrencyLong, concurrencyEnv, concurrencyDescription := "concurrency", "CODER_SCALETEST_CONCURRENCY", "Number of concurrent jobs to run. 0 means unlimited."
timeoutLong, timeoutEnv, timeoutDescription := "timeout", "CODER_SCALETEST_TIMEOUT", "Timeout for the entire test run. 0 means unlimited."
jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription := "job-timeout", "CODER_SCALETEST_JOB_TIMEOUT", "Timeout per job. Jobs may take longer to complete under higher concurrency limits."
if t.cleanup {
if s.cleanup {
concurrencyLong, concurrencyEnv, concurrencyDescription = "cleanup-"+concurrencyLong, "CODER_SCALETEST_CLEANUP_CONCURRENCY", strings.ReplaceAll(concurrencyDescription, "jobs", "cleanup jobs")
timeoutLong, timeoutEnv, timeoutDescription = "cleanup-"+timeoutLong, "CODER_SCALETEST_CLEANUP_TIMEOUT", strings.ReplaceAll(timeoutDescription, "test", "cleanup")
jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription = "cleanup-"+jobTimeoutLong, "CODER_SCALETEST_CLEANUP_JOB_TIMEOUT", strings.ReplaceAll(jobTimeoutDescription, "jobs", "cleanup jobs")
}
*opts = append(
*opts,
serpent.Option{
Flag: concurrencyLong,
Env: concurrencyEnv,
Description: concurrencyDescription,
Default: "1",
Value: serpent.Int64Of(&s.concurrency),
},
serpent.Option{
Flag: timeoutLong,
Env: timeoutEnv,
Description: timeoutDescription,
Default: "30m",
Value: serpent.DurationOf(&t.timeout),
Value: serpent.DurationOf(&s.timeout),
},
serpent.Option{
Flag: jobTimeoutLong,
Env: jobTimeoutEnv,
Description: jobTimeoutDescription,
Default: "5m",
Value: serpent.DurationOf(&t.timeoutPerJob),
Value: serpent.DurationOf(&s.timeoutPerJob),
},
)
}
func (t *timeoutFlags) wrapStrategy(strategy harness.ExecutionStrategy) harness.ExecutionStrategy {
if t.timeoutPerJob > 0 {
return harness.TimeoutExecutionStrategyWrapper{
Timeout: t.timeoutPerJob,
func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy {
var strategy harness.ExecutionStrategy
switch s.concurrency {
case 1:
strategy = harness.LinearExecutionStrategy{}
case 0:
strategy = harness.ConcurrentExecutionStrategy{}
default:
strategy = harness.ParallelExecutionStrategy{
Limit: int(s.concurrency),
}
}
if s.timeoutPerJob > 0 {
strategy = harness.TimeoutExecutionStrategyWrapper{
Timeout: s.timeoutPerJob,
Inner: strategy,
}
}
return strategy
}
func (t *timeoutFlags) toContext(ctx context.Context) (context.Context, context.CancelFunc) {
if t.timeout > 0 {
return context.WithTimeout(ctx, t.timeout)
func (s *scaletestStrategyFlags) toContext(ctx context.Context) (context.Context, context.CancelFunc) {
if s.timeout > 0 {
return context.WithTimeout(ctx, s.timeout)
}
return context.WithCancel(ctx)
}
type scaletestStrategyFlags struct {
concurrencyFlags
timeoutFlags
}
func newScaletestCleanupStrategy() *scaletestStrategyFlags {
return &scaletestStrategyFlags{
concurrencyFlags: concurrencyFlags{cleanup: true},
timeoutFlags: timeoutFlags{cleanup: true},
}
}
func (s *scaletestStrategyFlags) attach(opts *serpent.OptionSet) {
s.timeoutFlags.attach(opts)
s.concurrencyFlags.attach(opts)
}
func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy {
return s.timeoutFlags.wrapStrategy(s.concurrencyFlags.toStrategy())
}
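For reference, a minimal sketch of how the consolidated scaletestStrategyFlags is typically wired into a scaletest command, using only helpers that appear in this diff (the command name is hypothetical): attach the options, derive the execution strategy from --concurrency/--job-timeout, and bound the run with --timeout.

func (r *RootCmd) scaletestExample() *serpent.Command {
	strategy := &scaletestStrategyFlags{}
	cleanupStrategy := &scaletestStrategyFlags{cleanup: true}
	cmd := &serpent.Command{
		Use:   "example",
		Short: "Hypothetical command illustrating the strategy flag wiring.",
		Handler: func(inv *serpent.Invocation) error {
			// --concurrency and --job-timeout shape the execution strategy;
			// --timeout bounds the whole run via the derived context.
			th := harness.NewTestHarness(strategy.toStrategy(), cleanupStrategy.toStrategy())
			testCtx, testCancel := strategy.toContext(inv.Context())
			defer testCancel()
			return th.Run(testCtx)
		},
	}
	strategy.attach(&cmd.Options)
	cleanupStrategy.attach(&cmd.Options)
	return cmd
}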
type scaleTestOutputFormat string
const (
@@ -434,17 +395,18 @@ func (r *userCleanupRunner) Run(ctx context.Context, _ string, _ io.Writer) erro
func (r *RootCmd) scaletestCleanup() *serpent.Command {
var template string
cleanupStrategy := newScaletestCleanupStrategy()
cleanupStrategy := &scaletestStrategyFlags{cleanup: true}
client := new(codersdk.Client)
cmd := &serpent.Command{
Use: "cleanup",
Short: "Cleanup scaletest workspaces, then cleanup scaletest users.",
Long: "The strategy flags will apply to each stage of the cleanup process.",
Middleware: serpent.Chain(
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
ctx := inv.Context()
me, err := requireAdmin(ctx, client)
@@ -585,20 +547,18 @@ func (r *RootCmd) scaletestCreateWorkspaces() *serpent.Command {
tracingFlags = &scaletestTracingFlags{}
strategy = &scaletestStrategyFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
cleanupStrategy = &scaletestStrategyFlags{cleanup: true}
output = &scaletestOutputFlags{}
)
cmd := &serpent.Command{
Use: "create-workspaces",
Short: "Creates many users, then creates a workspace for each user and waits for them finish building and fully come online. Optionally runs a command inside each workspace, and connects to the workspace over WireGuard.",
Long: `It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`,
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
client := new(codersdk.Client)
cmd := &serpent.Command{
Use: "create-workspaces",
Short: "Creates many users, then creates a workspace for each user and waits for them finish building and fully come online. Optionally runs a command inside each workspace, and connects to the workspace over WireGuard.",
Long: `It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`,
Middleware: r.InitClient(client),
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
me, err := requireAdmin(ctx, client)
@@ -687,6 +647,16 @@ func (r *RootCmd) scaletestCreateWorkspaces() *serpent.Command {
if useHostUser {
config.User.SessionToken = client.SessionToken()
} else {
config.User.Username, config.User.Email, err = newScaleTestUser(id)
if err != nil {
return xerrors.Errorf("create scaletest username and email: %w", err)
}
}
config.Workspace.Request.Name, err = newScaleTestWorkspace(id)
if err != nil {
return xerrors.Errorf("create scaletest workspace name: %w", err)
}
if runCommand != "" {
@@ -889,319 +859,21 @@ func (r *RootCmd) scaletestCreateWorkspaces() *serpent.Command {
return cmd
}
func (r *RootCmd) scaletestWorkspaceUpdates() *serpent.Command {
var (
workspaceCount int64
powerUserWorkspaces int64
powerUserPercentage float64
workspaceUpdatesTimeout time.Duration
dialTimeout time.Duration
template string
noCleanup bool
parameterFlags workspaceParameterFlags
tracingFlags = &scaletestTracingFlags{}
// This test requires unlimited concurrency
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
Use: "workspace-updates",
Short: "Simulate the load of Coder Desktop clients receiving workspace updates",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.TryInitClient(inv)
if err != nil {
return err
}
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) // Checked later.
defer stop()
ctx = notifyCtx
me, err := requireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
if workspaceCount <= 0 {
return xerrors.Errorf("--workspace-count must be greater than 0")
}
if powerUserWorkspaces <= 1 {
return xerrors.Errorf("--power-user-workspaces must be greater than 1")
}
if powerUserPercentage < 0 || powerUserPercentage > 100 {
return xerrors.Errorf("--power-user-proportion must be between 0 and 100")
}
powerUserWorkspaceCount := int64(float64(workspaceCount) * powerUserPercentage / 100)
remainder := powerUserWorkspaceCount % powerUserWorkspaces
// If the power user workspaces can't be evenly divided, round down
// to the nearest multiple so that we only have two groups of users.
workspaceCount -= remainder
powerUserWorkspaceCount -= remainder
powerUserCount := powerUserWorkspaceCount / powerUserWorkspaces
regularWorkspaceCount := workspaceCount - powerUserWorkspaceCount
regularUserCount := regularWorkspaceCount
regularUserWorkspaceCount := 1
_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
_, _ = fmt.Fprintf(inv.Stderr, " Total workspaces: %d\n", workspaceCount)
_, _ = fmt.Fprintf(inv.Stderr, " Power users: %d (each owning %d workspaces = %d total)\n",
powerUserCount, powerUserWorkspaces, powerUserWorkspaceCount)
_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (each owning %d workspace = %d total)\n",
regularUserCount, regularUserWorkspaceCount, regularWorkspaceCount)
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags")
}
tpl, err := parseTemplate(ctx, client, me.OrganizationIDs, template)
if err != nil {
return xerrors.Errorf("parse template: %w", err)
}
cliRichParameters, err := asWorkspaceBuildParameters(parameterFlags.richParameters)
if err != nil {
return xerrors.Errorf("can't parse given parameter values: %w", err)
}
richParameters, err := prepWorkspaceBuild(inv, client, prepWorkspaceBuildArgs{
Action: WorkspaceCreate,
TemplateVersionID: tpl.ActiveVersionID,
RichParameterFile: parameterFlags.richParameterFile,
RichParameters: cliRichParameters,
})
if err != nil {
return xerrors.Errorf("prepare build: %w", err)
}
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := workspaceupdates.NewMetrics(reg)
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")
dialBarrier := new(sync.WaitGroup)
dialBarrier.Add(int(powerUserCount + regularUserCount))
configs := make([]workspaceupdates.Config, 0, powerUserCount+regularUserCount)
for range powerUserCount {
config := workspaceupdates.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Workspace: workspacebuild.Config{
OrganizationID: me.OrganizationIDs[0],
Request: codersdk.CreateWorkspaceRequest{
TemplateID: tpl.ID,
RichParameterValues: richParameters,
},
NoWaitForAgents: true,
},
WorkspaceCount: powerUserWorkspaces,
WorkspaceUpdatesTimeout: workspaceUpdatesTimeout,
DialTimeout: dialTimeout,
Metrics: metrics,
DialBarrier: dialBarrier,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
for range regularUserCount {
config := workspaceupdates.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Workspace: workspacebuild.Config{
OrganizationID: me.OrganizationIDs[0],
Request: codersdk.CreateWorkspaceRequest{
TemplateID: tpl.ID,
RichParameterValues: richParameters,
},
NoWaitForAgents: true,
},
WorkspaceCount: int64(regularUserWorkspaceCount),
WorkspaceUpdatesTimeout: workspaceUpdatesTimeout,
DialTimeout: dialTimeout,
Metrics: metrics,
DialBarrier: dialBarrier,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i, config := range configs {
name := fmt.Sprintf("workspaceupdates-%dw", config.WorkspaceCount)
id := strconv.Itoa(i)
var runner harness.Runnable = workspaceupdates.NewRunner(client, config)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
spanName: fmt.Sprintf("%s/%s", name, id),
runner: runner,
}
}
th.AddRun(name, id, runner)
}
_, _ = fmt.Fprintln(inv.Stderr, "Running workspace updates scaletest...")
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "workspace-count",
FlagShorthand: "c",
Env: "CODER_SCALETEST_WORKSPACE_COUNT",
Description: "Required: Total number of workspaces to create.",
Value: serpent.Int64Of(&workspaceCount),
Required: true,
},
{
Flag: "power-user-workspaces",
Env: "CODER_SCALETEST_POWER_USER_WORKSPACES",
Description: "Number of workspaces each power-user owns.",
Value: serpent.Int64Of(&powerUserWorkspaces),
Required: true,
},
{
Flag: "power-user-percentage",
Env: "CODER_SCALETEST_POWER_USER_PERCENTAGE",
Default: "50.0",
Description: "Percentage of total workspaces owned by power-users (0-100).",
Value: serpent.Float64Of(&powerUserPercentage),
},
{
Flag: "workspace-updates-timeout",
Env: "CODER_SCALETEST_WORKSPACE_UPDATES_TIMEOUT",
Default: "5m",
Description: "How long to wait for all expected workspace updates.",
Value: serpent.DurationOf(&workspaceUpdatesTimeout),
},
{
Flag: "dial-timeout",
Env: "CODER_SCALETEST_DIAL_TIMEOUT",
Default: "2m",
Description: "Timeout for dialing the tailnet endpoint.",
Value: serpent.DurationOf(&dialTimeout),
},
{
Flag: "template",
FlagShorthand: "t",
Env: "CODER_SCALETEST_TEMPLATE",
Description: "Required: Name or ID of the template to use for workspaces.",
Value: serpent.StringOf(&template),
Required: true,
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes.",
Value: serpent.BoolOf(&noCleanup),
},
}
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
tracingFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
var (
tickInterval time.Duration
bytesPerTick int64
ssh bool
disableDirect bool
useHostLogin bool
app string
template string
targetWorkspaces string
workspaceProxyURL string
client = &codersdk.Client{}
tracingFlags = &scaletestTracingFlags{}
strategy = &scaletestStrategyFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
cleanupStrategy = &scaletestStrategyFlags{cleanup: true}
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
@@ -1209,12 +881,10 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
cmd := &serpent.Command{
Use: "workspace-traffic",
Short: "Generate traffic to scaletest workspaces through coderd",
Middleware: serpent.Chain(
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) (err error) {
client, err := r.InitClient(inv)
if err != nil {
return err
}
ctx := inv.Context()
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) // Checked later.
@@ -1341,10 +1011,9 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
return xerrors.Errorf("parse workspace proxy URL: %w", err)
}
webClient = codersdk.New(u,
codersdk.WithHTTPClient(client.HTTPClient),
codersdk.WithSessionToken(client.SessionToken()),
)
webClient = codersdk.New(u)
webClient.HTTPClient = client.HTTPClient
webClient.SetSessionToken(client.SessionToken())
appConfig, err = createWorkspaceAppConfig(webClient, appHost.Host, app, ws, agent)
if err != nil {
@@ -1354,16 +1023,15 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
// Setup our workspace agent connection.
config := workspacetraffic.Config{
AgentID: agent.ID,
BytesPerTick: bytesPerTick,
Duration: strategy.timeout,
TickInterval: tickInterval,
ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name),
WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name),
SSH: ssh,
DisableDirect: disableDirect,
Echo: ssh,
App: appConfig,
AgentID: agent.ID,
BytesPerTick: bytesPerTick,
Duration: strategy.timeout,
TickInterval: tickInterval,
ReadMetrics: metrics.ReadMetrics(ws.OwnerName, ws.Name, agent.Name),
WriteMetrics: metrics.WriteMetrics(ws.OwnerName, ws.Name, agent.Name),
SSH: ssh,
Echo: ssh,
App: appConfig,
}
if webClient != nil {
@@ -1449,13 +1117,6 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
Description: "Send traffic over SSH, cannot be used with --app.",
Value: serpent.BoolOf(&ssh),
},
{
Flag: "disable-direct",
Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_DISABLE_DIRECT_CONNECTIONS",
Default: "false",
Description: "Disable direct connections for SSH traffic to workspaces. Does nothing if `--ssh` is not also set.",
Value: serpent.BoolOf(&disableDirect),
},
{
Flag: "app",
Env: "CODER_SCALETEST_WORKSPACE_TRAFFIC_APP",
@@ -1490,14 +1151,16 @@ func (r *RootCmd) scaletestWorkspaceTraffic() *serpent.Command {
func (r *RootCmd) scaletestDashboard() *serpent.Command {
var (
interval time.Duration
jitter time.Duration
headless bool
randSeed int64
targetUsers string
interval time.Duration
jitter time.Duration
headless bool
randSeed int64
targetUsers string
client = &codersdk.Client{}
tracingFlags = &scaletestTracingFlags{}
strategy = &scaletestStrategyFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
cleanupStrategy = &scaletestStrategyFlags{cleanup: true}
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
@@ -1505,12 +1168,10 @@ func (r *RootCmd) scaletestDashboard() *serpent.Command {
cmd := &serpent.Command{
Use: "dashboard",
Short: "Generate traffic to the HTTP API to simulate use of the dashboard.",
Middleware: serpent.Chain(
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
if !(interval > 0) {
return xerrors.Errorf("--interval must be greater than zero")
}
@@ -1578,9 +1239,8 @@ func (r *RootCmd) scaletestDashboard() *serpent.Command {
return xerrors.Errorf("create token for user: %w", err)
}
userClient := codersdk.New(client.URL,
codersdk.WithSessionToken(userTokResp.Key),
)
userClient := codersdk.New(client.URL)
userClient.SetSessionToken(userTokResp.Key)
config := dashboard.Config{
Interval: interval,
@@ -1687,239 +1347,6 @@ func (r *RootCmd) scaletestDashboard() *serpent.Command {
return cmd
}
const (
autostartTestName = "autostart"
)
func (r *RootCmd) scaletestAutostart() *serpent.Command {
var (
workspaceCount int64
workspaceJobTimeout time.Duration
autostartDelay time.Duration
autostartTimeout time.Duration
template string
noCleanup bool
parameterFlags workspaceParameterFlags
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
Use: "autostart",
Short: "Replicate a thundering herd of autostarting workspaces",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...) // Checked later.
defer stop()
ctx = notifyCtx
me, err := requireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
if workspaceCount <= 0 {
return xerrors.Errorf("--workspace-count must be greater than zero")
}
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags")
}
tpl, err := parseTemplate(ctx, client, me.OrganizationIDs, template)
if err != nil {
return xerrors.Errorf("parse template: %w", err)
}
cliRichParameters, err := asWorkspaceBuildParameters(parameterFlags.richParameters)
if err != nil {
return xerrors.Errorf("can't parse given parameter values: %w", err)
}
richParameters, err := prepWorkspaceBuild(inv, client, prepWorkspaceBuildArgs{
Action: WorkspaceCreate,
TemplateVersionID: tpl.ActiveVersionID,
RichParameterFile: parameterFlags.richParameterFile,
RichParameters: cliRichParameters,
})
if err != nil {
return xerrors.Errorf("prepare build: %w", err)
}
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := autostart.NewMetrics(reg)
setupBarrier := new(sync.WaitGroup)
setupBarrier.Add(int(workspaceCount))
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i := range workspaceCount {
id := strconv.Itoa(int(i))
config := autostart.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Workspace: workspacebuild.Config{
OrganizationID: me.OrganizationIDs[0],
Request: codersdk.CreateWorkspaceRequest{
TemplateID: tpl.ID,
RichParameterValues: richParameters,
},
},
WorkspaceJobTimeout: workspaceJobTimeout,
AutostartDelay: autostartDelay,
AutostartTimeout: autostartTimeout,
Metrics: metrics,
SetupBarrier: setupBarrier,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
var runner harness.Runnable = autostart.NewRunner(client, config)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
spanName: fmt.Sprintf("%s/%s", autostartTestName, id),
runner: runner,
}
}
th.AddRun(autostartTestName, id, runner)
}
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Running autostart load test...")
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "workspace-count",
FlagShorthand: "c",
Env: "CODER_SCALETEST_WORKSPACE_COUNT",
Description: "Required: Total number of workspaces to create.",
Value: serpent.Int64Of(&workspaceCount),
Required: true,
},
{
Flag: "workspace-job-timeout",
Env: "CODER_SCALETEST_WORKSPACE_JOB_TIMEOUT",
Default: "5m",
Description: "Timeout for workspace jobs (e.g. build, start).",
Value: serpent.DurationOf(&workspaceJobTimeout),
},
{
Flag: "autostart-delay",
Env: "CODER_SCALETEST_AUTOSTART_DELAY",
Default: "2m",
Description: "How long after all the workspaces have been stopped to schedule them to be started again.",
Value: serpent.DurationOf(&autostartDelay),
},
{
Flag: "autostart-timeout",
Env: "CODER_SCALETEST_AUTOSTART_TIMEOUT",
Default: "5m",
Description: "Timeout for the autostart build to be initiated after the scheduled start time.",
Value: serpent.DurationOf(&autostartTimeout),
},
{
Flag: "template",
FlagShorthand: "t",
Env: "CODER_SCALETEST_TEMPLATE",
Description: "Required: Name or ID of the template to use for workspaces.",
Value: serpent.StringOf(&template),
Required: true,
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes.",
Value: serpent.BoolOf(&noCleanup),
},
}
cmd.Options = append(cmd.Options, parameterFlags.cliParameters()...)
tracingFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
type runnableTraceWrapper struct {
tracer trace.Tracer
spanName string
@@ -1929,9 +1356,8 @@ type runnableTraceWrapper struct {
}
var (
_ harness.Runnable = &runnableTraceWrapper{}
_ harness.Cleanable = &runnableTraceWrapper{}
_ harness.Collectable = &runnableTraceWrapper{}
_ harness.Runnable = &runnableTraceWrapper{}
_ harness.Cleanable = &runnableTraceWrapper{}
)
func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
@@ -1973,12 +1399,29 @@ func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string, logs io.W
return c.Cleanup(ctx, id, logs)
}
func (r *runnableTraceWrapper) GetMetrics() map[string]any {
c, ok := r.runner.(harness.Collectable)
if !ok {
return nil
}
return c.GetMetrics()
}
// newScaleTestUser returns a random username and email address that can be used
// for scale testing. The returned username is prefixed with "scaletest-" and
// the returned email address is suffixed with "@scaletest.local".
func newScaleTestUser(id string) (username string, email string, err error) {
randStr, err := cryptorand.String(8)
return fmt.Sprintf("scaletest-%s-%s", randStr, id), fmt.Sprintf("%s-%s@scaletest.local", randStr, id), err
}
// newScaleTestWorkspace returns a random workspace name that can be used for
// scale testing. The returned workspace name is prefixed with "scaletest-" and
// suffixed with the given id.
func newScaleTestWorkspace(id string) (name string, err error) {
randStr, err := cryptorand.String(8)
return fmt.Sprintf("scaletest-%s-%s", randStr, id), err
}
func isScaleTestUser(user codersdk.User) bool {
return strings.HasSuffix(user.Email, "@scaletest.local")
}
func isScaleTestWorkspace(workspace codersdk.Workspace) bool {
return strings.HasPrefix(workspace.OwnerName, "scaletest-") ||
strings.HasPrefix(workspace.Name, "scaletest-")
}
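A hypothetical illustration of the reinstated naming helpers (the id and printed values are made up); cleanup later identifies these resources via isScaleTestUser and isScaleTestWorkspace.

func exampleScaleTestNaming() {
	// newScaleTestUser embeds a random 8-character string plus the run id.
	username, email, err := newScaleTestUser("7")
	if err != nil {
		panic(err)
	}
	workspaceName, err := newScaleTestWorkspace("7")
	if err != nil {
		panic(err)
	}
	// Typical output: user=scaletest-1a2b3c4d-7 email=1a2b3c4d-7@scaletest.local
	// workspace=scaletest-1a2b3c4d-7. isScaleTestUser and isScaleTestWorkspace
	// recognize these by the "@scaletest.local" suffix and "scaletest-" prefix.
	fmt.Printf("user=%s email=%s workspace=%s\n", username, email, workspaceName)
}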
func getScaletestWorkspaces(ctx context.Context, client *codersdk.Client, owner, template string) ([]codersdk.Workspace, int, error) {
@@ -2019,7 +1462,7 @@ func getScaletestWorkspaces(ctx context.Context, client *codersdk.Client, owner,
pageWorkspaces := make([]codersdk.Workspace, 0, len(page.Workspaces))
for _, w := range page.Workspaces {
if !loadtestutil.IsScaleTestWorkspace(w.Name, w.OwnerName) {
if !isScaleTestWorkspace(w) {
continue
}
if noOwnerAccess && w.OwnerID != me.ID {
@@ -2059,7 +1502,7 @@ func getScaletestUsers(ctx context.Context, client *codersdk.Client) ([]codersdk
pageUsers := make([]codersdk.User, 0, len(page.Users))
for _, u := range page.Users {
if loadtestutil.IsScaleTestUser(u.Username, u.Email) {
if isScaleTestUser(u) {
pageUsers = append(pageUsers, u)
}
}
-181
@@ -1,181 +0,0 @@
//go:build !slim
package cli
import (
"fmt"
"net/http"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/serpent"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/dynamicparameters"
"github.com/coder/coder/v2/scaletest/harness"
)
const (
dynamicParametersTestName = "dynamic-parameters"
)
func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
var (
templateName string
provisionerTags []string
numEvals int64
tracingFlags = &scaletestTracingFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
// This test requires unlimited concurrency
timeoutStrategy = &timeoutFlags{}
)
orgContext := NewOrganizationContext()
output := &scaletestOutputFlags{}
cmd := &serpent.Command{
Use: "dynamic-parameters",
Short: "Generates load on the Coder server evaluating dynamic parameters",
Long: `It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`,
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags")
}
client, err := r.InitClient(inv)
if err != nil {
return err
}
if templateName == "" {
return xerrors.Errorf("template cannot be empty")
}
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
org, err := orgContext.Selected(inv, client)
if err != nil {
return err
}
_, err = requireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
reg := prometheus.NewRegistry()
metrics := dynamicparameters.NewMetrics(reg, "concurrent_evaluations")
logger := slog.Make(sloghuman.Sink(inv.Stdout)).Leveled(slog.LevelDebug)
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
defer func() {
// Allow time for traces to flush even if command context is
// canceled. This is a no-op if tracing is not enabled.
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
tracer := tracerProvider.Tracer(scaletestTracerName)
partitions, err := dynamicparameters.SetupPartitions(ctx, client, org.ID, templateName, tags, numEvals, logger)
if err != nil {
return xerrors.Errorf("setup dynamic parameters partitions: %w", err)
}
th := harness.NewTestHarness(
timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}),
// there is no cleanup since it's just a connection that we sever.
nil)
for i, part := range partitions {
for j := range part.ConcurrentEvaluations {
cfg := dynamicparameters.Config{
TemplateVersion: part.TemplateVersion.ID,
Metrics: metrics,
MetricLabelValues: []string{fmt.Sprintf("%d", part.ConcurrentEvaluations)},
}
var runner harness.Runnable = dynamicparameters.NewRunner(client, cfg)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
spanName: fmt.Sprintf("%s/%d/%d", dynamicParametersTestName, i, j),
runner: runner,
}
}
th.AddRun(dynamicParametersTestName, fmt.Sprintf("%d/%d", j, i), runner)
}
}
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness: %w", err)
}
res := th.Results()
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "template",
Description: "Name of the template to use. If it does not exist, it will be created.",
Default: "scaletest-dynamic-parameters",
Value: serpent.StringOf(&templateName),
},
{
Flag: "concurrent-evaluations",
Description: "Number of concurrent dynamic parameter evaluations to perform.",
Default: "100",
Value: serpent.Int64Of(&numEvals),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}
orgContext.AttachOptions(cmd)
output.attach(&cmd.Options)
tracingFlags.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
return cmd
}
-447
@@ -1,447 +0,0 @@
//go:build !slim
package cli
import (
"context"
"fmt"
"net/http"
"os/signal"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/xerrors"
"cdr.dev/slog"
notificationsLib "github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/scaletest/createusers"
"github.com/coder/coder/v2/scaletest/harness"
"github.com/coder/coder/v2/scaletest/notifications"
"github.com/coder/serpent"
)
func (r *RootCmd) scaletestNotifications() *serpent.Command {
var (
userCount int64
ownerUserPercentage float64
notificationTimeout time.Duration
dialTimeout time.Duration
noCleanup bool
smtpAPIURL string
tracingFlags = &scaletestTracingFlags{}
// This test requires unlimited concurrency.
timeoutStrategy = &timeoutFlags{}
cleanupStrategy = newScaletestCleanupStrategy()
output = &scaletestOutputFlags{}
prometheusFlags = &scaletestPrometheusFlags{}
)
cmd := &serpent.Command{
Use: "notifications",
Short: "Simulate notification delivery by creating many users listening to notifications.",
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
defer stop()
ctx = notifyCtx
me, err := requireAdmin(ctx, client)
if err != nil {
return err
}
client.HTTPClient = &http.Client{
Transport: &codersdk.HeaderTransport{
Transport: http.DefaultTransport,
Header: map[string][]string{
codersdk.BypassRatelimitHeader: {"true"},
},
},
}
if userCount <= 0 {
return xerrors.Errorf("--user-count must be greater than 0")
}
if ownerUserPercentage < 0 || ownerUserPercentage > 100 {
return xerrors.Errorf("--owner-user-percentage must be between 0 and 100")
}
if smtpAPIURL != "" && !strings.HasPrefix(smtpAPIURL, "http://") && !strings.HasPrefix(smtpAPIURL, "https://") {
return xerrors.Errorf("--smtp-api-url must start with http:// or https://")
}
ownerUserCount := int64(float64(userCount) * ownerUserPercentage / 100)
if ownerUserCount == 0 && ownerUserPercentage > 0 {
ownerUserCount = 1
}
regularUserCount := userCount - ownerUserCount
_, _ = fmt.Fprintf(inv.Stderr, "Distribution plan:\n")
_, _ = fmt.Fprintf(inv.Stderr, " Total users: %d\n", userCount)
_, _ = fmt.Fprintf(inv.Stderr, " Owner users: %d (%.1f%%)\n", ownerUserCount, ownerUserPercentage)
_, _ = fmt.Fprintf(inv.Stderr, " Regular users: %d (%.1f%%)\n", regularUserCount, 100.0-ownerUserPercentage)
outputs, err := output.parse()
if err != nil {
return xerrors.Errorf("could not parse --output flags")
}
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
if err != nil {
return xerrors.Errorf("create tracer provider: %w", err)
}
tracer := tracerProvider.Tracer(scaletestTracerName)
reg := prometheus.NewRegistry()
metrics := notifications.NewMetrics(reg)
logger := inv.Logger
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
defer prometheusSrvClose()
defer func() {
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
if err := closeTracing(ctx); err != nil {
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
}
// Wait for prometheus metrics to be scraped
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
<-time.After(prometheusFlags.Wait)
}()
_, _ = fmt.Fprintln(inv.Stderr, "Creating users...")
dialBarrier := &sync.WaitGroup{}
ownerWatchBarrier := &sync.WaitGroup{}
dialBarrier.Add(int(userCount))
ownerWatchBarrier.Add(int(ownerUserCount))
expectedNotificationIDs := map[uuid.UUID]struct{}{
notificationsLib.TemplateUserAccountCreated: {},
notificationsLib.TemplateUserAccountDeleted: {},
}
triggerTimes := make(map[uuid.UUID]chan time.Time, len(expectedNotificationIDs))
for id := range expectedNotificationIDs {
triggerTimes[id] = make(chan time.Time, 1)
}
configs := make([]notifications.Config, 0, userCount)
for range ownerUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{codersdk.RoleOwner},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
ExpectedNotificationsIDs: expectedNotificationIDs,
Metrics: metrics,
SMTPApiURL: smtpAPIURL,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
for range regularUserCount {
config := notifications.Config{
User: createusers.Config{
OrganizationID: me.OrganizationIDs[0],
},
Roles: []string{},
NotificationTimeout: notificationTimeout,
DialTimeout: dialTimeout,
DialBarrier: dialBarrier,
ReceivingWatchBarrier: ownerWatchBarrier,
Metrics: metrics,
SMTPApiURL: smtpAPIURL,
}
if err := config.Validate(); err != nil {
return xerrors.Errorf("validate config: %w", err)
}
configs = append(configs, config)
}
go triggerUserNotifications(
ctx,
logger,
client,
me.OrganizationIDs[0],
dialBarrier,
dialTimeout,
triggerTimes,
)
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
for i, config := range configs {
id := strconv.Itoa(i)
name := fmt.Sprintf("notifications-%s", id)
var runner harness.Runnable = notifications.NewRunner(client, config)
if tracingEnabled {
runner = &runnableTraceWrapper{
tracer: tracer,
spanName: name,
runner: runner,
}
}
th.AddRun(name, id, runner)
}
_, _ = fmt.Fprintln(inv.Stderr, "Running notification delivery scaletest...")
testCtx, testCancel := timeoutStrategy.toContext(ctx)
defer testCancel()
err = th.Run(testCtx)
if err != nil {
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
}
// If the command was interrupted, skip stats.
if notifyCtx.Err() != nil {
return notifyCtx.Err()
}
res := th.Results()
if err := computeNotificationLatencies(ctx, logger, triggerTimes, res, metrics); err != nil {
return xerrors.Errorf("compute notification latencies: %w", err)
}
for _, o := range outputs {
err = o.write(res, inv.Stdout)
if err != nil {
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
}
}
if !noCleanup {
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
defer cleanupCancel()
err = th.Cleanup(cleanupCtx)
if err != nil {
return xerrors.Errorf("cleanup tests: %w", err)
}
}
if res.TotalFail > 0 {
return xerrors.New("load test failed, see above for more details")
}
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "user-count",
FlagShorthand: "c",
Env: "CODER_SCALETEST_NOTIFICATION_USER_COUNT",
Description: "Required: Total number of users to create.",
Value: serpent.Int64Of(&userCount),
Required: true,
},
{
Flag: "owner-user-percentage",
Env: "CODER_SCALETEST_NOTIFICATION_OWNER_USER_PERCENTAGE",
Default: "20.0",
Description: "Percentage of users to assign Owner role to (0-100).",
Value: serpent.Float64Of(&ownerUserPercentage),
},
{
Flag: "notification-timeout",
Env: "CODER_SCALETEST_NOTIFICATION_TIMEOUT",
Default: "5m",
Description: "How long to wait for notifications after triggering.",
Value: serpent.DurationOf(&notificationTimeout),
},
{
Flag: "dial-timeout",
Env: "CODER_SCALETEST_DIAL_TIMEOUT",
Default: "2m",
Description: "Timeout for dialing the notification websocket endpoint.",
Value: serpent.DurationOf(&dialTimeout),
},
{
Flag: "no-cleanup",
Env: "CODER_SCALETEST_NO_CLEANUP",
Description: "Do not clean up resources after the test completes.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "smtp-api-url",
Env: "CODER_SCALETEST_SMTP_API_URL",
Description: "SMTP mock HTTP API address.",
Value: serpent.StringOf(&smtpAPIURL),
},
}
tracingFlags.attach(&cmd.Options)
timeoutStrategy.attach(&cmd.Options)
cleanupStrategy.attach(&cmd.Options)
output.attach(&cmd.Options)
prometheusFlags.attach(&cmd.Options)
return cmd
}
func computeNotificationLatencies(
ctx context.Context,
logger slog.Logger,
expectedNotifications map[uuid.UUID]chan time.Time,
results harness.Results,
metrics *notifications.Metrics,
) error {
triggerTimes := make(map[uuid.UUID]time.Time)
for notificationID, triggerTimeChan := range expectedNotifications {
select {
case triggerTime := <-triggerTimeChan:
triggerTimes[notificationID] = triggerTime
logger.Info(ctx, "received trigger time",
slog.F("notification_id", notificationID),
slog.F("trigger_time", triggerTime))
default:
logger.Warn(ctx, "no trigger time received for notification",
slog.F("notification_id", notificationID))
}
}
if len(triggerTimes) == 0 {
logger.Warn(ctx, "no trigger times available, skipping latency computation")
return nil
}
var totalLatencies int
for runID, runResult := range results.Runs {
if runResult.Error != nil {
logger.Debug(ctx, "skipping failed run for latency computation",
slog.F("run_id", runID))
continue
}
if runResult.Metrics == nil {
continue
}
// Process websocket notifications.
if wsReceiptTimes, ok := runResult.Metrics[notifications.WebsocketNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok {
for notificationID, receiptTime := range wsReceiptTimes {
if triggerTime, ok := triggerTimes[notificationID]; ok {
latency := receiptTime.Sub(triggerTime)
metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeWebsocket)
totalLatencies++
logger.Debug(ctx, "computed websocket latency",
slog.F("run_id", runID),
slog.F("notification_id", notificationID),
slog.F("latency", latency))
}
}
}
// Process SMTP notifications
if smtpReceiptTimes, ok := runResult.Metrics[notifications.SMTPNotificationReceiptTimeMetric].(map[uuid.UUID]time.Time); ok {
for notificationID, receiptTime := range smtpReceiptTimes {
if triggerTime, ok := triggerTimes[notificationID]; ok {
latency := receiptTime.Sub(triggerTime)
metrics.RecordLatency(latency, notificationID.String(), notifications.NotificationTypeSMTP)
totalLatencies++
logger.Debug(ctx, "computed SMTP latency",
slog.F("run_id", runID),
slog.F("notification_id", notificationID),
slog.F("latency", latency))
}
}
}
}
logger.Info(ctx, "finished computing notification latencies",
slog.F("total_runs", results.TotalRuns),
slog.F("total_latencies_computed", totalLatencies))
return nil
}
// triggerUserNotifications waits for all test users to connect,
// then creates and deletes a test user to trigger notification events for testing.
func triggerUserNotifications(
ctx context.Context,
logger slog.Logger,
client *codersdk.Client,
orgID uuid.UUID,
dialBarrier *sync.WaitGroup,
dialTimeout time.Duration,
expectedNotifications map[uuid.UUID]chan time.Time,
) {
logger.Info(ctx, "waiting for all users to connect")
// Wait for all users to connect
waitCtx, cancel := context.WithTimeout(ctx, dialTimeout+30*time.Second)
defer cancel()
done := make(chan struct{})
go func() {
dialBarrier.Wait()
close(done)
}()
select {
case <-done:
logger.Info(ctx, "all users connected")
case <-waitCtx.Done():
if waitCtx.Err() == context.DeadlineExceeded {
logger.Error(ctx, "timeout waiting for users to connect")
} else {
logger.Info(ctx, "context canceled while waiting for users")
}
return
}
const (
triggerUsername = "scaletest-trigger-user"
triggerEmail = "scaletest-trigger@example.com"
)
logger.Info(ctx, "creating test user to test notifications",
slog.F("username", triggerUsername),
slog.F("email", triggerEmail),
slog.F("org_id", orgID))
testUser, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{orgID},
Username: triggerUsername,
Email: triggerEmail,
Password: "test-password-123",
})
if err != nil {
logger.Error(ctx, "create test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountCreated] <- time.Now()
err = client.DeleteUser(ctx, testUser.ID)
if err != nil {
logger.Error(ctx, "delete test user", slog.Error(err))
return
}
expectedNotifications[notificationsLib.TemplateUserAccountDeleted] <- time.Now()
close(expectedNotifications[notificationsLib.TemplateUserAccountCreated])
close(expectedNotifications[notificationsLib.TemplateUserAccountDeleted])
}
-112
@@ -1,112 +0,0 @@
//go:build !slim
package cli
import (
"fmt"
"os/signal"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/scaletest/smtpmock"
"github.com/coder/serpent"
)
func (*RootCmd) scaletestSMTP() *serpent.Command {
var (
hostAddress string
smtpPort int64
apiPort int64
purgeAtCount int64
)
cmd := &serpent.Command{
Use: "smtp",
Short: "Start a mock SMTP server for testing",
Long: `Start a mock SMTP server with an HTTP API server that can be used to purge
messages and get messages by email.`,
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
defer stop()
ctx = notifyCtx
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)
config := smtpmock.Config{
HostAddress: hostAddress,
SMTPPort: int(smtpPort),
APIPort: int(apiPort),
Logger: logger,
}
srv := new(smtpmock.Server)
if err := srv.Start(ctx, config); err != nil {
return xerrors.Errorf("start mock SMTP server: %w", err)
}
defer func() {
_ = srv.Stop()
}()
_, _ = fmt.Fprintf(inv.Stdout, "Mock SMTP server started on %s\n", srv.SMTPAddress())
_, _ = fmt.Fprintf(inv.Stdout, "HTTP API server started on %s\n", srv.APIAddress())
if purgeAtCount > 0 {
_, _ = fmt.Fprintf(inv.Stdout, " Auto-purge when message count reaches %d\n", purgeAtCount)
}
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
_, _ = fmt.Fprintf(inv.Stdout, "\nTotal messages received since last purge: %d\n", srv.MessageCount())
return nil
case <-ticker.C:
count := srv.MessageCount()
if count > 0 {
_, _ = fmt.Fprintf(inv.Stdout, "Messages received: %d\n", count)
}
if purgeAtCount > 0 && int64(count) >= purgeAtCount {
_, _ = fmt.Fprintf(inv.Stdout, "Message count (%d) reached threshold (%d). Purging...\n", count, purgeAtCount)
srv.Purge()
continue
}
}
}
},
}
cmd.Options = []serpent.Option{
{
Flag: "host-address",
Env: "CODER_SCALETEST_SMTP_HOST_ADDRESS",
Default: "localhost",
Description: "Host address to bind the mock SMTP and API servers.",
Value: serpent.StringOf(&hostAddress),
},
{
Flag: "smtp-port",
Env: "CODER_SCALETEST_SMTP_PORT",
Description: "Port for the mock SMTP server. Uses a random port if not specified.",
Value: serpent.Int64Of(&smtpPort),
},
{
Flag: "api-port",
Env: "CODER_SCALETEST_SMTP_API_PORT",
Description: "Port for the HTTP API server. Uses a random port if not specified.",
Value: serpent.Int64Of(&apiPort),
},
{
Flag: "purge-at-count",
Env: "CODER_SCALETEST_SMTP_PURGE_AT_COUNT",
Default: "100000",
Description: "Maximum number of messages to keep before auto-purging. Set to 0 to disable.",
Value: serpent.Int64Of(&purgeAtCount),
},
}
return cmd
}
+1 -4
@@ -13,11 +13,8 @@ func (r *RootCmd) tasksCommand() *serpent.Command {
return i.Command.HelpHandler(i)
},
Children: []*serpent.Command{
r.taskCreate(),
r.taskDelete(),
r.taskList(),
r.taskLogs(),
r.taskSend(),
r.taskCreate(),
r.taskStatus(),
},
}
-237
@@ -1,237 +0,0 @@
package cli
import (
"fmt"
"io"
"strings"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) taskCreate() *serpent.Command {
var (
orgContext = NewOrganizationContext()
ownerArg string
taskName string
templateName string
templateVersionName string
presetName string
stdin bool
quiet bool
)
cmd := &serpent.Command{
Use: "create [input]",
Short: "Create an experimental task",
Long: FormatExamples(
Example{
Description: "Create a task with direct input",
Command: "coder exp task create \"Add authentication to the user service\"",
},
Example{
Description: "Create a task with stdin input",
Command: "echo \"Add authentication to the user service\" | coder exp task create",
},
Example{
Description: "Create a task with a specific name",
Command: "coder exp task create --name task1 \"Add authentication to the user service\"",
},
Example{
Description: "Create a task from a specific template / preset",
Command: "coder exp task create --template backend-dev --preset \"My Preset\" \"Add authentication to the user service\"",
},
Example{
Description: "Create a task for another user (requires appropriate permissions)",
Command: "coder exp task create --owner user@example.com \"Add authentication to the user service\"",
},
),
Middleware: serpent.Chain(
serpent.RequireRangeArgs(0, 1),
),
Options: serpent.OptionSet{
{
Name: "name",
Flag: "name",
Description: "Specify the name of the task. If you do not specify one, a name will be generated for you.",
Value: serpent.StringOf(&taskName),
Required: false,
Default: "",
},
{
Name: "owner",
Flag: "owner",
Description: "Specify the owner of the task. Defaults to the current user.",
Value: serpent.StringOf(&ownerArg),
Required: false,
Default: codersdk.Me,
},
{
Name: "template",
Flag: "template",
Env: "CODER_TASK_TEMPLATE_NAME",
Value: serpent.StringOf(&templateName),
},
{
Name: "template-version",
Flag: "template-version",
Env: "CODER_TASK_TEMPLATE_VERSION",
Value: serpent.StringOf(&templateVersionName),
},
{
Name: "preset",
Flag: "preset",
Env: "CODER_TASK_PRESET_NAME",
Value: serpent.StringOf(&presetName),
Default: PresetNone,
},
{
Name: "stdin",
Flag: "stdin",
Description: "Reads from stdin for the task input.",
Value: serpent.BoolOf(&stdin),
},
{
Name: "quiet",
Flag: "quiet",
FlagShorthand: "q",
Description: "Only display the created task's ID.",
Value: serpent.BoolOf(&quiet),
},
},
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
var (
ctx = inv.Context()
expClient = codersdk.NewExperimentalClient(client)
taskInput string
templateVersionID uuid.UUID
templateVersionPresetID uuid.UUID
)
organization, err := orgContext.Selected(inv, client)
if err != nil {
return xerrors.Errorf("get current organization: %w", err)
}
if stdin {
bytes, err := io.ReadAll(inv.Stdin)
if err != nil {
return xerrors.Errorf("reading stdin: %w", err)
}
taskInput = string(bytes)
} else {
if len(inv.Args) != 1 {
return xerrors.Errorf("expected an input for task")
}
taskInput = inv.Args[0]
}
if taskInput == "" {
return xerrors.Errorf("a task cannot be started with an empty input")
}
switch {
case templateName == "":
templates, err := client.Templates(ctx, codersdk.TemplateFilter{SearchQuery: "has-ai-task:true", OrganizationID: organization.ID})
if err != nil {
return xerrors.Errorf("list templates: %w", err)
}
if len(templates) == 0 {
return xerrors.Errorf("no task templates configured")
}
// When a deployment has only 1 AI task template, we will
// allow omitting the template. Otherwise we will require
// the user to be explicit with their choice of template.
if len(templates) > 1 {
templateNames := make([]string, 0, len(templates))
for _, template := range templates {
templateNames = append(templateNames, template.Name)
}
return xerrors.Errorf("template name not provided, available templates: %s", strings.Join(templateNames, ", "))
}
if templateVersionName != "" {
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templates[0].Name, templateVersionName)
if err != nil {
return xerrors.Errorf("get template version: %w", err)
}
templateVersionID = templateVersion.ID
} else {
templateVersionID = templates[0].ActiveVersionID
}
case templateVersionName != "":
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templateName, templateVersionName)
if err != nil {
return xerrors.Errorf("get template version: %w", err)
}
templateVersionID = templateVersion.ID
default:
template, err := client.TemplateByName(ctx, organization.ID, templateName)
if err != nil {
return xerrors.Errorf("get template: %w", err)
}
templateVersionID = template.ActiveVersionID
}
if presetName != PresetNone {
templatePresets, err := client.TemplateVersionPresets(ctx, templateVersionID)
if err != nil {
return xerrors.Errorf("get template presets: %w", err)
}
preset, err := resolvePreset(templatePresets, presetName)
if err != nil {
return xerrors.Errorf("resolve preset: %w", err)
}
templateVersionPresetID = preset.ID
}
task, err := expClient.CreateTask(ctx, ownerArg, codersdk.CreateTaskRequest{
Name: taskName,
TemplateVersionID: templateVersionID,
TemplateVersionPresetID: templateVersionPresetID,
Input: taskInput,
})
if err != nil {
return xerrors.Errorf("create task: %w", err)
}
if quiet {
_, _ = fmt.Fprintln(inv.Stdout, task.ID)
} else {
_, _ = fmt.Fprintf(
inv.Stdout,
"The task %s has been created at %s!\n",
cliui.Keyword(task.Name),
cliui.Timestamp(task.CreatedAt),
)
}
return nil
},
}
orgContext.AttachOptions(cmd)
return cmd
}
-87
@@ -1,87 +0,0 @@
package cli
import (
"fmt"
"strings"
"time"
"golang.org/x/xerrors"
"github.com/coder/pretty"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) taskDelete() *serpent.Command {
cmd := &serpent.Command{
Use: "delete <task> [<task> ...]",
Short: "Delete experimental tasks",
Long: FormatExamples(
Example{
Description: "Delete a single task.",
Command: "$ coder exp task delete task1",
},
Example{
Description: "Delete multiple tasks.",
Command: "$ coder exp task delete task1 task2 task3",
},
Example{
Description: "Delete a task without confirmation.",
Command: "$ coder exp task delete task4 --yes",
},
),
Middleware: serpent.Chain(
serpent.RequireRangeArgs(1, -1),
),
Options: serpent.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
exp := codersdk.NewExperimentalClient(client)
var tasks []codersdk.Task
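// Resolve each identifier (task name or UUID) to a task up front so the
// confirmation prompt can list everything that will be deleted.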
for _, identifier := range inv.Args {
task, err := exp.TaskByIdentifier(ctx, identifier)
if err != nil {
return xerrors.Errorf("resolve task %q: %w", identifier, err)
}
tasks = append(tasks, task)
}
// Confirm deletion of the tasks.
var displayList []string
for _, task := range tasks {
displayList = append(displayList, fmt.Sprintf("%s/%s", task.OwnerName, task.Name))
}
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Delete these tasks: %s?", pretty.Sprint(cliui.DefaultStyles.Code, strings.Join(displayList, ", "))),
IsConfirm: true,
Default: cliui.ConfirmNo,
})
if err != nil {
return err
}
for i, task := range tasks {
display := displayList[i]
if err := exp.DeleteTask(ctx, task.OwnerName, task.ID); err != nil {
return xerrors.Errorf("delete task %q: %w", display, err)
}
_, _ = fmt.Fprintln(
inv.Stdout, "Deleted task "+pretty.Sprint(cliui.DefaultStyles.Keyword, display)+" at "+cliui.Timestamp(time.Now()),
)
}
return nil
},
}
return cmd
}
-248
@@ -1,248 +0,0 @@
package cli_test
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
func TestExpTaskDelete(t *testing.T) {
t.Parallel()
type testCounters struct {
deleteCalls atomic.Int64
nameResolves atomic.Int64
}
type handlerBuilder func(c *testCounters) http.HandlerFunc
type testCase struct {
name string
args []string
promptYes bool
wantErr bool
wantDeleteCalls int64
wantNameResolves int64
wantDeletedMessage int
buildHandler handlerBuilder
}
const (
id1 = "11111111-1111-1111-1111-111111111111"
id2 = "22222222-2222-2222-2222-222222222222"
id3 = "33333333-3333-3333-3333-333333333333"
id4 = "44444444-4444-4444-4444-444444444444"
id5 = "55555555-5555-5555-5555-555555555555"
)
cases := []testCase{
{
name: "Prompted_ByName_OK",
args: []string{"exists"},
promptYes: true,
buildHandler: func(c *testCounters) http.HandlerFunc {
taskID := uuid.MustParse(id1)
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
c.nameResolves.Add(1)
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: taskID,
Name: "exists",
OwnerName: "me",
}},
Count: 1,
})
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id1:
c.deleteCalls.Add(1)
w.WriteHeader(http.StatusAccepted)
default:
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
}
}
},
wantDeleteCalls: 1,
wantNameResolves: 1,
},
{
name: "Prompted_ByUUID_OK",
args: []string{id2},
promptYes: true,
buildHandler: func(c *testCounters) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/"+id2:
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse(id2),
OwnerName: "me",
Name: "uuid-task",
})
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id2:
c.deleteCalls.Add(1)
w.WriteHeader(http.StatusAccepted)
default:
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
}
}
},
wantDeleteCalls: 1,
},
{
name: "Multiple_YesFlag",
args: []string{"--yes", "first", id4},
buildHandler: func(c *testCounters) http.HandlerFunc {
firstID := uuid.MustParse(id3)
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
c.nameResolves.Add(1)
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: firstID,
Name: "first",
OwnerName: "me",
}},
Count: 1,
})
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/"+id4:
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse(id4),
OwnerName: "me",
Name: "uuid-task-2",
})
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id3:
c.deleteCalls.Add(1)
w.WriteHeader(http.StatusAccepted)
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id4:
c.deleteCalls.Add(1)
w.WriteHeader(http.StatusAccepted)
default:
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
}
}
},
wantDeleteCalls: 2,
wantNameResolves: 1,
wantDeletedMessage: 2,
},
{
name: "ResolveNameError",
args: []string{"doesnotexist"},
wantErr: true,
buildHandler: func(_ *testCounters) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{},
Count: 0,
})
default:
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
}
}
},
},
{
name: "DeleteError",
args: []string{"bad"},
promptYes: true,
wantErr: true,
buildHandler: func(c *testCounters) http.HandlerFunc {
taskID := uuid.MustParse(id5)
return func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks" && r.URL.Query().Get("q") == "owner:\"me\"":
c.nameResolves.Add(1)
httpapi.Write(r.Context(), w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: taskID,
Name: "bad",
OwnerName: "me",
}},
Count: 1,
})
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id5:
httpapi.InternalServerError(w, xerrors.New("boom"))
default:
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
}
}
},
wantNameResolves: 1,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
var counters testCounters
srv := httptest.NewServer(tc.buildHandler(&counters))
t.Cleanup(srv.Close)
client := codersdk.New(testutil.MustURL(t, srv.URL))
args := append([]string{"exp", "task", "delete"}, tc.args...)
inv, root := clitest.New(t, args...)
inv = inv.WithContext(ctx)
clitest.SetupConfig(t, client, root)
var runErr error
var outBuf bytes.Buffer
if tc.promptYes {
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatch("Delete these tasks:")
pty.WriteLine("yes")
runErr = w.Wait()
outBuf.Write(pty.ReadAll())
} else {
inv.Stdout = &outBuf
inv.Stderr = &outBuf
runErr = inv.Run()
}
if tc.wantErr {
require.Error(t, runErr)
} else {
require.NoError(t, runErr)
}
require.Equal(t, tc.wantDeleteCalls, counters.deleteCalls.Load(), "wrong delete call count")
require.Equal(t, tc.wantNameResolves, counters.nameResolves.Load(), "wrong name resolve count")
if tc.wantDeletedMessage > 0 {
output := outBuf.String()
require.GreaterOrEqual(t, strings.Count(output, "Deleted task"), tc.wantDeletedMessage)
}
})
}
}
-70
@@ -1,70 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) taskLogs() *serpent.Command {
formatter := cliui.NewOutputFormatter(
cliui.TableFormat(
[]codersdk.TaskLogEntry{},
[]string{
"type",
"content",
},
),
cliui.JSONFormat(),
)
cmd := &serpent.Command{
Use: "logs <task>",
Short: "Show a task's logs",
Long: FormatExamples(
Example{
Description: "Show logs for a given task.",
Command: "coder exp task logs task1",
}),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
var (
ctx = inv.Context()
exp = codersdk.NewExperimentalClient(client)
identifier = inv.Args[0]
)
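// Resolve the task by name or UUID, then fetch and format its logs.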
task, err := exp.TaskByIdentifier(ctx, identifier)
if err != nil {
return xerrors.Errorf("resolve task %q: %w", identifier, err)
}
logs, err := exp.TaskLogs(ctx, codersdk.Me, task.ID)
if err != nil {
return xerrors.Errorf("get task logs: %w", err)
}
out, err := formatter.Format(ctx, logs.Logs)
if err != nil {
return xerrors.Errorf("format task logs: %w", err)
}
_, _ = fmt.Fprintln(inv.Stdout, out)
return nil
},
}
formatter.AttachOptions(&cmd.Options)
return cmd
}
-187
@@ -1,187 +0,0 @@
package cli_test
import (
"encoding/json"
"net/http"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
agentapisdk "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
func Test_TaskLogs(t *testing.T) {
t.Parallel()
testMessages := []agentapisdk.Message{
{
Id: 0,
Role: agentapisdk.RoleUser,
Content: "What is 1 + 1?",
Time: time.Now().Add(-2 * time.Minute),
},
{
Id: 1,
Role: agentapisdk.RoleAgent,
Content: "2",
Time: time.Now().Add(-1 * time.Minute),
},
}
t.Run("ByTaskName_JSON", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client // user already has access to their own workspace
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", task.Name, "--output", "json")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
var logs []codersdk.TaskLogEntry
err = json.NewDecoder(strings.NewReader(stdout.String())).Decode(&logs)
require.NoError(t, err)
require.Len(t, logs, 2)
require.Equal(t, "What is 1 + 1?", logs[0].Content)
require.Equal(t, codersdk.TaskLogTypeInput, logs[0].Type)
require.Equal(t, "2", logs[1].Content)
require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
})
t.Run("ByTaskID_JSON", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String(), "--output", "json")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
var logs []codersdk.TaskLogEntry
err = json.NewDecoder(strings.NewReader(stdout.String())).Decode(&logs)
require.NoError(t, err)
require.Len(t, logs, 2)
require.Equal(t, "What is 1 + 1?", logs[0].Content)
require.Equal(t, codersdk.TaskLogTypeInput, logs[0].Type)
require.Equal(t, "2", logs[1].Content)
require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
})
t.Run("ByTaskID_Table", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String())
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
output := stdout.String()
require.Contains(t, output, "What is 1 + 1?")
require.Contains(t, output, "2")
require.Contains(t, output, "input")
require.Contains(t, output, "output")
})
t.Run("TaskNotFound_ByName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", "doesnotexist")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.Error(t, err)
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
})
t.Run("TaskNotFound_ByID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "logs", uuid.Nil.String())
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.Error(t, err)
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
})
t.Run("ErrorFetchingLogs", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsErr(assert.AnError))
userClient := client
inv, root := clitest.New(t, "exp", "task", "logs", task.ID.String())
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.ErrorContains(t, err, assert.AnError.Error())
})
}
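// fakeAgentAPITaskLogsOK returns AgentAPI handlers whose /messages endpoint
// serves the given messages so task logs can be fetched successfully.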
func fakeAgentAPITaskLogsOK(messages []agentapisdk.Message) map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/messages": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"messages": messages,
})
},
}
}
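// fakeAgentAPITaskLogsErr returns AgentAPI handlers whose /messages endpoint
// always responds with the given error.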
func fakeAgentAPITaskLogsErr(err error) map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/messages": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"error": err.Error(),
})
},
}
}
-77
@@ -1,77 +0,0 @@
package cli
import (
"io"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) taskSend() *serpent.Command {
var stdin bool
cmd := &serpent.Command{
Use: "send <task> [<input> | --stdin]",
Short: "Send input to a task",
Long: FormatExamples(Example{
Description: "Send direct input to a task.",
Command: "coder exp task send task1 \"Please also add unit tests\"",
}, Example{
Description: "Send input from stdin to a task.",
Command: "echo \"Please also add unit tests\" | coder exp task send task1 --stdin",
}),
Middleware: serpent.RequireRangeArgs(1, 2),
Options: serpent.OptionSet{
{
Name: "stdin",
Flag: "stdin",
Description: "Reads the input from stdin.",
Value: serpent.BoolOf(&stdin),
},
},
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
var (
ctx = inv.Context()
exp = codersdk.NewExperimentalClient(client)
identifier = inv.Args[0]
taskInput string
)
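// Read the input either from stdin (when --stdin is set) or from the
// second positional argument.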
if stdin {
bytes, err := io.ReadAll(inv.Stdin)
if err != nil {
return xerrors.Errorf("reading stdio: %w", err)
}
taskInput = string(bytes)
} else {
if len(inv.Args) != 2 {
return xerrors.Errorf("expected an input for the task")
}
taskInput = inv.Args[1]
}
task, err := exp.TaskByIdentifier(ctx, identifier)
if err != nil {
return xerrors.Errorf("resolve task: %w", err)
}
if err = exp.TaskSend(ctx, codersdk.Me, task.ID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
return xerrors.Errorf("send input to task: %w", err)
}
return nil
},
}
return cmd
}
-171
@@ -1,171 +0,0 @@
package cli_test
import (
"encoding/json"
"net/http"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
agentapisdk "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/testutil"
)
func Test_TaskSend(t *testing.T) {
t.Parallel()
t.Run("ByTaskName_WithArgument", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "carry on with the task")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
})
t.Run("ByTaskID_WithArgument", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.ID.String(), "carry on with the task")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
})
t.Run("ByTaskName_WithStdin", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "--stdin")
inv.Stdout = &stdout
inv.Stdin = strings.NewReader("carry on with the task")
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
})
t.Run("TaskNotFound_ByName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", "doesnotexist", "some task input")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.Error(t, err)
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
})
t.Run("TaskNotFound_ByID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", uuid.Nil.String(), "some task input")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.Error(t, err)
require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
})
t.Run("SendError", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
userClient, task := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
var stdout strings.Builder
inv, root := clitest.New(t, "exp", "task", "send", task.Name, "some task input")
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
err := inv.WithContext(ctx).Run()
require.ErrorContains(t, err, assert.AnError.Error())
})
}
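// fakeAgentAPITaskSendOK returns AgentAPI handlers that report a stable
// status and assert that the posted message content equals expectMessage,
// replying with returnMessage.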
func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/status": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "stable",
})
},
"/message": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
var msg agentapisdk.PostMessageParams
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
assert.Equal(t, expectMessage, msg.Content)
message := agentapisdk.Message{
Id: 999,
Role: agentapisdk.RoleAgent,
Content: returnMessage,
Time: time.Now(),
}
_ = json.NewEncoder(w).Encode(message)
},
}
}
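// fakeAgentAPITaskSendErr returns AgentAPI handlers that report a stable
// status but fail any posted message with the given error.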
func fakeAgentAPITaskSendErr(t *testing.T, returnErr error) map[string]http.HandlerFunc {
return map[string]http.HandlerFunc{
"/status": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"status": "stable",
})
},
"/message": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(returnErr.Error()))
},
}
}
+53 -80
@@ -5,6 +5,7 @@ import (
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
@@ -14,13 +15,13 @@ import (
func (r *RootCmd) taskStatus() *serpent.Command {
var (
client = new(codersdk.Client)
formatter = cliui.NewOutputFormatter(
cliui.TableFormat(
[]taskStatusRow{},
[]string{
"state changed",
"status",
"healthy",
"state",
"message",
},
@@ -43,17 +44,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
watchIntervalArg time.Duration
)
cmd := &serpent.Command{
Short: "Show the status of a task.",
Long: FormatExamples(
Example{
Description: "Show the status of a given task.",
Command: "coder exp task status task1",
},
Example{
Description: "Watch the status of a given task until it completes (idle or stopped).",
Command: "coder exp task status task1 --watch",
},
),
Short: "Show the status of a task.",
Use: "status",
Aliases: []string{"stat"},
Options: serpent.OptionSet{
@@ -75,62 +66,67 @@ func (r *RootCmd) taskStatus() *serpent.Command {
},
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
r.InitClient(client),
),
Handler: func(i *serpent.Invocation) error {
client, err := r.InitClient(i)
if err != nil {
return err
}
ctx := i.Context()
exp := codersdk.NewExperimentalClient(client)
ec := codersdk.NewExperimentalClient(client)
identifier := i.Args[0]
task, err := exp.TaskByIdentifier(ctx, identifier)
taskID, err := uuid.Parse(identifier)
if err != nil {
// Try to resolve the task as a named workspace
// TODO: right now tasks are still "workspaces" under the hood.
// We should update this once we have a proper task model.
ws, err := namedWorkspace(ctx, client, identifier)
if err != nil {
return err
}
taskID = ws.ID
}
task, err := ec.TaskByID(ctx, taskID)
if err != nil {
return err
}
tsr := toStatusRow(task)
out, err := formatter.Format(ctx, []taskStatusRow{tsr})
out, err := formatter.Format(ctx, toStatusRow(task))
if err != nil {
return xerrors.Errorf("format task status: %w", err)
}
_, _ = fmt.Fprintln(i.Stdout, out)
if !watchArg || taskWatchIsEnded(task) {
if !watchArg {
return nil
}
lastStatus := task.Status
lastState := task.CurrentState
t := time.NewTicker(watchIntervalArg)
defer t.Stop()
// TODO: implement streaming updates instead of polling
lastStatusRow := tsr
for range t.C {
task, err := exp.TaskByID(ctx, task.ID)
task, err := ec.TaskByID(ctx, taskID)
if err != nil {
return err
}
// Only print if something changed
newStatusRow := toStatusRow(task)
if !taskStatusRowEqual(lastStatusRow, newStatusRow) {
out, err := formatter.Format(ctx, []taskStatusRow{newStatusRow})
if err != nil {
return xerrors.Errorf("format task status: %w", err)
}
// hack: skip the extra column header from formatter
if formatter.FormatID() != cliui.JSONFormat().ID() {
out = strings.SplitN(out, "\n", 2)[1]
}
_, _ = fmt.Fprintln(i.Stdout, out)
if lastStatus == task.Status && taskStatusEqual(lastState, task.CurrentState) {
continue
}
out, err := formatter.Format(ctx, toStatusRow(task))
if err != nil {
return xerrors.Errorf("format task status: %w", err)
}
// hack: skip the extra column header from formatter
if formatter.FormatID() != cliui.JSONFormat().ID() {
out = strings.SplitN(out, "\n", 2)[1]
}
_, _ = fmt.Fprintln(i.Stdout, out)
if taskWatchIsEnded(task) {
if task.Status == codersdk.WorkspaceStatusStopped {
return nil
}
lastStatusRow = newStatusRow
lastStatus = task.Status
lastState = task.CurrentState
}
return nil
},
@@ -139,60 +135,37 @@ func (r *RootCmd) taskStatus() *serpent.Command {
return cmd
}
func taskWatchIsEnded(task codersdk.Task) bool {
if task.WorkspaceStatus == codersdk.WorkspaceStatusStopped {
func taskStatusEqual(s1, s2 *codersdk.TaskStateEntry) bool {
if s1 == nil && s2 == nil {
return true
}
if task.WorkspaceAgentHealth == nil || !task.WorkspaceAgentHealth.Healthy {
if s1 == nil || s2 == nil {
return false
}
if task.WorkspaceAgentLifecycle == nil || task.WorkspaceAgentLifecycle.Starting() || task.WorkspaceAgentLifecycle.ShuttingDown() {
return false
}
if task.CurrentState == nil || task.CurrentState.State == codersdk.TaskStateWorking {
return false
}
return true
return s1.State == s2.State
}
type taskStatusRow struct {
codersdk.Task `table:"r,recursive_inline"`
ChangedAgo string `json:"-" table:"state changed"`
Healthy bool `json:"-" table:"healthy"`
codersdk.Task `table:"-"`
ChangedAgo string `json:"-" table:"state changed,default_sort"`
Timestamp time.Time `json:"-" table:"-"`
TaskStatus string `json:"-" table:"status"`
TaskState string `json:"-" table:"state"`
Message string `json:"-" table:"message"`
}
func taskStatusRowEqual(r1, r2 taskStatusRow) bool {
return r1.Status == r2.Status &&
r1.Healthy == r2.Healthy &&
taskStateEqual(r1.CurrentState, r2.CurrentState)
}
func toStatusRow(task codersdk.Task) taskStatusRow {
func toStatusRow(task codersdk.Task) []taskStatusRow {
tsr := taskStatusRow{
Task: task,
ChangedAgo: time.Since(task.UpdatedAt).Truncate(time.Second).String() + " ago",
Timestamp: task.UpdatedAt,
TaskStatus: string(task.Status),
}
tsr.Healthy = task.WorkspaceAgentHealth != nil &&
task.WorkspaceAgentHealth.Healthy &&
task.WorkspaceAgentLifecycle != nil &&
!task.WorkspaceAgentLifecycle.Starting() &&
!task.WorkspaceAgentLifecycle.ShuttingDown()
if task.CurrentState != nil {
tsr.ChangedAgo = time.Since(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
tsr.Timestamp = task.CurrentState.Timestamp
tsr.TaskState = string(task.CurrentState.State)
tsr.Message = task.CurrentState.Message
}
return tsr
}
func taskStateEqual(se1, se2 *codersdk.TaskStateEntry) bool {
var s1, m1, s2, m2 string
if se1 != nil {
s1 = string(se1.State)
m1 = se1.Message
}
if se2 != nil {
s2 = string(se2.State)
m2 = se2.Message
}
return s1 == s2 && m1 == m2
return []taskStatusRow{tsr}
}
+96 -178
@@ -16,7 +16,6 @@ import (
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
@@ -36,17 +35,26 @@ func Test_TaskStatus(t *testing.T) {
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{},
Count: 0,
})
return
}
case "/api/v2/users/me/workspace/doesnotexist":
httpapi.ResourceNotFound(w)
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
}
},
},
{
args: []string{"err-fetching-workspace"},
expectError: assert.AnError.Error(),
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me/workspace/err-fetching-workspace":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
httpapi.InternalServerError(w, assert.AnError)
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
@@ -55,57 +63,27 @@ func Test_TaskStatus(t *testing.T) {
},
{
args: []string{"exists"},
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
0s ago active true working Thinking furiously...`,
expectOutput: `STATE CHANGED STATUS STATE MESSAGE
0s ago running working Thinking furiously...`,
hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now,
UpdatedAt: now,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: now,
Message: "Thinking furiously...",
},
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusActive,
}},
Count: 1,
})
return
}
case "/api/v2/users/me/workspace/exists":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now,
UpdatedAt: now,
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now,
UpdatedAt: now,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: now,
Message: "Thinking furiously...",
},
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusActive,
})
return
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
@@ -114,114 +92,92 @@ func Test_TaskStatus(t *testing.T) {
},
{
args: []string{"exists", "--watch"},
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
5s ago pending true
4s ago initializing true
4s ago active true
3s ago active true working Reticulating splines...
2s ago active true complete Splines reticulated successfully!`,
expectOutput: `
STATE CHANGED STATUS STATE MESSAGE
4s ago running
3s ago running working Reticulating splines...
2s ago running completed Splines reticulated successfully!
2s ago stopping completed Splines reticulated successfully!
2s ago stopped completed Splines reticulated successfully!`,
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
var calls atomic.Int64
return func(w http.ResponseWriter, r *http.Request) {
defer calls.Add(1)
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
// Return initial task state for --watch test
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusPending,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-5 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusPending,
}},
Count: 1,
})
return
}
case "/api/v2/users/me/workspace/exists":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
defer calls.Add(1)
switch calls.Load() {
case 0:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusInitializing,
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusPending,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-5 * time.Second),
})
return
case 1:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
UpdatedAt: now.Add(-4 * time.Second),
Status: codersdk.TaskStatusActive,
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
})
return
case 2:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: now.Add(-3 * time.Second),
Message: "Reticulating splines...",
},
Status: codersdk.TaskStatusActive,
})
return
case 3:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-4 * time.Second),
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateComplete,
State: codersdk.TaskStateCompleted,
Timestamp: now.Add(-2 * time.Second),
Message: "Splines reticulated successfully!",
},
})
case 4:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusStopping,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now.Add(-1 * time.Second),
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateCompleted,
Timestamp: now.Add(-2 * time.Second),
Message: "Splines reticulated successfully!",
},
})
case 5:
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusStopped,
CreatedAt: now.Add(-5 * time.Second),
UpdatedAt: now,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateCompleted,
Timestamp: now.Add(-2 * time.Second),
Message: "Splines reticulated successfully!",
},
Status: codersdk.TaskStatusActive,
})
return
default:
httpapi.InternalServerError(w, xerrors.New("too many calls!"))
return
}
default:
httpapi.InternalServerError(w, xerrors.Errorf("unexpected path: %q", r.URL.Path))
return
}
}
},
@@ -232,24 +188,11 @@ func Test_TaskStatus(t *testing.T) {
"id": "11111111-1111-1111-1111-111111111111",
"organization_id": "00000000-0000-0000-0000-000000000000",
"owner_id": "00000000-0000-0000-0000-000000000000",
"owner_name": "me",
"name": "exists",
"name": "",
"template_id": "00000000-0000-0000-0000-000000000000",
"template_version_id": "00000000-0000-0000-0000-000000000000",
"template_name": "",
"template_display_name": "",
"template_icon": "",
"workspace_id": null,
"workspace_name": "",
"workspace_status": "running",
"workspace_agent_id": null,
"workspace_agent_lifecycle": "ready",
"workspace_agent_health": {
"healthy": true
},
"workspace_app_id": null,
"initial_prompt": "",
"status": "active",
"status": "running",
"current_state": {
"timestamp": "2025-08-26T12:34:57Z",
"state": "working",
@@ -259,52 +202,26 @@ func Test_TaskStatus(t *testing.T) {
"created_at": "2025-08-26T12:34:56Z",
"updated_at": "2025-08-26T12:34:56Z"
}`,
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
hf: func(ctx context.Context, _ time.Time) func(w http.ResponseWriter, r *http.Request) {
ts := time.Date(2025, 8, 26, 12, 34, 56, 0, time.UTC)
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/experimental/tasks":
if r.URL.Query().Get("q") == "owner:\"me\"" {
httpapi.Write(ctx, w, http.StatusOK, struct {
Tasks []codersdk.Task `json:"tasks"`
Count int `json:"count"`
}{
Tasks: []codersdk.Task{{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Name: "exists",
OwnerName: "me",
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: ts,
UpdatedAt: ts,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: ts.Add(time.Second),
Message: "Thinking furiously...",
},
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
Healthy: true,
},
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
Status: codersdk.TaskStatusActive,
}},
Count: 1,
})
return
}
case "/api/v2/users/me/workspace/exists":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Workspace{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
})
case "/api/experimental/tasks/me/11111111-1111-1111-1111-111111111111":
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
WorkspaceStatus: codersdk.WorkspaceStatusRunning,
CreatedAt: ts,
UpdatedAt: ts,
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
Status: codersdk.WorkspaceStatusRunning,
CreatedAt: ts,
UpdatedAt: ts,
CurrentState: &codersdk.TaskStateEntry{
State: codersdk.TaskStateWorking,
Timestamp: ts.Add(time.Second),
Message: "Thinking furiously...",
},
Status: codersdk.TaskStatusActive,
})
return
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
@@ -319,12 +236,13 @@ func Test_TaskStatus(t *testing.T) {
ctx = testutil.Context(t, testutil.WaitShort)
now = time.Now().UTC() // TODO: replace with quartz
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
client = codersdk.New(testutil.MustURL(t, srv.URL))
client = new(codersdk.Client)
sb = strings.Builder{}
args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()}
)
t.Cleanup(srv.Close)
client.URL = testutil.MustURL(t, srv.URL)
args = append(args, tc.args...)
inv, root := clitest.New(t, args...)
inv.Stdout = &sb
-425
@@ -1,425 +0,0 @@
package cli_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
agentapisdk "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
// This test performs an integration-style test for tasks functionality.
//
//nolint:tparallel // The sub-tests of this test must be run sequentially.
func Test_Tasks(t *testing.T) {
t.Parallel()
// Given: a template configured for tasks
var (
ctx = testutil.Context(t, testutil.WaitLong)
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner = coderdtest.CreateFirstUser(t, client)
userClient, _ = coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
initMsg = agentapisdk.Message{
Content: "test task input for " + t.Name(),
Id: 0,
Role: "user",
Time: time.Now().UTC(),
}
authToken = uuid.NewString()
echoAgentAPI = startFakeAgentAPI(t, fakeAgentAPIEcho(ctx, t, initMsg, "hello"))
taskTpl = createAITaskTemplate(t, client, owner.OrganizationID, withAgentToken(authToken), withSidebarURL(echoAgentAPI.URL()))
taskName = strings.ReplaceAll(testutil.GetRandomName(t), "_", "-")
)
//nolint:paralleltest // The sub-tests of this test must be run sequentially.
for _, tc := range []struct {
name string
cmdArgs []string
assertFn func(stdout string, userClient *codersdk.Client)
}{
{
name: "create task",
cmdArgs: []string{"exp", "task", "create", "test task input for " + t.Name(), "--name", taskName, "--template", taskTpl.Name},
assertFn: func(stdout string, userClient *codersdk.Client) {
require.Contains(t, stdout, taskName, "task name should be in output")
},
},
{
name: "list tasks after create",
cmdArgs: []string{"exp", "task", "list", "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var tasks []codersdk.Task
err := json.NewDecoder(strings.NewReader(stdout)).Decode(&tasks)
require.NoError(t, err, "list output should unmarshal properly")
require.Len(t, tasks, 1, "expected one task")
require.Equal(t, taskName, tasks[0].Name, "task name should match")
require.Equal(t, initMsg.Content, tasks[0].InitialPrompt, "initial prompt should match")
require.True(t, tasks[0].WorkspaceID.Valid, "workspace should be created")
// For the next test, we need to wait for the workspace to be healthy
ws := coderdtest.MustWorkspace(t, userClient, tasks[0].WorkspaceID.UUID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, ws.LatestBuild.ID)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, userClient, tasks[0].WorkspaceID.UUID).WithContext(ctx).WaitFor(coderdtest.AgentsReady)
},
},
{
name: "get task status after create",
cmdArgs: []string{"exp", "task", "status", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var task codersdk.Task
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
require.Equal(t, task.Name, taskName, "task name should match")
require.Equal(t, codersdk.TaskStatusActive, task.Status, "task should be active")
},
},
{
name: "send task message",
cmdArgs: []string{"exp", "task", "send", taskName, "hello"},
// Assertions for this happen in the fake agent API handler.
},
{
name: "read task logs",
cmdArgs: []string{"exp", "task", "logs", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var logs []codersdk.TaskLogEntry
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&logs), "should unmarshal task logs")
require.Len(t, logs, 3, "should have 3 logs")
require.Equal(t, logs[0].Content, initMsg.Content, "first message should be the init message")
require.Equal(t, logs[0].Type, codersdk.TaskLogTypeInput, "first message should be an input")
require.Equal(t, logs[1].Content, "hello", "second message should be the sent message")
require.Equal(t, logs[1].Type, codersdk.TaskLogTypeInput, "second message should be an input")
require.Equal(t, logs[2].Content, "hello", "third message should be the echoed message")
require.Equal(t, logs[2].Type, codersdk.TaskLogTypeOutput, "third message should be an output")
},
},
{
name: "delete task",
cmdArgs: []string{"exp", "task", "delete", taskName, "--yes"},
assertFn: func(stdout string, userClient *codersdk.Client) {
// The task should eventually no longer show up in the list of tasks
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
expClient := codersdk.NewExperimentalClient(userClient)
tasks, err := expClient.Tasks(ctx, &codersdk.TasksFilter{})
if !assert.NoError(t, err) {
return false
}
return slices.IndexFunc(tasks, func(task codersdk.Task) bool {
return task.Name == taskName
}) == -1
}, testutil.IntervalMedium)
},
},
} {
t.Run(tc.name, func(t *testing.T) {
var stdout strings.Builder
inv, root := clitest.New(t, tc.cmdArgs...)
inv.Stdout = &stdout
clitest.SetupConfig(t, userClient, root)
require.NoError(t, inv.WithContext(ctx).Run())
if tc.assertFn != nil {
tc.assertFn(stdout.String(), userClient)
}
})
}
}
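// fakeAgentAPIEcho returns AgentAPI handlers that seed the message history
// with initMsg and echo each expected user message back as an agent reply,
// failing the test on any unexpected message.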
func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Message, want ...string) map[string]http.HandlerFunc {
t.Helper()
var mmu sync.RWMutex
msgs := []agentapisdk.Message{initMsg}
wantCpy := make([]string, len(want))
copy(wantCpy, want)
t.Cleanup(func() {
mmu.Lock()
defer mmu.Unlock()
if !t.Failed() {
assert.Empty(t, wantCpy, "not all expected messages received: missing %v", wantCpy)
}
})
writeAgentAPIError := func(w http.ResponseWriter, err error, status int) {
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(agentapisdk.ErrorModel{
Errors: ptr.Ref([]agentapisdk.ErrorDetail{
{
Message: ptr.Ref(err.Error()),
},
}),
})
}
return map[string]http.HandlerFunc{
"/status": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(agentapisdk.GetStatusResponse{
Status: "stable",
})
},
"/messages": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
mmu.RLock()
defer mmu.RUnlock()
bs, err := json.Marshal(agentapisdk.GetMessagesResponse{
Messages: msgs,
})
if err != nil {
writeAgentAPIError(w, err, http.StatusBadRequest)
return
}
_, _ = w.Write(bs)
},
"/message": func(w http.ResponseWriter, r *http.Request) {
mmu.Lock()
defer mmu.Unlock()
var params agentapisdk.PostMessageParams
w.Header().Set("Content-Type", "application/json")
err := json.NewDecoder(r.Body).Decode(&params)
if !assert.NoError(t, err, "decode message") {
writeAgentAPIError(w, err, http.StatusBadRequest)
return
}
if len(wantCpy) == 0 {
assert.Fail(t, "unexpected message", "received message %v, but no more expected messages", params)
writeAgentAPIError(w, xerrors.New("no more expected messages"), http.StatusBadRequest)
return
}
exp := wantCpy[0]
wantCpy = wantCpy[1:]
if !assert.Equal(t, exp, params.Content, "message content mismatch") {
writeAgentAPIError(w, xerrors.New("unexpected message content: expected "+exp+", got "+params.Content), http.StatusBadRequest)
return
}
msgs = append(msgs, agentapisdk.Message{
Id: int64(len(msgs) + 1),
Content: params.Content,
Role: agentapisdk.RoleUser,
Time: time.Now().UTC(),
})
msgs = append(msgs, agentapisdk.Message{
Id: int64(len(msgs) + 1),
Content: params.Content,
Role: agentapisdk.RoleAgent,
Time: time.Now().UTC(),
})
assert.NoError(t, json.NewEncoder(w).Encode(agentapisdk.PostMessageResponse{
Ok: true,
}))
},
}
}
// setupCLITaskTest creates a test workspace with an AI task template and agent,
// with a fake agent API configured with the provided set of handlers.
// Returns the user client and the created task.
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Task) {
t.Helper()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
fakeAPI := startFakeAgentAPI(t, agentAPIHandlers)
authToken := uuid.NewString()
template := createAITaskTemplate(t, client, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
wantPrompt := "test prompt"
exp := codersdk.NewExperimentalClient(userClient)
task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: wantPrompt,
Name: "test-task",
})
require.NoError(t, err)
// Wait for the task's underlying workspace to be built
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).
WaitFor(coderdtest.AgentsReady)
return userClient, task
}
// createAITaskTemplate creates a template configured for AI tasks with a sidebar app.
func createAITaskTemplate(t *testing.T, client *codersdk.Client, orgID uuid.UUID, opts ...aiTemplateOpt) codersdk.Template {
t.Helper()
opt := aiTemplateOpts{
authToken: uuid.NewString(),
}
for _, o := range opts {
o(&opt)
}
taskAppID := uuid.New()
version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{
{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
HasAiTasks: true,
},
},
},
},
ProvisionApply: []*proto.Response{
{
Type: &proto.Response_Apply{
Apply: &proto.ApplyComplete{
Resources: []*proto.Resource{
{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{
{
Id: uuid.NewString(),
Name: "example",
Auth: &proto.Agent_Token{
Token: opt.authToken,
},
Apps: []*proto.App{
{
Id: taskAppID.String(),
Slug: "task-sidebar",
DisplayName: "Task Sidebar",
Url: opt.appURL,
},
},
},
},
},
},
AiTasks: []*proto.AITask{
{
SidebarApp: &proto.AITaskSidebarApp{
Id: taskAppID.String(),
},
},
},
},
},
},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, orgID, version.ID)
return template
}
// fakeAgentAPI implements a fake AgentAPI HTTP server for testing.
type fakeAgentAPI struct {
t *testing.T
server *httptest.Server
handlers map[string]http.HandlerFunc
called map[string]bool
mu sync.Mutex
}
// startFakeAgentAPI starts an HTTP server that implements the AgentAPI endpoints.
// handlers is a map of path -> handler function.
func startFakeAgentAPI(t *testing.T, handlers map[string]http.HandlerFunc) *fakeAgentAPI {
t.Helper()
fake := &fakeAgentAPI{
t: t,
handlers: handlers,
called: make(map[string]bool),
}
mux := http.NewServeMux()
// Register all provided handlers with call tracking
for path, handler := range handlers {
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
fake.mu.Lock()
fake.called[path] = true
fake.mu.Unlock()
handler(w, r)
})
}
knownEndpoints := []string{"/status", "/messages", "/message"}
for _, endpoint := range knownEndpoints {
if handlers[endpoint] == nil {
endpoint := endpoint // capture loop variable
mux.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected call to %s %s - no handler defined", r.Method, endpoint)
})
}
}
// Default handler for unknown endpoints should cause the test to fail.
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected call to %s %s - no handler defined", r.Method, r.URL.Path)
})
fake.server = httptest.NewServer(mux)
// Register cleanup to check that all defined handlers were called
t.Cleanup(func() {
fake.server.Close()
fake.mu.Lock()
for path := range handlers {
if !fake.called[path] {
t.Errorf("handler for %s was defined but never called", path)
}
}
})
return fake
}
func (f *fakeAgentAPI) URL() string {
return f.server.URL
}
type aiTemplateOpts struct {
appURL string
authToken string
}
type aiTemplateOpt func(*aiTemplateOpts)
func withSidebarURL(url string) aiTemplateOpt {
return func(o *aiTemplateOpts) { o.appURL = url }
}
func withAgentToken(token string) aiTemplateOpt {
return func(o *aiTemplateOpts) { o.authToken = token }
}
+128
@@ -0,0 +1,128 @@
package cli
import (
"fmt"
"strings"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) taskCreate() *serpent.Command {
var (
orgContext = NewOrganizationContext()
client = new(codersdk.Client)
templateName string
templateVersionName string
presetName string
taskInput string
)
cmd := &serpent.Command{
Use: "create [template]",
Short: "Create an experimental task",
Middleware: serpent.Chain(
serpent.RequireRangeArgs(0, 1),
r.InitClient(client),
),
Options: serpent.OptionSet{
{
Flag: "input",
Env: "CODER_TASK_INPUT",
Value: serpent.StringOf(&taskInput),
Required: true,
},
{
Env: "CODER_TASK_TEMPLATE_NAME",
Value: serpent.StringOf(&templateName),
},
{
Env: "CODER_TASK_TEMPLATE_VERSION",
Value: serpent.StringOf(&templateVersionName),
},
{
Flag: "preset",
Env: "CODER_TASK_PRESET_NAME",
Value: serpent.StringOf(&presetName),
Default: PresetNone,
},
},
Handler: func(inv *serpent.Invocation) error {
var (
ctx = inv.Context()
expClient = codersdk.NewExperimentalClient(client)
templateVersionID uuid.UUID
templateVersionPresetID uuid.UUID
)
organization, err := orgContext.Selected(inv, client)
if err != nil {
return xerrors.Errorf("get current organization: %w", err)
}
if len(inv.Args) > 0 {
templateName, templateVersionName, _ = strings.Cut(inv.Args[0], "@")
}
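// Resolve the template version from the name (and optional "@version"
// suffix), falling back to the template's active version.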
if templateName == "" {
return xerrors.Errorf("template name not provided")
}
if templateVersionName != "" {
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templateName, templateVersionName)
if err != nil {
return xerrors.Errorf("get template version: %w", err)
}
templateVersionID = templateVersion.ID
} else {
template, err := client.TemplateByName(ctx, organization.ID, templateName)
if err != nil {
return xerrors.Errorf("get template: %w", err)
}
templateVersionID = template.ActiveVersionID
}
if presetName != PresetNone {
templatePresets, err := client.TemplateVersionPresets(ctx, templateVersionID)
if err != nil {
return xerrors.Errorf("get template presets: %w", err)
}
preset, err := resolvePreset(templatePresets, presetName)
if err != nil {
return xerrors.Errorf("resolve preset: %w", err)
}
templateVersionPresetID = preset.ID
}
workspace, err := expClient.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
TemplateVersionID: templateVersionID,
TemplateVersionPresetID: templateVersionPresetID,
Prompt: taskInput,
})
if err != nil {
return xerrors.Errorf("create task: %w", err)
}
_, _ = fmt.Fprintf(
inv.Stdout,
"The task %s has been created at %s!\n",
cliui.Keyword(workspace.Name),
cliui.Timestamp(workspace.CreatedAt),
)
return nil
},
}
orgContext.AttachOptions(cmd)
return cmd
}
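// Illustrative only: the same task creation the handler above performs, made
// directly against the experimental API. The deployment URL, session token,
// and template version ID are placeholders; resolve them the same way the
// command does. This sketch additionally needs the context and net/url imports.
func createTaskSketch(ctx context.Context, deploymentURL *url.URL, sessionToken string, templateVersionID uuid.UUID) error {
    client := codersdk.New(deploymentURL)
    client.SetSessionToken(sessionToken)

    exp := codersdk.NewExperimentalClient(client)
    task, err := exp.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
        TemplateVersionID: templateVersionID,
        Prompt:            "my custom prompt",
    })
    if err != nil {
        return xerrors.Errorf("create task: %w", err)
    }
    _, _ = fmt.Printf("created %s at %s\n", task.Name, task.CreatedAt)
    return nil
}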
@@ -5,12 +5,14 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/cliui"
@@ -31,10 +33,9 @@ func TestTaskCreate(t *testing.T) {
templateID = uuid.New()
templateVersionID = uuid.New()
templateVersionPresetID = uuid.New()
taskID = uuid.New()
)
templateAndVersionFoundHandler := func(t *testing.T, ctx context.Context, orgID uuid.UUID, templateName, templateVersionName, presetName, prompt, taskName, username string) http.HandlerFunc {
templateAndVersionFoundHandler := func(t *testing.T, ctx context.Context, orgID uuid.UUID, templateName, templateVersionName, presetName, prompt string) http.HandlerFunc {
t.Helper()
return func(w http.ResponseWriter, r *http.Request) {
@@ -45,11 +46,11 @@ func TestTaskCreate(t *testing.T) {
ID: orgID,
}},
})
case fmt.Sprintf("/api/v2/organizations/%s/templates/%s/versions/%s", orgID, templateName, templateVersionName):
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template/versions/my-template-version", orgID):
httpapi.Write(ctx, w, http.StatusOK, codersdk.TemplateVersion{
ID: templateVersionID,
})
case fmt.Sprintf("/api/v2/organizations/%s/templates/%s", orgID, templateName):
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template", orgID):
httpapi.Write(ctx, w, http.StatusOK, codersdk.Template{
ID: templateID,
ActiveVersionID: templateVersionID,
@@ -61,21 +62,13 @@ func TestTaskCreate(t *testing.T) {
Name: presetName,
},
})
case "/api/v2/templates":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{
{
ID: templateID,
Name: templateName,
ActiveVersionID: templateVersionID,
},
})
case fmt.Sprintf("/api/experimental/tasks/%s", username):
case "/api/experimental/tasks/me":
var req codersdk.CreateTaskRequest
if !httpapi.Read(ctx, w, r, &req) {
return
}
assert.Equal(t, prompt, req.Input, "prompt mismatch")
assert.Equal(t, prompt, req.Prompt, "prompt mismatch")
assert.Equal(t, templateVersionID, req.TemplateVersionID, "template version mismatch")
if presetName == "" {
@@ -84,17 +77,10 @@ func TestTaskCreate(t *testing.T) {
assert.Equal(t, templateVersionPresetID, req.TemplateVersionPresetID, "template version preset id mismatch")
}
created := codersdk.Task{
ID: taskID,
Name: taskName,
httpapi.Write(ctx, w, http.StatusCreated, codersdk.Workspace{
Name: "task-wild-goldfish-27",
CreatedAt: taskCreatedAt,
}
if req.Name != "" {
assert.Equal(t, req.Name, taskName, "name mismatch")
created.Name = req.Name
}
httpapi.Write(ctx, w, http.StatusCreated, created)
})
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
@@ -104,101 +90,71 @@ func TestTaskCreate(t *testing.T) {
tests := []struct {
args []string
env []string
stdin string
expectError string
expectOutput string
handler func(t *testing.T, ctx context.Context) http.HandlerFunc
}{
{
args: []string{"--stdin"},
stdin: "reads prompt from stdin",
args: []string{"my-template@my-template-version", "--input", "my custom prompt", "--org", organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "reads prompt from stdin", "task-wild-goldfish-27", codersdk.Me)
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
args: []string{"my custom prompt"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
},
},
{
args: []string{"my custom prompt", "--owner", "someone-else"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", "someone-else")
},
},
{
args: []string{"--name", "abc123", "my custom prompt"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("abc123"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "abc123", codersdk.Me)
},
},
{
args: []string{"my custom prompt", "--template", "my-template", "--template-version", "my-template-version", "--org", organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
},
},
{
args: []string{"my custom prompt", "--template", "my-template", "--org", organizationID.String()},
args: []string{"my-template", "--input", "my custom prompt", "--org", organizationID.String()},
env: []string{"CODER_TASK_TEMPLATE_VERSION=my-template-version"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
args: []string{"my custom prompt", "--org", organizationID.String()},
args: []string{"--input", "my custom prompt", "--org", organizationID.String()},
env: []string{"CODER_TASK_TEMPLATE_NAME=my-template", "CODER_TASK_TEMPLATE_VERSION=my-template-version"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
args: []string{"my custom prompt", "--template", "my-template", "--org", organizationID.String()},
env: []string{"CODER_TASK_TEMPLATE_NAME=my-template", "CODER_TASK_TEMPLATE_VERSION=my-template-version", "CODER_TASK_INPUT=my custom prompt", "CODER_ORGANIZATION=" + organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
args: []string{"my custom prompt", "--template", "my-template", "--preset", "my-preset", "--org", organizationID.String()},
args: []string{"my-template", "--input", "my custom prompt", "--org", organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "", "my custom prompt")
},
},
{
args: []string{"my custom prompt", "--template", "my-template"},
args: []string{"my-template", "--input", "my custom prompt", "--preset", "my-preset", "--org", organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
},
},
{
args: []string{"my-template", "--input", "my custom prompt"},
env: []string{"CODER_TASK_PRESET_NAME=my-preset"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
},
},
{
args: []string{"my custom prompt", "-q"},
expectOutput: taskID.String(),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
},
},
{
args: []string{"my custom prompt", "--template", "my-template", "--preset", "not-real-preset"},
args: []string{"my-template", "--input", "my custom prompt", "--preset", "not-real-preset"},
expectError: `preset "not-real-preset" not found`,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
},
},
{
args: []string{"my custom prompt", "--template", "my-template", "--template-version", "not-real-template-version"},
args: []string{"my-template@not-real-template-version", "--input", "my custom prompt"},
expectError: httpapi.ResourceNotFoundResponse.Message,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -209,11 +165,6 @@ func TestTaskCreate(t *testing.T) {
ID: organizationID,
}},
})
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template", organizationID):
httpapi.Write(ctx, w, http.StatusOK, codersdk.Template{
ID: templateID,
ActiveVersionID: templateVersionID,
})
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template/versions/not-real-template-version", organizationID):
httpapi.ResourceNotFound(w)
default:
@@ -223,7 +174,7 @@ func TestTaskCreate(t *testing.T) {
},
},
{
args: []string{"my custom prompt", "--template", "not-real-template", "--org", organizationID.String()},
args: []string{"not-real-template", "--input", "my custom prompt", "--org", organizationID.String()},
expectError: httpapi.ResourceNotFoundResponse.Message,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -243,7 +194,7 @@ func TestTaskCreate(t *testing.T) {
},
},
{
args: []string{"my-custom-prompt", "--template", "template-in-different-org", "--org", anotherOrganizationID.String()},
args: []string{"template-in-different-org", "--input", "my-custom-prompt", "--org", anotherOrganizationID.String()},
expectError: httpapi.ResourceNotFoundResponse.Message,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -263,7 +214,7 @@ func TestTaskCreate(t *testing.T) {
},
},
{
args: []string{"no-org-prompt"},
args: []string{"no-org", "--input", "my-custom-prompt"},
expectError: "Must select an organization with --org=<org_name>",
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -276,49 +227,6 @@ func TestTaskCreate(t *testing.T) {
}
},
},
{
args: []string{"no task templates"},
expectError: "no task templates configured",
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me/organizations":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Organization{
{MinimalOrganization: codersdk.MinimalOrganization{
ID: organizationID,
}},
})
case "/api/v2/templates":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{})
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
}
},
},
{
args: []string{"no template name provided"},
expectError: "template name not provided, available templates: wibble, wobble",
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me/organizations":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Organization{
{MinimalOrganization: codersdk.MinimalOrganization{
ID: organizationID,
}},
})
case "/api/v2/templates":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{
{Name: "wibble"},
{Name: "wobble"},
})
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
}
},
},
}
for _, tt := range tests {
@@ -328,7 +236,7 @@ func TestTaskCreate(t *testing.T) {
var (
ctx = testutil.Context(t, testutil.WaitShort)
srv = httptest.NewServer(tt.handler(t, ctx))
client = codersdk.New(testutil.MustURL(t, srv.URL))
client = new(codersdk.Client)
args = []string{"exp", "task", "create"}
sb strings.Builder
err error
@@ -336,9 +244,11 @@ func TestTaskCreate(t *testing.T) {
t.Cleanup(srv.Close)
client.URL, err = url.Parse(srv.URL)
require.NoError(t, err)
inv, root := clitest.New(t, append(args, tt.args...)...)
inv.Environ = serpent.ParseEnviron(tt.env, "")
inv.Stdin = strings.NewReader(tt.stdin)
inv.Stdout = &sb
inv.Stderr = &sb
clitest.SetupConfig(t, client, root)
@@ -8,7 +8,6 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
@@ -37,12 +36,13 @@ func (r *RootCmd) taskList() *serpent.Command {
statusFilter string
all bool
user string
quiet bool
client = new(codersdk.Client)
formatter = cliui.NewOutputFormatter(
cliui.TableFormat(
[]taskListRow{},
[]string{
"id",
"name",
"status",
"state",
@@ -68,41 +68,20 @@ func (r *RootCmd) taskList() *serpent.Command {
)
cmd := &serpent.Command{
Use: "list",
Short: "List experimental tasks",
Long: FormatExamples(
Example{
Description: "List tasks for the current user.",
Command: "coder exp task list",
},
Example{
Description: "List tasks for a specific user.",
Command: "coder exp task list --user someone-else",
},
Example{
Description: "List all tasks you can view.",
Command: "coder exp task list --all",
},
Example{
Description: "List all your running tasks.",
Command: "coder exp task list --status running",
},
Example{
Description: "As above, but only show IDs.",
Command: "coder exp task list --status running --quiet",
},
),
Use: "list",
Short: "List experimental tasks",
Aliases: []string{"ls"},
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
r.InitClient(client),
),
Options: serpent.OptionSet{
{
Name: "status",
Description: "Filter by task status.",
Description: "Filter by task status (e.g. running, failed, etc).",
Flag: "status",
Default: "",
Value: serpent.EnumOf(&statusFilter, slice.ToStrings(codersdk.AllTaskStatuses())...),
Value: serpent.StringOf(&statusFilter),
},
{
Name: "all",
@@ -119,21 +98,8 @@ func (r *RootCmd) taskList() *serpent.Command {
Default: "",
Value: serpent.StringOf(&user),
},
{
Name: "quiet",
Description: "Only display task IDs.",
Flag: "quiet",
FlagShorthand: "q",
Default: "false",
Value: serpent.BoolOf(&quiet),
},
},
Handler: func(inv *serpent.Invocation) error {
client, err := r.InitClient(inv)
if err != nil {
return err
}
ctx := inv.Context()
exp := codersdk.NewExperimentalClient(client)
@@ -144,20 +110,12 @@ func (r *RootCmd) taskList() *serpent.Command {
tasks, err := exp.Tasks(ctx, &codersdk.TasksFilter{
Owner: targetUser,
Status: codersdk.TaskStatus(statusFilter),
Status: statusFilter,
})
if err != nil {
return xerrors.Errorf("list tasks: %w", err)
}
if quiet {
for _, task := range tasks {
_, _ = fmt.Fprintln(inv.Stdout, task.ID.String())
}
return nil
}
// If no rows and not JSON, show a friendly message.
if len(tasks) == 0 && formatter.FormatID() != cliui.JSONFormat().ID() {
_, _ = fmt.Fprintln(inv.Stderr, "No tasks found.")
@@ -6,8 +6,6 @@ import (
"database/sql"
"encoding/json"
"io"
"slices"
"strings"
"testing"
"github.com/google/uuid"
@@ -22,7 +20,6 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
@@ -30,7 +27,7 @@ import (
)
// makeAITask creates an AI-task workspace.
func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UUID, transition database.WorkspaceTransition, prompt string) database.Task {
func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UUID, transition database.WorkspaceTransition, prompt string) (workspace database.WorkspaceTable) {
t.Helper()
tv := dbfake.TemplateVersion(t, db).
@@ -92,27 +89,7 @@ func makeAITask(t *testing.T, db database.Store, orgID, adminID, ownerID uuid.UU
)
require.NoError(t, err)
// Create a task record in the tasks table for the new data model.
task := dbgen.Task(t, db, database.TaskTable{
OrganizationID: orgID,
OwnerID: ownerID,
Name: build.Workspace.Name,
WorkspaceID: uuid.NullUUID{UUID: build.Workspace.ID, Valid: true},
TemplateVersionID: tv.TemplateVersion.ID,
TemplateParameters: []byte("{}"),
Prompt: prompt,
CreatedAt: dbtime.Now(),
})
// Link the task to the workspace app.
dbgen.TaskWorkspaceApp(t, db, database.TaskWorkspaceApp{
TaskID: task.ID,
WorkspaceBuildNumber: build.Build.BuildNumber,
WorkspaceAgentID: uuid.NullUUID{UUID: agentID, Valid: true},
WorkspaceAppID: uuid.NullUUID{UUID: app.ID, Valid: true},
})
return task
return build.Workspace
}
func TestExpTaskList(t *testing.T) {
@@ -149,7 +126,7 @@ func TestExpTaskList(t *testing.T) {
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
wantPrompt := "build me a web app"
task := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, wantPrompt)
ws := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, wantPrompt)
inv, root := clitest.New(t, "exp", "task", "list", "--column", "id,name,status,initial prompt")
clitest.SetupConfig(t, memberClient, root)
@@ -161,8 +138,8 @@ func TestExpTaskList(t *testing.T) {
require.NoError(t, err)
// Validate the table includes the task and status.
pty.ExpectMatch(task.Name)
pty.ExpectMatch("initializing")
pty.ExpectMatch(ws.Name)
pty.ExpectMatch("running")
pty.ExpectMatch(wantPrompt)
})
@@ -175,12 +152,12 @@ func TestExpTaskList(t *testing.T) {
owner := coderdtest.CreateFirstUser(t, client)
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Create two AI tasks: one initializing, one paused.
initializingTask := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me initializing")
pausedTask := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
// Create two AI tasks: one running, one stopped.
running := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me running")
stopped := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
// Use JSON output to reliably validate filtering.
inv, root := clitest.New(t, "exp", "task", "list", "--status=paused", "--output=json")
inv, root := clitest.New(t, "exp", "task", "list", "--status=stopped", "--output=json")
clitest.SetupConfig(t, memberClient, root)
ctx := testutil.Context(t, testutil.WaitShort)
@@ -194,10 +171,10 @@ func TestExpTaskList(t *testing.T) {
var tasks []codersdk.Task
require.NoError(t, json.Unmarshal(stdout.Bytes(), &tasks))
// Only the paused task is returned.
// Only the stopped task is returned.
require.Len(t, tasks, 1, "expected one task after filtering")
require.Equal(t, pausedTask.ID, tasks[0].ID)
require.NotEqual(t, initializingTask.ID, tasks[0].ID)
require.Equal(t, stopped.ID, tasks[0].ID)
require.NotEqual(t, running.ID, tasks[0].ID)
})
t.Run("UserFlag_Me_Table", func(t *testing.T) {
@@ -209,7 +186,7 @@ func TestExpTaskList(t *testing.T) {
_, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
_ = makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "other-task")
task := makeAITask(t, db, owner.OrganizationID, owner.UserID, owner.UserID, database.WorkspaceTransitionStart, "me-task")
ws := makeAITask(t, db, owner.OrganizationID, owner.UserID, owner.UserID, database.WorkspaceTransitionStart, "me-task")
inv, root := clitest.New(t, "exp", "task", "list", "--user", "me")
//nolint:gocritic // Owner client is intended here to smoke test that the member task does not show up.
@@ -221,44 +198,7 @@ func TestExpTaskList(t *testing.T) {
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
pty.ExpectMatch(task.Name)
})
t.Run("Quiet", func(t *testing.T) {
t.Parallel()
// Quiet logger to reduce noise.
quiet := slog.Make(sloghuman.Sink(io.Discard))
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &quiet})
owner := coderdtest.CreateFirstUser(t, client)
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Given: We have two tasks
task1 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me active")
task2 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
// Given: We add the `--quiet` flag
inv, root := clitest.New(t, "exp", "task", "list", "--quiet")
clitest.SetupConfig(t, memberClient, root)
ctx := testutil.Context(t, testutil.WaitShort)
var stdout bytes.Buffer
inv.Stdout = &stdout
inv.Stderr = &stdout
// When: We run the command
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
want := []string{task1.ID.String(), task2.ID.String()}
got := slice.Filter(strings.Split(stdout.String(), "\n"), func(s string) bool {
return len(s) != 0
})
slices.Sort(want)
slices.Sort(got)
require.Equal(t, want, got)
pty.ExpectMatch(ws.Name)
})
}
