Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 41dfbc7f1d | |||
| ee8e8cb805 | |||
| 4793806569 | |||
| 03440f6ae2 | |||
| 7afe6c813b | |||
| 536920459d | |||
| c0f1b9d73e | |||
| a056cb6577 | |||
| 0a73f842b3 |
@@ -1,21 +1,13 @@
|
||||
#!/bin/sh
|
||||
|
||||
install_devcontainer_cli() {
|
||||
set -e
|
||||
echo "🔧 Installing DevContainer CLI..."
|
||||
cd "$(dirname "$0")/../tools/devcontainer-cli"
|
||||
npm ci --omit=dev
|
||||
ln -sf "$(pwd)/node_modules/.bin/devcontainer" "$(npm config get prefix)/bin/devcontainer"
|
||||
npm install -g @devcontainers/cli@0.80.0 --integrity=sha512-w2EaxgjyeVGyzfA/KUEZBhyXqu/5PyWNXcnrXsZOBrt3aN2zyGiHrXoG54TF6K0b5DSCF01Rt5fnIyrCeFzFKw==
|
||||
}
|
||||
|
||||
install_ssh_config() {
|
||||
echo "🔑 Installing SSH configuration..."
|
||||
if [ -d /mnt/home/coder/.ssh ]; then
|
||||
rsync -a /mnt/home/coder/.ssh/ ~/.ssh/
|
||||
chmod 0700 ~/.ssh
|
||||
else
|
||||
echo "⚠️ SSH directory not found."
|
||||
fi
|
||||
rsync -a /mnt/home/coder/.ssh/ ~/.ssh/
|
||||
chmod 0700 ~/.ssh
|
||||
}
|
||||
|
||||
install_git_config() {
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
{
|
||||
"name": "devcontainer-cli",
|
||||
"version": "1.0.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "devcontainer-cli",
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"@devcontainers/cli": "^0.80.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@devcontainers/cli": {
|
||||
"version": "0.80.0",
|
||||
"resolved": "https://registry.npmjs.org/@devcontainers/cli/-/cli-0.80.0.tgz",
|
||||
"integrity": "sha512-w2EaxgjyeVGyzfA/KUEZBhyXqu/5PyWNXcnrXsZOBrt3aN2zyGiHrXoG54TF6K0b5DSCF01Rt5fnIyrCeFzFKw==",
|
||||
"bin": {
|
||||
"devcontainer": "devcontainer.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": "^16.13.0 || >=18.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"name": "devcontainer-cli",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"@devcontainers/cli": "^0.80.0"
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,5 @@ ignorePatterns:
|
||||
- pattern: "claude.ai"
|
||||
- pattern: "splunk.com"
|
||||
- pattern: "stackoverflow.com/questions"
|
||||
- pattern: "developer.hashicorp.com/terraform/language"
|
||||
aliveStatusCodes:
|
||||
- 200
|
||||
|
||||
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.24.6"
|
||||
default: "1.24.10"
|
||||
use-preinstalled-go:
|
||||
description: "Whether to use preinstalled Go."
|
||||
default: "false"
|
||||
|
||||
@@ -16,7 +16,7 @@ runs:
|
||||
- name: Setup Node
|
||||
uses: actions/setup-node@0a44ba7841725637a19e28fa30b79a866c81b0a6 # v4.0.4
|
||||
with:
|
||||
node-version: 22.19.0
|
||||
node-version: 20.19.4
|
||||
# See https://github.com/actions/setup-node#caching-global-packages-data
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: ${{ inputs.directory }}/pnpm-lock.yaml
|
||||
|
||||
@@ -80,9 +80,6 @@ updates:
|
||||
mui:
|
||||
patterns:
|
||||
- "@mui*"
|
||||
radix:
|
||||
patterns:
|
||||
- "@radix-ui/*"
|
||||
react:
|
||||
patterns:
|
||||
- "react"
|
||||
@@ -107,7 +104,6 @@ updates:
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
- version-update:semver-major
|
||||
- dependency-name: "@playwright/test"
|
||||
open-pull-requests-limit: 15
|
||||
|
||||
- package-ecosystem: "terraform"
|
||||
|
||||
@@ -1,5 +1 @@
|
||||
<!--
|
||||
|
||||
If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting.
|
||||
|
||||
-->
|
||||
|
||||
+179
-87
@@ -4,7 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- release/*
|
||||
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
@@ -35,12 +34,12 @@ jobs:
|
||||
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -124,7 +123,7 @@ jobs:
|
||||
# runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
# steps:
|
||||
# - name: Checkout
|
||||
# uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
# uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
# with:
|
||||
# fetch-depth: 1
|
||||
# # See: https://github.com/stefanzweifel/git-auto-commit-action?tab=readme-ov-file#commits-made-by-this-action-do-not-trigger-new-workflow-runs
|
||||
@@ -157,12 +156,12 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -191,7 +190,7 @@ jobs:
|
||||
|
||||
# Check for any typos
|
||||
- name: Check for typos
|
||||
uses: crate-ci/typos@85f62a8a84f939ae994ab3763f01a0296d61a7ee # v1.36.2
|
||||
uses: crate-ci/typos@52bd719c2c91f9d676e2aa359fc8e0db8925e6d8 # v1.35.3
|
||||
with:
|
||||
config: .github/workflows/typos.toml
|
||||
|
||||
@@ -204,7 +203,7 @@ jobs:
|
||||
|
||||
# Needed for helm chart linting
|
||||
- name: Install helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # v4.3.1
|
||||
uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0
|
||||
with:
|
||||
version: v3.9.2
|
||||
|
||||
@@ -235,12 +234,12 @@ jobs:
|
||||
if: ${{ !cancelled() }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -291,12 +290,12 @@ jobs:
|
||||
timeout-minutes: 7
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -341,7 +340,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -367,7 +366,7 @@ jobs:
|
||||
uses: coder/setup-ramdisk-action@e1100847ab2d7bcd9d14bcda8f2d1b0f07b36f1b # v0.1.0
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -532,6 +531,9 @@ jobs:
|
||||
with:
|
||||
api-key: ${{ secrets.DATADOG_API_KEY }}
|
||||
|
||||
# NOTE: this could instead be defined as a matrix strategy, but we want to
|
||||
# only block merging if tests on postgres 13 fail. Using a matrix strategy
|
||||
# here makes the check in the above `required` job rather complicated.
|
||||
test-go-pg-17:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
needs:
|
||||
@@ -544,12 +546,12 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -593,12 +595,12 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -653,12 +655,12 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -680,12 +682,12 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -713,12 +715,12 @@ jobs:
|
||||
name: ${{ matrix.variant.name }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -785,12 +787,12 @@ jobs:
|
||||
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
# 👇 Ensures Chromatic can read your full git history
|
||||
fetch-depth: 0
|
||||
@@ -806,7 +808,7 @@ jobs:
|
||||
# the check to pass. This is desired in PRs, but not in mainline.
|
||||
- name: Publish to Chromatic (non-mainline)
|
||||
if: github.ref != 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@20c7e42e1b2f6becd5d188df9acb02f3e2f51519 # v13.2.0
|
||||
uses: chromaui/action@58d9ffb36c90c97a02d061544ecc849cc4a242a9 # v13.1.3
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -838,7 +840,7 @@ jobs:
|
||||
# infinitely "in progress" in mainline unless we re-review each build.
|
||||
- name: Publish to Chromatic (mainline)
|
||||
if: github.ref == 'refs/heads/main' && github.repository_owner == 'coder'
|
||||
uses: chromaui/action@20c7e42e1b2f6becd5d188df9acb02f3e2f51519 # v13.2.0
|
||||
uses: chromaui/action@58d9ffb36c90c97a02d061544ecc849cc4a242a9 # v13.1.3
|
||||
env:
|
||||
NODE_OPTIONS: "--max_old_space_size=4096"
|
||||
STORYBOOK: true
|
||||
@@ -866,12 +868,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
# 0 is required here for version.sh to work.
|
||||
fetch-depth: 0
|
||||
@@ -920,12 +922,10 @@ jobs:
|
||||
required:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- changes
|
||||
- fmt
|
||||
- lint
|
||||
- gen
|
||||
- test-go-pg
|
||||
- test-go-pg-17
|
||||
- test-go-race-pg
|
||||
- test-js
|
||||
- test-e2e
|
||||
@@ -937,19 +937,17 @@ jobs:
|
||||
if: always()
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Ensure required checks
|
||||
run: | # zizmor: ignore[template-injection] We're just reading needs.x.result here, no risk of injection
|
||||
echo "Checking required checks"
|
||||
echo "- changes: ${{ needs.changes.result }}"
|
||||
echo "- fmt: ${{ needs.fmt.result }}"
|
||||
echo "- lint: ${{ needs.lint.result }}"
|
||||
echo "- gen: ${{ needs.gen.result }}"
|
||||
echo "- test-go-pg: ${{ needs.test-go-pg.result }}"
|
||||
echo "- test-go-pg-17: ${{ needs.test-go-pg-17.result }}"
|
||||
echo "- test-go-race-pg: ${{ needs.test-go-race-pg.result }}"
|
||||
echo "- test-js: ${{ needs.test-js.result }}"
|
||||
echo "- test-e2e: ${{ needs.test-e2e.result }}"
|
||||
@@ -970,12 +968,12 @@ jobs:
|
||||
needs: changes
|
||||
# We always build the dylibs on Go changes to verify we're not merging unbuildable code,
|
||||
# but they need only be signed and uploaded on coder/coder main.
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
|
||||
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-macos-latest' || 'macos-latest' }}
|
||||
steps:
|
||||
# Harden Runner doesn't work on macOS
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -998,7 +996,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Install rcodesign
|
||||
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
|
||||
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
wget -O /tmp/rcodesign.tar.gz https://github.com/indygreg/apple-platform-rs/releases/download/apple-codesign%2F0.22.0/apple-codesign-0.22.0-macos-universal.tar.gz
|
||||
@@ -1009,7 +1007,7 @@ jobs:
|
||||
rm /tmp/rcodesign.tar.gz
|
||||
|
||||
- name: Setup Apple Developer certificate and API key
|
||||
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
|
||||
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
touch /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
|
||||
@@ -1030,12 +1028,12 @@ jobs:
|
||||
make gen/mark-fresh
|
||||
make build/coder-dylib
|
||||
env:
|
||||
CODER_SIGN_DARWIN: ${{ (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) && '1' || '0' }}
|
||||
CODER_SIGN_DARWIN: ${{ github.ref == 'refs/heads/main' && '1' || '0' }}
|
||||
AC_CERTIFICATE_FILE: /tmp/apple_cert.p12
|
||||
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
|
||||
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: dylibs
|
||||
@@ -1045,7 +1043,7 @@ jobs:
|
||||
retention-days: 7
|
||||
|
||||
- name: Delete Apple Developer certificate and API key
|
||||
if: ${{ github.repository_owner == 'coder' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }}
|
||||
if: ${{ github.repository_owner == 'coder' && github.ref == 'refs/heads/main' }}
|
||||
run: rm -f /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
|
||||
|
||||
check-build:
|
||||
@@ -1057,12 +1055,12 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -1095,7 +1093,7 @@ jobs:
|
||||
needs:
|
||||
- changes
|
||||
- build-dylib
|
||||
if: (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) && needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
|
||||
if: github.ref == 'refs/heads/main' && needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-22.04' }}
|
||||
permissions:
|
||||
# Necessary to push docker images to ghcr.io.
|
||||
@@ -1112,12 +1110,12 @@ jobs:
|
||||
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -1158,7 +1156,7 @@ jobs:
|
||||
|
||||
# Necessary for signing Windows binaries.
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@dded0888837ed1f317902acf8a20df0ad188d165 # v5.0.0
|
||||
uses: actions/setup-java@c5195efecf7bdfc987ee8bae7a71cb8b11521c00 # v4.7.1
|
||||
with:
|
||||
distribution: "zulu"
|
||||
java-version: "11.0"
|
||||
@@ -1191,14 +1189,14 @@ jobs:
|
||||
# Setup GCloud for signing Windows binaries.
|
||||
- name: Authenticate to Google Cloud
|
||||
id: gcloud_auth
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
|
||||
with:
|
||||
workload_identity_provider: ${{ vars.GCP_CODE_SIGNING_WORKLOAD_ID_PROVIDER }}
|
||||
service_account: ${{ vars.GCP_CODE_SIGNING_SERVICE_ACCOUNT }}
|
||||
token_format: "access_token"
|
||||
|
||||
- name: Setup GCloud SDK
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # v2.2.0
|
||||
|
||||
- name: Download dylibs
|
||||
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
|
||||
@@ -1248,45 +1246,40 @@ jobs:
|
||||
id: build-docker
|
||||
env:
|
||||
CODER_IMAGE_BASE: ghcr.io/coder/coder-preview
|
||||
CODER_IMAGE_TAG_PREFIX: main
|
||||
DOCKER_CLI_EXPERIMENTAL: "enabled"
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
|
||||
# build Docker images for each architecture
|
||||
version="$(./scripts/version.sh)"
|
||||
tag="${version//+/-}"
|
||||
tag="main-${version//+/-}"
|
||||
echo "tag=$tag" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# build images for each architecture
|
||||
# note: omitting the -j argument to avoid race conditions when pushing
|
||||
make build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
|
||||
|
||||
# only push if we are on main branch or release branch
|
||||
if [[ "${GITHUB_REF}" == "refs/heads/main" || "${GITHUB_REF}" == refs/heads/release/* ]]; then
|
||||
# only push if we are on main branch
|
||||
if [ "${GITHUB_REF}" == "refs/heads/main" ]; then
|
||||
# build and push multi-arch manifest, this depends on the other images
|
||||
# being pushed so will automatically push them
|
||||
# note: omitting the -j argument to avoid race conditions when pushing
|
||||
make push/build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
|
||||
|
||||
# Define specific tags
|
||||
tags=("$tag")
|
||||
if [ "${GITHUB_REF}" == "refs/heads/main" ]; then
|
||||
tags+=("main" "latest")
|
||||
elif [[ "${GITHUB_REF}" == refs/heads/release/* ]]; then
|
||||
tags+=("release-${GITHUB_REF#refs/heads/release/}")
|
||||
fi
|
||||
tags=("$tag" "main" "latest")
|
||||
|
||||
# Create and push a multi-arch manifest for each tag
|
||||
# we are adding `latest` tag and keeping `main` for backward
|
||||
# compatibality
|
||||
for t in "${tags[@]}"; do
|
||||
echo "Pushing multi-arch manifest for tag: $t"
|
||||
# shellcheck disable=SC2046
|
||||
./scripts/build_docker_multiarch.sh \
|
||||
--push \
|
||||
--target "ghcr.io/coder/coder-preview:$t" \
|
||||
--version "$version" \
|
||||
$(cat build/coder_"$version"_linux_{amd64,arm64,armv7}.tag)
|
||||
# shellcheck disable=SC2046
|
||||
./scripts/build_docker_multiarch.sh \
|
||||
--push \
|
||||
--target "ghcr.io/coder/coder-preview:$t" \
|
||||
--version "$version" \
|
||||
$(cat build/coder_"$version"_linux_{amd64,arm64,armv7}.tag)
|
||||
done
|
||||
fi
|
||||
|
||||
@@ -1330,7 +1323,7 @@ jobs:
|
||||
id: attest_main
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:main"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1367,7 +1360,7 @@ jobs:
|
||||
id: attest_latest
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:latest"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1404,7 +1397,7 @@ jobs:
|
||||
id: attest_version
|
||||
if: github.ref == 'refs/heads/main'
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
|
||||
with:
|
||||
subject-name: "ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}"
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -1477,28 +1470,112 @@ jobs:
|
||||
./build/*.deb
|
||||
retention-days: 7
|
||||
|
||||
# Deploy is handled in deploy.yaml so we can apply concurrency limits.
|
||||
deploy:
|
||||
name: "deploy"
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
needs:
|
||||
- changes
|
||||
- build
|
||||
if: |
|
||||
(github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/'))
|
||||
github.ref == 'refs/heads/main' && !github.event.pull_request.head.repo.fork
|
||||
&& needs.changes.outputs.docs-only == 'false'
|
||||
&& !github.event.pull_request.head.repo.fork
|
||||
uses: ./.github/workflows/deploy.yaml
|
||||
with:
|
||||
image: ${{ needs.build.outputs.IMAGE }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
packages: write # to retag image as dogfood
|
||||
secrets:
|
||||
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
|
||||
FLY_PARIS_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
|
||||
FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
|
||||
FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
|
||||
FLY_JNB_CODER_PROXY_SESSION_TOKEN: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Authenticate to Google Cloud
|
||||
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
|
||||
with:
|
||||
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
|
||||
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
|
||||
|
||||
- name: Set up Google Cloud SDK
|
||||
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # v2.2.0
|
||||
|
||||
- name: Set up Flux CLI
|
||||
uses: fluxcd/flux2/action@6bf37f6a560fd84982d67f853162e4b3c2235edb # v2.6.4
|
||||
with:
|
||||
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
|
||||
version: "2.5.1"
|
||||
|
||||
- name: Get Cluster Credentials
|
||||
uses: google-github-actions/get-gke-credentials@8e574c49425fa7efed1e74650a449bfa6a23308a # v2.3.4
|
||||
with:
|
||||
cluster_name: dogfood-v2
|
||||
location: us-central1-a
|
||||
project_id: coder-dogfood-v2
|
||||
|
||||
- name: Reconcile Flux
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
flux --namespace flux-system reconcile source git flux-system
|
||||
flux --namespace flux-system reconcile source git coder-main
|
||||
flux --namespace flux-system reconcile kustomization flux-system
|
||||
flux --namespace flux-system reconcile kustomization coder
|
||||
flux --namespace flux-system reconcile source chart coder-coder
|
||||
flux --namespace flux-system reconcile source chart coder-coder-provisioner
|
||||
flux --namespace coder reconcile helmrelease coder
|
||||
flux --namespace coder reconcile helmrelease coder-provisioner
|
||||
|
||||
# Just updating Flux is usually not enough. The Helm release may get
|
||||
# redeployed, but unless something causes the Deployment to update the
|
||||
# pods won't be recreated. It's important that the pods get recreated,
|
||||
# since we use `imagePullPolicy: Always` to ensure we're running the
|
||||
# latest image.
|
||||
- name: Rollout Deployment
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
kubectl --namespace coder rollout restart deployment/coder
|
||||
kubectl --namespace coder rollout status deployment/coder
|
||||
kubectl --namespace coder rollout restart deployment/coder-provisioner
|
||||
kubectl --namespace coder rollout status deployment/coder-provisioner
|
||||
kubectl --namespace coder rollout restart deployment/coder-provisioner-tagged
|
||||
kubectl --namespace coder rollout status deployment/coder-provisioner-tagged
|
||||
|
||||
deploy-wsproxies:
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
if: github.ref == 'refs/heads/main' && !github.event.pull_request.head.repo.fork
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup flyctl
|
||||
uses: superfly/flyctl-actions/setup-flyctl@fc53c09e1bc3be6f54706524e3b82c4f462f77be # v1.5
|
||||
|
||||
- name: Deploy workspace proxies
|
||||
run: |
|
||||
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
|
||||
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
|
||||
flyctl deploy --image "$IMAGE" --app sao-paulo-coder --config ./.github/fly-wsproxies/sao-paulo-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SAO_PAULO" --yes
|
||||
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
|
||||
env:
|
||||
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
|
||||
IMAGE: ${{ needs.build.outputs.IMAGE }}
|
||||
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_SAO_PAULO: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
|
||||
|
||||
# sqlc-vet runs a postgres docker container, runs Coder migrations, and then
|
||||
# runs sqlc-vet to ensure all queries are valid. This catches any mistakes
|
||||
@@ -1509,12 +1586,12 @@ jobs:
|
||||
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -1537,7 +1614,6 @@ jobs:
|
||||
steps:
|
||||
- name: Send Slack notification
|
||||
run: |
|
||||
ESCAPED_PROMPT=$(printf "%s" "<@U09LQ75AHKR> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
|
||||
curl -X POST -H 'Content-type: application/json' \
|
||||
--data '{
|
||||
"blocks": [
|
||||
@@ -1549,6 +1625,23 @@ jobs:
|
||||
"emoji": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "section",
|
||||
"fields": [
|
||||
{
|
||||
"type": "mrkdwn",
|
||||
"text": "*Workflow:*\n'"${GITHUB_WORKFLOW}"'"
|
||||
},
|
||||
{
|
||||
"type": "mrkdwn",
|
||||
"text": "*Committer:*\n'"${GITHUB_ACTOR}"'"
|
||||
},
|
||||
{
|
||||
"type": "mrkdwn",
|
||||
"text": "*Commit:*\n'"${GITHUB_SHA}"'"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
@@ -1560,7 +1653,7 @@ jobs:
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": '"$ESCAPED_PROMPT"'
|
||||
"text": "<@U08TJ4YNCA3> investigate this CI failure. Check logs, search for existing issues, use git blame to find who last modified failing tests, create issue in coder/internal (not public repo), use title format \"flake: TestName\" for flaky tests, and assign to the person from git blame."
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -1568,4 +1661,3 @@ jobs:
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.CI_FAILURE_SLACK_WEBHOOK }}
|
||||
RUN_URL: "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
BLINK_CI_FAILURE_PROMPT: ${{ vars.BLINK_CI_FAILURE_PROMPT }}
|
||||
|
||||
@@ -53,7 +53,7 @@ jobs:
|
||||
if: ${{ github.event_name == 'pull_request_target' && !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- name: release-labels
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
# This script ensures PR title and labels are in sync:
|
||||
#
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
name: deploy
|
||||
|
||||
on:
|
||||
# Via workflow_call, called from ci.yaml
|
||||
workflow_call:
|
||||
inputs:
|
||||
image:
|
||||
description: "Image and tag to potentially deploy. Current branch will be validated against should-deploy check."
|
||||
required: true
|
||||
type: string
|
||||
secrets:
|
||||
FLY_API_TOKEN:
|
||||
required: true
|
||||
FLY_PARIS_CODER_PROXY_SESSION_TOKEN:
|
||||
required: true
|
||||
FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN:
|
||||
required: true
|
||||
FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN:
|
||||
required: true
|
||||
FLY_JNB_CODER_PROXY_SESSION_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }} # no per-branch concurrency
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
# Determines if the given branch should be deployed to dogfood.
|
||||
should-deploy:
|
||||
name: should-deploy
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check if deploy is enabled
|
||||
id: check
|
||||
run: |
|
||||
set -euo pipefail
|
||||
verdict="$(./scripts/should_deploy.sh)"
|
||||
echo "verdict=$verdict" >> "$GITHUB_OUTPUT"
|
||||
|
||||
deploy:
|
||||
name: "deploy"
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
needs: should-deploy
|
||||
if: needs.should-deploy.outputs.verdict == 'DEPLOY'
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
packages: write # to retag image as dogfood
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: GHCR Login
|
||||
uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3.5.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Authenticate to Google Cloud
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
with:
|
||||
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
|
||||
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
|
||||
|
||||
- name: Set up Google Cloud SDK
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
|
||||
- name: Set up Flux CLI
|
||||
uses: fluxcd/flux2/action@6bf37f6a560fd84982d67f853162e4b3c2235edb # v2.6.4
|
||||
with:
|
||||
# Keep this and the github action up to date with the version of flux installed in dogfood cluster
|
||||
version: "2.7.0"
|
||||
|
||||
- name: Get Cluster Credentials
|
||||
uses: google-github-actions/get-gke-credentials@3da1e46a907576cefaa90c484278bb5b259dd395 # v3.0.0
|
||||
with:
|
||||
cluster_name: dogfood-v2
|
||||
location: us-central1-a
|
||||
project_id: coder-dogfood-v2
|
||||
|
||||
# Retag image as dogfood while maintaining the multi-arch manifest
|
||||
- name: Tag image as dogfood
|
||||
run: docker buildx imagetools create --tag "ghcr.io/coder/coder-preview:dogfood" "$IMAGE"
|
||||
env:
|
||||
IMAGE: ${{ inputs.image }}
|
||||
|
||||
- name: Reconcile Flux
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
flux --namespace flux-system reconcile source git flux-system
|
||||
flux --namespace flux-system reconcile source git coder-main
|
||||
flux --namespace flux-system reconcile kustomization flux-system
|
||||
flux --namespace flux-system reconcile kustomization coder
|
||||
flux --namespace flux-system reconcile source chart coder-coder
|
||||
flux --namespace flux-system reconcile source chart coder-coder-provisioner
|
||||
flux --namespace coder reconcile helmrelease coder
|
||||
flux --namespace coder reconcile helmrelease coder-provisioner
|
||||
flux --namespace coder reconcile helmrelease coder-provisioner-tagged
|
||||
flux --namespace coder reconcile helmrelease coder-provisioner-tagged-prebuilds
|
||||
|
||||
# Just updating Flux is usually not enough. The Helm release may get
|
||||
# redeployed, but unless something causes the Deployment to update the
|
||||
# pods won't be recreated. It's important that the pods get recreated,
|
||||
# since we use `imagePullPolicy: Always` to ensure we're running the
|
||||
# latest image.
|
||||
- name: Rollout Deployment
|
||||
run: |
|
||||
set -euxo pipefail
|
||||
kubectl --namespace coder rollout restart deployment/coder
|
||||
kubectl --namespace coder rollout status deployment/coder
|
||||
kubectl --namespace coder rollout restart deployment/coder-provisioner
|
||||
kubectl --namespace coder rollout status deployment/coder-provisioner
|
||||
kubectl --namespace coder rollout restart deployment/coder-provisioner-tagged
|
||||
kubectl --namespace coder rollout status deployment/coder-provisioner-tagged
|
||||
kubectl --namespace coder rollout restart deployment/coder-provisioner-tagged-prebuilds
|
||||
kubectl --namespace coder rollout status deployment/coder-provisioner-tagged-prebuilds
|
||||
|
||||
deploy-wsproxies:
|
||||
runs-on: ubuntu-latest
|
||||
needs: deploy
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup flyctl
|
||||
uses: superfly/flyctl-actions/setup-flyctl@fc53c09e1bc3be6f54706524e3b82c4f462f77be # v1.5
|
||||
|
||||
- name: Deploy workspace proxies
|
||||
run: |
|
||||
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
|
||||
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
|
||||
flyctl deploy --image "$IMAGE" --app sao-paulo-coder --config ./.github/fly-wsproxies/sao-paulo-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SAO_PAULO" --yes
|
||||
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
|
||||
env:
|
||||
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
|
||||
IMAGE: ${{ inputs.image }}
|
||||
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_SAO_PAULO: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
|
||||
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
|
||||
@@ -38,12 +38,12 @@ jobs:
|
||||
if: github.repository_owner == 'coder'
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
# This uses OIDC authentication, so no auth variables are required.
|
||||
- name: Build base Docker image via depot.dev
|
||||
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
|
||||
uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1.15.0
|
||||
with:
|
||||
project: wl5hnrrkns
|
||||
context: base-build-context
|
||||
|
||||
@@ -23,14 +23,14 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Node
|
||||
uses: ./.github/actions/setup-node
|
||||
|
||||
- uses: tj-actions/changed-files@4563c729c555b4141fac99c80f699f571219b836 # v45.0.7
|
||||
- uses: tj-actions/changed-files@f963b3f3562b00b6d2dd25efc390eb04e51ef6c6 # v45.0.7
|
||||
id: changed-files
|
||||
with:
|
||||
files: |
|
||||
|
||||
@@ -26,17 +26,17 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Nix
|
||||
uses: nixbuild/nix-quick-install-action@1f095fee853b33114486cfdeae62fa099cda35a9 # v33
|
||||
uses: nixbuild/nix-quick-install-action@63ca48f939ee3b8d835f4126562537df0fee5b91 # v32
|
||||
with:
|
||||
# Pinning to 2.28 here, as Nix gets a "error: [json.exception.type_error.302] type must be array, but is string"
|
||||
# on version 2.29 and above.
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and push Non-Nix image
|
||||
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
|
||||
uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1.15.0
|
||||
with:
|
||||
project: b4q6ltmpzh
|
||||
token: ${{ secrets.DEPOT_TOKEN }}
|
||||
@@ -125,12 +125,12 @@ jobs:
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -138,7 +138,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-tf
|
||||
|
||||
- name: Authenticate to Google Cloud
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
|
||||
with:
|
||||
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
|
||||
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
- windows-2022
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -53,7 +53,7 @@ jobs:
|
||||
uses: coder/setup-ramdisk-action@e1100847ab2d7bcd9d14bcda8f2d1b0f07b36f1b # v0.1.0
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
@@ -143,7 +143,7 @@ jobs:
|
||||
|
||||
DB=ci gotestsum \
|
||||
--format standard-quiet --packages "./..." \
|
||||
-- -timeout=20m -v -p "$NUM_PARALLEL_PACKAGES" -parallel="$NUM_PARALLEL_TESTS" "$TESTCOUNT"
|
||||
-- -timeout=20m -v -p $NUM_PARALLEL_PACKAGES -parallel=$NUM_PARALLEL_TESTS $TESTCOUNT
|
||||
|
||||
- name: Upload Embedded Postgres Cache
|
||||
uses: ./.github/actions/embedded-pg-cache/upload
|
||||
@@ -170,7 +170,6 @@ jobs:
|
||||
steps:
|
||||
- name: Send Slack notification
|
||||
run: |
|
||||
ESCAPED_PROMPT=$(printf "%s" "<@U09LQ75AHKR> $BLINK_CI_FAILURE_PROMPT" | jq -Rsa .)
|
||||
curl -X POST -H 'Content-type: application/json' \
|
||||
--data '{
|
||||
"blocks": [
|
||||
@@ -182,6 +181,23 @@ jobs:
|
||||
"emoji": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "section",
|
||||
"fields": [
|
||||
{
|
||||
"type": "mrkdwn",
|
||||
"text": "*Workflow:*\n'"${GITHUB_WORKFLOW}"'"
|
||||
},
|
||||
{
|
||||
"type": "mrkdwn",
|
||||
"text": "*Committer:*\n'"${GITHUB_ACTOR}"'"
|
||||
},
|
||||
{
|
||||
"type": "mrkdwn",
|
||||
"text": "*Commit:*\n'"${GITHUB_SHA}"'"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
@@ -193,7 +209,7 @@ jobs:
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": '"$ESCAPED_PROMPT"'
|
||||
"text": "<@U08TJ4YNCA3> investigate this CI failure. Check logs, search for existing issues, use git blame to find who last modified failing tests, create issue in coder/internal (not public repo), use title format \"flake: TestName\" for flaky tests, and assign to the person from git blame."
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -201,4 +217,3 @@ jobs:
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.CI_FAILURE_SLACK_WEBHOOK }}
|
||||
RUN_URL: "${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
|
||||
BLINK_CI_FAILURE_PROMPT: ${{ vars.BLINK_CI_FAILURE_PROMPT }}
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -39,12 +39,12 @@ jobs:
|
||||
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -76,12 +76,12 @@ jobs:
|
||||
runs-on: "ubuntu-latest"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
pull-requests: write # needed for commenting on PRs
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -228,12 +228,12 @@ jobs:
|
||||
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -337,7 +337,7 @@ jobs:
|
||||
kubectl create namespace "pr${PR_NUMBER}"
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -370,7 +370,6 @@ jobs:
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm install coder-db bitnami/postgresql \
|
||||
--namespace "pr${PR_NUMBER}" \
|
||||
--set image.repository=bitnamilegacy/postgresql \
|
||||
--set auth.username=coder \
|
||||
--set auth.password=coder \
|
||||
--set auth.database=coder \
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Allow only maintainers/admins
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@v7.0.1
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
steps:
|
||||
# Harden Runner doesn't work on macOS.
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -164,12 +164,12 @@ jobs:
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -253,7 +253,7 @@ jobs:
|
||||
|
||||
# Necessary for signing Windows binaries.
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@dded0888837ed1f317902acf8a20df0ad188d165 # v5.0.0
|
||||
uses: actions/setup-java@c5195efecf7bdfc987ee8bae7a71cb8b11521c00 # v4.7.1
|
||||
with:
|
||||
distribution: "zulu"
|
||||
java-version: "11.0"
|
||||
@@ -317,14 +317,14 @@ jobs:
|
||||
# Setup GCloud for signing Windows binaries.
|
||||
- name: Authenticate to Google Cloud
|
||||
id: gcloud_auth
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
|
||||
with:
|
||||
workload_identity_provider: ${{ vars.GCP_CODE_SIGNING_WORKLOAD_ID_PROVIDER }}
|
||||
service_account: ${{ vars.GCP_CODE_SIGNING_SERVICE_ACCOUNT }}
|
||||
token_format: "access_token"
|
||||
|
||||
- name: Setup GCloud SDK
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # v3.0.1
|
||||
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # v2.2.0
|
||||
|
||||
- name: Download dylibs
|
||||
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
|
||||
@@ -397,7 +397,7 @@ jobs:
|
||||
# This uses OIDC authentication, so no auth variables are required.
|
||||
- name: Build base Docker image via depot.dev
|
||||
if: steps.image-base-tag.outputs.tag != ''
|
||||
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
|
||||
uses: depot/build-push-action@2583627a84956d07561420dcc1d0eb1f2af3fac0 # v1.15.0
|
||||
with:
|
||||
project: wl5hnrrkns
|
||||
context: base-build-context
|
||||
@@ -454,7 +454,7 @@ jobs:
|
||||
id: attest_base
|
||||
if: ${{ !inputs.dry_run && steps.image-base-tag.outputs.tag != '' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
|
||||
with:
|
||||
subject-name: ${{ steps.image-base-tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -570,7 +570,7 @@ jobs:
|
||||
id: attest_main
|
||||
if: ${{ !inputs.dry_run }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
|
||||
with:
|
||||
subject-name: ${{ steps.build_docker.outputs.multiarch_image }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -614,7 +614,7 @@ jobs:
|
||||
id: attest_latest
|
||||
if: ${{ !inputs.dry_run && steps.build_docker.outputs.created_latest_tag == 'true' }}
|
||||
continue-on-error: true
|
||||
uses: actions/attest@daf44fb950173508f38bd2406030372c1d1162b1 # v3.0.0
|
||||
uses: actions/attest@ce27ba3b4a9a139d9a20a4a07d69fabb52f1e5bc # v2.4.0
|
||||
with:
|
||||
subject-name: ${{ steps.latest_tag.outputs.tag }}
|
||||
predicate-type: "https://slsa.dev/provenance/v1"
|
||||
@@ -734,13 +734,13 @@ jobs:
|
||||
CREATED_LATEST_TAG: ${{ steps.build_docker.outputs.created_latest_tag }}
|
||||
|
||||
- name: Authenticate to Google Cloud
|
||||
uses: google-github-actions/auth@7c6bc770dae815cd3e89ee6cdf493a5fab2cc093 # v3.0.0
|
||||
uses: google-github-actions/auth@b7593ed2efd1c1617e1b0254da33b86225adb2a5 # v2.1.12
|
||||
with:
|
||||
workload_identity_provider: ${{ vars.GCP_WORKLOAD_ID_PROVIDER }}
|
||||
service_account: ${{ vars.GCP_SERVICE_ACCOUNT }}
|
||||
|
||||
- name: Setup GCloud SDK
|
||||
uses: google-github-actions/setup-gcloud@aa5489c8933f4cc7a4f7d45035b3b1440c9c10db # 3.0.1
|
||||
uses: google-github-actions/setup-gcloud@cb1e50a9932213ecece00a606661ae9ca44f3397 # 2.2.0
|
||||
|
||||
- name: Publish Helm Chart
|
||||
if: ${{ !inputs.dry_run }}
|
||||
@@ -802,7 +802,7 @@ jobs:
|
||||
# TODO: skip this if it's not a new release (i.e. a backport). This is
|
||||
# fine right now because it just makes a PR that we can close.
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -878,7 +878,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -888,7 +888,7 @@ jobs:
|
||||
GH_TOKEN: ${{ secrets.CDRCI_GITHUB_TOKEN }}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -971,12 +971,12 @@ jobs:
|
||||
if: ${{ !inputs.dry_run }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 1
|
||||
persist-credentials: false
|
||||
|
||||
@@ -20,12 +20,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: "Checkout code"
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -47,6 +47,6 @@ jobs:
|
||||
|
||||
# Upload the results to GitHub's code scanning dashboard.
|
||||
- name: "Upload to code-scanning"
|
||||
uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -27,12 +27,12 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/init@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
|
||||
with:
|
||||
languages: go, javascript
|
||||
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
rm Makefile
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/analyze@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
|
||||
|
||||
- name: Send Slack notification on failure
|
||||
if: ${{ failure() }}
|
||||
@@ -69,12 +69,12 @@ jobs:
|
||||
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8
|
||||
uses: aquasecurity/trivy-action@dc5a429b52fcf669ce959baa2c2dd26090d2a6c4
|
||||
with:
|
||||
image-ref: ${{ steps.build.outputs.image }}
|
||||
format: sarif
|
||||
@@ -154,7 +154,7 @@ jobs:
|
||||
severity: "CRITICAL,HIGH"
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.29.5
|
||||
uses: github/codeql-action/upload-sarif@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5
|
||||
with:
|
||||
sarif_file: trivy-results.sarif
|
||||
category: "Trivy"
|
||||
|
||||
@@ -18,12 +18,12 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: stale
|
||||
uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
|
||||
uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0
|
||||
with:
|
||||
stale-issue-label: "stale"
|
||||
stale-pr-label: "stale"
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
# Start with the oldest issues, always.
|
||||
ascending: true
|
||||
- name: "Close old issues labeled likely-no"
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
@@ -96,12 +96,12 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Run delete-old-branches-action
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
actions: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
name: AI Triage Automation
|
||||
|
||||
on:
|
||||
issues:
|
||||
types:
|
||||
- labeled
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
issue_url:
|
||||
description: "GitHub Issue URL to process"
|
||||
required: true
|
||||
type: string
|
||||
template_name:
|
||||
description: "Coder template to use for workspace"
|
||||
required: true
|
||||
default: "traiage"
|
||||
type: string
|
||||
template_preset:
|
||||
description: "Template preset to use"
|
||||
required: true
|
||||
default: "Default"
|
||||
type: string
|
||||
prefix:
|
||||
description: "Prefix for workspace name"
|
||||
required: false
|
||||
default: "traiage"
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
traiage:
|
||||
name: Triage GitHub Issue with Claude Code
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.label.name == 'traiage' || github.event_name == 'workflow_dispatch'
|
||||
timeout-minutes: 30
|
||||
env:
|
||||
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
|
||||
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
actions: write
|
||||
|
||||
steps:
|
||||
# This is only required for testing locally using nektos/act, so leaving commented out.
|
||||
# An alternative is to use a larger or custom image.
|
||||
# - name: Install Github CLI
|
||||
# id: install-gh
|
||||
# run: |
|
||||
# (type -p wget >/dev/null || (sudo apt update && sudo apt install wget -y)) \
|
||||
# && sudo mkdir -p -m 755 /etc/apt/keyrings \
|
||||
# && out=$(mktemp) && wget -nv -O$out https://cli.github.com/packages/githubcli-archive-keyring.gpg \
|
||||
# && cat $out | sudo tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \
|
||||
# && sudo chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
# && sudo mkdir -p -m 755 /etc/apt/sources.list.d \
|
||||
# && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | sudo tee /etc/apt/sources.list.d/github-cli.list > /dev/null \
|
||||
# && sudo apt update \
|
||||
# && sudo apt install gh -y
|
||||
|
||||
- name: Determine Inputs
|
||||
id: determine-inputs
|
||||
if: always()
|
||||
env:
|
||||
GITHUB_ACTOR: ${{ github.actor }}
|
||||
GITHUB_EVENT_ISSUE_HTML_URL: ${{ github.event.issue.html_url }}
|
||||
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||
GITHUB_EVENT_USER_ID: ${{ github.event.sender.id }}
|
||||
GITHUB_EVENT_USER_LOGIN: ${{ github.event.sender.login }}
|
||||
INPUTS_ISSUE_URL: ${{ inputs.issue_url }}
|
||||
INPUTS_TEMPLATE_NAME: ${{ inputs.template_name || 'traiage' }}
|
||||
INPUTS_TEMPLATE_PRESET: ${{ inputs.template_preset || 'Default'}}
|
||||
INPUTS_PREFIX: ${{ inputs.prefix || 'traiage' }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
echo "Using template name: ${INPUTS_TEMPLATE_NAME}"
|
||||
echo "template_name=${INPUTS_TEMPLATE_NAME}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
echo "Using template preset: ${INPUTS_TEMPLATE_PRESET}"
|
||||
echo "template_preset=${INPUTS_TEMPLATE_PRESET}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
echo "Using prefix: ${INPUTS_PREFIX}"
|
||||
echo "prefix=${INPUTS_PREFIX}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
# For workflow_dispatch, use the actor who triggered it
|
||||
# For issues events, use the issue author.
|
||||
if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ]]; then
|
||||
if ! GITHUB_USER_ID=$(gh api "users/${GITHUB_ACTOR}" --jq '.id'); then
|
||||
echo "::error::Failed to get GitHub user ID for actor ${GITHUB_ACTOR}"
|
||||
exit 1
|
||||
fi
|
||||
echo "Using workflow_dispatch actor: ${GITHUB_ACTOR} (ID: ${GITHUB_USER_ID})"
|
||||
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
|
||||
echo "github_username=${GITHUB_ACTOR}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
echo "Using issue URL: ${INPUTS_ISSUE_URL}"
|
||||
echo "issue_url=${INPUTS_ISSUE_URL}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
exit 0
|
||||
elif [[ "${GITHUB_EVENT_NAME}" == "issues" ]]; then
|
||||
GITHUB_USER_ID=${GITHUB_EVENT_USER_ID}
|
||||
echo "Using issue author: ${GITHUB_EVENT_USER_LOGIN} (ID: ${GITHUB_USER_ID})"
|
||||
echo "github_user_id=${GITHUB_USER_ID}" >> "${GITHUB_OUTPUT}"
|
||||
echo "github_username=${GITHUB_EVENT_USER_LOGIN}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
echo "Using issue URL: ${GITHUB_EVENT_ISSUE_HTML_URL}"
|
||||
echo "issue_url=${GITHUB_EVENT_ISSUE_HTML_URL}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
exit 0
|
||||
else
|
||||
echo "::error::Unsupported event type: ${GITHUB_EVENT_NAME}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Verify push access
|
||||
env:
|
||||
GITHUB_REPOSITORY: ${{ github.repository }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_USERNAME: ${{ steps.determine-inputs.outputs.github_username }}
|
||||
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
|
||||
run: |
|
||||
# Query the actor’s permission on this repo
|
||||
can_push="$(gh api "/repos/${GITHUB_REPOSITORY}/collaborators/${GITHUB_USERNAME}/permission" --jq '.user.permissions.push')"
|
||||
if [[ "${can_push}" != "true" ]]; then
|
||||
echo "::error title=Access Denied::${GITHUB_USERNAME} does not have push access to ${GITHUB_REPOSITORY}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Extract context key from issue
|
||||
id: extract-context
|
||||
env:
|
||||
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
issue_number="$(gh issue view "${ISSUE_URL}" --json number --jq '.number')"
|
||||
context_key="gh-${issue_number}"
|
||||
echo "context_key=${context_key}" >> "${GITHUB_OUTPUT}"
|
||||
echo "CONTEXT_KEY=${context_key}" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Download and install Coder binary
|
||||
shell: bash
|
||||
env:
|
||||
CODER_URL: ${{ secrets.TRAIAGE_CODER_URL }}
|
||||
run: |
|
||||
if [ "${{ runner.arch }}" == "ARM64" ]; then
|
||||
ARCH="arm64"
|
||||
else
|
||||
ARCH="amd64"
|
||||
fi
|
||||
mkdir -p "${HOME}/.local/bin"
|
||||
curl -fsSL --compressed "$CODER_URL/bin/coder-linux-${ARCH}" -o "${HOME}/.local/bin/coder"
|
||||
chmod +x "${HOME}/.local/bin/coder"
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
coder version
|
||||
coder whoami
|
||||
echo "$HOME/.local/bin" >> "${GITHUB_PATH}"
|
||||
|
||||
- name: Get Coder username from GitHub actor
|
||||
id: get-coder-username
|
||||
env:
|
||||
CODER_SESSION_TOKEN: ${{ secrets.TRAIAGE_CODER_SESSION_TOKEN }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_USER_ID: ${{ steps.determine-inputs.outputs.github_user_id }}
|
||||
run: |
|
||||
user_json=$(
|
||||
coder users list --github-user-id="${GITHUB_USER_ID}" --output=json
|
||||
)
|
||||
coder_username=$(jq -r 'first | .username' <<< "$user_json")
|
||||
[[ -z "${coder_username}" || "${coder_username}" == "null" ]] && echo "No Coder user with GitHub user ID ${GITHUB_USER_ID} found" && exit 1
|
||||
echo "coder_username=${coder_username}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
# TODO(Cian): this is a good use-case for 'recipes'
|
||||
- name: Create Coder task
|
||||
id: create-task
|
||||
env:
|
||||
CODER_USERNAME: ${{ steps.get-coder-username.outputs.coder_username }}
|
||||
CONTEXT_KEY: ${{ steps.extract-context.outputs.context_key }}
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_REPOSITORY: ${{ github.repository }}
|
||||
ISSUE_URL: ${{ steps.determine-inputs.outputs.issue_url }}
|
||||
PREFIX: ${{ steps.determine-inputs.outputs.prefix }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
TEMPLATE_NAME: ${{ steps.determine-inputs.outputs.template_name }}
|
||||
TEMPLATE_PARAMETERS: ${{ secrets.TRAIAGE_TEMPLATE_PARAMETERS }}
|
||||
TEMPLATE_PRESET: ${{ steps.determine-inputs.outputs.template_preset }}
|
||||
run: |
|
||||
# Fetch issue description using `gh` CLI
|
||||
#shellcheck disable=SC2016 # The template string should not be subject to shell expansion
|
||||
issue_description=$(gh issue view "${ISSUE_URL}" \
|
||||
--json 'title,body,comments' \
|
||||
--template '{{printf "%s\n\n%s\n\nComments:\n" .title .body}}{{range $k, $v := .comments}} - {{index $v.author "login"}}: {{printf "%s\n" $v.body}}{{end}}')
|
||||
|
||||
# Write a prompt to PROMPT_FILE
|
||||
PROMPT=$(cat <<EOF
|
||||
Fix ${ISSUE_URL}
|
||||
|
||||
Analyze the below GitHub issue description, understand the root cause, and make appropriate changes to resolve the issue.
|
||||
---
|
||||
${issue_description}
|
||||
EOF
|
||||
)
|
||||
export PROMPT
|
||||
|
||||
export TASK_NAME="${PREFIX}-${CONTEXT_KEY}-${RUN_ID}"
|
||||
echo "Creating task: $TASK_NAME"
|
||||
./scripts/traiage.sh create
|
||||
if [[ "${ISSUE_URL}" == "https://github.com/${GITHUB_REPOSITORY}"* ]]; then
|
||||
gh issue comment "${ISSUE_URL}" --body "Task created: https://dev.coder.com/tasks/${CODER_USERNAME}/${TASK_NAME}" --create-if-none --edit-last
|
||||
else
|
||||
echo "Skipping comment on other repo."
|
||||
fi
|
||||
echo "TASK_NAME=${CODER_USERNAME}/${TASK_NAME}" >> "${GITHUB_OUTPUT}"
|
||||
echo "TASK_NAME=${CODER_USERNAME}/${TASK_NAME}" >> "${GITHUB_ENV}"
|
||||
@@ -1,6 +1,5 @@
|
||||
[default]
|
||||
extend-ignore-identifiers-re = ["gho_.*"]
|
||||
extend-ignore-re = ["(#|//)\\s*spellchecker:ignore-next-line\\n.*"]
|
||||
|
||||
[default.extend-identifiers]
|
||||
alog = "alog"
|
||||
|
||||
@@ -21,12 +21,12 @@ jobs:
|
||||
pull-requests: write # required to post PR review comments by the action
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@f4a75cfd619ee5ce8d5b864b0d183aff3c69b55a # v2.13.1
|
||||
uses: step-security/harden-runner@ec9f2d5744a09debf3a187a3f4f675c53b671911 # v2.13.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
rules:
|
||||
cache-poisoning:
|
||||
ignore:
|
||||
- "ci.yaml:184"
|
||||
+1
-11
@@ -169,16 +169,6 @@ linters-settings:
|
||||
- name: var-declaration
|
||||
- name: var-naming
|
||||
- name: waitgroup-by-value
|
||||
usetesting:
|
||||
# Only os-setenv is enabled because we migrated to usetesting from another linter that
|
||||
# only covered os-setenv.
|
||||
os-setenv: true
|
||||
os-create-temp: false
|
||||
os-mkdir-temp: false
|
||||
os-temp-dir: false
|
||||
os-chdir: false
|
||||
context-background: false
|
||||
context-todo: false
|
||||
|
||||
# irrelevant as of Go v1.22: https://go.dev/blog/loopvar-preview
|
||||
govet:
|
||||
@@ -262,6 +252,7 @@ linters:
|
||||
# - wastedassign
|
||||
|
||||
- staticcheck
|
||||
- tenv
|
||||
# In Go, it's possible for a package to test it's internal functionality
|
||||
# without testing any exported functions. This is enabled to promote
|
||||
# decomposing a package before testing it's internals. A function caller
|
||||
@@ -274,5 +265,4 @@ linters:
|
||||
- typecheck
|
||||
- unconvert
|
||||
- unused
|
||||
- usetesting
|
||||
- dupl
|
||||
|
||||
Vendored
+1
-3
@@ -54,13 +54,11 @@
|
||||
}
|
||||
},
|
||||
|
||||
"tailwindCSS.classFunctions": ["cva", "cn"],
|
||||
"[css][html][markdown][yaml]": {
|
||||
"editor.defaultFormatter": "esbenp.prettier-vscode"
|
||||
},
|
||||
"typos.config": ".github/workflows/typos.toml",
|
||||
"[markdown]": {
|
||||
"editor.defaultFormatter": "DavidAnson.vscode-markdownlint"
|
||||
},
|
||||
"biome.lsp.bin": "site/node_modules/.bin/biome"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -561,7 +561,7 @@ endif
|
||||
|
||||
# Note: we don't run zizmor in the lint target because it takes a while. CI
|
||||
# runs it explicitly.
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint
|
||||
.PHONY: lint
|
||||
|
||||
lint/site-icons:
|
||||
@@ -614,11 +614,6 @@ lint/actions/zizmor:
|
||||
.
|
||||
.PHONY: lint/actions/zizmor
|
||||
|
||||
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
|
||||
lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# All files generated by the database should be added here, and this can be used
|
||||
# as a target for jobs that need to run after the database is generated.
|
||||
DB_GEN_FILES := \
|
||||
@@ -635,23 +630,16 @@ TAILNETTEST_MOCKS := \
|
||||
tailnet/tailnettest/workspaceupdatesprovidermock.go \
|
||||
tailnet/tailnettest/subscriptionmock.go
|
||||
|
||||
AIBRIDGED_MOCKS := \
|
||||
enterprise/x/aibridged/aibridgedmock/clientmock.go \
|
||||
enterprise/x/aibridged/aibridgedmock/poolmock.go
|
||||
|
||||
GEN_FILES := \
|
||||
tailnet/proto/tailnet.pb.go \
|
||||
agent/proto/agent.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go \
|
||||
$(DB_GEN_FILES) \
|
||||
$(SITE_GEN_FILES) \
|
||||
coderd/rbac/object_gen.go \
|
||||
codersdk/rbacresources_gen.go \
|
||||
coderd/rbac/scopes_constants_gen.go \
|
||||
codersdk/apikey_scopes_gen.go \
|
||||
docs/admin/integrations/prometheus.md \
|
||||
docs/reference/cli/index.md \
|
||||
docs/admin/security/audit-logs.md \
|
||||
@@ -665,8 +653,7 @@ GEN_FILES := \
|
||||
agent/agentcontainers/acmock/acmock.go \
|
||||
agent/agentcontainers/dcspec/dcspec_gen.go \
|
||||
coderd/httpmw/loggermw/loggermock/loggermock.go \
|
||||
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
|
||||
$(AIBRIDGED_MOCKS)
|
||||
codersdk/workspacesdk/agentconnmock/agentconnmock.go
|
||||
|
||||
# all gen targets should be added here and to gen/mark-fresh
|
||||
gen: gen/db gen/golden-files $(GEN_FILES)
|
||||
@@ -696,13 +683,11 @@ gen/mark-fresh:
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
vpn/vpn.pb.go \
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go \
|
||||
coderd/database/dump.sql \
|
||||
$(DB_GEN_FILES) \
|
||||
site/src/api/typesGenerated.ts \
|
||||
coderd/rbac/object_gen.go \
|
||||
codersdk/rbacresources_gen.go \
|
||||
coderd/rbac/scopes_constants_gen.go \
|
||||
site/src/api/rbacresourcesGenerated.ts \
|
||||
site/src/api/countriesGenerated.ts \
|
||||
docs/admin/integrations/prometheus.md \
|
||||
@@ -719,7 +704,6 @@ gen/mark-fresh:
|
||||
agent/agentcontainers/dcspec/dcspec_gen.go \
|
||||
coderd/httpmw/loggermw/loggermock/loggermock.go \
|
||||
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
|
||||
$(AIBRIDGED_MOCKS) \
|
||||
"
|
||||
|
||||
for file in $$files; do
|
||||
@@ -767,10 +751,6 @@ codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agen
|
||||
go generate ./codersdk/workspacesdk/agentconnmock/
|
||||
touch "$@"
|
||||
|
||||
$(AIBRIDGED_MOCKS): enterprise/x/aibridged/client.go enterprise/x/aibridged/pool.go
|
||||
go generate ./enterprise/x/aibridged/aibridgedmock/
|
||||
touch "$@"
|
||||
|
||||
agent/agentcontainers/dcspec/dcspec_gen.go: \
|
||||
node_modules/.installed \
|
||||
agent/agentcontainers/dcspec/devContainer.base.schema.json \
|
||||
@@ -821,14 +801,6 @@ vpn/vpn.pb.go: vpn/vpn.proto
|
||||
--go_opt=paths=source_relative \
|
||||
./vpn/vpn.proto
|
||||
|
||||
enterprise/x/aibridged/proto/aibridged.pb.go: enterprise/x/aibridged/proto/aibridged.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/x/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
|
||||
# -C sets the directory for the go run command
|
||||
go run -C ./scripts/apitypings main.go > $@
|
||||
@@ -855,15 +827,6 @@ coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/mai
|
||||
rmdir -v "$$tempdir"
|
||||
touch "$@"
|
||||
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go
|
||||
# Generate typed low-level ScopeName constants from RBACPermissions
|
||||
# Write to a temp file first to avoid truncating the package during build
|
||||
# since the generator imports the rbac package.
|
||||
tempfile=$(shell mktemp /tmp/scopes_constants_gen.XXXXXX)
|
||||
go run ./scripts/typegen/main.go rbac scopenames > "$$tempfile"
|
||||
mv -v "$$tempfile" coderd/rbac/scopes_constants_gen.go
|
||||
touch "$@"
|
||||
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
|
||||
# Do no overwrite codersdk/rbacresources_gen.go directly, as it would make the file empty, breaking
|
||||
# the `codersdk` package and any parallel build targets.
|
||||
@@ -871,12 +834,6 @@ codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/m
|
||||
mv /tmp/rbacresources_gen.go codersdk/rbacresources_gen.go
|
||||
touch "$@"
|
||||
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go
|
||||
# Generate SDK constants for external API key scopes.
|
||||
go run ./scripts/apikeyscopesgen > /tmp/apikey_scopes_gen.go
|
||||
mv /tmp/apikey_scopes_gen.go codersdk/apikey_scopes_gen.go
|
||||
touch "$@"
|
||||
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
|
||||
go run scripts/typegen/main.go rbac typescript > "$@"
|
||||
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
|
||||
|
||||
+15
-5
@@ -74,6 +74,7 @@ type Options struct {
|
||||
LogDir string
|
||||
TempDir string
|
||||
ScriptDataDir string
|
||||
ExchangeToken func(ctx context.Context) (string, error)
|
||||
Client Client
|
||||
ReconnectingPTYTimeout time.Duration
|
||||
EnvironmentVariables map[string]string
|
||||
@@ -98,7 +99,6 @@ type Client interface {
|
||||
proto.DRPCAgentClient26, tailnetproto.DRPCTailnetClient26, error,
|
||||
)
|
||||
tailnet.DERPMapRewriter
|
||||
agentsdk.RefreshableSessionTokenProvider
|
||||
}
|
||||
|
||||
type Agent interface {
|
||||
@@ -131,6 +131,11 @@ func New(options Options) Agent {
|
||||
}
|
||||
options.ScriptDataDir = options.TempDir
|
||||
}
|
||||
if options.ExchangeToken == nil {
|
||||
options.ExchangeToken = func(_ context.Context) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
if options.ReportMetadataInterval == 0 {
|
||||
options.ReportMetadataInterval = time.Second
|
||||
}
|
||||
@@ -167,6 +172,7 @@ func New(options Options) Agent {
|
||||
coordDisconnected: make(chan struct{}),
|
||||
environmentVariables: options.EnvironmentVariables,
|
||||
client: options.Client,
|
||||
exchangeToken: options.ExchangeToken,
|
||||
filesystem: options.Filesystem,
|
||||
logDir: options.LogDir,
|
||||
tempDir: options.TempDir,
|
||||
@@ -197,6 +203,7 @@ func New(options Options) Agent {
|
||||
// coordinator during shut down.
|
||||
close(a.coordDisconnected)
|
||||
a.announcementBanners.Store(new([]codersdk.BannerConfig))
|
||||
a.sessionToken.Store(new(string))
|
||||
a.init()
|
||||
return a
|
||||
}
|
||||
@@ -205,6 +212,7 @@ type agent struct {
|
||||
clock quartz.Clock
|
||||
logger slog.Logger
|
||||
client Client
|
||||
exchangeToken func(ctx context.Context) (string, error)
|
||||
tailnetListenPort uint16
|
||||
filesystem afero.Fs
|
||||
logDir string
|
||||
@@ -246,6 +254,7 @@ type agent struct {
|
||||
scriptRunner *agentscripts.Runner
|
||||
announcementBanners atomic.Pointer[[]codersdk.BannerConfig] // announcementBanners is atomic because it is periodically updated.
|
||||
announcementBannersRefreshInterval time.Duration
|
||||
sessionToken atomic.Pointer[string]
|
||||
sshServer *agentssh.Server
|
||||
sshMaxTimeout time.Duration
|
||||
blockFileTransfer bool
|
||||
@@ -785,7 +794,7 @@ func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentC
|
||||
// log a warning.
|
||||
// Related to https://github.com/coder/coder/issues/20194
|
||||
logger.Warn(ctx, "failed to report connection to server", slog.Error(err))
|
||||
// keep going, we still need to remove it from the slice
|
||||
// no continue here, we still need to remove it from the slice
|
||||
} else {
|
||||
logger.Debug(ctx, "successfully reported connection")
|
||||
}
|
||||
@@ -918,10 +927,11 @@ func (a *agent) run() (retErr error) {
|
||||
// This allows the agent to refresh its token if necessary.
|
||||
// For instance identity this is required, since the instance
|
||||
// may not have re-provisioned, but a new agent ID was created.
|
||||
err := a.client.RefreshToken(a.hardCtx)
|
||||
sessionToken, err := a.exchangeToken(a.hardCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("refresh token: %w", err)
|
||||
return xerrors.Errorf("exchange token: %w", err)
|
||||
}
|
||||
a.sessionToken.Store(&sessionToken)
|
||||
|
||||
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
|
||||
aAPI, tAPI, err := a.client.ConnectRPC26(a.hardCtx)
|
||||
@@ -1360,7 +1370,7 @@ func (a *agent) updateCommandEnv(current []string) (updated []string, err error)
|
||||
"CODER_WORKSPACE_OWNER_NAME": manifest.OwnerName,
|
||||
|
||||
// Specific Coder subcommands require the agent token exposed!
|
||||
"CODER_AGENT_TOKEN": a.client.GetSessionToken(),
|
||||
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),
|
||||
|
||||
// Git on Windows resolves with UNIX-style paths.
|
||||
// If using backslashes, it's unable to find the executable.
|
||||
|
||||
+24
-16
@@ -22,6 +22,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1807,12 +1808,11 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
|
||||
|
||||
//nolint:dogsled
|
||||
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
|
||||
idConnectionReport := uuid.New()
|
||||
id := uuid.New()
|
||||
|
||||
// Test that the connection is reported. This must be tested in the
|
||||
// first connection because we care about verifying all of these.
|
||||
netConn0, err := conn.ReconnectingPTY(ctx, idConnectionReport, 80, 80, "bash --norc")
|
||||
netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
|
||||
require.NoError(t, err)
|
||||
_ = netConn0.Close()
|
||||
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")
|
||||
@@ -2028,8 +2028,7 @@ func runSubAgentMain() int {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "agent connection failed: %v\n", err)
|
||||
return 11
|
||||
@@ -2927,11 +2926,11 @@ func TestAgent_Speedtest(t *testing.T) {
|
||||
|
||||
func TestAgent_Reconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := testutil.Logger(t)
|
||||
// After the agent is disconnected from a coordinator, it's supposed
|
||||
// to reconnect!
|
||||
fCoordinator := tailnettest.NewFakeCoordinator()
|
||||
coordinator := tailnet.NewCoordinator(logger)
|
||||
defer coordinator.Close()
|
||||
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *proto.Stats, 50)
|
||||
@@ -2943,24 +2942,27 @@ func TestAgent_Reconnect(t *testing.T) {
|
||||
DERPMap: derpMap,
|
||||
},
|
||||
statsCh,
|
||||
fCoordinator,
|
||||
coordinator,
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
initialized := atomic.Int32{}
|
||||
closer := agent.New(agent.Options{
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
initialized.Add(1)
|
||||
return "", nil
|
||||
},
|
||||
Client: client,
|
||||
Logger: logger.Named("agent"),
|
||||
})
|
||||
defer closer.Close()
|
||||
|
||||
call1 := testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
require.Equal(t, client.GetNumRefreshTokenCalls(), 1)
|
||||
close(call1.Resps) // hang up
|
||||
// expect reconnect
|
||||
testutil.RequireReceive(ctx, t, fCoordinator.CoordinateCalls)
|
||||
// Check that the agent refreshes the token when it reconnects.
|
||||
require.Equal(t, client.GetNumRefreshTokenCalls(), 2)
|
||||
closer.Close()
|
||||
require.Eventually(t, func() bool {
|
||||
return coordinator.Node(agentID) != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
client.LastWorkspaceAgent()
|
||||
require.Eventually(t, func() bool {
|
||||
return initialized.Load() == 2
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
||||
@@ -2982,6 +2984,9 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
|
||||
defer client.Close()
|
||||
filesystem := afero.NewMemMapFs()
|
||||
closer := agent.New(agent.Options{
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
Client: client,
|
||||
Logger: logger.Named("agent"),
|
||||
Filesystem: filesystem,
|
||||
@@ -3010,6 +3015,9 @@ func TestAgent_DebugServer(t *testing.T) {
|
||||
conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{
|
||||
DERPMap: derpMap,
|
||||
}, 0, func(c *agenttest.Client, o *agent.Options) {
|
||||
o.ExchangeToken = func(context.Context) (string, error) {
|
||||
return "token", nil
|
||||
}
|
||||
o.LogDir = logDir
|
||||
})
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package agenttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
@@ -30,11 +31,18 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
|
||||
}
|
||||
|
||||
if o.Client == nil {
|
||||
agentClient := agentsdk.New(coderURL, agentsdk.WithFixedToken(agentToken))
|
||||
agentClient := agentsdk.New(coderURL)
|
||||
agentClient.SetSessionToken(agentToken)
|
||||
agentClient.SDK.SetLogger(log)
|
||||
o.Client = agentClient
|
||||
}
|
||||
|
||||
if o.ExchangeToken == nil {
|
||||
o.ExchangeToken = func(_ context.Context) (string, error) {
|
||||
return agentToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
if o.LogDir == "" {
|
||||
o.LogDir = t.TempDir()
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package agenttest
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -29,7 +28,6 @@ import (
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
const statsInterval = 500 * time.Millisecond
|
||||
@@ -88,34 +86,10 @@ type Client struct {
|
||||
fakeAgentAPI *FakeAgentAPI
|
||||
LastWorkspaceAgent func()
|
||||
|
||||
mu sync.Mutex // Protects following.
|
||||
logs []agentsdk.Log
|
||||
derpMapUpdates chan *tailcfg.DERPMap
|
||||
derpMapOnce sync.Once
|
||||
refreshTokenCalls int
|
||||
}
|
||||
|
||||
func (*Client) AsRequestOption() codersdk.RequestOption {
|
||||
return func(_ *http.Request) {}
|
||||
}
|
||||
|
||||
func (*Client) SetDialOption(*websocket.DialOptions) {}
|
||||
|
||||
func (*Client) GetSessionToken() string {
|
||||
return "agenttest-token"
|
||||
}
|
||||
|
||||
func (c *Client) RefreshToken(context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.refreshTokenCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) GetNumRefreshTokenCalls() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.refreshTokenCalls
|
||||
mu sync.Mutex // Protects following.
|
||||
logs []agentsdk.Log
|
||||
derpMapUpdates chan *tailcfg.DERPMap
|
||||
derpMapOnce sync.Once
|
||||
}
|
||||
|
||||
func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}
|
||||
|
||||
@@ -60,9 +60,6 @@ func (a *agent) apiHandler() http.Handler {
|
||||
r.Get("/api/v0/listening-ports", lp.handler)
|
||||
r.Get("/api/v0/netcheck", a.HandleNetcheck)
|
||||
r.Post("/api/v0/list-directory", a.HandleLS)
|
||||
r.Get("/api/v0/read-file", a.HandleReadFile)
|
||||
r.Post("/api/v0/write-file", a.HandleWriteFile)
|
||||
r.Post("/api/v0/edit-files", a.HandleEditFiles)
|
||||
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
|
||||
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
|
||||
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)
|
||||
|
||||
+1
-2
@@ -63,7 +63,6 @@ func NewAppHealthReporterWithClock(
|
||||
// run a ticker for each app health check.
|
||||
var mu sync.RWMutex
|
||||
failures := make(map[uuid.UUID]int, 0)
|
||||
client := &http.Client{}
|
||||
for _, nextApp := range apps {
|
||||
if !shouldStartTicker(nextApp) {
|
||||
continue
|
||||
@@ -92,7 +91,7 @@ func NewAppHealthReporterWithClock(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
-273
@@ -1,273 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"github.com/icholy/replace"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/text/transform"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type HTTPResponseCode = int
|
||||
|
||||
func (a *agent) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
query := r.URL.Query()
|
||||
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
|
||||
path := parser.String(query, "", "path")
|
||||
offset := parser.PositiveInt64(query, 0, "offset")
|
||||
limit := parser.PositiveInt64(query, 0, "limit")
|
||||
parser.ErrorExcessParams(query)
|
||||
if len(parser.Errors) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Query parameters have invalid values.",
|
||||
Validations: parser.Errors,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
status, err := a.streamFile(ctx, rw, path, offset, limit)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, status, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// streamFile writes up to limit bytes of the file at path, starting at
// offset, to rw. It returns the HTTP status code to use if an error occurs
// before any response bytes are written; once streaming has begun, errors can
// only be logged. A limit of 0 means "read to the end of the file".
func (a *agent) streamFile(ctx context.Context, rw http.ResponseWriter, path string, offset, limit int64) (HTTPResponseCode, error) {
	if !filepath.IsAbs(path) {
		return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
	}

	f, err := a.filesystem.Open(path)
	if err != nil {
		// Map common filesystem errors to meaningful HTTP statuses.
		status := http.StatusInternalServerError
		switch {
		case errors.Is(err, os.ErrNotExist):
			status = http.StatusNotFound
		case errors.Is(err, os.ErrPermission):
			status = http.StatusForbidden
		}
		return status, err
	}
	defer f.Close()

	stat, err := f.Stat()
	if err != nil {
		return http.StatusInternalServerError, err
	}

	if stat.IsDir() {
		return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
	}

	// Clamp the requested window to the actual file size so the
	// Content-Length header matches the number of bytes actually streamed.
	size := stat.Size()
	if limit == 0 {
		limit = size
	}
	bytesRemaining := max(size-offset, 0)
	bytesToRead := min(bytesRemaining, limit)

	// Relying on just the file name for the mime type for now.
	mimeType := mime.TypeByExtension(filepath.Ext(path))
	if mimeType == "" {
		mimeType = "application/octet-stream"
	}
	rw.Header().Set("Content-Type", mimeType)
	rw.Header().Set("Content-Length", strconv.FormatInt(bytesToRead, 10))
	rw.WriteHeader(http.StatusOK)

	// Headers are already sent at this point, so copy failures can only be
	// logged; skip logging when the client disconnected (ctx canceled).
	reader := io.NewSectionReader(f, offset, bytesToRead)
	_, err = io.Copy(rw, reader)
	if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
		a.logger.Error(ctx, "workspace agent read file", slog.Error(err))
	}

	return 0, nil
}
|
||||
|
||||
// HandleWriteFile writes the request body to a file in the workspace
// filesystem. The required "path" query parameter must be an absolute path;
// missing parent directories are created automatically. Responds with a JSON
// codersdk.Response describing success or failure.
func (a *agent) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	query := r.URL.Query()
	parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
	path := parser.String(query, "", "path")
	parser.ErrorExcessParams(query)
	if len(parser.Errors) > 0 {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message:     "Query parameters have invalid values.",
			Validations: parser.Errors,
		})
		return
	}

	status, err := a.writeFile(ctx, r, path)
	if err != nil {
		httpapi.Write(ctx, rw, status, codersdk.Response{
			Message: err.Error(),
		})
		return
	}

	httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
		Message: fmt.Sprintf("Successfully wrote to %q", path),
	})
}
|
||||
|
||||
func (a *agent) writeFile(ctx context.Context, r *http.Request, path string) (HTTPResponseCode, error) {
|
||||
if !filepath.IsAbs(path) {
|
||||
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
err := a.filesystem.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
case errors.Is(err, os.ErrPermission):
|
||||
status = http.StatusForbidden
|
||||
case errors.Is(err, syscall.ENOTDIR):
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
|
||||
f, err := a.filesystem.Create(path)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
case errors.Is(err, os.ErrPermission):
|
||||
status = http.StatusForbidden
|
||||
case errors.Is(err, syscall.EISDIR):
|
||||
status = http.StatusBadRequest
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, r.Body)
|
||||
if err != nil && !errors.Is(err, io.EOF) && ctx.Err() == nil {
|
||||
a.logger.Error(ctx, "workspace agent write file", slog.Error(err))
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (a *agent) HandleEditFiles(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req workspacesdk.FileEditRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Files) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "must specify at least one file",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var combinedErr error
|
||||
status := http.StatusOK
|
||||
for _, edit := range req.Files {
|
||||
s, err := a.editFile(r.Context(), edit.Path, edit.Edits)
|
||||
// Keep the highest response status, so 500 will be preferred over 400, etc.
|
||||
if s > status {
|
||||
status = s
|
||||
}
|
||||
if err != nil {
|
||||
combinedErr = errors.Join(combinedErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if combinedErr != nil {
|
||||
httpapi.Write(ctx, rw, status, codersdk.Response{
|
||||
Message: combinedErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
|
||||
Message: "Successfully edited file(s)",
|
||||
})
|
||||
}
|
||||
|
||||
func (a *agent) editFile(ctx context.Context, path string, edits []workspacesdk.FileEdit) (int, error) {
|
||||
if path == "" {
|
||||
return http.StatusBadRequest, xerrors.New("\"path\" is required")
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(path) {
|
||||
return http.StatusBadRequest, xerrors.Errorf("file path must be absolute: %q", path)
|
||||
}
|
||||
|
||||
if len(edits) == 0 {
|
||||
return http.StatusBadRequest, xerrors.New("must specify at least one edit")
|
||||
}
|
||||
|
||||
f, err := a.filesystem.Open(path)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
status = http.StatusNotFound
|
||||
case errors.Is(err, os.ErrPermission):
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
|
||||
}
|
||||
|
||||
transforms := make([]transform.Transformer, len(edits))
|
||||
for i, edit := range edits {
|
||||
transforms[i] = replace.String(edit.Search, edit.Replace)
|
||||
}
|
||||
|
||||
tmpfile, err := afero.TempFile(a.filesystem, "", filepath.Base(path))
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
defer tmpfile.Close()
|
||||
|
||||
_, err = io.Copy(tmpfile, replace.Chain(f, transforms...))
|
||||
if err != nil {
|
||||
if rerr := a.filesystem.Remove(tmpfile.Name()); rerr != nil {
|
||||
a.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
|
||||
}
|
||||
return http.StatusInternalServerError, xerrors.Errorf("edit %s: %w", path, err)
|
||||
}
|
||||
|
||||
err = a.filesystem.Rename(tmpfile.Name(), path)
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
@@ -1,722 +0,0 @@
|
||||
package agent_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// testFs wraps an afero.Fs so tests can inject errors for specific
// filesystem calls on specific paths.
type testFs struct {
	afero.Fs
	// intercept can return an error for testing when a call fails.
	intercept func(call, file string) error
}
|
||||
|
||||
func newTestFs(base afero.Fs, intercept func(call, file string) error) *testFs {
|
||||
return &testFs{
|
||||
Fs: base,
|
||||
intercept: intercept,
|
||||
}
|
||||
}
|
||||
|
||||
func (fs *testFs) Open(name string) (afero.File, error) {
|
||||
if err := fs.intercept("open", name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fs.Fs.Open(name)
|
||||
}
|
||||
|
||||
// Create defers to the wrapped filesystem after letting the interceptor
// inject an error, and emulates os semantics that afero's in-memory
// filesystem lacks (EISDIR when the target is a directory, ENOTDIR when a
// parent component is a regular file).
func (fs *testFs) Create(name string) (afero.File, error) {
	if err := fs.intercept("create", name); err != nil {
		return nil, err
	}
	// Unlike os, afero lets you create files where directories already exist and
	// lets you nest them underneath files, somehow.
	stat, err := fs.Fs.Stat(name)
	if err == nil && stat.IsDir() {
		return nil, &os.PathError{
			Op:   "open",
			Path: name,
			Err:  syscall.EISDIR,
		}
	}
	// Reject creating a file whose parent component is itself a regular file.
	stat, err = fs.Fs.Stat(filepath.Dir(name))
	if err == nil && !stat.IsDir() {
		return nil, &os.PathError{
			Op:   "open",
			Path: name,
			Err:  syscall.ENOTDIR,
		}
	}
	return fs.Fs.Create(name)
}
|
||||
|
||||
// MkdirAll defers to the wrapped filesystem after letting the interceptor
// inject an error, and emulates os semantics that afero's in-memory
// filesystem lacks (ENOTDIR when the target or a parent component already
// exists as a regular file).
func (fs *testFs) MkdirAll(name string, mode os.FileMode) error {
	if err := fs.intercept("mkdirall", name); err != nil {
		return err
	}
	// Unlike os, afero lets you create directories where files already exist and
	// lets you nest them underneath files somehow.
	stat, err := fs.Fs.Stat(filepath.Dir(name))
	if err == nil && !stat.IsDir() {
		return &os.PathError{
			Op:   "mkdir",
			Path: name,
			Err:  syscall.ENOTDIR,
		}
	}
	// The target itself may also already exist as a regular file.
	stat, err = fs.Fs.Stat(name)
	if err == nil && !stat.IsDir() {
		return &os.PathError{
			Op:   "mkdir",
			Path: name,
			Err:  syscall.ENOTDIR,
		}
	}
	return fs.Fs.MkdirAll(name, mode)
}
|
||||
|
||||
func (fs *testFs) Rename(oldName, newName string) error {
|
||||
if err := fs.intercept("rename", newName); err != nil {
|
||||
return err
|
||||
}
|
||||
return fs.Fs.Rename(oldName, newName)
|
||||
}
|
||||
|
||||
// TestReadFile exercises the agent's read-file endpoint: parameter
// validation, filesystem error mapping (404/403/400), offset/limit windowing,
// and mime-type detection by file extension.
func TestReadFile(t *testing.T) {
	t.Parallel()

	tmpdir := os.TempDir()
	noPermsFilePath := filepath.Join(tmpdir, "no-perms")
	//nolint:dogsled
	conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
		// Inject a permission error for the designated path.
		opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
			if file == noPermsFilePath {
				return os.ErrPermission
			}
			return nil
		})
	})

	dirPath := filepath.Join(tmpdir, "a-directory")
	err := fs.MkdirAll(dirPath, 0o755)
	require.NoError(t, err)

	filePath := filepath.Join(tmpdir, "file")
	err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
	require.NoError(t, err)

	imagePath := filepath.Join(tmpdir, "file.png")
	err = afero.WriteFile(fs, imagePath, []byte("not really an image"), 0o644)
	require.NoError(t, err)

	tests := []struct {
		name     string
		path     string
		limit    int64
		offset   int64
		bytes    []byte
		mimeType string
		errCode  int
		error    string
	}{
		{
			name:    "NoPath",
			path:    "",
			errCode: http.StatusBadRequest,
			error:   "\"path\" is required",
		},
		{
			name:    "RelativePathDotSlash",
			path:    "./relative",
			errCode: http.StatusBadRequest,
			error:   "file path must be absolute",
		},
		{
			name:    "RelativePath",
			path:    "also-relative",
			errCode: http.StatusBadRequest,
			error:   "file path must be absolute",
		},
		{
			name:    "NegativeLimit",
			path:    filePath,
			limit:   -10,
			errCode: http.StatusBadRequest,
			error:   "value is negative",
		},
		{
			name:    "NegativeOffset",
			path:    filePath,
			offset:  -10,
			errCode: http.StatusBadRequest,
			error:   "value is negative",
		},
		{
			name:    "NonExistent",
			path:    filepath.Join(tmpdir, "does-not-exist"),
			errCode: http.StatusNotFound,
			error:   "file does not exist",
		},
		{
			name:    "IsDir",
			path:    dirPath,
			errCode: http.StatusBadRequest,
			error:   "not a file",
		},
		{
			name:    "NoPermissions",
			path:    noPermsFilePath,
			errCode: http.StatusForbidden,
			error:   "permission denied",
		},
		{
			name:     "Defaults",
			path:     filePath,
			bytes:    []byte("content"),
			mimeType: "application/octet-stream",
		},
		{
			name:     "Limit1",
			path:     filePath,
			limit:    1,
			bytes:    []byte("c"),
			mimeType: "application/octet-stream",
		},
		{
			name:     "Offset1",
			path:     filePath,
			offset:   1,
			bytes:    []byte("ontent"),
			mimeType: "application/octet-stream",
		},
		{
			name:     "Limit1Offset2",
			path:     filePath,
			limit:    1,
			offset:   2,
			bytes:    []byte("n"),
			mimeType: "application/octet-stream",
		},
		{
			name:     "Limit7Offset0",
			path:     filePath,
			limit:    7,
			offset:   0,
			bytes:    []byte("content"),
			mimeType: "application/octet-stream",
		},
		{
			name:     "Limit100",
			path:     filePath,
			limit:    100,
			bytes:    []byte("content"),
			mimeType: "application/octet-stream",
		},
		{
			name:     "Offset7",
			path:     filePath,
			offset:   7,
			bytes:    []byte{},
			mimeType: "application/octet-stream",
		},
		{
			name:     "Offset100",
			path:     filePath,
			offset:   100,
			bytes:    []byte{},
			mimeType: "application/octet-stream",
		},
		{
			name:     "MimeTypePng",
			path:     imagePath,
			bytes:    []byte("not really an image"),
			mimeType: "image/png",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()

			reader, mimeType, err := conn.ReadFile(ctx, tt.path, tt.offset, tt.limit)
			if tt.errCode != 0 {
				require.Error(t, err)
				cerr := coderdtest.SDKError(t, err)
				require.Contains(t, cerr.Error(), tt.error)
				require.Equal(t, tt.errCode, cerr.StatusCode())
			} else {
				require.NoError(t, err)
				defer reader.Close()
				bytes, err := io.ReadAll(reader)
				require.NoError(t, err)
				require.Equal(t, tt.bytes, bytes)
				require.Equal(t, tt.mimeType, mimeType)
			}
		})
	}
}
|
||||
|
||||
// TestWriteFile exercises the agent's write-file endpoint: parameter
// validation, automatic creation of parent directories, and filesystem error
// mapping (403 for permissions, 400 for directory conflicts).
func TestWriteFile(t *testing.T) {
	t.Parallel()

	tmpdir := os.TempDir()
	noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
	noPermsDirPath := filepath.Join(tmpdir, "no-perms-dir")
	//nolint:dogsled
	conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
		// Inject permission errors for the designated paths.
		opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
			if file == noPermsFilePath || file == noPermsDirPath {
				return os.ErrPermission
			}
			return nil
		})
	})

	dirPath := filepath.Join(tmpdir, "directory")
	err := fs.MkdirAll(dirPath, 0o755)
	require.NoError(t, err)

	filePath := filepath.Join(tmpdir, "file")
	err = afero.WriteFile(fs, filePath, []byte("content"), 0o644)
	require.NoError(t, err)

	// Windows reports a different error string for ENOTDIR-style failures.
	notDirErr := "not a directory"
	if runtime.GOOS == "windows" {
		notDirErr = "cannot find the path"
	}

	tests := []struct {
		name    string
		path    string
		bytes   []byte
		errCode int
		error   string
	}{
		{
			name:    "NoPath",
			path:    "",
			errCode: http.StatusBadRequest,
			error:   "\"path\" is required",
		},
		{
			name:    "RelativePathDotSlash",
			path:    "./relative",
			errCode: http.StatusBadRequest,
			error:   "file path must be absolute",
		},
		{
			name:    "RelativePath",
			path:    "also-relative",
			errCode: http.StatusBadRequest,
			error:   "file path must be absolute",
		},
		{
			name:  "NonExistent",
			path:  filepath.Join(tmpdir, "/nested/does-not-exist"),
			bytes: []byte("now it does exist"),
		},
		{
			name:    "IsDir",
			path:    dirPath,
			errCode: http.StatusBadRequest,
			error:   "is a directory",
		},
		{
			name:    "IsNotDir",
			path:    filepath.Join(filePath, "file2"),
			errCode: http.StatusBadRequest,
			error:   notDirErr,
		},
		{
			name:    "NoPermissionsFile",
			path:    noPermsFilePath,
			errCode: http.StatusForbidden,
			error:   "permission denied",
		},
		{
			name:    "NoPermissionsDir",
			path:    filepath.Join(noPermsDirPath, "within-no-perm-dir"),
			errCode: http.StatusForbidden,
			error:   "permission denied",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()

			reader := bytes.NewReader(tt.bytes)
			err := conn.WriteFile(ctx, tt.path, reader)
			if tt.errCode != 0 {
				require.Error(t, err)
				cerr := coderdtest.SDKError(t, err)
				require.Contains(t, cerr.Error(), tt.error)
				require.Equal(t, tt.errCode, cerr.StatusCode())
			} else {
				require.NoError(t, err)
				b, err := afero.ReadFile(fs, tt.path)
				require.NoError(t, err)
				require.Equal(t, tt.bytes, b)
			}
		})
	}
}
|
||||
|
||||
// TestEditFiles exercises the agent's edit-files endpoint: request
// validation, per-file error mapping, sequential application of edits within
// a file, multi-file batches, and the "highest status wins" aggregation when
// some files in a batch fail.
func TestEditFiles(t *testing.T) {
	t.Parallel()

	tmpdir := os.TempDir()
	noPermsFilePath := filepath.Join(tmpdir, "no-perms-file")
	failRenameFilePath := filepath.Join(tmpdir, "fail-rename")
	//nolint:dogsled
	conn, _, _, fs, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, opts *agent.Options) {
		// Inject a permission error for one path and a rename failure for
		// another, to exercise both error-mapping branches.
		opts.Filesystem = newTestFs(opts.Filesystem, func(call, file string) error {
			if file == noPermsFilePath {
				return &os.PathError{
					Op:   call,
					Path: file,
					Err:  os.ErrPermission,
				}
			} else if file == failRenameFilePath && call == "rename" {
				return xerrors.New("rename failed")
			}
			return nil
		})
	})

	dirPath := filepath.Join(tmpdir, "directory")
	err := fs.MkdirAll(dirPath, 0o755)
	require.NoError(t, err)

	tests := []struct {
		name     string
		contents map[string]string
		edits    []workspacesdk.FileEdits
		expected map[string]string
		errCode  int
		errors   []string
	}{
		{
			name:    "NoFiles",
			errCode: http.StatusBadRequest,
			errors:  []string{"must specify at least one file"},
		},
		{
			name:    "NoPath",
			errCode: http.StatusBadRequest,
			edits: []workspacesdk.FileEdits{
				{
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			errors: []string{"\"path\" is required"},
		},
		{
			name: "RelativePathDotSlash",
			edits: []workspacesdk.FileEdits{
				{
					Path: "./relative",
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			errCode: http.StatusBadRequest,
			errors:  []string{"file path must be absolute"},
		},
		{
			name: "RelativePath",
			edits: []workspacesdk.FileEdits{
				{
					Path: "also-relative",
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			errCode: http.StatusBadRequest,
			errors:  []string{"file path must be absolute"},
		},
		{
			name: "NoEdits",
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "no-edits"),
				},
			},
			errCode: http.StatusBadRequest,
			errors:  []string{"must specify at least one edit"},
		},
		{
			name: "NonExistent",
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "does-not-exist"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			errCode: http.StatusNotFound,
			errors:  []string{"file does not exist"},
		},
		{
			name: "IsDir",
			edits: []workspacesdk.FileEdits{
				{
					Path: dirPath,
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			errCode: http.StatusBadRequest,
			errors:  []string{"not a file"},
		},
		{
			name: "NoPermissions",
			edits: []workspacesdk.FileEdits{
				{
					Path: noPermsFilePath,
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			errCode: http.StatusForbidden,
			errors:  []string{"permission denied"},
		},
		{
			name:     "FailRename",
			contents: map[string]string{failRenameFilePath: "foo bar"},
			edits: []workspacesdk.FileEdits{
				{
					Path: failRenameFilePath,
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			errCode: http.StatusInternalServerError,
			errors:  []string{"rename failed"},
		},
		{
			name:     "Edit1",
			contents: map[string]string{filepath.Join(tmpdir, "edit1"): "foo bar"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "edit1"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "edit1"): "bar bar"},
		},
		{
			name:     "EditEdit", // Edits affect previous edits.
			contents: map[string]string{filepath.Join(tmpdir, "edit-edit"): "foo bar"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "edit-edit"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo",
							Replace: "bar",
						},
						{
							Search:  "bar",
							Replace: "qux",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "edit-edit"): "qux qux"},
		},
		{
			name:     "Multiline",
			contents: map[string]string{filepath.Join(tmpdir, "multiline"): "foo\nbar\nbaz\nqux"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "multiline"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "bar\nbaz",
							Replace: "frob",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "multiline"): "foo\nfrob\nqux"},
		},
		{
			name: "Multifile",
			contents: map[string]string{
				filepath.Join(tmpdir, "file1"): "file 1",
				filepath.Join(tmpdir, "file2"): "file 2",
				filepath.Join(tmpdir, "file3"): "file 3",
			},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "file1"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "file",
							Replace: "edited1",
						},
					},
				},
				{
					Path: filepath.Join(tmpdir, "file2"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "file",
							Replace: "edited2",
						},
					},
				},
				{
					Path: filepath.Join(tmpdir, "file3"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "file",
							Replace: "edited3",
						},
					},
				},
			},
			expected: map[string]string{
				filepath.Join(tmpdir, "file1"): "edited1 1",
				filepath.Join(tmpdir, "file2"): "edited2 2",
				filepath.Join(tmpdir, "file3"): "edited3 3",
			},
		},
		{
			name: "MultiError",
			contents: map[string]string{
				filepath.Join(tmpdir, "file8"): "file 8",
			},
			edits: []workspacesdk.FileEdits{
				{
					Path: noPermsFilePath,
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "file",
							Replace: "edited7",
						},
					},
				},
				{
					Path: filepath.Join(tmpdir, "file8"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "file",
							Replace: "edited8",
						},
					},
				},
				{
					Path: filepath.Join(tmpdir, "file9"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "file",
							Replace: "edited9",
						},
					},
				},
			},
			expected: map[string]string{
				filepath.Join(tmpdir, "file8"): "edited8 8",
			},
			// Higher status codes will override lower ones, so in this case the 404
			// takes priority over the 403.
			errCode: http.StatusNotFound,
			errors: []string{
				fmt.Sprintf("%s: permission denied", noPermsFilePath),
				"file9: file does not exist",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()

			// Seed the filesystem with this case's starting contents.
			for path, content := range tt.contents {
				err := afero.WriteFile(fs, path, []byte(content), 0o644)
				require.NoError(t, err)
			}

			err := conn.EditFiles(ctx, workspacesdk.FileEditRequest{Files: tt.edits})
			if tt.errCode != 0 {
				require.Error(t, err)
				cerr := coderdtest.SDKError(t, err)
				for _, error := range tt.errors {
					require.Contains(t, cerr.Error(), error)
				}
				require.Equal(t, tt.errCode, cerr.StatusCode())
			} else {
				require.NoError(t, err)
			}
			// Even on partial failure, successfully edited files must match.
			for path, expect := range tt.expected {
				b, err := afero.ReadFile(fs, path)
				require.NoError(t, err)
				require.Equal(t, expect, string(b))
			}
		})
	}
}
|
||||
@@ -1,350 +0,0 @@
|
||||
package backedpipe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by BackedPipe operations; compare with errors.Is.
var (
	ErrPipeClosed             = xerrors.New("pipe is closed")
	ErrPipeAlreadyConnected   = xerrors.New("pipe is already connected")
	ErrReconnectionInProgress = xerrors.New("reconnection already in progress")
	ErrReconnectFailed        = xerrors.New("reconnect failed")
	ErrInvalidSequenceNumber  = xerrors.New("remote sequence number exceeds local sequence")
	ErrReconnectWriterFailed  = xerrors.New("reconnect writer failed")
)

// connectionState represents the current state of the BackedPipe connection.
type connectionState int

const (
	// connected indicates the pipe is connected and operational.
	connected connectionState = iota
	// disconnected indicates the pipe is not connected but not closed.
	disconnected
	// reconnecting indicates a reconnection attempt is in progress.
	reconnecting
	// closed indicates the pipe is permanently closed.
	closed
)

// ErrorEvent represents an error from a reader or writer with connection generation info.
type ErrorEvent struct {
	Err        error
	Component  string // "reader" or "writer"
	Generation uint64 // connection generation when error occurred
}

const (
	// Default buffer capacity used by the writer - 64MB
	DefaultBufferSize = 64 * 1024 * 1024
)

// Reconnector is an interface for establishing connections when the BackedPipe needs to reconnect.
// Implementations should:
// 1. Establish a new connection to the remote side
// 2. Exchange sequence numbers with the remote side
// 3. Return the new connection and the remote's reader sequence number
//
// The readerSeqNum parameter is the local reader's current sequence number
// (total bytes successfully read from the remote). This must be sent to the
// remote so it can replay its data to us starting from this number.
//
// The returned remoteReaderSeqNum should be the remote side's reader sequence
// number (how many bytes of our outbound data it has successfully read). This
// informs our writer where to resume (i.e., which bytes to replay to the remote).
type Reconnector interface {
	Reconnect(ctx context.Context, readerSeqNum uint64) (conn io.ReadWriteCloser, remoteReaderSeqNum uint64, err error)
}
|
||||
|
||||
// BackedPipe provides a reliable bidirectional byte stream over unreliable network connections.
|
||||
// It orchestrates a BackedReader and BackedWriter to provide transparent reconnection
|
||||
// and data replay capabilities.
|
||||
type BackedPipe struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
reader *BackedReader
|
||||
writer *BackedWriter
|
||||
reconnector Reconnector
|
||||
conn io.ReadWriteCloser
|
||||
|
||||
// State machine
|
||||
state connectionState
|
||||
connGen uint64 // Increments on each successful reconnection
|
||||
|
||||
// Unified error handling with generation filtering
|
||||
errChan chan ErrorEvent
|
||||
|
||||
// singleflight group to dedupe concurrent ForceReconnect calls
|
||||
sf singleflight.Group
|
||||
|
||||
// Track first error per generation to avoid duplicate reconnections
|
||||
lastErrorGen uint64
|
||||
}
|
||||
|
||||
// NewBackedPipe creates a new BackedPipe with default options and the specified reconnector.
|
||||
// The pipe starts disconnected and must be connected using Connect().
|
||||
func NewBackedPipe(ctx context.Context, reconnector Reconnector) *BackedPipe {
|
||||
pipeCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errChan := make(chan ErrorEvent, 1)
|
||||
|
||||
bp := &BackedPipe{
|
||||
ctx: pipeCtx,
|
||||
cancel: cancel,
|
||||
reconnector: reconnector,
|
||||
state: disconnected,
|
||||
connGen: 0, // Start with generation 0
|
||||
errChan: errChan,
|
||||
}
|
||||
|
||||
// Create reader and writer with typed error channel for generation-aware error reporting
|
||||
bp.reader = NewBackedReader(errChan)
|
||||
bp.writer = NewBackedWriter(DefaultBufferSize, errChan)
|
||||
|
||||
// Start error handler goroutine
|
||||
go bp.handleErrors()
|
||||
|
||||
return bp
|
||||
}
|
||||
|
||||
// Connect establishes the initial connection using the reconnect function.
|
||||
func (bp *BackedPipe) Connect() error {
|
||||
bp.mu.Lock()
|
||||
defer bp.mu.Unlock()
|
||||
|
||||
if bp.state == closed {
|
||||
return ErrPipeClosed
|
||||
}
|
||||
|
||||
if bp.state == connected {
|
||||
return ErrPipeAlreadyConnected
|
||||
}
|
||||
|
||||
// Use internal context for the actual reconnect operation to ensure
|
||||
// Close() reliably cancels any in-flight attempt.
|
||||
return bp.reconnectLocked()
|
||||
}
|
||||
|
||||
// Read implements io.Reader by delegating to the BackedReader.
// NOTE(review): unlike Write, this does not check bp.state for closed;
// presumably the BackedReader itself fails once closed — confirm there.
func (bp *BackedPipe) Read(p []byte) (int, error) {
	return bp.reader.Read(p)
}
|
||||
|
||||
// Write implements io.Writer by delegating to the BackedWriter.
func (bp *BackedPipe) Write(p []byte) (int, error) {
	// Snapshot the writer and state under the read lock so Write itself does
	// not serialize against reconnection.
	bp.mu.RLock()
	writer := bp.writer
	state := bp.state
	bp.mu.RUnlock()

	if state == closed {
		// NOTE(review): io.EOF is an unusual error for a closed write side;
		// callers might expect ErrPipeClosed or io.ErrClosedPipe — confirm
		// existing callers before changing.
		return 0, io.EOF
	}

	return writer.Write(p)
}
|
||||
|
||||
// Close closes the pipe and all underlying connections. It is idempotent:
// subsequent calls return nil. Returns the first error reported by any of
// the parallel close operations.
func (bp *BackedPipe) Close() error {
	bp.mu.Lock()
	defer bp.mu.Unlock()

	if bp.state == closed {
		return nil
	}

	bp.state = closed
	bp.cancel() // Cancel main context

	// Close all components in parallel to avoid deadlocks
	//
	// IMPORTANT: The connection must be closed first to unblock any
	// readers or writers that might be holding the mutex on Read/Write
	var g errgroup.Group

	if bp.conn != nil {
		// Capture the connection before nilling the field so the goroutine
		// closes the right object.
		conn := bp.conn
		g.Go(func() error {
			return conn.Close()
		})
		bp.conn = nil
	}

	if bp.reader != nil {
		reader := bp.reader
		g.Go(func() error {
			return reader.Close()
		})
	}

	if bp.writer != nil {
		writer := bp.writer
		g.Go(func() error {
			return writer.Close()
		})
	}

	// Wait for all close operations to complete and return any error
	return g.Wait()
}
|
||||
|
||||
// Connected returns whether the pipe is currently connected.
|
||||
func (bp *BackedPipe) Connected() bool {
|
||||
bp.mu.RLock()
|
||||
defer bp.mu.RUnlock()
|
||||
return bp.state == connected && bp.reader.Connected() && bp.writer.Connected()
|
||||
}
|
||||
|
||||
// reconnectLocked handles the reconnection logic. Must be called with write lock held.
//
// The handshake with the reader is ordered carefully:
//  1. the reader's Reconnect goroutine publishes its current sequence
//     number on seqNum while holding the reader's own lock,
//  2. we call the external reconnector with that exact sequence number,
//  3. we hand the resulting connection (or nil on failure) back on newR,
//     which releases the reader.
// Skipping step 3 on the failure path would leave the reader blocked,
// hence the explicit `newR <- nil`.
func (bp *BackedPipe) reconnectLocked() error {
	if bp.state == reconnecting {
		return ErrReconnectionInProgress
	}

	bp.state = reconnecting
	defer func() {
		// Only reset to disconnected if we're still in reconnecting state
		// (successful reconnection will set state to connected)
		if bp.state == reconnecting {
			bp.state = disconnected
		}
	}()

	// Close existing connection if any
	if bp.conn != nil {
		_ = bp.conn.Close()
		bp.conn = nil
	}

	// Increment the generation and update both reader and writer.
	// We do it now to track even the connections that fail during
	// Reconnect.
	bp.connGen++
	bp.reader.SetGeneration(bp.connGen)
	bp.writer.SetGeneration(bp.connGen)

	// Reconnect reader and writer
	seqNum := make(chan uint64, 1)
	newR := make(chan io.Reader, 1)

	go bp.reader.Reconnect(seqNum, newR)

	// Get the precise reader sequence number from the reader while it holds its lock
	readerSeqNum, ok := <-seqNum
	if !ok {
		// Reader was closed during reconnection
		return ErrReconnectFailed
	}

	// Perform reconnect using the exact sequence number we just received
	conn, remoteReaderSeqNum, err := bp.reconnector.Reconnect(bp.ctx, readerSeqNum)
	if err != nil {
		// Unblock reader reconnect
		// NOTE(review): the reconnector's underlying error is discarded
		// here; callers only ever see ErrReconnectFailed. Consider
		// wrapping with %w if the cause should be surfaced.
		newR <- nil
		return ErrReconnectFailed
	}

	// Provide the new connection to the reader (reader still holds its lock)
	newR <- conn

	// Replay our outbound data from the remote's reader sequence number
	writerReconnectErr := bp.writer.Reconnect(remoteReaderSeqNum, conn)
	if writerReconnectErr != nil {
		return ErrReconnectWriterFailed
	}

	// Success - update state
	bp.conn = conn
	bp.state = connected

	return nil
}
|
||||
|
||||
// handleErrors listens for connection errors from reader/writer and triggers reconnection.
|
||||
// It filters errors from old connections and ensures only the first error per generation
|
||||
// triggers reconnection.
|
||||
func (bp *BackedPipe) handleErrors() {
|
||||
for {
|
||||
select {
|
||||
case <-bp.ctx.Done():
|
||||
return
|
||||
case errorEvt := <-bp.errChan:
|
||||
bp.handleConnectionError(errorEvt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnectionError handles errors from either reader or writer components.
|
||||
// It filters errors from old connections and ensures only one reconnection per generation.
|
||||
func (bp *BackedPipe) handleConnectionError(errorEvt ErrorEvent) {
|
||||
bp.mu.Lock()
|
||||
defer bp.mu.Unlock()
|
||||
|
||||
// Skip if already closed
|
||||
if bp.state == closed {
|
||||
return
|
||||
}
|
||||
|
||||
// Filter errors from old connections (lower generation)
|
||||
if errorEvt.Generation < bp.connGen {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if not connected (already disconnected or reconnecting)
|
||||
if bp.state != connected {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if we've already seen an error for this generation
|
||||
if bp.lastErrorGen >= errorEvt.Generation {
|
||||
return
|
||||
}
|
||||
|
||||
// This is the first error for this generation
|
||||
bp.lastErrorGen = errorEvt.Generation
|
||||
|
||||
// Mark as disconnected
|
||||
bp.state = disconnected
|
||||
|
||||
// Try to reconnect using internal context
|
||||
reconnectErr := bp.reconnectLocked()
|
||||
|
||||
if reconnectErr != nil {
|
||||
// Reconnection failed - log or handle as needed
|
||||
// For now, we'll just continue and wait for manual reconnection
|
||||
_ = errorEvt.Err // Use the original error from the component
|
||||
_ = errorEvt.Component // Component info available for potential logging by higher layers
|
||||
}
|
||||
}
|
||||
|
||||
// ForceReconnect forces a reconnection attempt immediately.
|
||||
// This can be used to force a reconnection if a new connection is established.
|
||||
// It prevents duplicate reconnections when called concurrently.
|
||||
func (bp *BackedPipe) ForceReconnect() error {
|
||||
// Deduplicate concurrent ForceReconnect calls so only one reconnection
|
||||
// attempt runs at a time from this API. Use the pipe's internal context
|
||||
// to ensure Close() cancels any in-flight attempt.
|
||||
_, err, _ := bp.sf.Do("force-reconnect", func() (interface{}, error) {
|
||||
bp.mu.Lock()
|
||||
defer bp.mu.Unlock()
|
||||
|
||||
if bp.state == closed {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// Don't force reconnect if already reconnecting
|
||||
if bp.state == reconnecting {
|
||||
return nil, ErrReconnectionInProgress
|
||||
}
|
||||
|
||||
return nil, bp.reconnectLocked()
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -1,989 +0,0 @@
|
||||
package backedpipe_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// mockConnection implements io.ReadWriteCloser for testing
|
||||
type mockConnection struct {
|
||||
mu sync.Mutex
|
||||
readBuffer bytes.Buffer
|
||||
writeBuffer bytes.Buffer
|
||||
closed bool
|
||||
readError error
|
||||
writeError error
|
||||
closeError error
|
||||
readFunc func([]byte) (int, error)
|
||||
writeFunc func([]byte) (int, error)
|
||||
seqNum uint64
|
||||
}
|
||||
|
||||
func newMockConnection() *mockConnection {
|
||||
return &mockConnection{}
|
||||
}
|
||||
|
||||
func (mc *mockConnection) Read(p []byte) (int, error) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
if mc.readFunc != nil {
|
||||
return mc.readFunc(p)
|
||||
}
|
||||
|
||||
if mc.readError != nil {
|
||||
return 0, mc.readError
|
||||
}
|
||||
|
||||
return mc.readBuffer.Read(p)
|
||||
}
|
||||
|
||||
func (mc *mockConnection) Write(p []byte) (int, error) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
|
||||
if mc.writeFunc != nil {
|
||||
return mc.writeFunc(p)
|
||||
}
|
||||
|
||||
if mc.writeError != nil {
|
||||
return 0, mc.writeError
|
||||
}
|
||||
|
||||
return mc.writeBuffer.Write(p)
|
||||
}
|
||||
|
||||
func (mc *mockConnection) Close() error {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
mc.closed = true
|
||||
return mc.closeError
|
||||
}
|
||||
|
||||
func (mc *mockConnection) WriteString(s string) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
_, _ = mc.readBuffer.WriteString(s)
|
||||
}
|
||||
|
||||
func (mc *mockConnection) ReadString() string {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
return mc.writeBuffer.String()
|
||||
}
|
||||
|
||||
func (mc *mockConnection) SetReadError(err error) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
mc.readError = err
|
||||
}
|
||||
|
||||
func (mc *mockConnection) SetWriteError(err error) {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
mc.writeError = err
|
||||
}
|
||||
|
||||
func (mc *mockConnection) Reset() {
|
||||
mc.mu.Lock()
|
||||
defer mc.mu.Unlock()
|
||||
mc.readBuffer.Reset()
|
||||
mc.writeBuffer.Reset()
|
||||
mc.readError = nil
|
||||
mc.writeError = nil
|
||||
mc.closed = false
|
||||
}
|
||||
|
||||
// mockReconnector implements the Reconnector interface for testing
|
||||
type mockReconnector struct {
|
||||
mu sync.Mutex
|
||||
connections []*mockConnection
|
||||
connectionIndex int
|
||||
callCount int
|
||||
signalChan chan struct{}
|
||||
}
|
||||
|
||||
// Reconnect implements the Reconnector interface
|
||||
func (m *mockReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.callCount++
|
||||
|
||||
if m.connectionIndex >= len(m.connections) {
|
||||
return nil, 0, xerrors.New("no more connections available")
|
||||
}
|
||||
|
||||
conn := m.connections[m.connectionIndex]
|
||||
m.connectionIndex++
|
||||
|
||||
// Signal when reconnection happens
|
||||
if m.connectionIndex > 1 {
|
||||
select {
|
||||
case m.signalChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Determine remoteReaderSeqNum (how many bytes of our outbound data the remote has read)
|
||||
var remoteReaderSeqNum uint64
|
||||
switch {
|
||||
case m.callCount == 1:
|
||||
remoteReaderSeqNum = 0
|
||||
case conn.seqNum != 0:
|
||||
remoteReaderSeqNum = conn.seqNum
|
||||
default:
|
||||
// Default to 0 if unspecified
|
||||
remoteReaderSeqNum = 0
|
||||
}
|
||||
|
||||
return conn, remoteReaderSeqNum, nil
|
||||
}
|
||||
|
||||
// GetCallCount returns the current call count in a thread-safe manner
|
||||
func (m *mockReconnector) GetCallCount() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.callCount
|
||||
}
|
||||
|
||||
// mockReconnectFunc creates a unified reconnector with all behaviors enabled
|
||||
func mockReconnectFunc(connections ...*mockConnection) (*mockReconnector, chan struct{}) {
|
||||
signalChan := make(chan struct{}, 1)
|
||||
|
||||
reconnector := &mockReconnector{
|
||||
connections: connections,
|
||||
signalChan: signalChan,
|
||||
}
|
||||
|
||||
return reconnector, signalChan
|
||||
}
|
||||
|
||||
// blockingReconnector is a reconnector that blocks on a channel for deterministic testing
|
||||
type blockingReconnector struct {
|
||||
conn1 *mockConnection
|
||||
conn2 *mockConnection
|
||||
callCount int
|
||||
blockChan <-chan struct{}
|
||||
blockedChan chan struct{}
|
||||
mu sync.Mutex
|
||||
signalOnce sync.Once // Ensure we only signal once for the first actual reconnect
|
||||
}
|
||||
|
||||
func (b *blockingReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
|
||||
b.mu.Lock()
|
||||
b.callCount++
|
||||
currentCall := b.callCount
|
||||
b.mu.Unlock()
|
||||
|
||||
if currentCall == 1 {
|
||||
// Initial connect
|
||||
return b.conn1, 0, nil
|
||||
}
|
||||
|
||||
// Signal that we're about to block, but only once for the first reconnect attempt
|
||||
// This ensures we properly test singleflight deduplication
|
||||
b.signalOnce.Do(func() {
|
||||
select {
|
||||
case b.blockedChan <- struct{}{}:
|
||||
default:
|
||||
// If channel is full, don't block
|
||||
}
|
||||
})
|
||||
|
||||
// For subsequent calls, block until channel is closed
|
||||
select {
|
||||
case <-b.blockChan:
|
||||
// Channel closed, proceed with reconnection
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
}
|
||||
|
||||
return b.conn2, 0, nil
|
||||
}
|
||||
|
||||
// GetCallCount returns the current call count in a thread-safe manner
|
||||
func (b *blockingReconnector) GetCallCount() int {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.callCount
|
||||
}
|
||||
|
||||
func mockBlockingReconnectFunc(conn1, conn2 *mockConnection, blockChan <-chan struct{}) (*blockingReconnector, chan struct{}) {
|
||||
blockedChan := make(chan struct{}, 1)
|
||||
reconnector := &blockingReconnector{
|
||||
conn1: conn1,
|
||||
conn2: conn2,
|
||||
blockChan: blockChan,
|
||||
blockedChan: blockedChan,
|
||||
}
|
||||
|
||||
return reconnector, blockedChan
|
||||
}
|
||||
|
||||
// eofTestReconnector is a custom reconnector for the EOF test case
|
||||
type eofTestReconnector struct {
|
||||
mu sync.Mutex
|
||||
conn1 io.ReadWriteCloser
|
||||
conn2 io.ReadWriteCloser
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (e *eofTestReconnector) Reconnect(ctx context.Context, readerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.callCount++
|
||||
|
||||
if e.callCount == 1 {
|
||||
return e.conn1, 0, nil
|
||||
}
|
||||
if e.callCount == 2 {
|
||||
// Second call is the reconnection after EOF
|
||||
// Return 5 to indicate remote has read all 5 bytes of "hello"
|
||||
return e.conn2, 5, nil
|
||||
}
|
||||
|
||||
return nil, 0, xerrors.New("no more connections")
|
||||
}
|
||||
|
||||
// GetCallCount returns the current call count in a thread-safe manner
|
||||
func (e *eofTestReconnector) GetCallCount() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.callCount
|
||||
}
|
||||
|
||||
func TestBackedPipe_NewBackedPipe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
reconnectFn, _ := mockReconnectFunc(newMockConnection())
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
defer bp.Close()
|
||||
require.NotNil(t, bp)
|
||||
require.False(t, bp.Connected())
|
||||
}
|
||||
|
||||
func TestBackedPipe_Connect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnector, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnector)
|
||||
defer bp.Close()
|
||||
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
require.True(t, bp.Connected())
|
||||
require.Equal(t, 1, reconnector.GetCallCount())
|
||||
}
|
||||
|
||||
func TestBackedPipe_ConnectAlreadyConnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
defer bp.Close()
|
||||
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second connect should fail
|
||||
err = bp.Connect()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrPipeAlreadyConnected)
|
||||
}
|
||||
|
||||
func TestBackedPipe_ConnectAfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
|
||||
err := bp.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = bp.Connect()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrPipeClosed)
|
||||
}
|
||||
|
||||
func TestBackedPipe_BasicReadWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
defer bp.Close()
|
||||
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write data
|
||||
n, err := bp.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
|
||||
// Simulate data coming back
|
||||
conn.WriteString("world")
|
||||
|
||||
// Read data
|
||||
buf := make([]byte, 10)
|
||||
n, err = bp.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, "world", string(buf[:n]))
|
||||
}
|
||||
|
||||
func TestBackedPipe_WriteBeforeConnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
defer bp.Close()
|
||||
|
||||
// Write before connecting should block
|
||||
writeComplete := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := bp.Write([]byte("hello"))
|
||||
writeComplete <- err
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should have blocked when disconnected")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Connect should unblock the write
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should now complete
|
||||
err = testutil.RequireReceive(ctx, t, writeComplete)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that data was replayed to connection
|
||||
require.Equal(t, "hello", conn.ReadString())
|
||||
}
|
||||
|
||||
func TestBackedPipe_ReadBlocksWhenDisconnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||
reconnectFn, _ := mockReconnectFunc(newMockConnection())
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
defer bp.Close()
|
||||
|
||||
// Start a read that should block
|
||||
readDone := make(chan struct{})
|
||||
readStarted := make(chan struct{}, 1)
|
||||
var readErr error
|
||||
|
||||
go func() {
|
||||
defer close(readDone)
|
||||
readStarted <- struct{}{} // Signal that we're about to start the read
|
||||
buf := make([]byte, 10)
|
||||
_, readErr = bp.Read(buf)
|
||||
}()
|
||||
|
||||
// Wait for the goroutine to start
|
||||
testutil.TryReceive(testCtx, t, readStarted)
|
||||
|
||||
// Ensure the read is actually blocked by verifying it hasn't completed
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-readDone:
|
||||
t.Fatal("Read should be blocked when disconnected")
|
||||
return false
|
||||
default:
|
||||
// Good, still blocked
|
||||
return true
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
// Close should unblock the read
|
||||
bp.Close()
|
||||
|
||||
testutil.TryReceive(testCtx, t, readDone)
|
||||
require.Equal(t, io.EOF, readErr)
|
||||
}
|
||||
|
||||
func TestBackedPipe_Reconnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||
conn1 := newMockConnection()
|
||||
conn2 := newMockConnection()
|
||||
conn2.seqNum = 17 // Remote has received 17 bytes, so replay from sequence 17
|
||||
reconnectFn, signalChan := mockReconnectFunc(conn1, conn2)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
defer bp.Close()
|
||||
|
||||
// Initial connect
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write some data before failure
|
||||
bp.Write([]byte("before disconnect***"))
|
||||
|
||||
// Simulate connection failure
|
||||
conn1.SetReadError(xerrors.New("connection lost"))
|
||||
conn1.SetWriteError(xerrors.New("connection lost"))
|
||||
|
||||
// Trigger a write to cause the pipe to notice the failure
|
||||
_, _ = bp.Write([]byte("trigger failure "))
|
||||
|
||||
testutil.RequireReceive(testCtx, t, signalChan)
|
||||
|
||||
// Wait for reconnection to complete
|
||||
require.Eventually(t, func() bool {
|
||||
return bp.Connected()
|
||||
}, testutil.WaitShort, testutil.IntervalFast, "pipe should reconnect")
|
||||
|
||||
replayedData := conn2.ReadString()
|
||||
require.Equal(t, "***trigger failure ", replayedData, "Should replay exactly the data written after sequence 17")
|
||||
|
||||
// Verify that new writes work with the reconnected pipe
|
||||
_, err = bp.Write([]byte("new data after reconnect"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read all data from the connection (replayed + new data)
|
||||
allData := conn2.ReadString()
|
||||
require.Equal(t, "***trigger failure new data after reconnect", allData, "Should have replayed data plus new data")
|
||||
}
|
||||
|
||||
func TestBackedPipe_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = bp.Close()
|
||||
require.NoError(t, err)
|
||||
require.True(t, conn.closed)
|
||||
|
||||
// Operations after close should fail
|
||||
_, err = bp.Read(make([]byte, 10))
|
||||
require.Equal(t, io.EOF, err)
|
||||
|
||||
_, err = bp.Write([]byte("test"))
|
||||
require.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestBackedPipe_CloseIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
|
||||
err := bp.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second close should be no-op
|
||||
err = bp.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackedPipe_ReconnectFunctionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
failingReconnector := &mockReconnector{
|
||||
connections: nil, // No connections available
|
||||
}
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, failingReconnector)
|
||||
defer bp.Close()
|
||||
|
||||
err := bp.Connect()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrReconnectFailed)
|
||||
require.False(t, bp.Connected())
|
||||
}
|
||||
|
||||
func TestBackedPipe_ForceReconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := newMockConnection()
|
||||
conn2 := newMockConnection()
|
||||
// Set conn2 sequence number to 9 to indicate remote has read all 9 bytes of "test data"
|
||||
conn2.seqNum = 9
|
||||
reconnector, _ := mockReconnectFunc(conn1, conn2)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnector)
|
||||
defer bp.Close()
|
||||
|
||||
// Initial connect
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
require.True(t, bp.Connected())
|
||||
require.Equal(t, 1, reconnector.GetCallCount())
|
||||
|
||||
// Write some data to the first connection
|
||||
_, err = bp.Write([]byte("test data"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test data", conn1.ReadString())
|
||||
|
||||
// Force a reconnection
|
||||
err = bp.ForceReconnect()
|
||||
require.NoError(t, err)
|
||||
require.True(t, bp.Connected())
|
||||
require.Equal(t, 2, reconnector.GetCallCount())
|
||||
|
||||
// Since the mock returns the proper sequence number, no data should be replayed
|
||||
// The new connection should be empty
|
||||
require.Equal(t, "", conn2.ReadString())
|
||||
|
||||
// Verify that data can still be written and read after forced reconnection
|
||||
_, err = bp.Write([]byte("new data"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new data", conn2.ReadString())
|
||||
|
||||
// Verify that reads work with the new connection
|
||||
conn2.WriteString("response data")
|
||||
buf := make([]byte, 20)
|
||||
n, err := bp.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 13, n)
|
||||
require.Equal(t, "response data", string(buf[:n]))
|
||||
}
|
||||
|
||||
func TestBackedPipe_ForceReconnectWhenClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
|
||||
// Close the pipe first
|
||||
err := bp.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to force reconnect when closed
|
||||
err = bp.ForceReconnect()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestBackedPipe_StateTransitionsAndGenerationTracking(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := newMockConnection()
|
||||
conn2 := newMockConnection()
|
||||
conn3 := newMockConnection()
|
||||
reconnector, signalChan := mockReconnectFunc(conn1, conn2, conn3)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnector)
|
||||
defer bp.Close()
|
||||
|
||||
// Initial state should be disconnected
|
||||
require.False(t, bp.Connected())
|
||||
|
||||
// Connect should transition to connected
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
require.True(t, bp.Connected())
|
||||
require.Equal(t, 1, reconnector.GetCallCount())
|
||||
|
||||
// Write some data
|
||||
_, err = bp.Write([]byte("test data gen 1"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate connection failure by setting errors on connection
|
||||
conn1.SetReadError(xerrors.New("connection lost"))
|
||||
conn1.SetWriteError(xerrors.New("connection lost"))
|
||||
|
||||
// Trigger a write to cause the pipe to notice the failure
|
||||
_, _ = bp.Write([]byte("trigger failure"))
|
||||
|
||||
// Wait for reconnection signal
|
||||
testutil.RequireReceive(testutil.Context(t, testutil.WaitShort), t, signalChan)
|
||||
|
||||
// Wait for reconnection to complete
|
||||
require.Eventually(t, func() bool {
|
||||
return bp.Connected()
|
||||
}, testutil.WaitShort, testutil.IntervalFast, "should reconnect")
|
||||
require.Equal(t, 2, reconnector.GetCallCount())
|
||||
|
||||
// Force another reconnection
|
||||
err = bp.ForceReconnect()
|
||||
require.NoError(t, err)
|
||||
require.True(t, bp.Connected())
|
||||
require.Equal(t, 3, reconnector.GetCallCount())
|
||||
|
||||
// Close should transition to closed state
|
||||
err = bp.Close()
|
||||
require.NoError(t, err)
|
||||
require.False(t, bp.Connected())
|
||||
|
||||
// Operations on closed pipe should fail
|
||||
err = bp.Connect()
|
||||
require.Equal(t, backedpipe.ErrPipeClosed, err)
|
||||
|
||||
err = bp.ForceReconnect()
|
||||
require.Equal(t, io.EOF, err)
|
||||
}
|
||||
|
||||
func TestBackedPipe_GenerationFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := newMockConnection()
|
||||
conn2 := newMockConnection()
|
||||
reconnector, _ := mockReconnectFunc(conn1, conn2)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnector)
|
||||
defer bp.Close()
|
||||
|
||||
// Connect
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
require.True(t, bp.Connected())
|
||||
|
||||
// Simulate multiple rapid errors from the same connection generation
|
||||
// Only the first one should trigger reconnection
|
||||
conn1.SetReadError(xerrors.New("error 1"))
|
||||
conn1.SetWriteError(xerrors.New("error 2"))
|
||||
|
||||
// Trigger multiple errors quickly
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = bp.Write([]byte("trigger error 1"))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = bp.Write([]byte("trigger error 2"))
|
||||
}()
|
||||
|
||||
// Wait for both writes to complete
|
||||
wg.Wait()
|
||||
|
||||
// Wait for reconnection to complete
|
||||
require.Eventually(t, func() bool {
|
||||
return bp.Connected()
|
||||
}, testutil.WaitShort, testutil.IntervalFast, "should reconnect once")
|
||||
|
||||
// Should have only reconnected once despite multiple errors
|
||||
require.Equal(t, 2, reconnector.GetCallCount()) // Initial connect + 1 reconnect
|
||||
}
|
||||
|
||||
func TestBackedPipe_DuplicateReconnectionPrevention(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
// Create a blocking reconnector for deterministic testing
|
||||
conn1 := newMockConnection()
|
||||
conn2 := newMockConnection()
|
||||
blockChan := make(chan struct{})
|
||||
reconnector, blockedChan := mockBlockingReconnectFunc(conn1, conn2, blockChan)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnector)
|
||||
defer bp.Close()
|
||||
|
||||
// Initial connect
|
||||
err := bp.Connect()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, reconnector.GetCallCount(), "should have exactly 1 call after initial connect")
|
||||
|
||||
// We'll use channels to coordinate the test execution:
|
||||
// 1. Start all goroutines but have them wait
|
||||
// 2. Release the first one and wait for it to block
|
||||
// 3. Release the others while the first is still blocked
|
||||
|
||||
const numConcurrent = 3
|
||||
startSignals := make([]chan struct{}, numConcurrent)
|
||||
startedSignals := make([]chan struct{}, numConcurrent)
|
||||
for i := range startSignals {
|
||||
startSignals[i] = make(chan struct{})
|
||||
startedSignals[i] = make(chan struct{})
|
||||
}
|
||||
|
||||
errors := make([]error, numConcurrent)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start all goroutines
|
||||
for i := 0; i < numConcurrent; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
// Wait for the signal to start
|
||||
<-startSignals[idx]
|
||||
// Signal that we're about to call ForceReconnect
|
||||
close(startedSignals[idx])
|
||||
errors[idx] = bp.ForceReconnect()
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Start the first ForceReconnect and wait for it to block
|
||||
close(startSignals[0])
|
||||
<-startedSignals[0]
|
||||
|
||||
// Wait for the first reconnect to actually start and block
|
||||
testutil.RequireReceive(testCtx, t, blockedChan)
|
||||
|
||||
// Now start all the other ForceReconnect calls
|
||||
// They should all join the same singleflight operation
|
||||
for i := 1; i < numConcurrent; i++ {
|
||||
close(startSignals[i])
|
||||
}
|
||||
|
||||
// Wait for all additional goroutines to have started their calls
|
||||
for i := 1; i < numConcurrent; i++ {
|
||||
<-startedSignals[i]
|
||||
}
|
||||
|
||||
// At this point, one reconnect has started and is blocked,
|
||||
// and all other goroutines have called ForceReconnect and should be
|
||||
// waiting on the same singleflight operation.
|
||||
// Due to singleflight, only one reconnect should have been attempted.
|
||||
require.Equal(t, 2, reconnector.GetCallCount(), "should have exactly 2 calls: initial connect + 1 reconnect due to singleflight")
|
||||
|
||||
// Release the blocking reconnect function
|
||||
close(blockChan)
|
||||
|
||||
// Wait for all ForceReconnect calls to complete
|
||||
wg.Wait()
|
||||
|
||||
// All calls should succeed (they share the same result from singleflight)
|
||||
for i, err := range errors {
|
||||
require.NoError(t, err, "ForceReconnect %d should succeed", i, err)
|
||||
}
|
||||
|
||||
// Final verification: call count should still be exactly 2
|
||||
require.Equal(t, 2, reconnector.GetCallCount(), "final call count should be exactly 2: initial connect + 1 singleflight reconnect")
|
||||
}
|
||||
|
||||
// TestBackedPipe_SingleReconnectionOnMultipleErrors verifies that when both the
// read and write sides of a connection fail at about the same time, the pipe
// performs only a single reconnection (call count: 1 initial connect + 1
// reconnect = 2) rather than one per failed component.
func TestBackedPipe_SingleReconnectionOnMultipleErrors(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	testCtx := testutil.Context(t, testutil.WaitShort)

	// Create connections for initial connect and reconnection
	conn1 := newMockConnection()
	conn2 := newMockConnection()
	reconnector, signalChan := mockReconnectFunc(conn1, conn2)

	bp := backedpipe.NewBackedPipe(ctx, reconnector)
	defer bp.Close()

	// Initial connect
	err := bp.Connect()
	require.NoError(t, err)
	require.True(t, bp.Connected())
	require.Equal(t, 1, reconnector.GetCallCount())

	// Write some initial data to establish the connection
	_, err = bp.Write([]byte("initial data"))
	require.NoError(t, err)

	// Set up both read and write errors on the connection
	conn1.SetReadError(xerrors.New("read connection lost"))
	conn1.SetWriteError(xerrors.New("write connection lost"))

	// Trigger write error (this will trigger reconnection)
	go func() {
		_, _ = bp.Write([]byte("trigger write error"))
	}()

	// Wait for reconnection to start
	testutil.RequireReceive(testCtx, t, signalChan)

	// Wait for reconnection to complete
	require.Eventually(t, func() bool {
		return bp.Connected()
	}, testutil.WaitShort, testutil.IntervalFast, "should reconnect after write error")

	// Verify that only one reconnection occurred
	require.Equal(t, 2, reconnector.GetCallCount(), "should have exactly 2 calls: initial connect + 1 reconnection")
	require.True(t, bp.Connected(), "should be connected after reconnection")
}
|
||||
|
||||
func TestBackedPipe_ForceReconnectWhenDisconnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnector, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnector)
|
||||
defer bp.Close()
|
||||
|
||||
// Don't connect initially, just force reconnect
|
||||
err := bp.ForceReconnect()
|
||||
require.NoError(t, err)
|
||||
require.True(t, bp.Connected())
|
||||
require.Equal(t, 1, reconnector.GetCallCount())
|
||||
|
||||
// Verify we can write and read
|
||||
_, err = bp.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", conn.ReadString())
|
||||
|
||||
conn.WriteString("response")
|
||||
buf := make([]byte, 10)
|
||||
n, err := bp.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8, n)
|
||||
require.Equal(t, "response", string(buf[:n]))
|
||||
}
|
||||
|
||||
// TestBackedPipe_EOFTriggersReconnection verifies that io.EOF from the
// underlying connection is handled like any other disconnect: the pipe
// reconnects transparently and subsequent reads/writes use the replacement
// connection.
func TestBackedPipe_EOFTriggersReconnection(t *testing.T) {
	t.Parallel()

	ctx := context.Background()

	// Create connections where we can control when EOF occurs
	conn1 := newMockConnection()
	conn2 := newMockConnection()
	conn2.WriteString("newdata") // Pre-populate conn2 with data

	// Make conn1 return EOF after reading "world"
	hasReadData := false
	conn1.readFunc = func(p []byte) (int, error) {
		// Don't lock here - the Read method already holds the lock

		// First time: return "world"
		if !hasReadData && conn1.readBuffer.Len() > 0 {
			n, _ := conn1.readBuffer.Read(p)
			hasReadData = true
			return n, nil
		}
		// After that: return EOF
		return 0, io.EOF
	}
	conn1.WriteString("world")

	reconnector := &eofTestReconnector{
		conn1: conn1,
		conn2: conn2,
	}

	bp := backedpipe.NewBackedPipe(ctx, reconnector)
	defer bp.Close()

	// Initial connect
	err := bp.Connect()
	require.NoError(t, err)
	require.Equal(t, 1, reconnector.GetCallCount())

	// Write some data
	_, err = bp.Write([]byte("hello"))
	require.NoError(t, err)

	buf := make([]byte, 10)

	// First read should succeed
	n, err := bp.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 5, n)
	require.Equal(t, "world", string(buf[:n]))

	// Next read will encounter EOF and should trigger reconnection
	// After reconnection, it should read from conn2
	n, err = bp.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 7, n)
	require.Equal(t, "newdata", string(buf[:n]))

	// Verify reconnection happened
	require.Equal(t, 2, reconnector.GetCallCount())

	// Verify the pipe is still connected and functional
	require.True(t, bp.Connected())

	// Further writes should go to the new connection
	_, err = bp.Write([]byte("aftereof"))
	require.NoError(t, err)
	require.Equal(t, "aftereof", conn2.ReadString())
}
|
||||
|
||||
func BenchmarkBackedPipe_Write(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
bp.Connect()
|
||||
b.Cleanup(func() {
|
||||
_ = bp.Close()
|
||||
})
|
||||
|
||||
data := make([]byte, 1024) // 1KB writes
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bp.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBackedPipe_Read(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
conn := newMockConnection()
|
||||
reconnectFn, _ := mockReconnectFunc(conn)
|
||||
|
||||
bp := backedpipe.NewBackedPipe(ctx, reconnectFn)
|
||||
bp.Connect()
|
||||
b.Cleanup(func() {
|
||||
_ = bp.Close()
|
||||
})
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Fill connection with fresh data for each iteration
|
||||
conn.WriteString(string(buf))
|
||||
bp.Read(buf)
|
||||
}
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
package backedpipe
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// BackedReader wraps an unreliable io.Reader and makes it resilient to disconnections.
// It tracks sequence numbers for all bytes read and can handle reconnection,
// blocking reads when disconnected instead of erroring.
type BackedReader struct {
	mu          sync.Mutex
	cond        *sync.Cond // signaled when a reader is (re)attached or the BackedReader closes
	reader      io.Reader  // nil while disconnected; Read blocks until Reconnect sets it
	sequenceNum uint64     // total bytes successfully read across all connections
	closed      bool       // once true, Read returns io.EOF and Reconnect is a no-op

	// Error channel for generation-aware error reporting
	errorEventChan chan<- ErrorEvent

	// Current connection generation for error reporting
	currentGen uint64
}
|
||||
|
||||
// NewBackedReader creates a new BackedReader with generation-aware error reporting.
|
||||
// The reader is initially disconnected and must be connected using Reconnect before
|
||||
// reads will succeed. The errorEventChan will receive ErrorEvent structs containing
|
||||
// error details, component info, and connection generation.
|
||||
func NewBackedReader(errorEventChan chan<- ErrorEvent) *BackedReader {
|
||||
if errorEventChan == nil {
|
||||
panic("error event channel cannot be nil")
|
||||
}
|
||||
br := &BackedReader{
|
||||
errorEventChan: errorEventChan,
|
||||
}
|
||||
br.cond = sync.NewCond(&br.mu)
|
||||
return br
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It blocks when disconnected until either:
// 1. A reconnection is established
// 2. The reader is closed
//
// When connected, it reads from the underlying reader and updates sequence numbers.
// Connection failures are automatically detected and reported to the higher layer
// via the error event channel; the failed Read then loops back to wait for a
// replacement reader instead of surfacing the error to the caller.
func (br *BackedReader) Read(p []byte) (int, error) {
	br.mu.Lock()
	defer br.mu.Unlock()

	for {
		// Step 1: Wait until we have a reader or are closed
		for br.reader == nil && !br.closed {
			br.cond.Wait()
		}

		if br.closed {
			return 0, io.EOF
		}

		// Step 2: Perform the read while holding the mutex
		// This ensures proper synchronization with Reconnect and Close operations
		n, err := br.reader.Read(p)
		br.sequenceNum += uint64(n) // #nosec G115 -- n is always >= 0 per io.Reader contract

		if err == nil {
			return n, nil
		}

		// Mark reader as disconnected so future reads will wait for reconnection
		br.reader = nil

		// Notify parent of error with generation information
		select {
		case br.errorEventChan <- ErrorEvent{
			Err:        err,
			Component:  "reader",
			Generation: br.currentGen,
		}:
		default:
			// Channel is full, drop the error.
			// This is not a problem, because we set the reader to nil
			// and block until reconnected so no new errors will be sent
			// until pipe processes the error and reconnects.
		}

		// If we got some data before the error, return it now
		// (per the io.Reader contract the n bytes are valid even when err != nil).
		if n > 0 {
			return n, nil
		}
	}
}
|
||||
|
||||
// Reconnect coordinates reconnection using channels for better synchronization.
// The seqNum channel is used to send the current sequence number to the caller;
// it is closed afterwards (or closed immediately, without a value, when the
// reader is already closed). The newR channel is used to receive the new reader
// from the caller; sending nil on it signals that reconnection failed and the
// reader stays disconnected.
func (br *BackedReader) Reconnect(seqNum chan<- uint64, newR <-chan io.Reader) {
	// Grab the lock
	br.mu.Lock()
	defer br.mu.Unlock()

	if br.closed {
		// Close the channel to indicate closed state
		close(seqNum)
		return
	}

	// Get the sequence number to send to the other side via seqNum channel
	seqNum <- br.sequenceNum
	close(seqNum)

	// Wait for the reconnect to complete, via newR channel, and give us a new io.Reader.
	// NOTE: the mutex is held across this receive, so Read and Close block until
	// the caller responds on newR.
	newReader := <-newR

	// If reconnection fails while we are starting it, the caller sends nil on newR
	if newReader == nil {
		// Reconnection failed, keep current state
		return
	}

	// Reconnection successful
	br.reader = newReader

	// Notify any waiting reads via the cond
	br.cond.Broadcast()
}
|
||||
|
||||
// Close the reader and wake up any blocked reads.
|
||||
// After closing, all Read calls will return io.EOF.
|
||||
func (br *BackedReader) Close() error {
|
||||
br.mu.Lock()
|
||||
defer br.mu.Unlock()
|
||||
|
||||
if br.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
br.closed = true
|
||||
br.reader = nil
|
||||
|
||||
// Wake up any blocked reads
|
||||
br.cond.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SequenceNum returns the current sequence number (total bytes read).
|
||||
func (br *BackedReader) SequenceNum() uint64 {
|
||||
br.mu.Lock()
|
||||
defer br.mu.Unlock()
|
||||
return br.sequenceNum
|
||||
}
|
||||
|
||||
// Connected returns whether the reader is currently connected.
|
||||
func (br *BackedReader) Connected() bool {
|
||||
br.mu.Lock()
|
||||
defer br.mu.Unlock()
|
||||
return br.reader != nil
|
||||
}
|
||||
|
||||
// SetGeneration sets the current connection generation for error reporting.
|
||||
func (br *BackedReader) SetGeneration(generation uint64) {
|
||||
br.mu.Lock()
|
||||
defer br.mu.Unlock()
|
||||
br.currentGen = generation
|
||||
}
|
||||
@@ -1,603 +0,0 @@
|
||||
package backedpipe_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// mockReader implements io.Reader with controllable behavior for testing
type mockReader struct {
	mu       sync.Mutex
	data     []byte // bytes served to Read, starting at pos
	pos      int    // read cursor into data
	err      error  // if set, Read returns it instead of data
	readFunc func([]byte) (int, error) // if set, overrides all other behavior
}
|
||||
|
||||
func newMockReader(data string) *mockReader {
|
||||
return &mockReader{data: []byte(data)}
|
||||
}
|
||||
|
||||
func (mr *mockReader) Read(p []byte) (int, error) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
|
||||
if mr.readFunc != nil {
|
||||
return mr.readFunc(p)
|
||||
}
|
||||
|
||||
if mr.err != nil {
|
||||
return 0, mr.err
|
||||
}
|
||||
|
||||
if mr.pos >= len(mr.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n := copy(p, mr.data[mr.pos:])
|
||||
mr.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (mr *mockReader) setError(err error) {
|
||||
mr.mu.Lock()
|
||||
defer mr.mu.Unlock()
|
||||
mr.err = err
|
||||
}
|
||||
|
||||
func TestBackedReader_NewBackedReader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
br := backedpipe.NewBackedReader(errChan)
|
||||
require.NotNil(t, br)
|
||||
require.Equal(t, uint64(0), br.SequenceNum())
|
||||
require.False(t, br.Connected())
|
||||
}
|
||||
|
||||
// TestBackedReader_BasicReadOperation connects a reader via the Reconnect
// channel handshake and verifies data and sequence numbers for ordinary reads.
func TestBackedReader_BasicReadOperation(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	errChan := make(chan backedpipe.ErrorEvent, 1)
	br := backedpipe.NewBackedReader(errChan)
	reader := newMockReader("hello world")

	// Connect the reader
	seqNum := make(chan uint64, 1)
	newR := make(chan io.Reader, 1)

	go br.Reconnect(seqNum, newR)

	// Get sequence number from reader
	seq := testutil.RequireReceive(ctx, t, seqNum)
	require.Equal(t, uint64(0), seq)

	// Send new reader
	testutil.RequireSend(ctx, t, newR, io.Reader(reader))

	// Read data
	buf := make([]byte, 5)
	n, err := br.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 5, n)
	require.Equal(t, "hello", string(buf))
	require.Equal(t, uint64(5), br.SequenceNum())

	// Read more data
	n, err = br.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 5, n)
	require.Equal(t, " worl", string(buf))
	require.Equal(t, uint64(10), br.SequenceNum())
}
|
||||
|
||||
// TestBackedReader_ReadBlocksWhenDisconnected verifies that Read blocks while
// no underlying reader is attached and unblocks once Reconnect supplies one.
func TestBackedReader_ReadBlocksWhenDisconnected(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	errChan := make(chan backedpipe.ErrorEvent, 1)
	br := backedpipe.NewBackedReader(errChan)

	// Start a read operation that should block
	readDone := make(chan struct{})
	var readErr error
	var readBuf []byte
	var readN int

	go func() {
		defer close(readDone)
		buf := make([]byte, 10)
		readN, readErr = br.Read(buf)
		readBuf = buf[:readN]
	}()

	// Ensure the read is actually blocked by verifying it hasn't completed
	// and that the reader is not connected
	select {
	case <-readDone:
		t.Fatal("Read should be blocked when disconnected")
	default:
		// Read is still blocked, which is what we want
	}
	require.False(t, br.Connected(), "Reader should not be connected")

	// Connect and the read should unblock
	reader := newMockReader("test")
	seqNum := make(chan uint64, 1)
	newR := make(chan io.Reader, 1)

	go br.Reconnect(seqNum, newR)

	// Get sequence number and send new reader
	testutil.RequireReceive(ctx, t, seqNum)
	testutil.RequireSend(ctx, t, newR, io.Reader(reader))

	// Wait for read to complete
	testutil.TryReceive(ctx, t, readDone)
	require.NoError(t, readErr)
	require.Equal(t, "test", string(readBuf))
}
|
||||
|
||||
// TestBackedReader_ReconnectionAfterFailure simulates a connection failure,
// confirms the error is surfaced on the error channel while the in-flight Read
// stays blocked, then reconnects and confirms the Read completes with the
// replacement reader and the sequence number is preserved.
func TestBackedReader_ReconnectionAfterFailure(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	errChan := make(chan backedpipe.ErrorEvent, 1)
	br := backedpipe.NewBackedReader(errChan)
	reader1 := newMockReader("first")

	// Initial connection
	seqNum := make(chan uint64, 1)
	newR := make(chan io.Reader, 1)

	go br.Reconnect(seqNum, newR)

	// Get sequence number and send new reader
	testutil.RequireReceive(ctx, t, seqNum)
	testutil.RequireSend(ctx, t, newR, io.Reader(reader1))

	// Read some data
	buf := make([]byte, 5)
	n, err := br.Read(buf)
	require.NoError(t, err)
	require.Equal(t, "first", string(buf[:n]))
	require.Equal(t, uint64(5), br.SequenceNum())

	// Simulate connection failure
	reader1.setError(xerrors.New("connection lost"))

	// Start a read that will block due to connection failure
	readDone := make(chan error, 1)
	go func() {
		_, err := br.Read(buf)
		readDone <- err
	}()

	// Wait for the error to be reported via error channel
	receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
	require.Error(t, receivedErrorEvent.Err)
	require.Equal(t, "reader", receivedErrorEvent.Component)
	require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")

	// Verify read is still blocked
	select {
	case err := <-readDone:
		t.Fatalf("Read should still be blocked, but completed with: %v", err)
	default:
		// Good, still blocked
	}

	// Verify disconnection
	require.False(t, br.Connected())

	// Reconnect with new reader
	reader2 := newMockReader("second")
	seqNum2 := make(chan uint64, 1)
	newR2 := make(chan io.Reader, 1)

	go br.Reconnect(seqNum2, newR2)

	// Get sequence number and send new reader
	seq := testutil.RequireReceive(ctx, t, seqNum2)
	require.Equal(t, uint64(5), seq) // Should return current sequence number
	testutil.RequireSend(ctx, t, newR2, io.Reader(reader2))

	// Wait for read to unblock and succeed with new data
	readErr := testutil.RequireReceive(ctx, t, readDone)
	require.NoError(t, readErr) // Should succeed with new reader
	require.True(t, br.Connected())
}
|
||||
|
||||
// TestBackedReader_Close verifies that once Close is called, all subsequent
// reads return io.EOF instead of blocking for a reconnection.
func TestBackedReader_Close(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	errChan := make(chan backedpipe.ErrorEvent, 1)
	br := backedpipe.NewBackedReader(errChan)
	reader := newMockReader("test")

	// Connect
	seqNum := make(chan uint64, 1)
	newR := make(chan io.Reader, 1)

	go br.Reconnect(seqNum, newR)

	// Get sequence number and send new reader
	testutil.RequireReceive(ctx, t, seqNum)
	testutil.RequireSend(ctx, t, newR, io.Reader(reader))

	// First, read all available data
	buf := make([]byte, 10)
	n, err := br.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 4, n) // "test" is 4 bytes

	// Close the reader before EOF triggers reconnection
	err = br.Close()
	require.NoError(t, err)

	// After close, reads should return EOF
	n, err = br.Read(buf)
	require.Equal(t, 0, n)
	require.Equal(t, io.EOF, err)

	// Subsequent reads should return EOF
	_, err = br.Read(buf)
	require.Equal(t, io.EOF, err)
}
|
||||
|
||||
func TestBackedReader_CloseIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
br := backedpipe.NewBackedReader(errChan)
|
||||
|
||||
err := br.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second close should be no-op
|
||||
err = br.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackedReader_ReconnectAfterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
br := backedpipe.NewBackedReader(errChan)
|
||||
|
||||
err := br.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
seqNum := make(chan uint64, 1)
|
||||
newR := make(chan io.Reader, 1)
|
||||
|
||||
go br.Reconnect(seqNum, newR)
|
||||
|
||||
// Should get 0 sequence number for closed reader
|
||||
seq := testutil.TryReceive(ctx, t, seqNum)
|
||||
require.Equal(t, uint64(0), seq)
|
||||
}
|
||||
|
||||
// Helper function to reconnect a reader using channels
|
||||
func reconnectReader(ctx context.Context, t testing.TB, br *backedpipe.BackedReader, reader io.Reader) {
|
||||
seqNum := make(chan uint64, 1)
|
||||
newR := make(chan io.Reader, 1)
|
||||
|
||||
go br.Reconnect(seqNum, newR)
|
||||
|
||||
// Get sequence number and send new reader
|
||||
testutil.RequireReceive(ctx, t, seqNum)
|
||||
testutil.RequireSend(ctx, t, newR, reader)
|
||||
}
|
||||
|
||||
func TestBackedReader_SequenceNumberTracking(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
br := backedpipe.NewBackedReader(errChan)
|
||||
reader := newMockReader("0123456789")
|
||||
|
||||
reconnectReader(ctx, t, br, reader)
|
||||
|
||||
// Read in chunks and verify sequence number
|
||||
buf := make([]byte, 3)
|
||||
|
||||
n, err := br.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, n)
|
||||
require.Equal(t, uint64(3), br.SequenceNum())
|
||||
|
||||
n, err = br.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, n)
|
||||
require.Equal(t, uint64(6), br.SequenceNum())
|
||||
|
||||
n, err = br.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, n)
|
||||
require.Equal(t, uint64(9), br.SequenceNum())
|
||||
}
|
||||
|
||||
// TestBackedReader_EOFHandling verifies that io.EOF from the underlying reader
// is reported on the error channel and treated as a disconnect: the in-flight
// Read blocks until a reconnection supplies fresh data.
func TestBackedReader_EOFHandling(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	errChan := make(chan backedpipe.ErrorEvent, 1)
	br := backedpipe.NewBackedReader(errChan)
	reader := newMockReader("test")

	reconnectReader(ctx, t, br, reader)

	// Read all data
	buf := make([]byte, 10)
	n, err := br.Read(buf)
	require.NoError(t, err)
	require.Equal(t, 4, n)
	require.Equal(t, "test", string(buf[:n]))

	// Next read should encounter EOF, which triggers disconnection
	// The read should block waiting for reconnection
	readDone := make(chan struct{})
	var readErr error
	var readN int

	go func() {
		defer close(readDone)
		readN, readErr = br.Read(buf)
	}()

	// Wait for EOF to be reported via error channel
	receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
	require.Equal(t, io.EOF, receivedErrorEvent.Err)
	require.Equal(t, "reader", receivedErrorEvent.Component)

	// Reader should be disconnected after EOF
	require.False(t, br.Connected())

	// Read should still be blocked
	select {
	case <-readDone:
		t.Fatal("Read should be blocked waiting for reconnection after EOF")
	default:
		// Good, still blocked
	}

	// Reconnect with new data
	reader2 := newMockReader("more")
	reconnectReader(ctx, t, br, reader2)

	// Wait for the blocked read to complete with new data
	testutil.TryReceive(ctx, t, readDone)
	require.NoError(t, readErr)
	require.Equal(t, 4, readN)
	require.Equal(t, "more", string(buf[:readN]))
}
|
||||
|
||||
func BenchmarkBackedReader_Read(b *testing.B) {
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
br := backedpipe.NewBackedReader(errChan)
|
||||
buf := make([]byte, 1024)
|
||||
|
||||
// Create a reader that never returns EOF by cycling through data
|
||||
reader := &mockReader{
|
||||
readFunc: func(p []byte) (int, error) {
|
||||
// Fill buffer with 'x' characters - never EOF
|
||||
for i := range p {
|
||||
p[i] = 'x'
|
||||
}
|
||||
return len(p), nil
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
reconnectReader(ctx, b, br, reader)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
br.Read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackedReader_PartialReads(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
br := backedpipe.NewBackedReader(errChan)
|
||||
|
||||
// Create a reader that returns partial reads
|
||||
reader := &mockReader{
|
||||
readFunc: func(p []byte) (int, error) {
|
||||
// Always return just 1 byte at a time
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
p[0] = 'A'
|
||||
return 1, nil
|
||||
},
|
||||
}
|
||||
|
||||
reconnectReader(ctx, t, br, reader)
|
||||
|
||||
// Read multiple times
|
||||
buf := make([]byte, 10)
|
||||
for i := 0; i < 5; i++ {
|
||||
n, err := br.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, n)
|
||||
require.Equal(t, byte('A'), buf[0])
|
||||
}
|
||||
|
||||
require.Equal(t, uint64(5), br.SequenceNum())
|
||||
}
|
||||
|
||||
// TestBackedReader_CloseWhileBlockedOnUnderlyingReader verifies that when the
// underlying reader's Read call is blocked, Close waits for it (both hold the
// mutex), and once the underlying read returns, the outer Read reports io.EOF
// because the BackedReader is now closed — even though the underlying reader
// returned a different error.
func TestBackedReader_CloseWhileBlockedOnUnderlyingReader(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	errChan := make(chan backedpipe.ErrorEvent, 1)
	br := backedpipe.NewBackedReader(errChan)

	// Create a reader that blocks on Read calls but can be unblocked
	readStarted := make(chan struct{}, 1)
	readUnblocked := make(chan struct{})
	blockingReader := &mockReader{
		readFunc: func(p []byte) (int, error) {
			select {
			case readStarted <- struct{}{}:
			default:
			}
			<-readUnblocked // Block until signaled
			// After unblocking, return an error to simulate connection failure
			return 0, xerrors.New("connection interrupted")
		},
	}

	// Connect the blocking reader
	seqNum := make(chan uint64, 1)
	newR := make(chan io.Reader, 1)

	go br.Reconnect(seqNum, newR)

	// Get sequence number and send blocking reader
	testutil.RequireReceive(ctx, t, seqNum)
	testutil.RequireSend(ctx, t, newR, io.Reader(blockingReader))

	// Start a read that will block on the underlying reader
	readDone := make(chan struct{})
	var readErr error
	var readN int

	go func() {
		defer close(readDone)
		buf := make([]byte, 10)
		readN, readErr = br.Read(buf)
	}()

	// Wait for the read to start and block on the underlying reader
	testutil.RequireReceive(ctx, t, readStarted)

	// Verify read is blocked by checking that it hasn't completed
	// and ensuring we have adequate time for it to reach the blocking state
	require.Eventually(t, func() bool {
		select {
		case <-readDone:
			t.Fatal("Read should be blocked on underlying reader")
			return false
		default:
			// Good, still blocked
			return true
		}
	}, testutil.WaitShort, testutil.IntervalMedium)

	// Start Close() in a goroutine since it will block until the underlying read completes
	closeDone := make(chan error, 1)
	go func() {
		closeDone <- br.Close()
	}()

	// Verify Close() is also blocked waiting for the underlying read
	select {
	case <-closeDone:
		t.Fatal("Close should be blocked until underlying read completes")
	case <-time.After(10 * time.Millisecond):
		// Good, Close is blocked
	}

	// Unblock the underlying reader, which will cause both the read and close to complete
	close(readUnblocked)

	// Wait for both the read and close to complete
	testutil.TryReceive(ctx, t, readDone)
	closeErr := testutil.RequireReceive(ctx, t, closeDone)
	require.NoError(t, closeErr)

	// The read should return EOF because Close() was called while it was blocked,
	// even though the underlying reader returned an error
	require.Equal(t, 0, readN)
	require.Equal(t, io.EOF, readErr)

	// Subsequent reads should return EOF since the reader is now closed
	buf := make([]byte, 10)
	n, err := br.Read(buf)
	require.Equal(t, 0, n)
	require.Equal(t, io.EOF, err)
}
|
||||
|
||||
// TestBackedReader_CloseWhileBlockedWaitingForReconnect verifies that a Read
// parked waiting for a reconnection is woken by Close and returns io.EOF.
func TestBackedReader_CloseWhileBlockedWaitingForReconnect(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)

	errChan := make(chan backedpipe.ErrorEvent, 1)
	br := backedpipe.NewBackedReader(errChan)
	reader1 := newMockReader("initial")

	// Initial connection
	seqNum := make(chan uint64, 1)
	newR := make(chan io.Reader, 1)

	go br.Reconnect(seqNum, newR)

	// Get sequence number and send initial reader
	testutil.RequireReceive(ctx, t, seqNum)
	testutil.RequireSend(ctx, t, newR, io.Reader(reader1))

	// Read initial data
	buf := make([]byte, 10)
	n, err := br.Read(buf)
	require.NoError(t, err)
	require.Equal(t, "initial", string(buf[:n]))

	// Simulate connection failure
	reader1.setError(xerrors.New("connection lost"))

	// Start a read that will block waiting for reconnection
	readDone := make(chan struct{})
	var readErr error
	var readN int

	go func() {
		defer close(readDone)
		readN, readErr = br.Read(buf)
	}()

	// Wait for the error to be reported (indicating disconnection)
	receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
	require.Error(t, receivedErrorEvent.Err)
	require.Equal(t, "reader", receivedErrorEvent.Component)
	require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")

	// Verify read is blocked waiting for reconnection
	select {
	case <-readDone:
		t.Fatal("Read should be blocked waiting for reconnection")
	default:
		// Good, still blocked
	}

	// Verify reader is disconnected
	require.False(t, br.Connected())

	// Close the BackedReader while read is blocked waiting for reconnection
	err = br.Close()
	require.NoError(t, err)

	// The read should unblock and return EOF
	testutil.TryReceive(ctx, t, readDone)
	require.Equal(t, 0, readN)
	require.Equal(t, io.EOF, readErr)
}
|
||||
@@ -1,243 +0,0 @@
|
||||
package backedpipe
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by BackedWriter operations; compare with errors.Is.
var (
	// ErrWriterClosed indicates a reconnect was attempted on a closed writer.
	ErrWriterClosed = xerrors.New("cannot reconnect closed writer")
	// ErrNilWriter indicates a nil writer was supplied for reconnection.
	ErrNilWriter = xerrors.New("new writer cannot be nil")
	// ErrFutureSequence indicates a replay was requested from a sequence number
	// ahead of what has been written.
	ErrFutureSequence = xerrors.New("cannot replay from future sequence")
	// ErrReplayDataUnavailable indicates the replay data could not be read back.
	ErrReplayDataUnavailable = xerrors.New("failed to read replay data")
	// ErrReplayFailed indicates the replay write itself failed.
	ErrReplayFailed = xerrors.New("replay failed")
	// ErrPartialReplay indicates only part of the replay data was written.
	ErrPartialReplay = xerrors.New("partial replay")
)
|
||||
|
||||
// BackedWriter wraps an unreliable io.Writer and makes it resilient to disconnections.
// It maintains a ring buffer of recent writes for replay during reconnection.
type BackedWriter struct {
	mu          sync.Mutex
	cond        *sync.Cond  // signaled when a writer is (re)attached or the BackedWriter closes
	writer      io.Writer   // nil while disconnected; Write blocks until reattached
	buffer      *ringBuffer // retains recent bytes so they can be replayed after reconnect
	sequenceNum uint64      // total bytes written
	closed      bool

	// Error channel for generation-aware error reporting
	errorEventChan chan<- ErrorEvent

	// Current connection generation for error reporting
	currentGen uint64
}
|
||||
|
||||
// NewBackedWriter creates a new BackedWriter with generation-aware error reporting.
|
||||
// The writer is initially disconnected and will block writes until connected.
|
||||
// The errorEventChan will receive ErrorEvent structs containing error details,
|
||||
// component info, and connection generation. Capacity must be > 0.
|
||||
func NewBackedWriter(capacity int, errorEventChan chan<- ErrorEvent) *BackedWriter {
|
||||
if capacity <= 0 {
|
||||
panic("backed writer capacity must be > 0")
|
||||
}
|
||||
if errorEventChan == nil {
|
||||
panic("error event channel cannot be nil")
|
||||
}
|
||||
bw := &BackedWriter{
|
||||
buffer: newRingBuffer(capacity),
|
||||
errorEventChan: errorEventChan,
|
||||
}
|
||||
bw.cond = sync.NewCond(&bw.mu)
|
||||
return bw
|
||||
}
|
||||
|
||||
// blockUntilConnectedOrClosed blocks until either a writer is available or the BackedWriter is closed.
|
||||
// Returns os.ErrClosed if closed while waiting, nil if connected. You must hold the mutex to call this.
|
||||
func (bw *BackedWriter) blockUntilConnectedOrClosed() error {
|
||||
for bw.writer == nil && !bw.closed {
|
||||
bw.cond.Wait()
|
||||
}
|
||||
if bw.closed {
|
||||
return os.ErrClosed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
// When connected, it writes to both the ring buffer (to preserve data in case we need to replay it)
// and the underlying writer.
// If the underlying write fails, the writer is marked as disconnected and the write blocks
// until reconnection occurs; the reconnection replay delivers the buffered bytes,
// so the blocked Write then reports success without retrying.
func (bw *BackedWriter) Write(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}

	bw.mu.Lock()
	defer bw.mu.Unlock()

	// Block until connected
	if err := bw.blockUntilConnectedOrClosed(); err != nil {
		return 0, err
	}

	// Write to buffer first so the bytes are replayable even if the underlying
	// write below fails.
	bw.buffer.Write(p)
	bw.sequenceNum += uint64(len(p))

	// Try to write to underlying writer
	n, err := bw.writer.Write(p)
	if err == nil && n != len(p) {
		// Treat a short write as a failure so the data gets replayed.
		err = io.ErrShortWrite
	}

	if err != nil {
		// Connection failed or partial write, mark as disconnected
		bw.writer = nil

		// Notify parent of error with generation information
		select {
		case bw.errorEventChan <- ErrorEvent{
			Err:        err,
			Component:  "writer",
			Generation: bw.currentGen,
		}:
		default:
			// Channel is full, drop the error.
			// This is not a problem, because we set the writer to nil
			// and block until reconnected so no new errors will be sent
			// until pipe processes the error and reconnects.
		}

		// Block until reconnected - reconnection will replay this data
		if err := bw.blockUntilConnectedOrClosed(); err != nil {
			return 0, err
		}

		// Don't retry - reconnection replay handled it
		return len(p), nil
	}

	// Write succeeded
	return len(p), nil
}
|
||||
|
||||
// Reconnect replaces the current writer with a new one and replays data from the specified
|
||||
// sequence number. If the requested sequence number is no longer in the buffer,
|
||||
// returns an error indicating data loss.
|
||||
//
|
||||
// IMPORTANT: You must close the current writer, if any, before calling this method.
|
||||
// Otherwise, if a Write operation is currently blocked in the underlying writer's
|
||||
// Write method, this method will deadlock waiting for the mutex that Write holds.
|
||||
func (bw *BackedWriter) Reconnect(replayFromSeq uint64, newWriter io.Writer) error {
|
||||
bw.mu.Lock()
|
||||
defer bw.mu.Unlock()
|
||||
|
||||
if bw.closed {
|
||||
return ErrWriterClosed
|
||||
}
|
||||
|
||||
if newWriter == nil {
|
||||
return ErrNilWriter
|
||||
}
|
||||
|
||||
// Check if we can replay from the requested sequence number
|
||||
if replayFromSeq > bw.sequenceNum {
|
||||
return ErrFutureSequence
|
||||
}
|
||||
|
||||
// Calculate how many bytes we need to replay
|
||||
replayBytes := bw.sequenceNum - replayFromSeq
|
||||
|
||||
var replayData []byte
|
||||
if replayBytes > 0 {
|
||||
// Get the last replayBytes from buffer
|
||||
// If the buffer doesn't have enough data (some was evicted),
|
||||
// ReadLast will return an error
|
||||
var err error
|
||||
// Safe conversion: The check above (replayFromSeq > bw.sequenceNum) ensures
|
||||
// replayBytes = bw.sequenceNum - replayFromSeq is always <= bw.sequenceNum.
|
||||
// Since sequence numbers are much smaller than maxInt, the uint64->int conversion is safe.
|
||||
//nolint:gosec // Safe conversion: replayBytes <= sequenceNum, which is much less than maxInt
|
||||
replayData, err = bw.buffer.ReadLast(int(replayBytes))
|
||||
if err != nil {
|
||||
return ErrReplayDataUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the current writer first in case replay fails
|
||||
bw.writer = nil
|
||||
|
||||
// Replay data if needed. We keep the mutex held during replay to ensure
|
||||
// no concurrent operations can interfere with the reconnection process.
|
||||
if len(replayData) > 0 {
|
||||
n, err := newWriter.Write(replayData)
|
||||
if err != nil {
|
||||
// Reconnect failed, writer remains nil
|
||||
return ErrReplayFailed
|
||||
}
|
||||
|
||||
if n != len(replayData) {
|
||||
// Reconnect failed, writer remains nil
|
||||
return ErrPartialReplay
|
||||
}
|
||||
}
|
||||
|
||||
// Set new writer only after successful replay. This ensures no concurrent
|
||||
// writes can interfere with the replay operation.
|
||||
bw.writer = newWriter
|
||||
|
||||
// Wake up any operations waiting for connection
|
||||
bw.cond.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the writer and prevents further writes.
|
||||
// After closing, all Write calls will return os.ErrClosed.
|
||||
// This code keeps the Close() signature consistent with io.Closer,
|
||||
// but it never actually returns an error.
|
||||
//
|
||||
// IMPORTANT: You must close the current underlying writer, if any, before calling
|
||||
// this method. Otherwise, if a Write operation is currently blocked in the
|
||||
// underlying writer's Write method, this method will deadlock waiting for the
|
||||
// mutex that Write holds.
|
||||
func (bw *BackedWriter) Close() error {
|
||||
bw.mu.Lock()
|
||||
defer bw.mu.Unlock()
|
||||
|
||||
if bw.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
bw.closed = true
|
||||
bw.writer = nil
|
||||
|
||||
// Wake up any blocked operations
|
||||
bw.cond.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SequenceNum returns the current sequence number (total bytes written).
|
||||
func (bw *BackedWriter) SequenceNum() uint64 {
|
||||
bw.mu.Lock()
|
||||
defer bw.mu.Unlock()
|
||||
return bw.sequenceNum
|
||||
}
|
||||
|
||||
// Connected returns whether the writer is currently connected.
|
||||
func (bw *BackedWriter) Connected() bool {
|
||||
bw.mu.Lock()
|
||||
defer bw.mu.Unlock()
|
||||
return bw.writer != nil
|
||||
}
|
||||
|
||||
// SetGeneration sets the current connection generation for error reporting.
|
||||
func (bw *BackedWriter) SetGeneration(generation uint64) {
|
||||
bw.mu.Lock()
|
||||
defer bw.mu.Unlock()
|
||||
bw.currentGen = generation
|
||||
}
|
||||
@@ -1,992 +0,0 @@
|
||||
package backedpipe_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/immortalstreams/backedpipe"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// mockWriter implements io.Writer with controllable behavior for testing
|
||||
type mockWriter struct {
|
||||
mu sync.Mutex
|
||||
buffer bytes.Buffer
|
||||
err error
|
||||
writeFunc func([]byte) (int, error)
|
||||
writeCalls int
|
||||
}
|
||||
|
||||
func newMockWriter() *mockWriter {
|
||||
return &mockWriter{}
|
||||
}
|
||||
|
||||
// newBackedWriterForTest creates a BackedWriter with a small buffer for testing eviction behavior
|
||||
func newBackedWriterForTest(bufferSize int) *backedpipe.BackedWriter {
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
return backedpipe.NewBackedWriter(bufferSize, errChan)
|
||||
}
|
||||
|
||||
func (mw *mockWriter) Write(p []byte) (int, error) {
|
||||
mw.mu.Lock()
|
||||
defer mw.mu.Unlock()
|
||||
|
||||
mw.writeCalls++
|
||||
|
||||
if mw.writeFunc != nil {
|
||||
return mw.writeFunc(p)
|
||||
}
|
||||
|
||||
if mw.err != nil {
|
||||
return 0, mw.err
|
||||
}
|
||||
|
||||
return mw.buffer.Write(p)
|
||||
}
|
||||
|
||||
func (mw *mockWriter) Len() int {
|
||||
mw.mu.Lock()
|
||||
defer mw.mu.Unlock()
|
||||
return mw.buffer.Len()
|
||||
}
|
||||
|
||||
func (mw *mockWriter) Reset() {
|
||||
mw.mu.Lock()
|
||||
defer mw.mu.Unlock()
|
||||
mw.buffer.Reset()
|
||||
mw.writeCalls = 0
|
||||
mw.err = nil
|
||||
mw.writeFunc = nil
|
||||
}
|
||||
|
||||
func (mw *mockWriter) setError(err error) {
|
||||
mw.mu.Lock()
|
||||
defer mw.mu.Unlock()
|
||||
mw.err = err
|
||||
}
|
||||
|
||||
func TestBackedWriter_NewBackedWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
require.NotNil(t, bw)
|
||||
require.Equal(t, uint64(0), bw.SequenceNum())
|
||||
require.False(t, bw.Connected())
|
||||
}
|
||||
|
||||
func TestBackedWriter_WriteBlocksWhenDisconnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Write should block when disconnected
|
||||
writeComplete := make(chan struct{})
|
||||
var writeErr error
|
||||
var n int
|
||||
|
||||
go func() {
|
||||
defer close(writeComplete)
|
||||
n, writeErr = bw.Write([]byte("hello"))
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should have blocked when disconnected")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Connect and verify write completes
|
||||
writer := newMockWriter()
|
||||
err := bw.Reconnect(0, writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should now complete
|
||||
testutil.TryReceive(ctx, t, writeComplete)
|
||||
|
||||
require.NoError(t, writeErr)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, uint64(5), bw.SequenceNum())
|
||||
require.Equal(t, []byte("hello"), writer.buffer.Bytes())
|
||||
}
|
||||
|
||||
func TestBackedWriter_WriteToUnderlyingWhenConnected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
writer := newMockWriter()
|
||||
|
||||
// Connect
|
||||
err := bw.Reconnect(0, writer)
|
||||
require.NoError(t, err)
|
||||
require.True(t, bw.Connected())
|
||||
|
||||
// Write should go to both buffer and underlying writer
|
||||
n, err := bw.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
|
||||
// Data should be buffered
|
||||
require.Equal(t, uint64(5), bw.SequenceNum())
|
||||
|
||||
// Check underlying writer
|
||||
require.Equal(t, []byte("hello"), writer.buffer.Bytes())
|
||||
}
|
||||
|
||||
func TestBackedWriter_BlockOnWriteFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
writer := newMockWriter()
|
||||
|
||||
// Connect
|
||||
err := bw.Reconnect(0, writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cause write to fail
|
||||
writer.setError(xerrors.New("write failed"))
|
||||
|
||||
// Write should block when underlying writer fails, not succeed immediately
|
||||
writeComplete := make(chan struct{})
|
||||
var writeErr error
|
||||
var n int
|
||||
|
||||
go func() {
|
||||
defer close(writeComplete)
|
||||
n, writeErr = bw.Write([]byte("hello"))
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should have blocked when underlying writer fails")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Wait for error event which implies writer was marked disconnected
|
||||
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
|
||||
require.Contains(t, receivedErrorEvent.Err.Error(), "write failed")
|
||||
require.Equal(t, "writer", receivedErrorEvent.Component)
|
||||
require.False(t, bw.Connected())
|
||||
|
||||
// Reconnect with working writer and verify write completes
|
||||
writer2 := newMockWriter()
|
||||
err = bw.Reconnect(0, writer2) // Replay from beginning
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should now complete
|
||||
testutil.TryReceive(ctx, t, writeComplete)
|
||||
|
||||
require.NoError(t, writeErr)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, uint64(5), bw.SequenceNum())
|
||||
require.Equal(t, []byte("hello"), writer2.buffer.Bytes())
|
||||
}
|
||||
|
||||
func TestBackedWriter_ReplayOnReconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Connect initially to write some data
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write some data while connected
|
||||
_, err = bw.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
_, err = bw.Write([]byte(" world"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, uint64(11), bw.SequenceNum())
|
||||
|
||||
// Disconnect by causing a write failure
|
||||
writer1.setError(xerrors.New("connection lost"))
|
||||
|
||||
// Write should block when underlying writer fails
|
||||
writeComplete := make(chan struct{})
|
||||
var writeErr error
|
||||
var n int
|
||||
|
||||
go func() {
|
||||
defer close(writeComplete)
|
||||
n, writeErr = bw.Write([]byte("test"))
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should have blocked when underlying writer fails")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Wait for error event which implies writer was marked disconnected
|
||||
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
|
||||
require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")
|
||||
require.Equal(t, "writer", receivedErrorEvent.Component)
|
||||
require.False(t, bw.Connected())
|
||||
|
||||
// Reconnect with new writer and request replay from beginning
|
||||
writer2 := newMockWriter()
|
||||
err = bw.Reconnect(0, writer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should now complete
|
||||
select {
|
||||
case <-writeComplete:
|
||||
// Expected - write completed
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Write should have completed after reconnection")
|
||||
}
|
||||
|
||||
require.NoError(t, writeErr)
|
||||
require.Equal(t, 4, n)
|
||||
|
||||
// Should have replayed all data including the failed write that was buffered
|
||||
require.Equal(t, []byte("hello worldtest"), writer2.buffer.Bytes())
|
||||
|
||||
// Write new data should go to both
|
||||
_, err = bw.Write([]byte("!"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("hello worldtest!"), writer2.buffer.Bytes())
|
||||
}
|
||||
|
||||
func TestBackedWriter_PartialReplay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Connect initially to write some data
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write some data
|
||||
_, err = bw.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
_, err = bw.Write([]byte(" world"))
|
||||
require.NoError(t, err)
|
||||
_, err = bw.Write([]byte("!"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reconnect with new writer and request replay from middle
|
||||
writer2 := newMockWriter()
|
||||
err = bw.Reconnect(5, writer2) // From " world!"
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have replayed only the requested portion
|
||||
require.Equal(t, []byte(" world!"), writer2.buffer.Bytes())
|
||||
}
|
||||
|
||||
func TestBackedWriter_ReplayFromFutureSequence(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Connect initially to write some data
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = bw.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
writer2 := newMockWriter()
|
||||
err = bw.Reconnect(10, writer2) // Future sequence
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrFutureSequence)
|
||||
}
|
||||
|
||||
func TestBackedWriter_ReplayDataLoss(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bw := newBackedWriterForTest(10) // Small buffer for testing
|
||||
|
||||
// Connect initially to write some data
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fill buffer beyond capacity to cause eviction
|
||||
_, err = bw.Write([]byte("0123456789")) // Fills buffer exactly
|
||||
require.NoError(t, err)
|
||||
_, err = bw.Write([]byte("abcdef")) // Should evict "012345"
|
||||
require.NoError(t, err)
|
||||
|
||||
writer2 := newMockWriter()
|
||||
err = bw.Reconnect(0, writer2) // Try to replay from evicted data
|
||||
// With the new error handling, this should fail because we can't read all the data
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable)
|
||||
}
|
||||
|
||||
func TestBackedWriter_BufferEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bw := newBackedWriterForTest(5) // Very small buffer for testing
|
||||
|
||||
// Connect initially
|
||||
writer := newMockWriter()
|
||||
err := bw.Reconnect(0, writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write data that will cause eviction
|
||||
n, err := bw.Write([]byte("abcde"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
|
||||
// Write more to cause eviction
|
||||
n, err = bw.Write([]byte("fg"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, n)
|
||||
|
||||
// Verify that the buffer contains only the latest data after eviction
|
||||
// Total sequence number should be 7 (5 + 2)
|
||||
require.Equal(t, uint64(7), bw.SequenceNum())
|
||||
|
||||
// Try to reconnect from the beginning - this should fail because
|
||||
// the early data was evicted from the buffer
|
||||
writer2 := newMockWriter()
|
||||
err = bw.Reconnect(0, writer2)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrReplayDataUnavailable)
|
||||
|
||||
// However, reconnecting from a sequence that's still in the buffer should work
|
||||
// The buffer should contain the last 5 bytes: "cdefg"
|
||||
writer3 := newMockWriter()
|
||||
err = bw.Reconnect(2, writer3) // From sequence 2, should replay "cdefg"
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("cdefg"), writer3.buffer.Bytes())
|
||||
require.True(t, bw.Connected())
|
||||
}
|
||||
|
||||
func TestBackedWriter_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
writer := newMockWriter()
|
||||
|
||||
bw.Reconnect(0, writer)
|
||||
|
||||
err := bw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Writes after close should fail
|
||||
_, err = bw.Write([]byte("test"))
|
||||
require.Equal(t, os.ErrClosed, err)
|
||||
|
||||
// Reconnect after close should fail
|
||||
err = bw.Reconnect(0, newMockWriter())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrWriterClosed)
|
||||
}
|
||||
|
||||
func TestBackedWriter_CloseIdempotent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
err := bw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second close should be no-op
|
||||
err = bw.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackedWriter_ReconnectDuringReplay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Connect initially to write some data
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = bw.Write([]byte("hello world"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a writer that fails during replay
|
||||
writer2 := &mockWriter{
|
||||
err: backedpipe.ErrReplayFailed,
|
||||
}
|
||||
|
||||
err = bw.Reconnect(0, writer2)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, backedpipe.ErrReplayFailed)
|
||||
require.False(t, bw.Connected())
|
||||
}
|
||||
|
||||
func TestBackedWriter_BlockOnPartialWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Create writer that does partial writes
|
||||
writer := &mockWriter{
|
||||
writeFunc: func(p []byte) (int, error) {
|
||||
if len(p) > 3 {
|
||||
return 3, nil // Only write first 3 bytes
|
||||
}
|
||||
return len(p), nil
|
||||
},
|
||||
}
|
||||
|
||||
bw.Reconnect(0, writer)
|
||||
|
||||
// Write should block due to partial write
|
||||
writeComplete := make(chan struct{})
|
||||
var writeErr error
|
||||
var n int
|
||||
|
||||
go func() {
|
||||
defer close(writeComplete)
|
||||
n, writeErr = bw.Write([]byte("hello"))
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should have blocked when underlying writer does partial write")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Wait for error event which implies writer was marked disconnected
|
||||
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
|
||||
require.Contains(t, receivedErrorEvent.Err.Error(), "short write")
|
||||
require.Equal(t, "writer", receivedErrorEvent.Component)
|
||||
require.False(t, bw.Connected())
|
||||
|
||||
// Reconnect with working writer and verify write completes
|
||||
writer2 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer2) // Replay from beginning
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should now complete
|
||||
testutil.TryReceive(ctx, t, writeComplete)
|
||||
|
||||
require.NoError(t, writeErr)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, uint64(5), bw.SequenceNum())
|
||||
require.Equal(t, []byte("hello"), writer2.buffer.Bytes())
|
||||
}
|
||||
|
||||
func TestBackedWriter_WriteUnblocksOnReconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Start a single write that should block
|
||||
writeResult := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := bw.Write([]byte("test"))
|
||||
writeResult <- err
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeResult:
|
||||
t.Fatal("Write should have blocked when disconnected")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Connect and verify write completes
|
||||
writer := newMockWriter()
|
||||
err := bw.Reconnect(0, writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should now complete
|
||||
err = testutil.RequireReceive(ctx, t, writeResult)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should have been written to the underlying writer
|
||||
require.Equal(t, "test", writer.buffer.String())
|
||||
}
|
||||
|
||||
func TestBackedWriter_CloseUnblocksWaitingWrites(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Start a write that should block
|
||||
writeComplete := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := bw.Write([]byte("test"))
|
||||
writeComplete <- err
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should have blocked when disconnected")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Close the writer
|
||||
err := bw.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should now complete with error
|
||||
err = testutil.RequireReceive(ctx, t, writeComplete)
|
||||
require.Equal(t, os.ErrClosed, err)
|
||||
}
|
||||
|
||||
func TestBackedWriter_WriteBlocksAfterDisconnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
writer := newMockWriter()
|
||||
|
||||
// Connect initially
|
||||
err := bw.Reconnect(0, writer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write should succeed when connected
|
||||
_, err = bw.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cause disconnection - the write should now block instead of returning an error
|
||||
writer.setError(xerrors.New("connection lost"))
|
||||
|
||||
// This write should block
|
||||
writeComplete := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := bw.Write([]byte("world"))
|
||||
writeComplete <- err
|
||||
}()
|
||||
|
||||
// Verify write is blocked
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should have blocked after disconnection")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Expected - write is blocked
|
||||
}
|
||||
|
||||
// Wait for error event which implies writer was marked disconnected
|
||||
receivedErrorEvent := testutil.RequireReceive(ctx, t, errChan)
|
||||
require.Contains(t, receivedErrorEvent.Err.Error(), "connection lost")
|
||||
require.Equal(t, "writer", receivedErrorEvent.Component)
|
||||
require.False(t, bw.Connected())
|
||||
|
||||
// Reconnect and verify write completes
|
||||
writer2 := newMockWriter()
|
||||
err = bw.Reconnect(5, writer2) // Replay from after "hello"
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testutil.RequireReceive(ctx, t, writeComplete)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that only "world" was written during replay (not duplicated)
|
||||
require.Equal(t, []byte("world"), writer2.buffer.Bytes()) // Only "world" since we replayed from sequence 5
|
||||
}
|
||||
|
||||
func TestBackedWriter_ConcurrentWriteAndClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Don't connect initially - this will cause writes to block in blockUntilConnectedOrClosed()
|
||||
|
||||
writeStarted := make(chan struct{}, 1)
|
||||
|
||||
// Start a write operation that will block waiting for connection
|
||||
writeComplete := make(chan struct{})
|
||||
var writeErr error
|
||||
var n int
|
||||
|
||||
go func() {
|
||||
defer close(writeComplete)
|
||||
// Signal that we're about to start the write
|
||||
writeStarted <- struct{}{}
|
||||
// This write will block in blockUntilConnectedOrClosed() since no writer is connected
|
||||
n, writeErr = bw.Write([]byte("hello"))
|
||||
}()
|
||||
|
||||
// Wait for write goroutine to start
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.RequireReceive(ctx, t, writeStarted)
|
||||
|
||||
// Ensure the write is actually blocked by repeatedly checking that:
|
||||
// 1. The write hasn't completed yet
|
||||
// 2. The writer is still not connected
|
||||
// We use require.Eventually to give it a fair chance to reach the blocking state
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should be blocked when no writer is connected")
|
||||
return false
|
||||
default:
|
||||
// Write is still blocked, which is what we want
|
||||
return !bw.Connected()
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalMedium)
|
||||
|
||||
// Close the writer while the write is blocked waiting for connection
|
||||
closeErr := bw.Close()
|
||||
require.NoError(t, closeErr)
|
||||
|
||||
// Wait for write to complete
|
||||
select {
|
||||
case <-writeComplete:
|
||||
// Good, write completed
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Write did not complete in time")
|
||||
}
|
||||
|
||||
// The write should have failed with os.ErrClosed because Close() was called
|
||||
// while it was waiting for connection
|
||||
require.ErrorIs(t, writeErr, os.ErrClosed)
|
||||
require.Equal(t, 0, n)
|
||||
|
||||
// Subsequent writes should also fail
|
||||
n, err := bw.Write([]byte("world"))
|
||||
require.Equal(t, 0, n)
|
||||
require.ErrorIs(t, err, os.ErrClosed)
|
||||
}
|
||||
|
||||
func TestBackedWriter_ConcurrentWriteAndReconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Initial connection
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write some initial data
|
||||
_, err = bw.Write([]byte("initial"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start reconnection which will block new writes
|
||||
replayStarted := make(chan struct{}, 1) // Buffered to prevent race condition
|
||||
replayCanComplete := make(chan struct{})
|
||||
writer2 := &mockWriter{
|
||||
writeFunc: func(p []byte) (int, error) {
|
||||
// Signal that replay has started
|
||||
select {
|
||||
case replayStarted <- struct{}{}:
|
||||
default:
|
||||
// Signal already sent, which is fine
|
||||
}
|
||||
// Wait for test to allow replay to complete
|
||||
<-replayCanComplete
|
||||
return len(p), nil
|
||||
},
|
||||
}
|
||||
|
||||
// Start the reconnection in a goroutine so we can control timing
|
||||
reconnectComplete := make(chan error, 1)
|
||||
go func() {
|
||||
reconnectComplete <- bw.Reconnect(0, writer2)
|
||||
}()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
// Wait for replay to start
|
||||
testutil.RequireReceive(ctx, t, replayStarted)
|
||||
|
||||
// Now start a write operation that will be blocked by the ongoing reconnect
|
||||
writeStarted := make(chan struct{}, 1)
|
||||
writeComplete := make(chan struct{})
|
||||
var writeErr error
|
||||
var n int
|
||||
|
||||
go func() {
|
||||
defer close(writeComplete)
|
||||
// Signal that we're about to start the write
|
||||
writeStarted <- struct{}{}
|
||||
// This write should be blocked during reconnect
|
||||
n, writeErr = bw.Write([]byte("blocked"))
|
||||
}()
|
||||
|
||||
// Wait for write to start
|
||||
testutil.RequireReceive(ctx, t, writeStarted)
|
||||
|
||||
// Use a small timeout to ensure the write goroutine has a chance to get blocked
|
||||
// on the mutex before we check if it's still blocked
|
||||
writeCheckTimer := time.NewTimer(testutil.IntervalFast)
|
||||
defer writeCheckTimer.Stop()
|
||||
|
||||
select {
|
||||
case <-writeComplete:
|
||||
t.Fatal("Write should be blocked during reconnect")
|
||||
case <-writeCheckTimer.C:
|
||||
// Write is still blocked after a reasonable wait
|
||||
}
|
||||
|
||||
// Allow replay to complete, which will allow reconnect to finish
|
||||
close(replayCanComplete)
|
||||
|
||||
// Wait for reconnection to complete
|
||||
select {
|
||||
case reconnectErr := <-reconnectComplete:
|
||||
require.NoError(t, reconnectErr)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Reconnect did not complete in time")
|
||||
}
|
||||
|
||||
// Wait for write to complete
|
||||
<-writeComplete
|
||||
|
||||
// Write should succeed after reconnection completes
|
||||
require.NoError(t, writeErr)
|
||||
require.Equal(t, 7, n) // "blocked" is 7 bytes
|
||||
|
||||
// Verify the writer is connected
|
||||
require.True(t, bw.Connected())
|
||||
}
|
||||
|
||||
func TestBackedWriter_ConcurrentReconnectAndClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Initial connection and write some data
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
_, err = bw.Write([]byte("test data"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start reconnection with slow replay
|
||||
reconnectStarted := make(chan struct{}, 1)
|
||||
replayCanComplete := make(chan struct{})
|
||||
reconnectComplete := make(chan struct{})
|
||||
var reconnectErr error
|
||||
|
||||
go func() {
|
||||
defer close(reconnectComplete)
|
||||
writer2 := &mockWriter{
|
||||
writeFunc: func(p []byte) (int, error) {
|
||||
// Signal that replay has started
|
||||
select {
|
||||
case reconnectStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
// Wait for test to allow replay to complete
|
||||
<-replayCanComplete
|
||||
return len(p), nil
|
||||
},
|
||||
}
|
||||
reconnectErr = bw.Reconnect(0, writer2)
|
||||
}()
|
||||
|
||||
// Wait for reconnection to start
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
testutil.RequireReceive(ctx, t, reconnectStarted)
|
||||
|
||||
// Start Close() in a separate goroutine since it will block until Reconnect() completes
|
||||
closeStarted := make(chan struct{}, 1)
|
||||
closeComplete := make(chan error, 1)
|
||||
go func() {
|
||||
closeStarted <- struct{}{} // Signal that Close() is starting
|
||||
closeComplete <- bw.Close()
|
||||
}()
|
||||
|
||||
// Wait for Close() to start, then give it a moment to attempt to acquire the mutex
|
||||
testutil.RequireReceive(ctx, t, closeStarted)
|
||||
closeCheckTimer := time.NewTimer(testutil.IntervalFast)
|
||||
defer closeCheckTimer.Stop()
|
||||
|
||||
select {
|
||||
case <-closeComplete:
|
||||
t.Fatal("Close should be blocked during reconnect")
|
||||
case <-closeCheckTimer.C:
|
||||
// Good, Close is still blocked after a reasonable wait
|
||||
}
|
||||
|
||||
// Allow replay to complete so reconnection can finish
|
||||
close(replayCanComplete)
|
||||
|
||||
// Wait for reconnect to complete
|
||||
select {
|
||||
case <-reconnectComplete:
|
||||
// Good, reconnect completed
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Reconnect did not complete in time")
|
||||
}
|
||||
|
||||
// Wait for close to complete
|
||||
select {
|
||||
case closeErr := <-closeComplete:
|
||||
require.NoError(t, closeErr)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Close did not complete in time")
|
||||
}
|
||||
|
||||
// With mutex held during replay, Close() waits for Reconnect() to finish.
|
||||
// So Reconnect() should succeed, then Close() runs and closes the writer.
|
||||
require.NoError(t, reconnectErr)
|
||||
|
||||
// Verify writer is closed (Close() ran after Reconnect() completed)
|
||||
require.False(t, bw.Connected())
|
||||
}
|
||||
|
||||
func TestBackedWriter_MultipleWritesDuringReconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Initial connection
|
||||
writer1 := newMockWriter()
|
||||
err := bw.Reconnect(0, writer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write some initial data
|
||||
_, err = bw.Write([]byte("initial"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start multiple write operations
|
||||
numWriters := 5
|
||||
var wg sync.WaitGroup
|
||||
writeResults := make([]error, numWriters)
|
||||
writesStarted := make(chan struct{}, numWriters)
|
||||
|
||||
for i := 0; i < numWriters; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
// Signal that this write is starting
|
||||
writesStarted <- struct{}{}
|
||||
data := []byte{byte('A' + id)}
|
||||
_, writeResults[id] = bw.Write(data)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all writes to start
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
for i := 0; i < numWriters; i++ {
|
||||
testutil.RequireReceive(ctx, t, writesStarted)
|
||||
}
|
||||
|
||||
// Use a timer to ensure all write goroutines have had a chance to start executing
|
||||
// and potentially get blocked on the mutex before we start the reconnection
|
||||
writesReadyTimer := time.NewTimer(testutil.IntervalFast)
|
||||
defer writesReadyTimer.Stop()
|
||||
<-writesReadyTimer.C
|
||||
|
||||
// Start reconnection with controlled replay
|
||||
replayStarted := make(chan struct{}, 1)
|
||||
replayCanComplete := make(chan struct{})
|
||||
writer2 := &mockWriter{
|
||||
writeFunc: func(p []byte) (int, error) {
|
||||
// Signal that replay has started
|
||||
select {
|
||||
case replayStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
// Wait for test to allow replay to complete
|
||||
<-replayCanComplete
|
||||
return len(p), nil
|
||||
},
|
||||
}
|
||||
|
||||
// Start reconnection in a goroutine so we can control timing
|
||||
reconnectComplete := make(chan error, 1)
|
||||
go func() {
|
||||
reconnectComplete <- bw.Reconnect(0, writer2)
|
||||
}()
|
||||
|
||||
// Wait for replay to start
|
||||
testutil.RequireReceive(ctx, t, replayStarted)
|
||||
|
||||
// Allow replay to complete
|
||||
close(replayCanComplete)
|
||||
|
||||
// Wait for reconnection to complete
|
||||
select {
|
||||
case reconnectErr := <-reconnectComplete:
|
||||
require.NoError(t, reconnectErr)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Reconnect did not complete in time")
|
||||
}
|
||||
|
||||
// Wait for all writes to complete
|
||||
wg.Wait()
|
||||
|
||||
// All writes should succeed
|
||||
for i, err := range writeResults {
|
||||
require.NoError(t, err, "Write %d should succeed", i)
|
||||
}
|
||||
|
||||
// Verify the writer is connected
|
||||
require.True(t, bw.Connected())
|
||||
}
|
||||
|
||||
func BenchmarkBackedWriter_Write(b *testing.B) {
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan) // 64KB buffer
|
||||
writer := newMockWriter()
|
||||
bw.Reconnect(0, writer)
|
||||
|
||||
data := bytes.Repeat([]byte("x"), 1024) // 1KB writes
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
bw.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBackedWriter_Reconnect(b *testing.B) {
|
||||
errChan := make(chan backedpipe.ErrorEvent, 1)
|
||||
bw := backedpipe.NewBackedWriter(backedpipe.DefaultBufferSize, errChan)
|
||||
|
||||
// Connect initially to fill buffer with data
|
||||
initialWriter := newMockWriter()
|
||||
err := bw.Reconnect(0, initialWriter)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Fill buffer with data
|
||||
data := bytes.Repeat([]byte("x"), 1024)
|
||||
for i := 0; i < 32; i++ {
|
||||
bw.Write(data)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
writer := newMockWriter()
|
||||
bw.Reconnect(0, writer)
|
||||
}
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
package backedpipe
|
||||
|
||||
import "golang.org/x/xerrors"
|
||||
|
||||
// ringBuffer implements an efficient circular buffer with a fixed-size allocation.
// Data is appended at the logical end; when full, the oldest bytes are evicted
// to make room. This implementation is not thread-safe and relies on external
// synchronization.
type ringBuffer struct {
	buffer []byte // fixed backing storage; len(buffer) is the capacity
	start  int    // index of first valid byte
	end    int    // index of last valid byte (-1 when empty)
}
|
||||
|
||||
// newRingBuffer creates a new ring buffer with the specified capacity.
|
||||
// Capacity must be > 0.
|
||||
func newRingBuffer(capacity int) *ringBuffer {
|
||||
if capacity <= 0 {
|
||||
panic("ring buffer capacity must be > 0")
|
||||
}
|
||||
return &ringBuffer{
|
||||
buffer: make([]byte, capacity),
|
||||
end: -1, // -1 indicates empty buffer
|
||||
}
|
||||
}
|
||||
|
||||
// Size returns the current number of bytes in the buffer.
|
||||
func (rb *ringBuffer) Size() int {
|
||||
if rb.end == -1 {
|
||||
return 0 // Buffer is empty
|
||||
}
|
||||
if rb.start <= rb.end {
|
||||
return rb.end - rb.start + 1
|
||||
}
|
||||
// Buffer wraps around
|
||||
return len(rb.buffer) - rb.start + rb.end + 1
|
||||
}
|
||||
|
||||
// Write writes data to the ring buffer. If the buffer would overflow,
|
||||
// it evicts the oldest data to make room for new data.
|
||||
func (rb *ringBuffer) Write(data []byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
capacity := len(rb.buffer)
|
||||
|
||||
// If data is larger than capacity, only keep the last capacity bytes
|
||||
if len(data) > capacity {
|
||||
data = data[len(data)-capacity:]
|
||||
// Clear buffer and write new data
|
||||
rb.start = 0
|
||||
rb.end = -1 // Will be set properly below
|
||||
}
|
||||
|
||||
// Calculate how much we need to evict to fit new data
|
||||
spaceNeeded := len(data)
|
||||
availableSpace := capacity - rb.Size()
|
||||
|
||||
if spaceNeeded > availableSpace {
|
||||
bytesToEvict := spaceNeeded - availableSpace
|
||||
rb.evict(bytesToEvict)
|
||||
}
|
||||
|
||||
// Buffer has data, write after current end
|
||||
writePos := (rb.end + 1) % capacity
|
||||
if writePos+len(data) <= capacity {
|
||||
// No wrap needed - single copy
|
||||
copy(rb.buffer[writePos:], data)
|
||||
rb.end = (rb.end + len(data)) % capacity
|
||||
} else {
|
||||
// Need to wrap around - two copies
|
||||
firstChunk := capacity - writePos
|
||||
copy(rb.buffer[writePos:], data[:firstChunk])
|
||||
copy(rb.buffer[0:], data[firstChunk:])
|
||||
rb.end = len(data) - firstChunk - 1
|
||||
}
|
||||
}
|
||||
|
||||
// evict removes the specified number of bytes from the beginning of the buffer.
|
||||
func (rb *ringBuffer) evict(count int) {
|
||||
if count >= rb.Size() {
|
||||
// Evict everything
|
||||
rb.start = 0
|
||||
rb.end = -1
|
||||
return
|
||||
}
|
||||
|
||||
rb.start = (rb.start + count) % len(rb.buffer)
|
||||
// Buffer remains non-empty after partial eviction
|
||||
}
|
||||
|
||||
// ReadLast returns the last n bytes from the buffer.
|
||||
// If n is greater than the available data, returns an error.
|
||||
// If n is negative, returns an error.
|
||||
func (rb *ringBuffer) ReadLast(n int) ([]byte, error) {
|
||||
if n < 0 {
|
||||
return nil, xerrors.New("cannot read negative number of bytes")
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
size := rb.Size()
|
||||
|
||||
// If requested more than available, return error
|
||||
if n > size {
|
||||
return nil, xerrors.Errorf("requested %d bytes but only %d available", n, size)
|
||||
}
|
||||
|
||||
result := make([]byte, n)
|
||||
capacity := len(rb.buffer)
|
||||
|
||||
// Calculate where to start reading from (n bytes before the end)
|
||||
startOffset := size - n
|
||||
actualStart := (rb.start + startOffset) % capacity
|
||||
|
||||
// Copy the last n bytes
|
||||
if actualStart+n <= capacity {
|
||||
// No wrap needed
|
||||
copy(result, rb.buffer[actualStart:actualStart+n])
|
||||
} else {
|
||||
// Need to wrap around
|
||||
firstChunk := capacity - actualStart
|
||||
copy(result[0:firstChunk], rb.buffer[actualStart:capacity])
|
||||
copy(result[firstChunk:], rb.buffer[0:n-firstChunk])
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,261 +0,0 @@
|
||||
package backedpipe
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Don't run goleak on windows tests, they're super flaky right now.
|
||||
// See: https://github.com/coder/coder/issues/8954
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
|
||||
}
|
||||
|
||||
func TestRingBuffer_NewRingBuffer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(100)
|
||||
// Test that we can write and read from the buffer
|
||||
rb.Write([]byte("test"))
|
||||
|
||||
data, err := rb.ReadLast(4)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("test"), data)
|
||||
}
|
||||
|
||||
func TestRingBuffer_WriteAndRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(10)
|
||||
|
||||
// Write some data
|
||||
rb.Write([]byte("hello"))
|
||||
|
||||
// Read last 4 bytes
|
||||
data, err := rb.ReadLast(4)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ello", string(data))
|
||||
|
||||
// Write more data
|
||||
rb.Write([]byte("world"))
|
||||
|
||||
// Read last 5 bytes
|
||||
data, err = rb.ReadLast(5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "world", string(data))
|
||||
|
||||
// Read last 3 bytes
|
||||
data, err = rb.ReadLast(3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "rld", string(data))
|
||||
|
||||
// Read more than available (should be 10 bytes total)
|
||||
_, err = rb.ReadLast(15)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requested 15 bytes but only")
|
||||
}
|
||||
|
||||
func TestRingBuffer_OverflowEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(5)
|
||||
|
||||
// Fill buffer
|
||||
rb.Write([]byte("abcde"))
|
||||
|
||||
// Overflow should evict oldest data
|
||||
rb.Write([]byte("fg"))
|
||||
|
||||
// Should now contain "cdefg"
|
||||
data, err := rb.ReadLast(5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("cdefg"), data)
|
||||
}
|
||||
|
||||
func TestRingBuffer_LargeWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(5)
|
||||
|
||||
// Write data larger than capacity
|
||||
rb.Write([]byte("abcdefghij"))
|
||||
|
||||
// Should contain last 5 bytes
|
||||
data, err := rb.ReadLast(5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("fghij"), data)
|
||||
}
|
||||
|
||||
func TestRingBuffer_WrapAround(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(5)
|
||||
|
||||
// Fill buffer
|
||||
rb.Write([]byte("abcde"))
|
||||
|
||||
// Write more to cause wrap-around
|
||||
rb.Write([]byte("fgh"))
|
||||
|
||||
// Should contain "defgh"
|
||||
data, err := rb.ReadLast(5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("defgh"), data)
|
||||
|
||||
// Test reading last 3 bytes after wrap
|
||||
data, err = rb.ReadLast(3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("fgh"), data)
|
||||
}
|
||||
|
||||
func TestRingBuffer_ReadLastEdgeCases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(3)
|
||||
|
||||
// Write some data (5 bytes to a 3-byte buffer, so only last 3 bytes remain)
|
||||
rb.Write([]byte("hello"))
|
||||
|
||||
// Test reading negative count
|
||||
data, err := rb.ReadLast(-1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot read negative number of bytes")
|
||||
require.Nil(t, data)
|
||||
|
||||
// Test reading zero bytes
|
||||
data, err = rb.ReadLast(0)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, data)
|
||||
|
||||
// Test reading more than available (buffer has 3 bytes, try to read 10)
|
||||
_, err = rb.ReadLast(10)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requested 10 bytes but only 3 available")
|
||||
|
||||
// Test reading exact amount available
|
||||
data, err = rb.ReadLast(3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("llo"), data)
|
||||
}
|
||||
|
||||
func TestRingBuffer_EmptyWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(10)
|
||||
|
||||
// Write empty data
|
||||
rb.Write([]byte{})
|
||||
|
||||
// Buffer should still be empty
|
||||
_, err := rb.ReadLast(5)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "requested 5 bytes but only 0 available")
|
||||
}
|
||||
|
||||
func TestRingBuffer_MultipleWrites(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(10)
|
||||
|
||||
// Write data in chunks
|
||||
rb.Write([]byte("ab"))
|
||||
rb.Write([]byte("cd"))
|
||||
rb.Write([]byte("ef"))
|
||||
|
||||
data, err := rb.ReadLast(6)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("abcdef"), data)
|
||||
|
||||
// Test partial reads
|
||||
data, err = rb.ReadLast(4)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("cdef"), data)
|
||||
|
||||
data, err = rb.ReadLast(2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ef"), data)
|
||||
}
|
||||
|
||||
func TestRingBuffer_EdgeCaseEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(3)
|
||||
|
||||
// Write data that will cause eviction
|
||||
rb.Write([]byte("abc"))
|
||||
|
||||
// Write more to cause eviction
|
||||
rb.Write([]byte("d"))
|
||||
|
||||
// Should now contain "bcd"
|
||||
data, err := rb.ReadLast(3)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("bcd"), data)
|
||||
}
|
||||
|
||||
func TestRingBuffer_ComplexWrapAroundScenario(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rb := newRingBuffer(8)
|
||||
|
||||
// Fill buffer
|
||||
rb.Write([]byte("12345678"))
|
||||
|
||||
// Evict some and add more to create complex wrap scenario
|
||||
rb.Write([]byte("abcd"))
|
||||
data, err := rb.ReadLast(8)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("5678abcd"), data)
|
||||
|
||||
// Add more
|
||||
rb.Write([]byte("xyz"))
|
||||
data, err = rb.ReadLast(8)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("8abcdxyz"), data)
|
||||
|
||||
// Test reading various amounts from the end
|
||||
data, err = rb.ReadLast(7)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("abcdxyz"), data)
|
||||
|
||||
data, err = rb.ReadLast(4)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("dxyz"), data)
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkRingBuffer_Write(b *testing.B) {
|
||||
rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks
|
||||
data := bytes.Repeat([]byte("x"), 1024) // 1KB writes
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rb.Write(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRingBuffer_ReadLast(b *testing.B) {
|
||||
rb := newRingBuffer(64 * 1024 * 1024) // 64MB for benchmarks
|
||||
// Fill buffer with test data
|
||||
for i := 0; i < 64; i++ {
|
||||
rb.Write(bytes.Repeat([]byte("x"), 1024))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := rb.ReadLast((i % 100) + 1)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
+78
-69
@@ -11,39 +11,23 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/shirou/gopsutil/v4/disk"
|
||||
"github.com/spf13/afero"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
var WindowsDriveRegex = regexp.MustCompile(`^[a-zA-Z]:\\$`)
|
||||
|
||||
func (a *agent) HandleLS(rw http.ResponseWriter, r *http.Request) {
|
||||
func (*agent) HandleLS(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// An absolute path may be optionally provided, otherwise a path split into an
|
||||
// array must be provided in the body (which can be relative).
|
||||
query := r.URL.Query()
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
path := parser.String(query, "", "path")
|
||||
parser.ErrorExcessParams(query)
|
||||
if len(parser.Errors) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Query parameters have invalid values.",
|
||||
Validations: parser.Errors,
|
||||
})
|
||||
var query LSRequest
|
||||
if !httpapi.Read(ctx, rw, r, &query) {
|
||||
return
|
||||
}
|
||||
|
||||
var req workspacesdk.LSRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := listFiles(a.filesystem, path, req)
|
||||
resp, err := listFiles(query)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
switch {
|
||||
@@ -62,66 +46,58 @@ func (a *agent) HandleLS(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func listFiles(fs afero.Fs, path string, query workspacesdk.LSRequest) (workspacesdk.LSResponse, error) {
|
||||
absolutePathString := path
|
||||
if absolutePathString != "" {
|
||||
if !filepath.IsAbs(path) {
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("path must be absolute: %q", path)
|
||||
}
|
||||
} else {
|
||||
var fullPath []string
|
||||
switch query.Relativity {
|
||||
case workspacesdk.LSRelativityHome:
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
fullPath = []string{home}
|
||||
case workspacesdk.LSRelativityRoot:
|
||||
if runtime.GOOS == "windows" {
|
||||
if len(query.Path) == 0 {
|
||||
return listDrives()
|
||||
}
|
||||
if !WindowsDriveRegex.MatchString(query.Path[0]) {
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("invalid drive letter %q", query.Path[0])
|
||||
}
|
||||
} else {
|
||||
fullPath = []string{"/"}
|
||||
}
|
||||
default:
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("unsupported relativity type %q", query.Relativity)
|
||||
}
|
||||
|
||||
fullPath = append(fullPath, query.Path...)
|
||||
fullPathRelative := filepath.Join(fullPath...)
|
||||
var err error
|
||||
absolutePathString, err = filepath.Abs(fullPathRelative)
|
||||
func listFiles(query LSRequest) (LSResponse, error) {
|
||||
var fullPath []string
|
||||
switch query.Relativity {
|
||||
case LSRelativityHome:
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to get absolute path of %q: %w", fullPathRelative, err)
|
||||
return LSResponse{}, xerrors.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
fullPath = []string{home}
|
||||
case LSRelativityRoot:
|
||||
if runtime.GOOS == "windows" {
|
||||
if len(query.Path) == 0 {
|
||||
return listDrives()
|
||||
}
|
||||
if !WindowsDriveRegex.MatchString(query.Path[0]) {
|
||||
return LSResponse{}, xerrors.Errorf("invalid drive letter %q", query.Path[0])
|
||||
}
|
||||
} else {
|
||||
fullPath = []string{"/"}
|
||||
}
|
||||
default:
|
||||
return LSResponse{}, xerrors.Errorf("unsupported relativity type %q", query.Relativity)
|
||||
}
|
||||
|
||||
fullPath = append(fullPath, query.Path...)
|
||||
fullPathRelative := filepath.Join(fullPath...)
|
||||
absolutePathString, err := filepath.Abs(fullPathRelative)
|
||||
if err != nil {
|
||||
return LSResponse{}, xerrors.Errorf("failed to get absolute path of %q: %w", fullPathRelative, err)
|
||||
}
|
||||
|
||||
// codeql[go/path-injection] - The intent is to allow the user to navigate to any directory in their workspace.
|
||||
f, err := fs.Open(absolutePathString)
|
||||
f, err := os.Open(absolutePathString)
|
||||
if err != nil {
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to open directory %q: %w", absolutePathString, err)
|
||||
return LSResponse{}, xerrors.Errorf("failed to open directory %q: %w", absolutePathString, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to stat directory %q: %w", absolutePathString, err)
|
||||
return LSResponse{}, xerrors.Errorf("failed to stat directory %q: %w", absolutePathString, err)
|
||||
}
|
||||
|
||||
if !stat.IsDir() {
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("path %q is not a directory", absolutePathString)
|
||||
return LSResponse{}, xerrors.Errorf("path %q is not a directory", absolutePathString)
|
||||
}
|
||||
|
||||
// `contents` may be partially populated even if the operation fails midway.
|
||||
contents, _ := f.Readdir(-1)
|
||||
respContents := make([]workspacesdk.LSFile, 0, len(contents))
|
||||
contents, _ := f.ReadDir(-1)
|
||||
respContents := make([]LSFile, 0, len(contents))
|
||||
for _, file := range contents {
|
||||
respContents = append(respContents, workspacesdk.LSFile{
|
||||
respContents = append(respContents, LSFile{
|
||||
Name: file.Name(),
|
||||
AbsolutePathString: filepath.Join(absolutePathString, file.Name()),
|
||||
IsDir: file.IsDir(),
|
||||
@@ -129,7 +105,7 @@ func listFiles(fs afero.Fs, path string, query workspacesdk.LSRequest) (workspac
|
||||
}
|
||||
|
||||
// Sort alphabetically: directories then files
|
||||
slices.SortFunc(respContents, func(a, b workspacesdk.LSFile) int {
|
||||
slices.SortFunc(respContents, func(a, b LSFile) int {
|
||||
if a.IsDir && !b.IsDir {
|
||||
return -1
|
||||
}
|
||||
@@ -141,35 +117,35 @@ func listFiles(fs afero.Fs, path string, query workspacesdk.LSRequest) (workspac
|
||||
|
||||
absolutePath := pathToArray(absolutePathString)
|
||||
|
||||
return workspacesdk.LSResponse{
|
||||
return LSResponse{
|
||||
AbsolutePath: absolutePath,
|
||||
AbsolutePathString: absolutePathString,
|
||||
Contents: respContents,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func listDrives() (workspacesdk.LSResponse, error) {
|
||||
func listDrives() (LSResponse, error) {
|
||||
// disk.Partitions() will return partitions even if there was a failure to
|
||||
// get one. Any errored partitions will not be returned.
|
||||
partitionStats, err := disk.Partitions(true)
|
||||
if err != nil && len(partitionStats) == 0 {
|
||||
// Only return the error if there were no partitions returned.
|
||||
return workspacesdk.LSResponse{}, xerrors.Errorf("failed to get partitions: %w", err)
|
||||
return LSResponse{}, xerrors.Errorf("failed to get partitions: %w", err)
|
||||
}
|
||||
|
||||
contents := make([]workspacesdk.LSFile, 0, len(partitionStats))
|
||||
contents := make([]LSFile, 0, len(partitionStats))
|
||||
for _, a := range partitionStats {
|
||||
// Drive letters on Windows have a trailing separator as part of their name.
|
||||
// i.e. `os.Open("C:")` does not work, but `os.Open("C:\\")` does.
|
||||
name := a.Mountpoint + string(os.PathSeparator)
|
||||
contents = append(contents, workspacesdk.LSFile{
|
||||
contents = append(contents, LSFile{
|
||||
Name: name,
|
||||
AbsolutePathString: name,
|
||||
IsDir: true,
|
||||
})
|
||||
}
|
||||
|
||||
return workspacesdk.LSResponse{
|
||||
return LSResponse{
|
||||
AbsolutePath: []string{},
|
||||
AbsolutePathString: "",
|
||||
Contents: contents,
|
||||
@@ -187,3 +163,36 @@ func pathToArray(path string) []string {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type LSRequest struct {
|
||||
// e.g. [], ["repos", "coder"],
|
||||
Path []string `json:"path"`
|
||||
// Whether the supplied path is relative to the user's home directory,
|
||||
// or the root directory.
|
||||
Relativity LSRelativity `json:"relativity"`
|
||||
}
|
||||
|
||||
type LSResponse struct {
|
||||
AbsolutePath []string `json:"absolute_path"`
|
||||
// Returned so clients can display the full path to the user, and
|
||||
// copy it to configure file sync
|
||||
// e.g. Windows: "C:\\Users\\coder"
|
||||
// Linux: "/home/coder"
|
||||
AbsolutePathString string `json:"absolute_path_string"`
|
||||
Contents []LSFile `json:"contents"`
|
||||
}
|
||||
|
||||
type LSFile struct {
|
||||
Name string `json:"name"`
|
||||
// e.g. "C:\\Users\\coder\\hello.txt"
|
||||
// "/home/coder/hello.txt"
|
||||
AbsolutePathString string `json:"absolute_path_string"`
|
||||
IsDir bool `json:"is_dir"`
|
||||
}
|
||||
|
||||
type LSRelativity string
|
||||
|
||||
const (
|
||||
LSRelativityRoot LSRelativity = "root"
|
||||
LSRelativityHome LSRelativity = "home"
|
||||
)
|
||||
|
||||
+38
-76
@@ -6,103 +6,67 @@ import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
type testFs struct {
|
||||
afero.Fs
|
||||
}
|
||||
|
||||
func newTestFs(base afero.Fs) *testFs {
|
||||
return &testFs{
|
||||
Fs: base,
|
||||
}
|
||||
}
|
||||
|
||||
func (*testFs) Open(name string) (afero.File, error) {
|
||||
return nil, os.ErrPermission
|
||||
}
|
||||
|
||||
func TestListFilesWithQueryParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
query := workspacesdk.LSRequest{}
|
||||
_, err := listFiles(fs, "not-relative", query)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "must be absolute")
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
err = fs.MkdirAll(tmpDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := listFiles(fs, tmpDir, query)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Contents, 0)
|
||||
}
|
||||
|
||||
func TestListFilesNonExistentDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
query := workspacesdk.LSRequest{
|
||||
query := LSRequest{
|
||||
Path: []string{"idontexist"},
|
||||
Relativity: workspacesdk.LSRelativityHome,
|
||||
Relativity: LSRelativityHome,
|
||||
}
|
||||
_, err := listFiles(fs, "", query)
|
||||
_, err := listFiles(query)
|
||||
require.ErrorIs(t, err, os.ErrNotExist)
|
||||
}
|
||||
|
||||
func TestListFilesPermissionDenied(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fs := newTestFs(afero.NewMemMapFs())
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("creating an unreadable-by-user directory is non-trivial on Windows")
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
reposDir := filepath.Join(tmpDir, "repos")
|
||||
err = fs.MkdirAll(reposDir, 0o000)
|
||||
err = os.Mkdir(reposDir, 0o000)
|
||||
require.NoError(t, err)
|
||||
|
||||
rel, err := filepath.Rel(home, reposDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
query := workspacesdk.LSRequest{
|
||||
query := LSRequest{
|
||||
Path: pathToArray(rel),
|
||||
Relativity: workspacesdk.LSRelativityHome,
|
||||
Relativity: LSRelativityHome,
|
||||
}
|
||||
_, err = listFiles(fs, "", query)
|
||||
_, err = listFiles(query)
|
||||
require.ErrorIs(t, err, os.ErrPermission)
|
||||
}
|
||||
|
||||
func TestListFilesNotADirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
home, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
err = fs.MkdirAll(tmpDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
filePath := filepath.Join(tmpDir, "file.txt")
|
||||
err = afero.WriteFile(fs, filePath, []byte("content"), 0o600)
|
||||
err = os.WriteFile(filePath, []byte("content"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
rel, err := filepath.Rel(home, filePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
query := workspacesdk.LSRequest{
|
||||
query := LSRequest{
|
||||
Path: pathToArray(rel),
|
||||
Relativity: workspacesdk.LSRelativityHome,
|
||||
Relativity: LSRelativityHome,
|
||||
}
|
||||
_, err = listFiles(fs, "", query)
|
||||
_, err = listFiles(query)
|
||||
require.ErrorContains(t, err, "is not a directory")
|
||||
}
|
||||
|
||||
@@ -112,7 +76,7 @@ func TestListFilesSuccess(t *testing.T) {
|
||||
tc := []struct {
|
||||
name string
|
||||
baseFunc func(t *testing.T) string
|
||||
relativity workspacesdk.LSRelativity
|
||||
relativity LSRelativity
|
||||
}{
|
||||
{
|
||||
name: "home",
|
||||
@@ -121,7 +85,7 @@ func TestListFilesSuccess(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
return home
|
||||
},
|
||||
relativity: workspacesdk.LSRelativityHome,
|
||||
relativity: LSRelativityHome,
|
||||
},
|
||||
{
|
||||
name: "root",
|
||||
@@ -131,7 +95,7 @@ func TestListFilesSuccess(t *testing.T) {
|
||||
}
|
||||
return "/"
|
||||
},
|
||||
relativity: workspacesdk.LSRelativityRoot,
|
||||
relativity: LSRelativityRoot,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -140,20 +104,19 @@ func TestListFilesSuccess(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
base := tc.baseFunc(t)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
reposDir := filepath.Join(tmpDir, "repos")
|
||||
err := fs.MkdirAll(reposDir, 0o755)
|
||||
err := os.Mkdir(reposDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
downloadsDir := filepath.Join(tmpDir, "Downloads")
|
||||
err = fs.MkdirAll(downloadsDir, 0o755)
|
||||
err = os.Mkdir(downloadsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
textFile := filepath.Join(tmpDir, "file.txt")
|
||||
err = afero.WriteFile(fs, textFile, []byte("content"), 0o600)
|
||||
err = os.WriteFile(textFile, []byte("content"), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
var queryComponents []string
|
||||
@@ -166,16 +129,16 @@ func TestListFilesSuccess(t *testing.T) {
|
||||
queryComponents = pathToArray(rel)
|
||||
}
|
||||
|
||||
query := workspacesdk.LSRequest{
|
||||
query := LSRequest{
|
||||
Path: queryComponents,
|
||||
Relativity: tc.relativity,
|
||||
}
|
||||
resp, err := listFiles(fs, "", query)
|
||||
resp, err := listFiles(query)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, tmpDir, resp.AbsolutePathString)
|
||||
// Output is sorted
|
||||
require.Equal(t, []workspacesdk.LSFile{
|
||||
require.Equal(t, []LSFile{
|
||||
{
|
||||
Name: "Downloads",
|
||||
AbsolutePathString: downloadsDir,
|
||||
@@ -203,44 +166,43 @@ func TestListFilesListDrives(t *testing.T) {
|
||||
t.Skip("skipping test on non-Windows OS")
|
||||
}
|
||||
|
||||
fs := afero.NewOsFs()
|
||||
query := workspacesdk.LSRequest{
|
||||
query := LSRequest{
|
||||
Path: []string{},
|
||||
Relativity: workspacesdk.LSRelativityRoot,
|
||||
Relativity: LSRelativityRoot,
|
||||
}
|
||||
resp, err := listFiles(fs, "", query)
|
||||
resp, err := listFiles(query)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, resp.Contents, workspacesdk.LSFile{
|
||||
require.Contains(t, resp.Contents, LSFile{
|
||||
Name: "C:\\",
|
||||
AbsolutePathString: "C:\\",
|
||||
IsDir: true,
|
||||
})
|
||||
|
||||
query = workspacesdk.LSRequest{
|
||||
query = LSRequest{
|
||||
Path: []string{"C:\\"},
|
||||
Relativity: workspacesdk.LSRelativityRoot,
|
||||
Relativity: LSRelativityRoot,
|
||||
}
|
||||
resp, err = listFiles(fs, "", query)
|
||||
resp, err = listFiles(query)
|
||||
require.NoError(t, err)
|
||||
|
||||
query = workspacesdk.LSRequest{
|
||||
query = LSRequest{
|
||||
Path: resp.AbsolutePath,
|
||||
Relativity: workspacesdk.LSRelativityRoot,
|
||||
Relativity: LSRelativityRoot,
|
||||
}
|
||||
resp, err = listFiles(fs, "", query)
|
||||
resp, err = listFiles(query)
|
||||
require.NoError(t, err)
|
||||
// System directory should always exist
|
||||
require.Contains(t, resp.Contents, workspacesdk.LSFile{
|
||||
require.Contains(t, resp.Contents, LSFile{
|
||||
Name: "Windows",
|
||||
AbsolutePathString: "C:\\Windows",
|
||||
IsDir: true,
|
||||
})
|
||||
|
||||
query = workspacesdk.LSRequest{
|
||||
query = LSRequest{
|
||||
// Network drives are not supported.
|
||||
Path: []string{"\\sshfs\\work"},
|
||||
Relativity: workspacesdk.LSRelativityRoot,
|
||||
Relativity: LSRelativityRoot,
|
||||
}
|
||||
resp, err = listFiles(fs, "", query)
|
||||
resp, err = listFiles(query)
|
||||
require.ErrorContains(t, err, "drive")
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
|
||||
// screenReconnectingPTY provides a reconnectable PTY via `screen`.
|
||||
type screenReconnectingPTY struct {
|
||||
logger slog.Logger
|
||||
execer agentexec.Execer
|
||||
command *pty.Cmd
|
||||
|
||||
@@ -63,7 +62,6 @@ type screenReconnectingPTY struct {
|
||||
// own which causes it to spawn with the specified size.
|
||||
func newScreen(ctx context.Context, logger slog.Logger, execer agentexec.Execer, cmd *pty.Cmd, options *Options) *screenReconnectingPTY {
|
||||
rpty := &screenReconnectingPTY{
|
||||
logger: logger,
|
||||
execer: execer,
|
||||
command: cmd,
|
||||
metrics: options.Metrics,
|
||||
@@ -175,7 +173,6 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
|
||||
|
||||
ptty, process, err := rpty.doAttach(ctx, conn, height, width, logger)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "unable to attach to screen reconnecting pty", slog.Error(err))
|
||||
if errors.Is(err, context.Canceled) {
|
||||
// Likely the process was too short-lived and canceled the version command.
|
||||
// TODO: Is it worth distinguishing between that and a cancel from the
|
||||
@@ -185,7 +182,6 @@ func (rpty *screenReconnectingPTY) Attach(ctx context.Context, _ string, conn ne
|
||||
}
|
||||
return err
|
||||
}
|
||||
logger.Debug(ctx, "attached to screen reconnecting pty")
|
||||
|
||||
defer func() {
|
||||
// Log only for debugging since the process might have already exited on its
|
||||
@@ -407,7 +403,6 @@ func (rpty *screenReconnectingPTY) Wait() {
|
||||
}
|
||||
|
||||
func (rpty *screenReconnectingPTY) Close(err error) {
|
||||
rpty.logger.Debug(context.Background(), "closing screen reconnecting pty", slog.Error(err))
|
||||
// The closing state change will be handled by the lifecycle.
|
||||
rpty.state.setState(StateClosing, err)
|
||||
}
|
||||
|
||||
+11
-5
@@ -6,7 +6,10 @@
|
||||
"defaultBranch": "main"
|
||||
},
|
||||
"files": {
|
||||
"includes": ["**", "!**/pnpm-lock.yaml"],
|
||||
"includes": [
|
||||
"**",
|
||||
"!**/pnpm-lock.yaml"
|
||||
],
|
||||
"ignoreUnknown": true
|
||||
},
|
||||
"linter": {
|
||||
@@ -45,14 +48,13 @@
|
||||
"options": {
|
||||
"paths": {
|
||||
"@mui/material": "Use @mui/material/<name> instead. See: https://material-ui.com/guides/minimizing-bundle-size/.",
|
||||
"@mui/icons-material": "Use @mui/icons-material/<name> instead. See: https://material-ui.com/guides/minimizing-bundle-size/.",
|
||||
"@mui/material/Avatar": "Use components/Avatar/Avatar instead.",
|
||||
"@mui/material/Alert": "Use components/Alert/Alert instead.",
|
||||
"@mui/material/Popover": "Use components/Popover/Popover instead.",
|
||||
"@mui/material/Typography": "Use native HTML elements instead. Eg: <span>, <p>, <h1>, etc.",
|
||||
"@mui/material/Box": "Use a <div> instead.",
|
||||
"@mui/material/Button": "Use a components/Button/Button instead.",
|
||||
"@mui/material/styles": "Import from @emotion/react instead.",
|
||||
"@mui/material/Table*": "Import from components/Table/Table instead.",
|
||||
"lodash": "Use lodash/<name> instead."
|
||||
}
|
||||
}
|
||||
@@ -67,7 +69,11 @@
|
||||
"noConsole": {
|
||||
"level": "error",
|
||||
"options": {
|
||||
"allow": ["error", "info", "warn"]
|
||||
"allow": [
|
||||
"error",
|
||||
"info",
|
||||
"warn"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -76,5 +82,5 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
|
||||
"$schema": "https://biomejs.dev/schemas/2.2.0/schema.json"
|
||||
}
|
||||
|
||||
+91
-12
@@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
"golang.org/x/xerrors"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
@@ -37,8 +38,9 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
)
|
||||
|
||||
func workspaceAgent() *serpent.Command {
|
||||
func (r *RootCmd) workspaceAgent() *serpent.Command {
|
||||
var (
|
||||
auth string
|
||||
logDir string
|
||||
scriptDataDir string
|
||||
pprofAddress string
|
||||
@@ -57,7 +59,6 @@ func workspaceAgent() *serpent.Command {
|
||||
devcontainerProjectDiscovery bool
|
||||
devcontainerDiscoveryAutostart bool
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "agent",
|
||||
Short: `Starts the Coder workspace agent.`,
|
||||
@@ -175,14 +176,12 @@ func workspaceAgent() *serpent.Command {
|
||||
|
||||
version := buildinfo.Version()
|
||||
logger.Info(ctx, "agent is starting now",
|
||||
slog.F("url", agentAuth.agentURL),
|
||||
slog.F("auth", agentAuth.agentAuth),
|
||||
slog.F("url", r.agentURL),
|
||||
slog.F("auth", auth),
|
||||
slog.F("version", version),
|
||||
)
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
client := agentsdk.New(r.agentURL)
|
||||
client.SDK.SetLogger(logger)
|
||||
// Set a reasonable timeout so requests can't hang forever!
|
||||
// The timeout needs to be reasonably long, because requests
|
||||
@@ -191,7 +190,7 @@ func workspaceAgent() *serpent.Command {
|
||||
client.SDK.HTTPClient.Timeout = 30 * time.Second
|
||||
// Attach header transport so we process --agent-header and
|
||||
// --agent-header-command flags
|
||||
headerTransport, err := headerTransport(ctx, &agentAuth.agentURL, agentHeader, agentHeaderCommand)
|
||||
headerTransport, err := headerTransport(ctx, r.agentURL, agentHeader, agentHeaderCommand)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("configure header transport: %w", err)
|
||||
}
|
||||
@@ -215,6 +214,68 @@ func workspaceAgent() *serpent.Command {
|
||||
ignorePorts[port] = "debug"
|
||||
}
|
||||
|
||||
// exchangeToken returns a session token.
|
||||
// This is abstracted to allow for the same looping condition
|
||||
// regardless of instance identity auth type.
|
||||
var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error)
|
||||
switch auth {
|
||||
case "token":
|
||||
token, _ := inv.ParsedFlags().GetString(varAgentToken)
|
||||
if token == "" {
|
||||
tokenFile, _ := inv.ParsedFlags().GetString(varAgentTokenFile)
|
||||
if tokenFile != "" {
|
||||
tokenBytes, err := os.ReadFile(tokenFile)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("read token file %q: %w", tokenFile, err)
|
||||
}
|
||||
token = strings.TrimSpace(string(tokenBytes))
|
||||
}
|
||||
}
|
||||
if token == "" {
|
||||
return xerrors.Errorf("CODER_AGENT_TOKEN or CODER_AGENT_TOKEN_FILE must be set for token auth")
|
||||
}
|
||||
client.SetSessionToken(token)
|
||||
case "google-instance-identity":
|
||||
// This is *only* done for testing to mock client authentication.
|
||||
// This will never be set in a production scenario.
|
||||
var gcpClient *metadata.Client
|
||||
gcpClientRaw := ctx.Value("gcp-client")
|
||||
if gcpClientRaw != nil {
|
||||
gcpClient, _ = gcpClientRaw.(*metadata.Client)
|
||||
}
|
||||
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
|
||||
return client.AuthGoogleInstanceIdentity(ctx, "", gcpClient)
|
||||
}
|
||||
case "aws-instance-identity":
|
||||
// This is *only* done for testing to mock client authentication.
|
||||
// This will never be set in a production scenario.
|
||||
var awsClient *http.Client
|
||||
awsClientRaw := ctx.Value("aws-client")
|
||||
if awsClientRaw != nil {
|
||||
awsClient, _ = awsClientRaw.(*http.Client)
|
||||
if awsClient != nil {
|
||||
client.SDK.HTTPClient = awsClient
|
||||
}
|
||||
}
|
||||
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
|
||||
return client.AuthAWSInstanceIdentity(ctx)
|
||||
}
|
||||
case "azure-instance-identity":
|
||||
// This is *only* done for testing to mock client authentication.
|
||||
// This will never be set in a production scenario.
|
||||
var azureClient *http.Client
|
||||
azureClientRaw := ctx.Value("azure-client")
|
||||
if azureClientRaw != nil {
|
||||
azureClient, _ = azureClientRaw.(*http.Client)
|
||||
if azureClient != nil {
|
||||
client.SDK.HTTPClient = azureClient
|
||||
}
|
||||
}
|
||||
exchangeToken = func(ctx context.Context) (agentsdk.AuthenticateResponse, error) {
|
||||
return client.AuthAzureInstanceIdentity(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
executablePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("getting os executable: %w", err)
|
||||
@@ -282,7 +343,18 @@ func workspaceAgent() *serpent.Command {
|
||||
LogDir: logDir,
|
||||
ScriptDataDir: scriptDataDir,
|
||||
// #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
|
||||
TailnetListenPort: uint16(tailnetListenPort),
|
||||
TailnetListenPort: uint16(tailnetListenPort),
|
||||
ExchangeToken: func(ctx context.Context) (string, error) {
|
||||
if exchangeToken == nil {
|
||||
return client.SDK.SessionToken(), nil
|
||||
}
|
||||
resp, err := exchangeToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
client.SetSessionToken(resp.SessionToken)
|
||||
return resp.SessionToken, nil
|
||||
},
|
||||
EnvironmentVariables: environmentVariables,
|
||||
IgnorePorts: ignorePorts,
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
@@ -293,7 +365,7 @@ func workspaceAgent() *serpent.Command {
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
DevcontainerAPIOptions: []agentcontainers.Option{
|
||||
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
|
||||
agentcontainers.WithSubAgentURL(r.agentURL.String()),
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
agentcontainers.WithDiscoveryAutostart(devcontainerDiscoveryAutostart),
|
||||
},
|
||||
@@ -328,6 +400,13 @@ func workspaceAgent() *serpent.Command {
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "auth",
|
||||
Default: "token",
|
||||
Description: "Specify the authentication type to use for the agent.",
|
||||
Env: "CODER_AGENT_AUTH",
|
||||
Value: serpent.StringOf(&auth),
|
||||
},
|
||||
{
|
||||
Flag: "log-dir",
|
||||
Default: os.TempDir(),
|
||||
@@ -450,7 +529,7 @@ func workspaceAgent() *serpent.Command {
|
||||
Value: serpent.BoolOf(&devcontainerDiscoveryAutostart),
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -19,7 +21,10 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@@ -59,6 +64,158 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
}, testutil.WaitLong, testutil.IntervalMedium)
|
||||
})
|
||||
|
||||
t.Run("Azure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
instanceID := "instanceidentifier"
|
||||
certificates, metadataClient := coderdtest.NewAzureInstanceIdentity(t, instanceID)
|
||||
db, ps := dbtestutil.NewDB(t,
|
||||
dbtestutil.WithDumpOnFailure(),
|
||||
)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
AzureCertificates: certificates,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
||||
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
|
||||
return agents
|
||||
}).Do()
|
||||
|
||||
inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
|
||||
inv = inv.WithContext(
|
||||
//nolint:revive,staticcheck
|
||||
context.WithValue(inv.Context(), "azure-client", metadataClient),
|
||||
)
|
||||
|
||||
ctx := inv.Context()
|
||||
clitest.Start(t, inv)
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
|
||||
MatchResources(matchAgentWithVersion).Wait()
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
resources := workspace.LatestBuild.Resources
|
||||
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := workspacesdk.New(client).
|
||||
DialAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
require.True(t, dialer.AwaitReachable(ctx))
|
||||
})
|
||||
|
||||
t.Run("AWS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
instanceID := "instanceidentifier"
|
||||
certificates, metadataClient := coderdtest.NewAWSInstanceIdentity(t, instanceID)
|
||||
db, ps := dbtestutil.NewDB(t,
|
||||
dbtestutil.WithDumpOnFailure(),
|
||||
)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
AWSCertificates: certificates,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
||||
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
|
||||
return agents
|
||||
}).Do()
|
||||
|
||||
inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
|
||||
inv = inv.WithContext(
|
||||
//nolint:revive,staticcheck
|
||||
context.WithValue(inv.Context(), "aws-client", metadataClient),
|
||||
)
|
||||
|
||||
clitest.Start(t, inv)
|
||||
ctx := inv.Context()
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
|
||||
MatchResources(matchAgentWithVersion).
|
||||
Wait()
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
resources := workspace.LatestBuild.Resources
|
||||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := workspacesdk.New(client).
|
||||
DialAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
require.True(t, dialer.AwaitReachable(ctx))
|
||||
})
|
||||
|
||||
t.Run("GoogleCloud", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
instanceID := "instanceidentifier"
|
||||
validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false)
|
||||
db, ps := dbtestutil.NewDB(t,
|
||||
dbtestutil.WithDumpOnFailure(),
|
||||
)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: ps,
|
||||
GoogleTokenValidator: validator,
|
||||
})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: memberUser.ID,
|
||||
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
|
||||
agents[0].Auth = &proto.Agent_InstanceId{InstanceId: instanceID}
|
||||
return agents
|
||||
}).Do()
|
||||
|
||||
inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
|
||||
clitest.SetupConfig(t, member, cfg)
|
||||
|
||||
clitest.Start(t,
|
||||
inv.WithContext(
|
||||
//nolint:revive,staticcheck
|
||||
context.WithValue(inv.Context(), "gcp-client", metadataClient),
|
||||
),
|
||||
)
|
||||
|
||||
ctx := inv.Context()
|
||||
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
|
||||
MatchResources(matchAgentWithVersion).
|
||||
Wait()
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
resources := workspace.LatestBuild.Resources
|
||||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
require.True(t, dialer.AwaitReachable(ctx))
|
||||
sshClient, err := dialer.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
key := "CODER_AGENT_TOKEN"
|
||||
command := "sh -c 'echo $" + key + "'"
|
||||
if runtime.GOOS == "windows" {
|
||||
command = "cmd.exe /c echo %" + key + "%"
|
||||
}
|
||||
token, err := session.CombinedOutput(command)
|
||||
require.NoError(t, err)
|
||||
_, err = uuid.Parse(strings.TrimSpace(string(token)))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("PostStartup", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+3
-6
@@ -12,21 +12,18 @@ import (
|
||||
)
|
||||
|
||||
func (r *RootCmd) autoupdate() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "autoupdate <workspace> <always|never>",
|
||||
Short: "Toggle auto-update policy for a workspace",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(2),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy := strings.ToLower(inv.Args[1])
|
||||
err = validateAutoUpdatePolicy(policy)
|
||||
err := validateAutoUpdatePolicy(policy)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("validate policy: %w", err)
|
||||
}
|
||||
|
||||
+1
-26
@@ -53,9 +53,6 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
|
||||
t := time.NewTimer(0)
|
||||
defer t.Stop()
|
||||
|
||||
startTime := time.Now()
|
||||
baseInterval := opts.FetchInterval
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -71,11 +68,7 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
|
||||
return
|
||||
}
|
||||
fetchedAgent <- fetchAgent{agent: agent}
|
||||
|
||||
// Adjust the interval based on how long we've been waiting.
|
||||
elapsed := time.Since(startTime)
|
||||
currentInterval := GetProgressiveInterval(baseInterval, elapsed)
|
||||
t.Reset(currentInterval)
|
||||
t.Reset(opts.FetchInterval)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -300,24 +293,6 @@ func safeDuration(sw *stageWriter, a, b *time.Time) time.Duration {
|
||||
return a.Sub(*b)
|
||||
}
|
||||
|
||||
// GetProgressiveInterval returns an interval that increases over time.
|
||||
// The interval starts at baseInterval and increases to
|
||||
// a maximum of baseInterval * 16 over time.
|
||||
func GetProgressiveInterval(baseInterval time.Duration, elapsed time.Duration) time.Duration {
|
||||
switch {
|
||||
case elapsed < 60*time.Second:
|
||||
return baseInterval // 500ms for first 60 seconds
|
||||
case elapsed < 2*time.Minute:
|
||||
return baseInterval * 2 // 1s for next 1 minute
|
||||
case elapsed < 5*time.Minute:
|
||||
return baseInterval * 4 // 2s for next 3 minutes
|
||||
case elapsed < 10*time.Minute:
|
||||
return baseInterval * 8 // 4s for next 5 minutes
|
||||
default:
|
||||
return baseInterval * 16 // 8s after 10 minutes
|
||||
}
|
||||
}
|
||||
|
||||
type closeFunc func() error
|
||||
|
||||
func (c closeFunc) Close() error {
|
||||
|
||||
@@ -866,31 +866,3 @@ func TestConnDiagnostics(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProgressiveInterval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseInterval := 500 * time.Millisecond
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
elapsed time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{"first_minute", 30 * time.Second, baseInterval},
|
||||
{"second_minute", 90 * time.Second, baseInterval * 2},
|
||||
{"third_to_fifth_minute", 3 * time.Minute, baseInterval * 4},
|
||||
{"sixth_to_tenth_minute", 7 * time.Minute, baseInterval * 8},
|
||||
{"after_ten_minutes", 15 * time.Minute, baseInterval * 16},
|
||||
{"boundary_first_minute", 59 * time.Second, baseInterval},
|
||||
{"boundary_second_minute", 61 * time.Second, baseInterval * 2},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := cliui.GetProgressiveInterval(baseInterval, tc.elapsed)
|
||||
require.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+3
-5
@@ -236,6 +236,7 @@ func (r *RootCmd) configSSH() *serpent.Command {
|
||||
dryRun bool
|
||||
coderCliPath string
|
||||
)
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "config-ssh",
|
||||
@@ -252,13 +253,9 @@ func (r *RootCmd) configSSH() *serpent.Command {
|
||||
),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := inv.Context()
|
||||
|
||||
if sshConfigOpts.waitEnum != "auto" && sshConfigOpts.skipProxyCommand {
|
||||
@@ -283,6 +280,7 @@ func (r *RootCmd) configSSH() *serpent.Command {
|
||||
out = inv.Stderr
|
||||
}
|
||||
|
||||
var err error
|
||||
coderBinary := coderCliPath
|
||||
if coderBinary == "" {
|
||||
coderBinary, err = currentBinPath(out)
|
||||
|
||||
@@ -135,13 +135,11 @@ func Test_sshConfigSplitOnCoderSection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// This test tries to mimic the behavior of OpenSSH when executing e.g. a ProxyCommand.
|
||||
// nolint:paralleltest
|
||||
// This test tries to mimic the behavior of OpenSSH
|
||||
// when executing e.g. a ProxyCommand.
|
||||
// nolint:tparallel
|
||||
func Test_sshConfigProxyCommandEscape(t *testing.T) {
|
||||
// Don't run this test, or any of its subtests in parallel. The test works by writing a file and then immediately
|
||||
// executing it. Other tests might also exec a subprocess, and if they do in parallel, there is a small race
|
||||
// condition where our file is open when they fork, and remains open while we attempt to execute it, causing
|
||||
// a "text file busy" error.
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
+3
-5
@@ -50,6 +50,7 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
// shares the same name across multiple organizations.
|
||||
orgContext = NewOrganizationContext()
|
||||
)
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "create [workspace]",
|
||||
@@ -60,12 +61,9 @@ func (r *RootCmd) Create(opts CreateOptions) *serpent.Command {
|
||||
Command: "coder create <username>/<workspace_name>",
|
||||
},
|
||||
),
|
||||
Middleware: serpent.Chain(r.InitClient(client)),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
workspaceOwner := codersdk.Me
|
||||
if len(inv.Args) >= 1 {
|
||||
workspaceOwner, workspaceName, err = splitNamedWorkspace(inv.Args[0])
|
||||
|
||||
+2
-5
@@ -16,6 +16,7 @@ func (r *RootCmd) deleteWorkspace() *serpent.Command {
|
||||
orphan bool
|
||||
prov buildFlags
|
||||
)
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "delete <workspace>",
|
||||
@@ -28,13 +29,9 @@ func (r *RootCmd) deleteWorkspace() *serpent.Command {
|
||||
),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
package cli
|
||||
|
||||
import "github.com/coder/serpent"
|
||||
|
||||
func (r *RootCmd) expCmd() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "exp",
|
||||
Short: "Internal commands for testing and experimentation. These are prone to breaking changes with no notice.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Hidden: true,
|
||||
Children: []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
r.tasksCommand(),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
boundarycli "github.com/coder/boundary/cli"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) boundary() *serpent.Command {
|
||||
cmd := boundarycli.BaseCommand() // Package coder/boundary/cli exports a "base command" designed to be integrated as a subcommand.
|
||||
cmd.Use += " [args...]" // The base command looks like `boundary -- command`. Serpent adds the flags piece, but we need to add the args.
|
||||
return cmd
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
boundarycli "github.com/coder/boundary/cli"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// Actually testing the functionality of coder/boundary takes place in the
|
||||
// coder/boundary repo, since it's a dependency of coder.
|
||||
// Here we want to test basically that integrating it as a subcommand doesn't break anything.
|
||||
func TestBoundarySubcommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "boundary", "--help")
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
|
||||
go func() {
|
||||
err := inv.WithContext(ctx).Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Expect the --help output to include the short description.
|
||||
// We're simply confirming that `coder boundary --help` ran without a runtime error as
|
||||
// a good chunk of serpents self validation logic happens at runtime.
|
||||
pty.ExpectMatch(boundarycli.BaseCommand().Short)
|
||||
}
|
||||
+9
-15
@@ -56,7 +56,7 @@ func (r *RootCmd) mcpConfigure() *serpent.Command {
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.mcpConfigureClaudeDesktop(),
|
||||
mcpConfigureClaudeCode(),
|
||||
r.mcpConfigureClaudeCode(),
|
||||
r.mcpConfigureCursor(),
|
||||
},
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (*RootCmd) mcpConfigureClaudeDesktop() *serpent.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func mcpConfigureClaudeCode() *serpent.Command {
|
||||
func (r *RootCmd) mcpConfigureClaudeCode() *serpent.Command {
|
||||
var (
|
||||
claudeAPIKey string
|
||||
claudeConfigPath string
|
||||
@@ -131,7 +131,6 @@ func mcpConfigureClaudeCode() *serpent.Command {
|
||||
|
||||
deprecatedCoderMCPClaudeAPIKey string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "claude-code <project-directory>",
|
||||
Short: "Configure the Claude Code server. You will need to run this command for each project you want to use. Specify the project directory as the first argument.",
|
||||
@@ -149,7 +148,7 @@ func mcpConfigureClaudeCode() *serpent.Command {
|
||||
binPath = testBinaryName
|
||||
}
|
||||
configureClaudeEnv := map[string]string{}
|
||||
agentClient, err := agentAuth.CreateClient()
|
||||
agentClient, err := r.createAgentClient()
|
||||
if err != nil {
|
||||
cliui.Warnf(inv.Stderr, "failed to create agent client: %s", err)
|
||||
} else {
|
||||
@@ -293,7 +292,6 @@ func mcpConfigureClaudeCode() *serpent.Command {
|
||||
},
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -399,20 +397,15 @@ type mcpServer struct {
|
||||
|
||||
func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
var (
|
||||
client = new(codersdk.Client)
|
||||
instructions string
|
||||
allowedTools []string
|
||||
appStatusSlug string
|
||||
aiAgentAPIURL url.URL
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
return &serpent.Command{
|
||||
Use: "server",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.TryInitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var lastReport taskReport
|
||||
// Create a queue that skips duplicates and preserves summaries.
|
||||
queue := cliutil.NewQueue[taskReport](512).WithPredicate(func(report taskReport) (taskReport, bool) {
|
||||
@@ -501,7 +494,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
}
|
||||
|
||||
// Try to create an agent client for status reporting. Not validated.
|
||||
agentClient, err := agentAuth.CreateClient()
|
||||
agentClient, err := r.createAgentClient()
|
||||
if err == nil {
|
||||
cliui.Infof(inv.Stderr, "Agent URL : %s", agentClient.SDK.URL.String())
|
||||
srv.agentClient = agentClient
|
||||
@@ -552,6 +545,9 @@ func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
return srv.startServer(ctx, inv, instructions, allowedTools)
|
||||
},
|
||||
Short: "Start the Coder MCP server.",
|
||||
Middleware: serpent.Chain(
|
||||
r.TryInitClient(client),
|
||||
),
|
||||
Options: []serpent.Option{
|
||||
{
|
||||
Name: "instructions",
|
||||
@@ -583,8 +579,6 @@ func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
},
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation) {
|
||||
|
||||
+5
-5
@@ -22,17 +22,16 @@ import (
|
||||
)
|
||||
|
||||
func (r *RootCmd) rptyCommand() *serpent.Command {
|
||||
var args handleRPTYArgs
|
||||
var (
|
||||
client = new(codersdk.Client)
|
||||
args handleRPTYArgs
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
if r.disableDirect {
|
||||
return xerrors.New("direct connections are disabled, but you can try websocat ;-)")
|
||||
}
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
args.NamedWorkspace = inv.Args[0]
|
||||
args.Command = inv.Args[1:]
|
||||
return handleRPTY(inv, client, args)
|
||||
@@ -40,6 +39,7 @@ func (r *RootCmd) rptyCommand() *serpent.Command {
|
||||
Long: "Establish an RPTY session with a workspace/agent. This uses the same mechanism as the Web Terminal.",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireRangeArgs(1, -1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Options: []serpent.Option{
|
||||
{
|
||||
|
||||
+118
-987
File diff suppressed because it is too large
Load Diff
@@ -1,170 +0,0 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/serpent"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/dynamicparameters"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
)
|
||||
|
||||
const (
|
||||
dynamicParametersTestName = "dynamic-parameters"
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestDynamicParameters() *serpent.Command {
|
||||
var (
|
||||
templateName string
|
||||
numEvals int64
|
||||
tracingFlags = &scaletestTracingFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
// This test requires unlimited concurrency
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
)
|
||||
orgContext := NewOrganizationContext()
|
||||
output := &scaletestOutputFlags{}
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "dynamic-parameters",
|
||||
Short: "Generates load on the Coder server evaluating dynamic parameters",
|
||||
Long: `It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("could not parse --output flags")
|
||||
}
|
||||
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if templateName == "" {
|
||||
return xerrors.Errorf("template cannot be empty")
|
||||
}
|
||||
|
||||
org, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := dynamicparameters.NewMetrics(reg, "concurrent_evaluations")
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stdout)).Leveled(slog.LevelDebug)
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
tracerProvider, closeTracing, tracingEnabled, err := tracingFlags.provider(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tracer provider: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
// Allow time for traces to flush even if command context is
|
||||
// canceled. This is a no-op if tracing is not enabled.
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...")
|
||||
if err := closeTracing(ctx); err != nil {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err)
|
||||
}
|
||||
// Wait for prometheus metrics to be scraped
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
tracer := tracerProvider.Tracer(scaletestTracerName)
|
||||
|
||||
partitions, err := dynamicparameters.SetupPartitions(ctx, client, org.ID, templateName, numEvals, logger)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("setup dynamic parameters partitions: %w", err)
|
||||
}
|
||||
|
||||
th := harness.NewTestHarness(
|
||||
timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}),
|
||||
// there is no cleanup since it's just a connection that we sever.
|
||||
nil)
|
||||
|
||||
for i, part := range partitions {
|
||||
for j := range part.ConcurrentEvaluations {
|
||||
cfg := dynamicparameters.Config{
|
||||
TemplateVersion: part.TemplateVersion.ID,
|
||||
Metrics: metrics,
|
||||
MetricLabelValues: []string{fmt.Sprintf("%d", part.ConcurrentEvaluations)},
|
||||
}
|
||||
var runner harness.Runnable = dynamicparameters.NewRunner(client, cfg)
|
||||
if tracingEnabled {
|
||||
runner = &runnableTraceWrapper{
|
||||
tracer: tracer,
|
||||
spanName: fmt.Sprintf("%s/%d/%d", dynamicParametersTestName, i, j),
|
||||
runner: runner,
|
||||
}
|
||||
}
|
||||
th.AddRun(dynamicParametersTestName, fmt.Sprintf("%d/%d", j, i), runner)
|
||||
}
|
||||
}
|
||||
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness: %w", err)
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "template",
|
||||
Description: "Name of the template to use. If it does not exist, it will be created.",
|
||||
Default: "scaletest-dynamic-parameters",
|
||||
Value: serpent.StringOf(&templateName),
|
||||
},
|
||||
{
|
||||
Flag: "concurrent-evaluations",
|
||||
Description: "Number of concurrent dynamic parameter evaluations to perform.",
|
||||
Default: "100",
|
||||
Value: serpent.Int64Of(&numEvals),
|
||||
},
|
||||
}
|
||||
orgContext.AttachOptions(cmd)
|
||||
output.attach(&cmd.Options)
|
||||
tracingFlags.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
+1
-4
@@ -13,11 +13,8 @@ func (r *RootCmd) tasksCommand() *serpent.Command {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.taskCreate(),
|
||||
r.taskDelete(),
|
||||
r.taskList(),
|
||||
r.taskLogs(),
|
||||
r.taskSend(),
|
||||
r.taskCreate(),
|
||||
r.taskStatus(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,237 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) taskCreate() *serpent.Command {
|
||||
var (
|
||||
orgContext = NewOrganizationContext()
|
||||
|
||||
ownerArg string
|
||||
taskName string
|
||||
templateName string
|
||||
templateVersionName string
|
||||
presetName string
|
||||
stdin bool
|
||||
quiet bool
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "create [input]",
|
||||
Short: "Create an experimental task",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "Create a task with direct input",
|
||||
Command: "coder exp task create \"Add authentication to the user service\"",
|
||||
},
|
||||
Example{
|
||||
Description: "Create a task with stdin input",
|
||||
Command: "echo \"Add authentication to the user service\" | coder exp task create",
|
||||
},
|
||||
Example{
|
||||
Description: "Create a task with a specific name",
|
||||
Command: "coder exp task create --name task1 \"Add authentication to the user service\"",
|
||||
},
|
||||
Example{
|
||||
Description: "Create a task from a specific template / preset",
|
||||
Command: "coder exp task create --template backend-dev --preset \"My Preset\" \"Add authentication to the user service\"",
|
||||
},
|
||||
Example{
|
||||
Description: "Create a task for another user (requires appropriate permissions)",
|
||||
Command: "coder exp task create --owner user@example.com \"Add authentication to the user service\"",
|
||||
},
|
||||
),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireRangeArgs(0, 1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "name",
|
||||
Flag: "name",
|
||||
Description: "Specify the name of the task. If you do not specify one, a name will be generated for you.",
|
||||
Value: serpent.StringOf(&taskName),
|
||||
Required: false,
|
||||
Default: "",
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Flag: "owner",
|
||||
Description: "Specify the owner of the task. Defaults to the current user.",
|
||||
Value: serpent.StringOf(&ownerArg),
|
||||
Required: false,
|
||||
Default: codersdk.Me,
|
||||
},
|
||||
{
|
||||
Name: "template",
|
||||
Flag: "template",
|
||||
Env: "CODER_TASK_TEMPLATE_NAME",
|
||||
Value: serpent.StringOf(&templateName),
|
||||
},
|
||||
{
|
||||
Name: "template-version",
|
||||
Flag: "template-version",
|
||||
Env: "CODER_TASK_TEMPLATE_VERSION",
|
||||
Value: serpent.StringOf(&templateVersionName),
|
||||
},
|
||||
{
|
||||
Name: "preset",
|
||||
Flag: "preset",
|
||||
Env: "CODER_TASK_PRESET_NAME",
|
||||
Value: serpent.StringOf(&presetName),
|
||||
Default: PresetNone,
|
||||
},
|
||||
{
|
||||
Name: "stdin",
|
||||
Flag: "stdin",
|
||||
Description: "Reads from stdin for the task input.",
|
||||
Value: serpent.BoolOf(&stdin),
|
||||
},
|
||||
{
|
||||
Name: "quiet",
|
||||
Flag: "quiet",
|
||||
FlagShorthand: "q",
|
||||
Description: "Only display the created task's ID.",
|
||||
Value: serpent.BoolOf(&quiet),
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = inv.Context()
|
||||
expClient = codersdk.NewExperimentalClient(client)
|
||||
|
||||
taskInput string
|
||||
templateVersionID uuid.UUID
|
||||
templateVersionPresetID uuid.UUID
|
||||
)
|
||||
|
||||
organization, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get current organization: %w", err)
|
||||
}
|
||||
|
||||
if stdin {
|
||||
bytes, err := io.ReadAll(inv.Stdin)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reading stdin: %w", err)
|
||||
}
|
||||
|
||||
taskInput = string(bytes)
|
||||
} else {
|
||||
if len(inv.Args) != 1 {
|
||||
return xerrors.Errorf("expected an input for task")
|
||||
}
|
||||
|
||||
taskInput = inv.Args[0]
|
||||
}
|
||||
|
||||
if taskInput == "" {
|
||||
return xerrors.Errorf("a task cannot be started with an empty input")
|
||||
}
|
||||
|
||||
switch {
|
||||
case templateName == "":
|
||||
templates, err := client.Templates(ctx, codersdk.TemplateFilter{SearchQuery: "has-ai-task:true", OrganizationID: organization.ID})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list templates: %w", err)
|
||||
}
|
||||
|
||||
if len(templates) == 0 {
|
||||
return xerrors.Errorf("no task templates configured")
|
||||
}
|
||||
|
||||
// When a deployment has only 1 AI task template, we will
|
||||
// allow omitting the template. Otherwise we will require
|
||||
// the user to be explicit with their choice of template.
|
||||
if len(templates) > 1 {
|
||||
templateNames := make([]string, 0, len(templates))
|
||||
for _, template := range templates {
|
||||
templateNames = append(templateNames, template.Name)
|
||||
}
|
||||
|
||||
return xerrors.Errorf("template name not provided, available templates: %s", strings.Join(templateNames, ", "))
|
||||
}
|
||||
|
||||
if templateVersionName != "" {
|
||||
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templates[0].Name, templateVersionName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template version: %w", err)
|
||||
}
|
||||
|
||||
templateVersionID = templateVersion.ID
|
||||
} else {
|
||||
templateVersionID = templates[0].ActiveVersionID
|
||||
}
|
||||
|
||||
case templateVersionName != "":
|
||||
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templateName, templateVersionName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template version: %w", err)
|
||||
}
|
||||
|
||||
templateVersionID = templateVersion.ID
|
||||
|
||||
default:
|
||||
template, err := client.TemplateByName(ctx, organization.ID, templateName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template: %w", err)
|
||||
}
|
||||
|
||||
templateVersionID = template.ActiveVersionID
|
||||
}
|
||||
|
||||
if presetName != PresetNone {
|
||||
templatePresets, err := client.TemplateVersionPresets(ctx, templateVersionID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template presets: %w", err)
|
||||
}
|
||||
|
||||
preset, err := resolvePreset(templatePresets, presetName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve preset: %w", err)
|
||||
}
|
||||
|
||||
templateVersionPresetID = preset.ID
|
||||
}
|
||||
|
||||
task, err := expClient.CreateTask(ctx, ownerArg, codersdk.CreateTaskRequest{
|
||||
Name: taskName,
|
||||
TemplateVersionID: templateVersionID,
|
||||
TemplateVersionPresetID: templateVersionPresetID,
|
||||
Input: taskInput,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create task: %w", err)
|
||||
}
|
||||
|
||||
if quiet {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, task.ID)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(
|
||||
inv.Stdout,
|
||||
"The task %s has been created at %s!\n",
|
||||
cliui.Keyword(task.Name),
|
||||
cliui.Timestamp(task.CreatedAt),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
orgContext.AttachOptions(cmd)
|
||||
return cmd
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/pretty"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) taskDelete() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "delete <task> [<task> ...]",
|
||||
Short: "Delete experimental tasks",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "Delete a single task.",
|
||||
Command: "$ coder exp task delete task1",
|
||||
},
|
||||
Example{
|
||||
Description: "Delete multiple tasks.",
|
||||
Command: "$ coder exp task delete task1 task2 task3",
|
||||
},
|
||||
Example{
|
||||
Description: "Delete a task without confirmation.",
|
||||
Command: "$ coder exp task delete task4 --yes",
|
||||
},
|
||||
),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireRangeArgs(1, -1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
cliui.SkipPromptOption(),
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
|
||||
type toDelete struct {
|
||||
ID uuid.UUID
|
||||
Owner string
|
||||
Display string
|
||||
}
|
||||
|
||||
var items []toDelete
|
||||
for _, identifier := range inv.Args {
|
||||
identifier = strings.TrimSpace(identifier)
|
||||
if identifier == "" {
|
||||
return xerrors.New("task identifier cannot be empty or whitespace")
|
||||
}
|
||||
|
||||
// Check task identifier, try UUID first.
|
||||
if id, err := uuid.Parse(identifier); err == nil {
|
||||
task, err := exp.TaskByID(ctx, id)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", identifier, err)
|
||||
}
|
||||
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
|
||||
items = append(items, toDelete{ID: id, Display: display, Owner: task.OwnerName})
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-UUID, treat as a workspace identifier (name or owner/name).
|
||||
ws, err := namedWorkspace(ctx, client, identifier)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", identifier, err)
|
||||
}
|
||||
display := ws.FullName()
|
||||
items = append(items, toDelete{ID: ws.ID, Display: display, Owner: ws.OwnerName})
|
||||
}
|
||||
|
||||
// Confirm deletion of the tasks.
|
||||
var displayList []string
|
||||
for _, it := range items {
|
||||
displayList = append(displayList, it.Display)
|
||||
}
|
||||
_, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: fmt.Sprintf("Delete these tasks: %s?", pretty.Sprint(cliui.DefaultStyles.Code, strings.Join(displayList, ", "))),
|
||||
IsConfirm: true,
|
||||
Default: cliui.ConfirmNo,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
if err := exp.DeleteTask(ctx, item.Owner, item.ID); err != nil {
|
||||
return xerrors.Errorf("delete task %q: %w", item.Display, err)
|
||||
}
|
||||
_, _ = fmt.Fprintln(
|
||||
inv.Stdout, "Deleted task "+pretty.Sprint(cliui.DefaultStyles.Keyword, item.Display)+" at "+cliui.Timestamp(time.Now()),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestExpTaskDelete(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type testCounters struct {
|
||||
deleteCalls atomic.Int64
|
||||
nameResolves atomic.Int64
|
||||
}
|
||||
type handlerBuilder func(c *testCounters) http.HandlerFunc
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
args []string
|
||||
promptYes bool
|
||||
wantErr bool
|
||||
wantDeleteCalls int64
|
||||
wantNameResolves int64
|
||||
wantDeletedMessage int
|
||||
buildHandler handlerBuilder
|
||||
}
|
||||
|
||||
const (
|
||||
id1 = "11111111-1111-1111-1111-111111111111"
|
||||
id2 = "22222222-2222-2222-2222-222222222222"
|
||||
id3 = "33333333-3333-3333-3333-333333333333"
|
||||
id4 = "44444444-4444-4444-4444-444444444444"
|
||||
id5 = "55555555-5555-5555-5555-555555555555"
|
||||
)
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
name: "Prompted_ByName_OK",
|
||||
args: []string{"exists"},
|
||||
promptYes: true,
|
||||
buildHandler: func(c *testCounters) http.HandlerFunc {
|
||||
taskID := uuid.MustParse(id1)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/exists":
|
||||
c.nameResolves.Add(1)
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
|
||||
ID: taskID,
|
||||
Name: "exists",
|
||||
OwnerName: "me",
|
||||
})
|
||||
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id1:
|
||||
c.deleteCalls.Add(1)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
|
||||
}
|
||||
}
|
||||
},
|
||||
wantDeleteCalls: 1,
|
||||
wantNameResolves: 1,
|
||||
},
|
||||
{
|
||||
name: "Prompted_ByUUID_OK",
|
||||
args: []string{id2},
|
||||
promptYes: true,
|
||||
buildHandler: func(c *testCounters) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/"+id2:
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse(id2),
|
||||
OwnerName: "me",
|
||||
Name: "uuid-task",
|
||||
})
|
||||
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id2:
|
||||
c.deleteCalls.Add(1)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
|
||||
}
|
||||
}
|
||||
},
|
||||
wantDeleteCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "Multiple_YesFlag",
|
||||
args: []string{"--yes", "first", id4},
|
||||
buildHandler: func(c *testCounters) http.HandlerFunc {
|
||||
firstID := uuid.MustParse(id3)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/first":
|
||||
c.nameResolves.Add(1)
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
|
||||
ID: firstID,
|
||||
Name: "first",
|
||||
OwnerName: "me",
|
||||
})
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/experimental/tasks/me/"+id4:
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse(id4),
|
||||
OwnerName: "me",
|
||||
Name: "uuid-task-2",
|
||||
})
|
||||
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id3:
|
||||
c.deleteCalls.Add(1)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id4:
|
||||
c.deleteCalls.Add(1)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
|
||||
}
|
||||
}
|
||||
},
|
||||
wantDeleteCalls: 2,
|
||||
wantNameResolves: 1,
|
||||
wantDeletedMessage: 2,
|
||||
},
|
||||
{
|
||||
name: "ResolveNameError",
|
||||
args: []string{"doesnotexist"},
|
||||
wantErr: true,
|
||||
buildHandler: func(_ *testCounters) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/doesnotexist":
|
||||
httpapi.ResourceNotFound(w)
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DeleteError",
|
||||
args: []string{"bad"},
|
||||
promptYes: true,
|
||||
wantErr: true,
|
||||
buildHandler: func(c *testCounters) http.HandlerFunc {
|
||||
taskID := uuid.MustParse(id5)
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v2/users/me/workspace/bad":
|
||||
c.nameResolves.Add(1)
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Workspace{
|
||||
ID: taskID,
|
||||
Name: "bad",
|
||||
OwnerName: "me",
|
||||
})
|
||||
case r.Method == http.MethodDelete && r.URL.Path == "/api/experimental/tasks/me/"+id5:
|
||||
httpapi.InternalServerError(w, xerrors.New("boom"))
|
||||
default:
|
||||
httpapi.InternalServerError(w, xerrors.New("unwanted path: "+r.Method+" "+r.URL.Path))
|
||||
}
|
||||
}
|
||||
},
|
||||
wantNameResolves: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
var counters testCounters
|
||||
srv := httptest.NewServer(tc.buildHandler(&counters))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
client := codersdk.New(testutil.MustURL(t, srv.URL))
|
||||
|
||||
args := append([]string{"exp", "task", "delete"}, tc.args...)
|
||||
inv, root := clitest.New(t, args...)
|
||||
inv = inv.WithContext(ctx)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
var runErr error
|
||||
var outBuf bytes.Buffer
|
||||
if tc.promptYes {
|
||||
pty := ptytest.New(t).Attach(inv)
|
||||
w := clitest.StartWithWaiter(t, inv)
|
||||
pty.ExpectMatch("Delete these tasks:")
|
||||
pty.WriteLine("yes")
|
||||
runErr = w.Wait()
|
||||
outBuf.Write(pty.ReadAll())
|
||||
} else {
|
||||
inv.Stdout = &outBuf
|
||||
inv.Stderr = &outBuf
|
||||
runErr = inv.Run()
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
require.Error(t, runErr)
|
||||
} else {
|
||||
require.NoError(t, runErr)
|
||||
}
|
||||
|
||||
require.Equal(t, tc.wantDeleteCalls, counters.deleteCalls.Load(), "wrong delete call count")
|
||||
require.Equal(t, tc.wantNameResolves, counters.nameResolves.Load(), "wrong name resolve count")
|
||||
|
||||
if tc.wantDeletedMessage > 0 {
|
||||
output := outBuf.String()
|
||||
require.GreaterOrEqual(t, strings.Count(output, "Deleted task"), tc.wantDeletedMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) taskLogs() *serpent.Command {
|
||||
formatter := cliui.NewOutputFormatter(
|
||||
cliui.TableFormat(
|
||||
[]codersdk.TaskLogEntry{},
|
||||
[]string{
|
||||
"type",
|
||||
"content",
|
||||
},
|
||||
),
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "logs <task>",
|
||||
Short: "Show a task's logs",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "Show logs for a given task.",
|
||||
Command: "coder exp task logs task1",
|
||||
}),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = inv.Context()
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
task = inv.Args[0]
|
||||
taskID uuid.UUID
|
||||
)
|
||||
|
||||
if id, err := uuid.Parse(task); err == nil {
|
||||
taskID = id
|
||||
} else {
|
||||
ws, err := namedWorkspace(ctx, client, task)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task %q: %w", task, err)
|
||||
}
|
||||
|
||||
taskID = ws.ID
|
||||
}
|
||||
|
||||
logs, err := exp.TaskLogs(ctx, codersdk.Me, taskID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get task logs: %w", err)
|
||||
}
|
||||
|
||||
out, err := formatter.Format(ctx, logs.Logs)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("format task logs: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stdout, out)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
formatter.AttachOptions(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
@@ -1,187 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// Test_TaskLogs exercises `coder exp task logs` against a fake AgentAPI:
// JSON and table output, lookup by workspace name and by workspace ID,
// not-found errors for both identifier kinds, and propagation of upstream
// fetch errors.
func Test_TaskLogs(t *testing.T) {
	t.Parallel()

	// Two canned messages: a user prompt and an agent reply. The CLI is
	// expected to surface RoleUser as "input" and RoleAgent as "output"
	// log entries (asserted in the subtests below).
	testMessages := []agentapisdk.Message{
		{
			Id:      0,
			Role:    agentapisdk.RoleUser,
			Content: "What is 1 + 1?",
			Time:    time.Now().Add(-2 * time.Minute),
		},
		{
			Id:      1,
			Role:    agentapisdk.RoleAgent,
			Content: "2",
			Time:    time.Now().Add(-1 * time.Minute),
		},
	}

	t.Run("ByWorkspaceName_JSON", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
		userClient := client // user already has access to their own workspace

		var stdout strings.Builder
		inv, root := clitest.New(t, "exp", "task", "logs", workspace.Name, "--output", "json")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)

		// The JSON output must decode into task log entries preserving
		// message order, content, and role-derived type.
		var logs []codersdk.TaskLogEntry
		err = json.NewDecoder(strings.NewReader(stdout.String())).Decode(&logs)
		require.NoError(t, err)

		require.Len(t, logs, 2)
		require.Equal(t, "What is 1 + 1?", logs[0].Content)
		require.Equal(t, codersdk.TaskLogTypeInput, logs[0].Type)
		require.Equal(t, "2", logs[1].Content)
		require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
	})

	t.Run("ByWorkspaceID_JSON", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
		userClient := client

		var stdout strings.Builder
		// Same as the subtest above, but the task is identified by its
		// workspace UUID instead of its name.
		inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String(), "--output", "json")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)

		var logs []codersdk.TaskLogEntry
		err = json.NewDecoder(strings.NewReader(stdout.String())).Decode(&logs)
		require.NoError(t, err)

		require.Len(t, logs, 2)
		require.Equal(t, "What is 1 + 1?", logs[0].Content)
		require.Equal(t, codersdk.TaskLogTypeInput, logs[0].Type)
		require.Equal(t, "2", logs[1].Content)
		require.Equal(t, codersdk.TaskLogTypeOutput, logs[1].Type)
	})

	t.Run("ByWorkspaceID_Table", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsOK(testMessages))
		userClient := client

		var stdout strings.Builder
		// Default (table) output: only check that content and both log
		// types appear; exact column layout is not pinned here.
		inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String())
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)

		output := stdout.String()
		require.Contains(t, output, "What is 1 + 1?")
		require.Contains(t, output, "2")
		require.Contains(t, output, "input")
		require.Contains(t, output, "output")
	})

	t.Run("WorkspaceNotFound_ByName", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
		owner := coderdtest.CreateFirstUser(t, client)
		userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)

		var stdout strings.Builder
		inv, root := clitest.New(t, "exp", "task", "logs", "doesnotexist")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.Error(t, err)
		require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
	})

	t.Run("WorkspaceNotFound_ByID", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
		owner := coderdtest.CreateFirstUser(t, client)
		userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)

		var stdout strings.Builder
		// uuid.Nil is syntactically a valid UUID, so the CLI takes the
		// by-ID path and the server responds with "not found".
		inv, root := clitest.New(t, "exp", "task", "logs", uuid.Nil.String())
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.Error(t, err)
		require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
	})

	t.Run("ErrorFetchingLogs", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		// The fake AgentAPI fails the /messages request; the CLI must
		// surface the upstream error text to the caller.
		client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskLogsErr(assert.AnError))
		userClient := client

		inv, root := clitest.New(t, "exp", "task", "logs", workspace.ID.String())
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.ErrorContains(t, err, assert.AnError.Error())
	})
}
|
||||
|
||||
func fakeAgentAPITaskLogsOK(messages []agentapisdk.Message) map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/messages": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"messages": messages,
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskLogsErr(err error) map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/messages": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) taskSend() *serpent.Command {
|
||||
var stdin bool
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "send <task> [<input> | --stdin]",
|
||||
Short: "Send input to a task",
|
||||
Long: FormatExamples(Example{
|
||||
Description: "Send direct input to a task.",
|
||||
Command: "coder exp task send task1 \"Please also add unit tests\"",
|
||||
}, Example{
|
||||
Description: "Send input from stdin to a task.",
|
||||
Command: "echo \"Please also add unit tests\" | coder exp task send task1 --stdin",
|
||||
}),
|
||||
Middleware: serpent.RequireRangeArgs(1, 2),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "stdin",
|
||||
Flag: "stdin",
|
||||
Description: "Reads the input from stdin.",
|
||||
Value: serpent.BoolOf(&stdin),
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = inv.Context()
|
||||
exp = codersdk.NewExperimentalClient(client)
|
||||
task = inv.Args[0]
|
||||
|
||||
taskInput string
|
||||
taskID uuid.UUID
|
||||
)
|
||||
|
||||
if stdin {
|
||||
bytes, err := io.ReadAll(inv.Stdin)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reading stdio: %w", err)
|
||||
}
|
||||
|
||||
taskInput = string(bytes)
|
||||
} else {
|
||||
if len(inv.Args) != 2 {
|
||||
return xerrors.Errorf("expected an input for the task")
|
||||
}
|
||||
|
||||
taskInput = inv.Args[1]
|
||||
}
|
||||
|
||||
if id, err := uuid.Parse(task); err == nil {
|
||||
taskID = id
|
||||
} else {
|
||||
ws, err := namedWorkspace(ctx, client, task)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve task: %w", err)
|
||||
}
|
||||
|
||||
taskID = ws.ID
|
||||
}
|
||||
|
||||
if err = exp.TaskSend(ctx, codersdk.Me, taskID, codersdk.TaskSendRequest{Input: taskInput}); err != nil {
|
||||
return xerrors.Errorf("send input to task: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
agentapisdk "github.com/coder/agentapi-sdk-go"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// Test_TaskSend exercises `coder exp task send`: input via positional
// argument and via stdin, task lookup by workspace name and by ID,
// not-found errors, and propagation of AgentAPI send failures.
func Test_TaskSend(t *testing.T) {
	t.Parallel()

	t.Run("ByWorkspaceName_WithArgument", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		// The fake AgentAPI asserts the exact message content it
		// receives ("carry on with the task").
		client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
		userClient := client

		var stdout strings.Builder
		inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "carry on with the task")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
	})

	t.Run("ByWorkspaceID_WithArgument", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
		userClient := client

		var stdout strings.Builder
		// Identify the task by its workspace UUID instead of its name.
		inv, root := clitest.New(t, "exp", "task", "send", workspace.ID.String(), "carry on with the task")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
	})

	t.Run("ByWorkspaceName_WithStdin", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
		userClient := client

		var stdout strings.Builder
		// With --stdin the input comes from the invocation's stdin
		// rather than a positional argument.
		inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "--stdin")
		inv.Stdout = &stdout
		inv.Stdin = strings.NewReader("carry on with the task")
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
	})

	t.Run("WorkspaceNotFound_ByName", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
		owner := coderdtest.CreateFirstUser(t, client)
		userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)

		var stdout strings.Builder
		inv, root := clitest.New(t, "exp", "task", "send", "doesnotexist", "some task input")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.Error(t, err)
		require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
	})

	t.Run("WorkspaceNotFound_ByID", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
		owner := coderdtest.CreateFirstUser(t, client)
		userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)

		var stdout strings.Builder
		// uuid.Nil parses as a valid UUID, so this takes the by-ID path.
		inv, root := clitest.New(t, "exp", "task", "send", uuid.Nil.String(), "some task input")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.Error(t, err)
		require.ErrorContains(t, err, httpapi.ResourceNotFoundResponse.Message)
	})

	t.Run("SendError", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitLong)

		// The fake AgentAPI fails the POST /message request; the CLI
		// must surface the upstream error text.
		userClient, workspace := setupCLITaskTest(ctx, t, fakeAgentAPITaskSendErr(t, assert.AnError))

		var stdout strings.Builder
		inv, root := clitest.New(t, "exp", "task", "send", workspace.Name, "some task input")
		inv.Stdout = &stdout
		clitest.SetupConfig(t, userClient, root)

		err := inv.WithContext(ctx).Run()
		require.ErrorContains(t, err, assert.AnError.Error())
	})
}
|
||||
|
||||
func fakeAgentAPITaskSendOK(t *testing.T, expectMessage, returnMessage string) map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/status": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "stable",
|
||||
})
|
||||
},
|
||||
"/message": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
var msg agentapisdk.PostMessageParams
|
||||
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, expectMessage, msg.Content)
|
||||
message := agentapisdk.Message{
|
||||
Id: 999,
|
||||
Role: agentapisdk.RoleAgent,
|
||||
Content: returnMessage,
|
||||
Time: time.Now(),
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(message)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func fakeAgentAPITaskSendErr(t *testing.T, returnErr error) map[string]http.HandlerFunc {
|
||||
return map[string]http.HandlerFunc{
|
||||
"/status": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "stable",
|
||||
})
|
||||
},
|
||||
"/message": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(returnErr.Error()))
|
||||
},
|
||||
}
|
||||
}
|
||||
+27
-63
@@ -15,13 +15,13 @@ import (
|
||||
|
||||
func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
var (
|
||||
client = new(codersdk.Client)
|
||||
formatter = cliui.NewOutputFormatter(
|
||||
cliui.TableFormat(
|
||||
[]taskStatusRow{},
|
||||
[]string{
|
||||
"state changed",
|
||||
"status",
|
||||
"healthy",
|
||||
"state",
|
||||
"message",
|
||||
},
|
||||
@@ -44,17 +44,7 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
watchIntervalArg time.Duration
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Short: "Show the status of a task.",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "Show the status of a given task.",
|
||||
Command: "coder exp task status task1",
|
||||
},
|
||||
Example{
|
||||
Description: "Watch the status of a given task until it completes (idle or stopped).",
|
||||
Command: "coder exp task status task1 --watch",
|
||||
},
|
||||
),
|
||||
Short: "Show the status of a task.",
|
||||
Use: "status",
|
||||
Aliases: []string{"stat"},
|
||||
Options: serpent.OptionSet{
|
||||
@@ -76,13 +66,9 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
client, err := r.InitClient(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := i.Context()
|
||||
ec := codersdk.NewExperimentalClient(client)
|
||||
identifier := i.Args[0]
|
||||
@@ -103,46 +89,44 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
tsr := toStatusRow(task)
|
||||
out, err := formatter.Format(ctx, []taskStatusRow{tsr})
|
||||
out, err := formatter.Format(ctx, toStatusRow(task))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("format task status: %w", err)
|
||||
}
|
||||
_, _ = fmt.Fprintln(i.Stdout, out)
|
||||
|
||||
if !watchArg || taskWatchIsEnded(task) {
|
||||
if !watchArg {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastStatus := task.Status
|
||||
lastState := task.CurrentState
|
||||
t := time.NewTicker(watchIntervalArg)
|
||||
defer t.Stop()
|
||||
// TODO: implement streaming updates instead of polling
|
||||
lastStatusRow := tsr
|
||||
for range t.C {
|
||||
task, err := ec.TaskByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only print if something changed
|
||||
newStatusRow := toStatusRow(task)
|
||||
if !taskStatusRowEqual(lastStatusRow, newStatusRow) {
|
||||
out, err := formatter.Format(ctx, []taskStatusRow{newStatusRow})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("format task status: %w", err)
|
||||
}
|
||||
// hack: skip the extra column header from formatter
|
||||
if formatter.FormatID() != cliui.JSONFormat().ID() {
|
||||
out = strings.SplitN(out, "\n", 2)[1]
|
||||
}
|
||||
_, _ = fmt.Fprintln(i.Stdout, out)
|
||||
if lastStatus == task.Status && taskStatusEqual(lastState, task.CurrentState) {
|
||||
continue
|
||||
}
|
||||
out, err := formatter.Format(ctx, toStatusRow(task))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("format task status: %w", err)
|
||||
}
|
||||
// hack: skip the extra column header from formatter
|
||||
if formatter.FormatID() != cliui.JSONFormat().ID() {
|
||||
out = strings.SplitN(out, "\n", 2)[1]
|
||||
}
|
||||
_, _ = fmt.Fprintln(i.Stdout, out)
|
||||
|
||||
if taskWatchIsEnded(task) {
|
||||
if task.Status == codersdk.WorkspaceStatusStopped {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastStatusRow = newStatusRow
|
||||
lastStatus = task.Status
|
||||
lastState = task.CurrentState
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -151,20 +135,14 @@ func (r *RootCmd) taskStatus() *serpent.Command {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func taskWatchIsEnded(task codersdk.Task) bool {
|
||||
if task.Status == codersdk.WorkspaceStatusStopped {
|
||||
func taskStatusEqual(s1, s2 *codersdk.TaskStateEntry) bool {
|
||||
if s1 == nil && s2 == nil {
|
||||
return true
|
||||
}
|
||||
if task.WorkspaceAgentHealth == nil || !task.WorkspaceAgentHealth.Healthy {
|
||||
if s1 == nil || s2 == nil {
|
||||
return false
|
||||
}
|
||||
if task.WorkspaceAgentLifecycle == nil || task.WorkspaceAgentLifecycle.Starting() || task.WorkspaceAgentLifecycle.ShuttingDown() {
|
||||
return false
|
||||
}
|
||||
if task.CurrentState == nil || task.CurrentState.State == codersdk.TaskStateWorking {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return s1.State == s2.State
|
||||
}
|
||||
|
||||
type taskStatusRow struct {
|
||||
@@ -172,36 +150,22 @@ type taskStatusRow struct {
|
||||
ChangedAgo string `json:"-" table:"state changed,default_sort"`
|
||||
Timestamp time.Time `json:"-" table:"-"`
|
||||
TaskStatus string `json:"-" table:"status"`
|
||||
Healthy bool `json:"-" table:"healthy"`
|
||||
TaskState string `json:"-" table:"state"`
|
||||
Message string `json:"-" table:"message"`
|
||||
}
|
||||
|
||||
func taskStatusRowEqual(r1, r2 taskStatusRow) bool {
|
||||
return r1.TaskStatus == r2.TaskStatus &&
|
||||
r1.Healthy == r2.Healthy &&
|
||||
r1.TaskState == r2.TaskState &&
|
||||
r1.Message == r2.Message
|
||||
}
|
||||
|
||||
func toStatusRow(task codersdk.Task) taskStatusRow {
|
||||
func toStatusRow(task codersdk.Task) []taskStatusRow {
|
||||
tsr := taskStatusRow{
|
||||
Task: task,
|
||||
ChangedAgo: time.Since(task.UpdatedAt).Truncate(time.Second).String() + " ago",
|
||||
Timestamp: task.UpdatedAt,
|
||||
TaskStatus: string(task.Status),
|
||||
}
|
||||
tsr.Healthy = task.WorkspaceAgentHealth != nil &&
|
||||
task.WorkspaceAgentHealth.Healthy &&
|
||||
task.WorkspaceAgentLifecycle != nil &&
|
||||
!task.WorkspaceAgentLifecycle.Starting() &&
|
||||
!task.WorkspaceAgentLifecycle.ShuttingDown()
|
||||
|
||||
if task.CurrentState != nil {
|
||||
tsr.ChangedAgo = time.Since(task.CurrentState.Timestamp).Truncate(time.Second).String() + " ago"
|
||||
tsr.Timestamp = task.CurrentState.Timestamp
|
||||
tsr.TaskState = string(task.CurrentState.State)
|
||||
tsr.Message = task.CurrentState.Message
|
||||
}
|
||||
return tsr
|
||||
return []taskStatusRow{tsr}
|
||||
}
|
||||
|
||||
+36
-38
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -64,8 +63,8 @@ func Test_TaskStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
args: []string{"exists"},
|
||||
expectOutput: `STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
0s ago running true working Thinking furiously...`,
|
||||
expectOutput: `STATE CHANGED STATUS STATE MESSAGE
|
||||
0s ago running working Thinking furiously...`,
|
||||
hf: func(ctx context.Context, now time.Time) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
@@ -84,10 +83,6 @@ func Test_TaskStatus(t *testing.T) {
|
||||
Timestamp: now,
|
||||
Message: "Thinking furiously...",
|
||||
},
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
})
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
@@ -98,10 +93,12 @@ func Test_TaskStatus(t *testing.T) {
|
||||
{
|
||||
args: []string{"exists", "--watch"},
|
||||
expectOutput: `
|
||||
STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
4s ago running true
|
||||
3s ago running true working Reticulating splines...
|
||||
2s ago running true complete Splines reticulated successfully!`,
|
||||
STATE CHANGED STATUS STATE MESSAGE
|
||||
4s ago running
|
||||
3s ago running working Reticulating splines...
|
||||
2s ago running completed Splines reticulated successfully!
|
||||
2s ago stopping completed Splines reticulated successfully!
|
||||
2s ago stopped completed Splines reticulated successfully!`,
|
||||
hf: func(ctx context.Context, now time.Time) func(http.ResponseWriter, *http.Request) {
|
||||
var calls atomic.Int64
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -119,21 +116,13 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
Status: codersdk.WorkspaceStatusPending,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-5 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
})
|
||||
case 1:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
})
|
||||
case 2:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
@@ -141,10 +130,6 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateWorking,
|
||||
Timestamp: now.Add(-3 * time.Second),
|
||||
@@ -157,12 +142,32 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
Status: codersdk.WorkspaceStatusRunning,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-4 * time.Second),
|
||||
WorkspaceAgentHealth: &codersdk.WorkspaceAgentHealth{
|
||||
Healthy: true,
|
||||
},
|
||||
WorkspaceAgentLifecycle: ptr.Ref(codersdk.WorkspaceAgentLifecycleReady),
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateComplete,
|
||||
State: codersdk.TaskStateCompleted,
|
||||
Timestamp: now.Add(-2 * time.Second),
|
||||
Message: "Splines reticulated successfully!",
|
||||
},
|
||||
})
|
||||
case 4:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusStopping,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now.Add(-1 * time.Second),
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateCompleted,
|
||||
Timestamp: now.Add(-2 * time.Second),
|
||||
Message: "Splines reticulated successfully!",
|
||||
},
|
||||
})
|
||||
case 5:
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Task{
|
||||
ID: uuid.MustParse("11111111-1111-1111-1111-111111111111"),
|
||||
Status: codersdk.WorkspaceStatusStopped,
|
||||
CreatedAt: now.Add(-5 * time.Second),
|
||||
UpdatedAt: now,
|
||||
CurrentState: &codersdk.TaskStateEntry{
|
||||
State: codersdk.TaskStateCompleted,
|
||||
Timestamp: now.Add(-2 * time.Second),
|
||||
Message: "Splines reticulated successfully!",
|
||||
},
|
||||
@@ -183,17 +188,9 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
"id": "11111111-1111-1111-1111-111111111111",
|
||||
"organization_id": "00000000-0000-0000-0000-000000000000",
|
||||
"owner_id": "00000000-0000-0000-0000-000000000000",
|
||||
"owner_name": "",
|
||||
"name": "",
|
||||
"template_id": "00000000-0000-0000-0000-000000000000",
|
||||
"template_name": "",
|
||||
"template_display_name": "",
|
||||
"template_icon": "",
|
||||
"workspace_id": null,
|
||||
"workspace_agent_id": null,
|
||||
"workspace_agent_lifecycle": null,
|
||||
"workspace_agent_health": null,
|
||||
"workspace_app_id": null,
|
||||
"initial_prompt": "",
|
||||
"status": "running",
|
||||
"current_state": {
|
||||
@@ -239,12 +236,13 @@ STATE CHANGED STATUS HEALTHY STATE MESSAGE
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
now = time.Now().UTC() // TODO: replace with quartz
|
||||
srv = httptest.NewServer(http.HandlerFunc(tc.hf(ctx, now)))
|
||||
client = codersdk.New(testutil.MustURL(t, srv.URL))
|
||||
client = new(codersdk.Client)
|
||||
sb = strings.Builder{}
|
||||
args = []string{"exp", "task", "status", "--watch-interval", testutil.IntervalFast.String()}
|
||||
)
|
||||
|
||||
t.Cleanup(srv.Close)
|
||||
client.URL = testutil.MustURL(t, srv.URL)
|
||||
args = append(args, tc.args...)
|
||||
inv, root := clitest.New(t, args...)
|
||||
inv.Stdout = &sb
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/agent"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/provisioner/echo"
|
||||
"github.com/coder/coder/v2/provisionersdk/proto"
|
||||
)
|
||||
|
||||
// setupCLITaskTest creates a test workspace with an AI task template and agent,
// with a fake agent API configured with the provided set of handlers.
// Returns the user client (a non-owner member) and the created workspace.
//
// The sequence is: coderd + owner + member user, a fake AgentAPI server,
// an AI-task template whose sidebar app points at that fake server, a
// workspace created from the template with a prompt parameter, and finally
// a running agent whose readiness is awaited before returning.
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Workspace) {
	t.Helper()

	client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
	owner := coderdtest.CreateFirstUser(t, client)
	userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)

	// Fake AgentAPI backing the task's sidebar app.
	fakeAPI := startFakeAgentAPI(t, agentAPIHandlers)

	authToken := uuid.NewString()
	template := createAITaskTemplate(t, client, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))

	// The AI-task prompt is passed as a rich build parameter.
	wantPrompt := "test prompt"
	workspace := coderdtest.CreateWorkspace(t, userClient, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
		req.RichParameterValues = []codersdk.WorkspaceBuildParameter{
			{Name: codersdk.AITaskPromptParameterName, Value: wantPrompt},
		}
	})
	coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)

	// Start a real agent authenticated with the token baked into the
	// template, and wait until it reports ready.
	agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
	_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
		o.Client = agentClient
	})

	coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).
		WaitFor(coderdtest.AgentsReady)

	return userClient, workspace
}
|
||||
|
||||
// createAITaskTemplate creates a template configured for AI tasks with a sidebar app.
// The template's plan declares the AI prompt parameter and HasAiTasks; its
// apply produces a single agent with one app ("task-sidebar") that is also
// registered as the task's sidebar app. Behavior is customizable via the
// aiTemplateOpt options (sidebar app URL, agent auth token).
func createAITaskTemplate(t *testing.T, client *codersdk.Client, orgID uuid.UUID, opts ...aiTemplateOpt) codersdk.Template {
	t.Helper()

	// Defaults: random agent auth token, empty sidebar URL; callers
	// override via the provided options.
	opt := aiTemplateOpts{
		authToken: uuid.NewString(),
	}
	for _, o := range opts {
		o(&opt)
	}

	// The same app ID is used for the agent's app and the AI task's
	// sidebar app reference so they link up.
	taskAppID := uuid.New()
	version := coderdtest.CreateTemplateVersion(t, client, orgID, &echo.Responses{
		Parse: echo.ParseComplete,
		ProvisionPlan: []*proto.Response{
			{
				Type: &proto.Response_Plan{
					Plan: &proto.PlanComplete{
						Parameters: []*proto.RichParameter{{Name: codersdk.AITaskPromptParameterName, Type: "string"}},
						HasAiTasks: true,
					},
				},
			},
		},
		ProvisionApply: []*proto.Response{
			{
				Type: &proto.Response_Apply{
					Apply: &proto.ApplyComplete{
						Resources: []*proto.Resource{
							{
								Name: "example",
								Type: "aws_instance",
								Agents: []*proto.Agent{
									{
										Id:   uuid.NewString(),
										Name: "example",
										Auth: &proto.Agent_Token{
											Token: opt.authToken,
										},
										Apps: []*proto.App{
											{
												Id:          taskAppID.String(),
												Slug:        "task-sidebar",
												DisplayName: "Task Sidebar",
												Url:         opt.appURL,
											},
										},
									},
								},
							},
						},
						AiTasks: []*proto.AITask{
							{
								SidebarApp: &proto.AITaskSidebarApp{
									Id: taskAppID.String(),
								},
							},
						},
					},
				},
			},
		},
	})
	coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
	template := coderdtest.CreateTemplate(t, client, orgID, version.ID)

	return template
}
|
||||
|
||||
// fakeAgentAPI implements a fake AgentAPI HTTP server for testing.
type fakeAgentAPI struct {
	t        *testing.T
	server   *httptest.Server
	handlers map[string]http.HandlerFunc // path -> handler, as passed to startFakeAgentAPI
	called   map[string]bool             // paths that received at least one request; guarded by mu
	mu       sync.Mutex                  // guards called
}
|
||||
|
||||
// startFakeAgentAPI starts an HTTP server that implements the AgentAPI endpoints.
|
||||
// handlers is a map of path -> handler function.
|
||||
func startFakeAgentAPI(t *testing.T, handlers map[string]http.HandlerFunc) *fakeAgentAPI {
|
||||
t.Helper()
|
||||
|
||||
fake := &fakeAgentAPI{
|
||||
t: t,
|
||||
handlers: handlers,
|
||||
called: make(map[string]bool),
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Register all provided handlers with call tracking
|
||||
for path, handler := range handlers {
|
||||
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
fake.mu.Lock()
|
||||
fake.called[path] = true
|
||||
fake.mu.Unlock()
|
||||
handler(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
knownEndpoints := []string{"/status", "/messages", "/message"}
|
||||
for _, endpoint := range knownEndpoints {
|
||||
if handlers[endpoint] == nil {
|
||||
endpoint := endpoint // capture loop variable
|
||||
mux.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("unexpected call to %s %s - no handler defined", r.Method, endpoint)
|
||||
})
|
||||
}
|
||||
}
|
||||
// Default handler for unknown endpoints should cause the test to fail.
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatalf("unexpected call to %s %s - no handler defined", r.Method, r.URL.Path)
|
||||
})
|
||||
|
||||
fake.server = httptest.NewServer(mux)
|
||||
|
||||
// Register cleanup to check that all defined handlers were called
|
||||
t.Cleanup(func() {
|
||||
fake.server.Close()
|
||||
fake.mu.Lock()
|
||||
for path := range handlers {
|
||||
if !fake.called[path] {
|
||||
t.Errorf("handler for %s was defined but never called", path)
|
||||
}
|
||||
}
|
||||
})
|
||||
return fake
|
||||
}
|
||||
|
||||
// URL returns the base URL of the fake AgentAPI's HTTP server.
func (f *fakeAgentAPI) URL() string {
	return f.server.URL
}
|
||||
|
||||
// aiTemplateOpts holds optional settings for createAITaskTemplate.
type aiTemplateOpts struct {
	appURL    string // URL for the task sidebar app; empty unless set via an option
	authToken string // agent auth token; createAITaskTemplate defaults this to a fresh UUID
}
|
||||
|
||||
// aiTemplateOpt is a functional option that mutates aiTemplateOpts.
type aiTemplateOpt func(*aiTemplateOpts)
|
||||
|
||||
// withSidebarURL sets the URL of the template's task sidebar app.
func withSidebarURL(url string) aiTemplateOpt {
	return func(o *aiTemplateOpts) { o.appURL = url }
}
|
||||
|
||||
// withAgentToken sets the auth token of the template's agent.
func withAgentToken(token string) aiTemplateOpt {
	return func(o *aiTemplateOpts) { o.authToken = token }
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) taskCreate() *serpent.Command {
|
||||
var (
|
||||
orgContext = NewOrganizationContext()
|
||||
client = new(codersdk.Client)
|
||||
|
||||
templateName string
|
||||
templateVersionName string
|
||||
presetName string
|
||||
taskInput string
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "create [template]",
|
||||
Short: "Create an experimental task",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireRangeArgs(0, 1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Flag: "input",
|
||||
Env: "CODER_TASK_INPUT",
|
||||
Value: serpent.StringOf(&taskInput),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Env: "CODER_TASK_TEMPLATE_NAME",
|
||||
Value: serpent.StringOf(&templateName),
|
||||
},
|
||||
{
|
||||
Env: "CODER_TASK_TEMPLATE_VERSION",
|
||||
Value: serpent.StringOf(&templateVersionName),
|
||||
},
|
||||
{
|
||||
Flag: "preset",
|
||||
Env: "CODER_TASK_PRESET_NAME",
|
||||
Value: serpent.StringOf(&presetName),
|
||||
Default: PresetNone,
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
var (
|
||||
ctx = inv.Context()
|
||||
expClient = codersdk.NewExperimentalClient(client)
|
||||
|
||||
templateVersionID uuid.UUID
|
||||
templateVersionPresetID uuid.UUID
|
||||
)
|
||||
|
||||
organization, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get current organization: %w", err)
|
||||
}
|
||||
|
||||
if len(inv.Args) > 0 {
|
||||
templateName, templateVersionName, _ = strings.Cut(inv.Args[0], "@")
|
||||
}
|
||||
|
||||
if templateName == "" {
|
||||
return xerrors.Errorf("template name not provided")
|
||||
}
|
||||
|
||||
if templateVersionName != "" {
|
||||
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templateName, templateVersionName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template version: %w", err)
|
||||
}
|
||||
|
||||
templateVersionID = templateVersion.ID
|
||||
} else {
|
||||
template, err := client.TemplateByName(ctx, organization.ID, templateName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template: %w", err)
|
||||
}
|
||||
|
||||
templateVersionID = template.ActiveVersionID
|
||||
}
|
||||
|
||||
if presetName != PresetNone {
|
||||
templatePresets, err := client.TemplateVersionPresets(ctx, templateVersionID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get template presets: %w", err)
|
||||
}
|
||||
|
||||
preset, err := resolvePreset(templatePresets, presetName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve preset: %w", err)
|
||||
}
|
||||
|
||||
templateVersionPresetID = preset.ID
|
||||
}
|
||||
|
||||
workspace, err := expClient.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
|
||||
TemplateVersionID: templateVersionID,
|
||||
TemplateVersionPresetID: templateVersionPresetID,
|
||||
Prompt: taskInput,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create task: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(
|
||||
inv.Stdout,
|
||||
"The task %s has been created at %s!\n",
|
||||
cliui.Keyword(workspace.Name),
|
||||
cliui.Timestamp(workspace.CreatedAt),
|
||||
)
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
orgContext.AttachOptions(cmd)
|
||||
return cmd
|
||||
}
|
||||
@@ -5,12 +5,14 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
@@ -31,10 +33,9 @@ func TestTaskCreate(t *testing.T) {
|
||||
templateID = uuid.New()
|
||||
templateVersionID = uuid.New()
|
||||
templateVersionPresetID = uuid.New()
|
||||
taskID = uuid.New()
|
||||
)
|
||||
|
||||
templateAndVersionFoundHandler := func(t *testing.T, ctx context.Context, orgID uuid.UUID, templateName, templateVersionName, presetName, prompt, taskName, username string) http.HandlerFunc {
|
||||
templateAndVersionFoundHandler := func(t *testing.T, ctx context.Context, orgID uuid.UUID, templateName, templateVersionName, presetName, prompt string) http.HandlerFunc {
|
||||
t.Helper()
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -45,11 +46,11 @@ func TestTaskCreate(t *testing.T) {
|
||||
ID: orgID,
|
||||
}},
|
||||
})
|
||||
case fmt.Sprintf("/api/v2/organizations/%s/templates/%s/versions/%s", orgID, templateName, templateVersionName):
|
||||
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template/versions/my-template-version", orgID):
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.TemplateVersion{
|
||||
ID: templateVersionID,
|
||||
})
|
||||
case fmt.Sprintf("/api/v2/organizations/%s/templates/%s", orgID, templateName):
|
||||
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template", orgID):
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Template{
|
||||
ID: templateID,
|
||||
ActiveVersionID: templateVersionID,
|
||||
@@ -61,21 +62,13 @@ func TestTaskCreate(t *testing.T) {
|
||||
Name: presetName,
|
||||
},
|
||||
})
|
||||
case "/api/v2/templates":
|
||||
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{
|
||||
{
|
||||
ID: templateID,
|
||||
Name: templateName,
|
||||
ActiveVersionID: templateVersionID,
|
||||
},
|
||||
})
|
||||
case fmt.Sprintf("/api/experimental/tasks/%s", username):
|
||||
case "/api/experimental/tasks/me":
|
||||
var req codersdk.CreateTaskRequest
|
||||
if !httpapi.Read(ctx, w, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, prompt, req.Input, "prompt mismatch")
|
||||
assert.Equal(t, prompt, req.Prompt, "prompt mismatch")
|
||||
assert.Equal(t, templateVersionID, req.TemplateVersionID, "template version mismatch")
|
||||
|
||||
if presetName == "" {
|
||||
@@ -84,17 +77,10 @@ func TestTaskCreate(t *testing.T) {
|
||||
assert.Equal(t, templateVersionPresetID, req.TemplateVersionPresetID, "template version preset id mismatch")
|
||||
}
|
||||
|
||||
created := codersdk.Task{
|
||||
ID: taskID,
|
||||
Name: taskName,
|
||||
httpapi.Write(ctx, w, http.StatusCreated, codersdk.Workspace{
|
||||
Name: "task-wild-goldfish-27",
|
||||
CreatedAt: taskCreatedAt,
|
||||
}
|
||||
if req.Name != "" {
|
||||
assert.Equal(t, req.Name, taskName, "name mismatch")
|
||||
created.Name = req.Name
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, w, http.StatusCreated, created)
|
||||
})
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
@@ -104,101 +90,71 @@ func TestTaskCreate(t *testing.T) {
|
||||
tests := []struct {
|
||||
args []string
|
||||
env []string
|
||||
stdin string
|
||||
expectError string
|
||||
expectOutput string
|
||||
handler func(t *testing.T, ctx context.Context) http.HandlerFunc
|
||||
}{
|
||||
{
|
||||
args: []string{"--stdin"},
|
||||
stdin: "reads prompt from stdin",
|
||||
args: []string{"my-template@my-template-version", "--input", "my custom prompt", "--org", organizationID.String()},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "reads prompt from stdin", "task-wild-goldfish-27", codersdk.Me)
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt"},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--owner", "someone-else"},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", "someone-else")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"--name", "abc123", "my custom prompt"},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("abc123"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "abc123", codersdk.Me)
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "my-template", "--template-version", "my-template-version", "--org", organizationID.String()},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "my-template", "--org", organizationID.String()},
|
||||
args: []string{"my-template", "--input", "my custom prompt", "--org", organizationID.String()},
|
||||
env: []string{"CODER_TASK_TEMPLATE_VERSION=my-template-version"},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--org", organizationID.String()},
|
||||
args: []string{"--input", "my custom prompt", "--org", organizationID.String()},
|
||||
env: []string{"CODER_TASK_TEMPLATE_NAME=my-template", "CODER_TASK_TEMPLATE_VERSION=my-template-version"},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "my-template", "--org", organizationID.String()},
|
||||
env: []string{"CODER_TASK_TEMPLATE_NAME=my-template", "CODER_TASK_TEMPLATE_VERSION=my-template-version", "CODER_TASK_INPUT=my custom prompt", "CODER_ORGANIZATION=" + organizationID.String()},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "my-template", "--preset", "my-preset", "--org", organizationID.String()},
|
||||
args: []string{"my-template", "--input", "my custom prompt", "--org", organizationID.String()},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "my-template"},
|
||||
args: []string{"my-template", "--input", "my custom prompt", "--preset", "my-preset", "--org", organizationID.String()},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my-template", "--input", "my custom prompt"},
|
||||
env: []string{"CODER_TASK_PRESET_NAME=my-preset"},
|
||||
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "-q"},
|
||||
expectOutput: taskID.String(),
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "my-template", "--preset", "not-real-preset"},
|
||||
args: []string{"my-template", "--input", "my custom prompt", "--preset", "not-real-preset"},
|
||||
expectError: `preset "not-real-preset" not found`,
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt", "task-wild-goldfish-27", codersdk.Me)
|
||||
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "my-template", "--template-version", "not-real-template-version"},
|
||||
args: []string{"my-template@not-real-template-version", "--input", "my custom prompt"},
|
||||
expectError: httpapi.ResourceNotFoundResponse.Message,
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -209,11 +165,6 @@ func TestTaskCreate(t *testing.T) {
|
||||
ID: organizationID,
|
||||
}},
|
||||
})
|
||||
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template", organizationID):
|
||||
httpapi.Write(ctx, w, http.StatusOK, codersdk.Template{
|
||||
ID: templateID,
|
||||
ActiveVersionID: templateVersionID,
|
||||
})
|
||||
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template/versions/not-real-template-version", organizationID):
|
||||
httpapi.ResourceNotFound(w)
|
||||
default:
|
||||
@@ -223,7 +174,7 @@ func TestTaskCreate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my custom prompt", "--template", "not-real-template", "--org", organizationID.String()},
|
||||
args: []string{"not-real-template", "--input", "my custom prompt", "--org", organizationID.String()},
|
||||
expectError: httpapi.ResourceNotFoundResponse.Message,
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -243,7 +194,7 @@ func TestTaskCreate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"my-custom-prompt", "--template", "template-in-different-org", "--org", anotherOrganizationID.String()},
|
||||
args: []string{"template-in-different-org", "--input", "my-custom-prompt", "--org", anotherOrganizationID.String()},
|
||||
expectError: httpapi.ResourceNotFoundResponse.Message,
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -263,7 +214,7 @@ func TestTaskCreate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"no-org-prompt"},
|
||||
args: []string{"no-org", "--input", "my-custom-prompt"},
|
||||
expectError: "Must select an organization with --org=<org_name>",
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -276,49 +227,6 @@ func TestTaskCreate(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"no task templates"},
|
||||
expectError: "no task templates configured",
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me/organizations":
|
||||
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Organization{
|
||||
{MinimalOrganization: codersdk.MinimalOrganization{
|
||||
ID: organizationID,
|
||||
}},
|
||||
})
|
||||
case "/api/v2/templates":
|
||||
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{})
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
args: []string{"no template name provided"},
|
||||
expectError: "template name not provided, available templates: wibble, wobble",
|
||||
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v2/users/me/organizations":
|
||||
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Organization{
|
||||
{MinimalOrganization: codersdk.MinimalOrganization{
|
||||
ID: organizationID,
|
||||
}},
|
||||
})
|
||||
case "/api/v2/templates":
|
||||
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{
|
||||
{Name: "wibble"},
|
||||
{Name: "wobble"},
|
||||
})
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -328,7 +236,7 @@ func TestTaskCreate(t *testing.T) {
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
srv = httptest.NewServer(tt.handler(t, ctx))
|
||||
client = codersdk.New(testutil.MustURL(t, srv.URL))
|
||||
client = new(codersdk.Client)
|
||||
args = []string{"exp", "task", "create"}
|
||||
sb strings.Builder
|
||||
err error
|
||||
@@ -336,9 +244,11 @@ func TestTaskCreate(t *testing.T) {
|
||||
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
client.URL, err = url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
inv, root := clitest.New(t, append(args, tt.args...)...)
|
||||
inv.Environ = serpent.ParseEnviron(tt.env, "")
|
||||
inv.Stdin = strings.NewReader(tt.stdin)
|
||||
inv.Stdout = &sb
|
||||
inv.Stderr = &sb
|
||||
clitest.SetupConfig(t, client, root)
|
||||
@@ -36,12 +36,13 @@ func (r *RootCmd) taskList() *serpent.Command {
|
||||
statusFilter string
|
||||
all bool
|
||||
user string
|
||||
quiet bool
|
||||
|
||||
client = new(codersdk.Client)
|
||||
formatter = cliui.NewOutputFormatter(
|
||||
cliui.TableFormat(
|
||||
[]taskListRow{},
|
||||
[]string{
|
||||
"id",
|
||||
"name",
|
||||
"status",
|
||||
"state",
|
||||
@@ -67,33 +68,12 @@ func (r *RootCmd) taskList() *serpent.Command {
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "list",
|
||||
Short: "List experimental tasks",
|
||||
Long: FormatExamples(
|
||||
Example{
|
||||
Description: "List tasks for the current user.",
|
||||
Command: "coder exp task list",
|
||||
},
|
||||
Example{
|
||||
Description: "List tasks for a specific user.",
|
||||
Command: "coder exp task list --user someone-else",
|
||||
},
|
||||
Example{
|
||||
Description: "List all tasks you can view.",
|
||||
Command: "coder exp task list --all",
|
||||
},
|
||||
Example{
|
||||
Description: "List all your running tasks.",
|
||||
Command: "coder exp task list --status running",
|
||||
},
|
||||
Example{
|
||||
Description: "As above, but only show IDs.",
|
||||
Command: "coder exp task list --status running --quiet",
|
||||
},
|
||||
),
|
||||
Use: "list",
|
||||
Short: "List experimental tasks",
|
||||
Aliases: []string{"ls"},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
@@ -118,21 +98,8 @@ func (r *RootCmd) taskList() *serpent.Command {
|
||||
Default: "",
|
||||
Value: serpent.StringOf(&user),
|
||||
},
|
||||
{
|
||||
Name: "quiet",
|
||||
Description: "Only display task IDs.",
|
||||
Flag: "quiet",
|
||||
FlagShorthand: "q",
|
||||
Default: "false",
|
||||
Value: serpent.BoolOf(&quiet),
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := inv.Context()
|
||||
exp := codersdk.NewExperimentalClient(client)
|
||||
|
||||
@@ -149,14 +116,6 @@ func (r *RootCmd) taskList() *serpent.Command {
|
||||
return xerrors.Errorf("list tasks: %w", err)
|
||||
}
|
||||
|
||||
if quiet {
|
||||
for _, task := range tasks {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, task.ID.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If no rows and not JSON, show a friendly message.
|
||||
if len(tasks) == 0 && formatter.FormatID() != cliui.JSONFormat().ID() {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "No tasks found.")
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -202,43 +200,6 @@ func TestExpTaskList(t *testing.T) {
|
||||
|
||||
pty.ExpectMatch(ws.Name)
|
||||
})
|
||||
|
||||
t.Run("Quiet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Quiet logger to reduce noise.
|
||||
quiet := slog.Make(sloghuman.Sink(io.Discard))
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &quiet})
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
// Given: We have two tasks
|
||||
task1 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStart, "keep me running")
|
||||
task2 := makeAITask(t, db, owner.OrganizationID, owner.UserID, memberUser.ID, database.WorkspaceTransitionStop, "stop me please")
|
||||
|
||||
// Given: We add the `--quiet` flag
|
||||
inv, root := clitest.New(t, "exp", "task", "list", "--quiet")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
var stdout bytes.Buffer
|
||||
inv.Stdout = &stdout
|
||||
inv.Stderr = &stdout
|
||||
|
||||
// When: We run the command
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
want := []string{task1.ID.String(), task2.ID.String()}
|
||||
got := slice.Filter(strings.Split(stdout.String(), "\n"), func(s string) bool {
|
||||
return len(s) != 0
|
||||
})
|
||||
|
||||
slices.Sort(want)
|
||||
slices.Sort(got)
|
||||
|
||||
require.Equal(t, want, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpTaskList_OwnerCanListOthers(t *testing.T) {
|
||||
+14
-9
@@ -2,16 +2,19 @@ package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/pretty"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func externalAuth() *serpent.Command {
|
||||
func (r *RootCmd) externalAuth() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "external-auth",
|
||||
Short: "Manage external authentication",
|
||||
@@ -20,15 +23,14 @@ func externalAuth() *serpent.Command {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
externalAuthAccessToken(),
|
||||
r.externalAuthAccessToken(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func externalAuthAccessToken() *serpent.Command {
|
||||
func (r *RootCmd) externalAuthAccessToken() *serpent.Command {
|
||||
var extra string
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
return &serpent.Command{
|
||||
Use: "access-token <provider>",
|
||||
Short: "Print auth for an external provider",
|
||||
Long: "Print an access-token for an external auth provider. " +
|
||||
@@ -68,7 +70,12 @@ fi
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if r.agentToken == "" {
|
||||
_, _ = fmt.Fprint(inv.Stderr, pretty.Sprintf(headLineStyle(), "No agent token found, this command must be run from inside a running workspace.\n"))
|
||||
return xerrors.Errorf("agent token not found")
|
||||
}
|
||||
|
||||
client, err := r.tryCreateAgentClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
@@ -108,6 +115,4 @@ fi
|
||||
return nil
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
+5
-10
@@ -5,10 +5,12 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) favorite() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Aliases: []string{"fav", "favou" + "rite"},
|
||||
Annotations: workspaceCommand,
|
||||
@@ -16,13 +18,9 @@ func (r *RootCmd) favorite() *serpent.Command {
|
||||
Short: "Add a workspace to your favorites",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ws, err := namedWorkspace(inv.Context(), client, inv.Args[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace: %w", err)
|
||||
@@ -39,6 +37,7 @@ func (r *RootCmd) favorite() *serpent.Command {
|
||||
}
|
||||
|
||||
func (r *RootCmd) unfavorite() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Aliases: []string{"unfav", "unfavou" + "rite"},
|
||||
Annotations: workspaceCommand,
|
||||
@@ -46,13 +45,9 @@ func (r *RootCmd) unfavorite() *serpent.Command {
|
||||
Short: "Remove a workspace from your favorites",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ws, err := namedWorkspace(inv.Context(), client, inv.Args[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace: %w", err)
|
||||
|
||||
+3
-5
@@ -18,8 +18,8 @@ import (
|
||||
|
||||
// gitAskpass is used by the Coder agent to automatically authenticate
|
||||
// with Git providers based on a hostname.
|
||||
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
func (r *RootCmd) gitAskpass() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "gitaskpass",
|
||||
Hidden: true,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
@@ -33,7 +33,7 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return xerrors.Errorf("parse host: %w", err)
|
||||
}
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
client, err := r.tryCreateAgentClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
@@ -90,6 +90,4 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestGitAskpass(t *testing.T) {
|
||||
@@ -33,7 +32,6 @@ func TestGitAskpass(t *testing.T) {
|
||||
url := srv.URL
|
||||
inv, _ := clitest.New(t, "--agent-url", url, "Username for 'https://github.com':")
|
||||
inv.Environ.Set("GIT_PREFIX", "/")
|
||||
inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token")
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdout = pty.Output()
|
||||
clitest.Start(t, inv)
|
||||
@@ -41,7 +39,6 @@ func TestGitAskpass(t *testing.T) {
|
||||
|
||||
inv, _ = clitest.New(t, "--agent-url", url, "Password for 'https://potato@github.com':")
|
||||
inv.Environ.Set("GIT_PREFIX", "/")
|
||||
inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token")
|
||||
pty = ptytest.New(t)
|
||||
inv.Stdout = pty.Output()
|
||||
clitest.Start(t, inv)
|
||||
@@ -59,7 +56,6 @@ func TestGitAskpass(t *testing.T) {
|
||||
url := srv.URL
|
||||
inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':")
|
||||
inv.Environ.Set("GIT_PREFIX", "/")
|
||||
inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token")
|
||||
pty := ptytest.New(t)
|
||||
inv.Stderr = pty.Output()
|
||||
err := inv.Run()
|
||||
@@ -69,7 +65,6 @@ func TestGitAskpass(t *testing.T) {
|
||||
|
||||
t.Run("Poll", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
resp := atomic.Pointer[agentsdk.ExternalAuthResponse]{}
|
||||
resp.Store(&agentsdk.ExternalAuthResponse{
|
||||
URL: "https://something.org",
|
||||
@@ -91,7 +86,6 @@ func TestGitAskpass(t *testing.T) {
|
||||
|
||||
inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':")
|
||||
inv.Environ.Set("GIT_PREFIX", "/")
|
||||
inv.Environ.Set("CODER_AGENT_TOKEN", "fake-token")
|
||||
stdout := ptytest.New(t)
|
||||
inv.Stdout = stdout.Output()
|
||||
stderr := ptytest.New(t)
|
||||
@@ -100,7 +94,7 @@ func TestGitAskpass(t *testing.T) {
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
testutil.RequireReceive(ctx, t, poll)
|
||||
<-poll
|
||||
stderr.ExpectMatch("Open the following URL to authenticate")
|
||||
resp.Store(&agentsdk.ExternalAuthResponse{
|
||||
Username: "username",
|
||||
|
||||
+3
-4
@@ -18,8 +18,7 @@ import (
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func gitssh() *serpent.Command {
|
||||
agentAuth := &AgentAuth{}
|
||||
func (r *RootCmd) gitssh() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "gitssh",
|
||||
Hidden: true,
|
||||
@@ -39,7 +38,7 @@ func gitssh() *serpent.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
client, err := r.tryCreateAgentClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
@@ -109,7 +108,7 @@ func gitssh() *serpent.Command {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -54,7 +54,8 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*agentsdk.Client, str
|
||||
}).WithAgent().Do()
|
||||
|
||||
// start workspace agent
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
agentClient := agentsdk.New(client.URL)
|
||||
agentClient.SetSessionToken(r.AgentToken)
|
||||
_ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) {
|
||||
o.Client = agentClient
|
||||
})
|
||||
|
||||
+3
-30
@@ -95,8 +95,8 @@ func (r *RootCmd) list() *serpent.Command {
|
||||
),
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
sharedWithMe bool
|
||||
)
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "list",
|
||||
@@ -104,37 +104,10 @@ func (r *RootCmd) list() *serpent.Command {
|
||||
Aliases: []string{"ls"},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "shared-with-me",
|
||||
Description: "Show workspaces shared with you.",
|
||||
Flag: "shared-with-me",
|
||||
Value: serpent.BoolOf(&sharedWithMe),
|
||||
Hidden: true,
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
workspaceFilter := filter.Filter()
|
||||
if sharedWithMe {
|
||||
user, err := client.User(inv.Context(), codersdk.Me)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("fetch current user: %w", err)
|
||||
}
|
||||
workspaceFilter.SharedWithUser = user.ID.String()
|
||||
|
||||
// Unset the default query that conflicts with the --shared-with-me flag
|
||||
if workspaceFilter.FilterQuery == "owner:me" {
|
||||
workspaceFilter.FilterQuery = ""
|
||||
}
|
||||
}
|
||||
|
||||
res, err := QueryConvertWorkspaces(inv.Context(), client, workspaceFilter, WorkspaceListRowFromWorkspace)
|
||||
res, err := QueryConvertWorkspaces(inv.Context(), client, filter.Filter(), WorkspaceListRowFromWorkspace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
@@ -101,49 +100,4 @@ func TestList(t *testing.T) {
|
||||
|
||||
require.Len(t, stderr.Bytes(), 0)
|
||||
})
|
||||
|
||||
t.Run("SharedWorkspaces", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
orgOwner = coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
|
||||
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
Name: "wibble",
|
||||
OwnerID: orgOwner.UserID,
|
||||
OrganizationID: orgOwner.OrganizationID,
|
||||
}).Do().Workspace
|
||||
_ = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
Name: "wobble",
|
||||
OwnerID: orgOwner.UserID,
|
||||
OrganizationID: orgOwner.OrganizationID,
|
||||
}).Do().Workspace
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
client.UpdateWorkspaceACL(ctx, sharedWorkspace.ID, codersdk.UpdateWorkspaceACL{
|
||||
UserRoles: map[string]codersdk.WorkspaceRole{
|
||||
member.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
},
|
||||
})
|
||||
|
||||
inv, root := clitest.New(t, "list", "--shared-with-me", "--output=json")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
|
||||
stdout := new(bytes.Buffer)
|
||||
inv.Stdout = stdout
|
||||
err := inv.WithContext(ctx).Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
var workspaces []codersdk.Workspace
|
||||
require.NoError(t, json.Unmarshal(stdout.Bytes(), &workspaces))
|
||||
require.Len(t, workspaces, 1)
|
||||
require.Equal(t, sharedWorkspace.ID, workspaces[0].ID)
|
||||
})
|
||||
}
|
||||
|
||||
+6
-5
@@ -8,23 +8,24 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) logout() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "logout",
|
||||
Short: "Unauthenticate your local session",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var errors []error
|
||||
|
||||
config := r.createConfig()
|
||||
|
||||
var err error
|
||||
_, err = cliui.Prompt(inv, cliui.PromptOptions{
|
||||
Text: "Are you sure you want to log out?",
|
||||
IsConfirm: true,
|
||||
|
||||
+6
-5
@@ -9,21 +9,22 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) netcheck() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "netcheck",
|
||||
Short: "Print network debug information for DERP and STUN",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(inv.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
||||
+10
-53
@@ -16,7 +16,7 @@ func (r *RootCmd) notifications() *serpent.Command {
|
||||
Short: "Manage Coder notifications",
|
||||
Long: "Administrators can use these commands to change notification settings.\n" + FormatExamples(
|
||||
Example{
|
||||
Description: "Pause Coder notifications. Administrators can temporarily stop notifiers from dispatching messages in case of the target outage (for example: unavailable SMTP server or Webhook not responding)",
|
||||
Description: "Pause Coder notifications. Administrators can temporarily stop notifiers from dispatching messages in case of the target outage (for example: unavailable SMTP server or Webhook not responding).",
|
||||
Command: "coder notifications pause",
|
||||
},
|
||||
Example{
|
||||
@@ -24,13 +24,9 @@ func (r *RootCmd) notifications() *serpent.Command {
|
||||
Command: "coder notifications resume",
|
||||
},
|
||||
Example{
|
||||
Description: "Send a test notification. Administrators can use this to verify the notification target settings",
|
||||
Description: "Send a test notification. Administrators can use this to verify the notification target settings.",
|
||||
Command: "coder notifications test",
|
||||
},
|
||||
Example{
|
||||
Description: "Send a custom notification to the requesting user. Sending notifications targeting other users or groups is currently not supported",
|
||||
Command: "coder notifications custom \"Custom Title\" \"Custom Message\"",
|
||||
},
|
||||
),
|
||||
Aliases: []string{"notification"},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
@@ -40,26 +36,22 @@ func (r *RootCmd) notifications() *serpent.Command {
|
||||
r.pauseNotifications(),
|
||||
r.resumeNotifications(),
|
||||
r.testNotifications(),
|
||||
r.customNotifications(),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) pauseNotifications() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "pause",
|
||||
Short: "Pause notifications",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.PutNotificationsSettings(inv.Context(), codersdk.NotificationsSettings{
|
||||
err := client.PutNotificationsSettings(inv.Context(), codersdk.NotificationsSettings{
|
||||
NotifierPaused: true,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -74,19 +66,16 @@ func (r *RootCmd) pauseNotifications() *serpent.Command {
|
||||
}
|
||||
|
||||
func (r *RootCmd) resumeNotifications() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "resume",
|
||||
Short: "Resume notifications",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.PutNotificationsSettings(inv.Context(), codersdk.NotificationsSettings{
|
||||
err := client.PutNotificationsSettings(inv.Context(), codersdk.NotificationsSettings{
|
||||
NotifierPaused: false,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -101,18 +90,15 @@ func (r *RootCmd) resumeNotifications() *serpent.Command {
|
||||
}
|
||||
|
||||
func (r *RootCmd) testNotifications() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "test",
|
||||
Short: "Send a test notification",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := client.PostTestNotification(inv.Context()); err != nil {
|
||||
return xerrors.Errorf("unable to post test notification: %w", err)
|
||||
}
|
||||
@@ -123,32 +109,3 @@ func (r *RootCmd) testNotifications() *serpent.Command {
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) customNotifications() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "custom <title> <message>",
|
||||
Short: "Send a custom notification",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(2),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = client.PostCustomNotification(inv.Context(), codersdk.CustomNotificationRequest{
|
||||
Content: &codersdk.CustomNotificationContent{
|
||||
Title: inv.Args[0],
|
||||
Message: inv.Args[1],
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("unable to post custom notification: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "A custom notification has been sent.")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -168,102 +166,3 @@ func TestNotificationsTest(t *testing.T) {
|
||||
require.Len(t, sent, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCustomNotifications(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("BadRequest", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
notifyEnq := ¬ificationstest.FakeEnqueuer{}
|
||||
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t),
|
||||
NotificationsEnqueuer: notifyEnq,
|
||||
})
|
||||
|
||||
// Given: A member user
|
||||
ownerUser := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, ownerUser.OrganizationID)
|
||||
|
||||
// When: The member user attempts to send a custom notification with empty title and message
|
||||
inv, root := clitest.New(t, "notifications", "custom", "", "")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
|
||||
// Then: an error is expected with no notifications sent
|
||||
err := inv.Run()
|
||||
var sdkError *codersdk.Error
|
||||
require.Error(t, err)
|
||||
require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error")
|
||||
require.Equal(t, http.StatusBadRequest, sdkError.StatusCode())
|
||||
require.Equal(t, "Invalid request body", sdkError.Message)
|
||||
|
||||
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTestNotification))
|
||||
require.Len(t, sent, 0)
|
||||
})
|
||||
|
||||
t.Run("SystemUserNotAllowed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
notifyEnq := ¬ificationstest.FakeEnqueuer{}
|
||||
|
||||
ownerClient, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t),
|
||||
NotificationsEnqueuer: notifyEnq,
|
||||
})
|
||||
|
||||
// Given: A system user (prebuilds system user)
|
||||
_, token := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: database.PrebuildsSystemUserID,
|
||||
LoginType: database.LoginTypeNone,
|
||||
})
|
||||
systemUserClient := codersdk.New(ownerClient.URL)
|
||||
systemUserClient.SetSessionToken(token)
|
||||
|
||||
// When: The system user attempts to send a custom notification
|
||||
inv, root := clitest.New(t, "notifications", "custom", "Custom Title", "Custom Message")
|
||||
clitest.SetupConfig(t, systemUserClient, root)
|
||||
|
||||
// Then: an error is expected with no notifications sent
|
||||
err := inv.Run()
|
||||
var sdkError *codersdk.Error
|
||||
require.Error(t, err)
|
||||
require.ErrorAsf(t, err, &sdkError, "error should be of type *codersdk.Error")
|
||||
require.Equal(t, http.StatusForbidden, sdkError.StatusCode())
|
||||
require.Equal(t, "Forbidden", sdkError.Message)
|
||||
|
||||
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTestNotification))
|
||||
require.Len(t, sent, 0)
|
||||
})
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
notifyEnq := ¬ificationstest.FakeEnqueuer{}
|
||||
|
||||
ownerClient := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t),
|
||||
NotificationsEnqueuer: notifyEnq,
|
||||
})
|
||||
|
||||
// Given: A member user
|
||||
ownerUser := coderdtest.CreateFirstUser(t, ownerClient)
|
||||
memberClient, memberUser := coderdtest.CreateAnotherUser(t, ownerClient, ownerUser.OrganizationID)
|
||||
|
||||
// When: The member user attempts to send a custom notification
|
||||
inv, root := clitest.New(t, "notifications", "custom", "Custom Title", "Custom Message")
|
||||
clitest.SetupConfig(t, memberClient, root)
|
||||
|
||||
// Then: we expect a custom notification to be sent to the member user
|
||||
err := inv.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateCustomNotification))
|
||||
require.Len(t, sent, 1)
|
||||
require.Equal(t, memberUser.ID, sent[0].UserID)
|
||||
require.Len(t, sent[0].Labels, 2)
|
||||
require.Equal(t, "Custom Title", sent[0].Labels["custom_title"])
|
||||
require.Equal(t, "Custom Message", sent[0].Labels["custom_message"])
|
||||
require.Equal(t, memberUser.ID.String(), sent[0].CreatedBy)
|
||||
})
|
||||
}
|
||||
|
||||
+10
-12
@@ -41,25 +41,24 @@ const vscodeDesktopName = "VS Code Desktop"
|
||||
|
||||
func (r *RootCmd) openVSCode() *serpent.Command {
|
||||
var (
|
||||
generateToken bool
|
||||
testOpenError bool
|
||||
generateToken bool
|
||||
testOpenError bool
|
||||
appearanceConfig codersdk.AppearanceConfig
|
||||
)
|
||||
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "vscode <workspace> [<directory in workspace>]",
|
||||
Short: fmt.Sprintf("Open a workspace in %s", vscodeDesktopName),
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireRangeArgs(1, 2),
|
||||
r.InitClient(client),
|
||||
initAppearance(client, &appearanceConfig),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
defer cancel()
|
||||
appearanceConfig := initAppearance(ctx, client)
|
||||
|
||||
// Check if we're inside a workspace, and especially inside _this_
|
||||
// workspace so we can perform path resolution/expansion. Generally,
|
||||
@@ -300,16 +299,15 @@ func (r *RootCmd) openApp() *serpent.Command {
|
||||
testOpenError bool
|
||||
)
|
||||
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "app <workspace> <app slug>",
|
||||
Short: "Open a workspace application.",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
defer cancel()
|
||||
|
||||
|
||||
+3
-5
@@ -37,6 +37,7 @@ func (r *RootCmd) organizations() *serpent.Command {
|
||||
func (r *RootCmd) showOrganization(orgContext *OrganizationContext) *serpent.Command {
|
||||
var (
|
||||
stringFormat func(orgs []codersdk.Organization) (string, error)
|
||||
client = new(codersdk.Client)
|
||||
formatter = cliui.NewOutputFormatter(
|
||||
cliui.ChangeFormatterData(cliui.TextFormat(), func(data any) (any, error) {
|
||||
typed, ok := data.([]codersdk.Organization)
|
||||
@@ -76,6 +77,7 @@ func (r *RootCmd) showOrganization(orgContext *OrganizationContext) *serpent.Com
|
||||
},
|
||||
),
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
serpent.RequireRangeArgs(0, 1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
@@ -88,17 +90,13 @@ func (r *RootCmd) showOrganization(orgContext *OrganizationContext) *serpent.Com
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
orgArg := "selected"
|
||||
if len(inv.Args) >= 1 {
|
||||
orgArg = inv.Args[0]
|
||||
}
|
||||
|
||||
var orgs []codersdk.Organization
|
||||
var err error
|
||||
switch strings.ToLower(orgArg) {
|
||||
case "selected":
|
||||
stringFormat = func(orgs []codersdk.Organization) (string, error) {
|
||||
|
||||
@@ -12,24 +12,22 @@ import (
|
||||
)
|
||||
|
||||
func (r *RootCmd) createOrganization() *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "create <organization name>",
|
||||
Short: "Create a new organization.",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Options: serpent.OptionSet{
|
||||
cliui.SkipPromptOption(),
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
orgName := inv.Args[0]
|
||||
|
||||
err = codersdk.NameValid(orgName)
|
||||
err := codersdk.NameValid(orgName)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("organization name %q is invalid: %w", orgName, err)
|
||||
}
|
||||
|
||||
+13
-17
@@ -31,17 +31,16 @@ func (r *RootCmd) organizationMembers(orgContext *OrganizationContext) *serpent.
|
||||
}
|
||||
|
||||
func (r *RootCmd) removeOrganizationMember(orgContext *OrganizationContext) *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "remove <username | user_id>",
|
||||
Short: "Remove a new member to the current organization",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := inv.Context()
|
||||
organization, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
@@ -63,17 +62,16 @@ func (r *RootCmd) removeOrganizationMember(orgContext *OrganizationContext) *ser
|
||||
}
|
||||
|
||||
func (r *RootCmd) addOrganizationMember(orgContext *OrganizationContext) *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "add <username | user_id>",
|
||||
Short: "Add a new member to the current organization",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
serpent.RequireNArgs(1),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := inv.Context()
|
||||
organization, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
@@ -95,15 +93,16 @@ func (r *RootCmd) addOrganizationMember(orgContext *OrganizationContext) *serpen
|
||||
}
|
||||
|
||||
func (r *RootCmd) assignOrganizationRoles(orgContext *OrganizationContext) *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "edit-roles <username | user_id> [roles...]",
|
||||
Aliases: []string{"edit-role"},
|
||||
Short: "Edit organization member's roles",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := inv.Context()
|
||||
organization, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
@@ -142,18 +141,15 @@ func (r *RootCmd) listOrganizationMembers(orgContext *OrganizationContext) *serp
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "list",
|
||||
Short: "List all organization members",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := inv.Context()
|
||||
organization, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
|
||||
@@ -54,15 +54,14 @@ func (r *RootCmd) showOrganizationRoles(orgContext *OrganizationContext) *serpen
|
||||
cliui.JSONFormat(),
|
||||
)
|
||||
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "show [role_names ...]",
|
||||
Short: "Show role(s)",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := inv.Context()
|
||||
org, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
@@ -118,6 +117,7 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
jsonInput bool
|
||||
)
|
||||
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "create <role_name>",
|
||||
Short: "Create a new organization custom role",
|
||||
@@ -144,13 +144,10 @@ func (r *RootCmd) createOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireRangeArgs(0, 1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
org, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -243,6 +240,7 @@ func (r *RootCmd) updateOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
jsonInput bool
|
||||
)
|
||||
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "update <role_name>",
|
||||
Short: "Update an organization custom role",
|
||||
@@ -269,13 +267,9 @@ func (r *RootCmd) updateOrganizationRole(orgContext *OrganizationContext) *serpe
|
||||
},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireRangeArgs(0, 1),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := inv.Context()
|
||||
org, err := orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
|
||||
@@ -95,6 +95,7 @@ type organizationSetting struct {
|
||||
}
|
||||
|
||||
func (r *RootCmd) setOrganizationSettings(orgContext *OrganizationContext, settings []organizationSetting) *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "set",
|
||||
Short: "Update specified organization setting.",
|
||||
@@ -107,6 +108,7 @@ func (r *RootCmd) setOrganizationSettings(orgContext *OrganizationContext, setti
|
||||
Options: []serpent.Option{},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
return inv.Command.HelpHandler(inv)
|
||||
@@ -122,15 +124,12 @@ func (r *RootCmd) setOrganizationSettings(orgContext *OrganizationContext, setti
|
||||
Options: []serpent.Option{},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := inv.Context()
|
||||
var org codersdk.Organization
|
||||
var err error
|
||||
|
||||
if !set.DisableOrgContext {
|
||||
org, err = orgContext.Selected(inv, client)
|
||||
@@ -171,6 +170,7 @@ func (r *RootCmd) setOrganizationSettings(orgContext *OrganizationContext, setti
|
||||
}
|
||||
|
||||
func (r *RootCmd) printOrganizationSetting(orgContext *OrganizationContext, settings []organizationSetting) *serpent.Command {
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Use: "show",
|
||||
Short: "Outputs specified organization setting.",
|
||||
@@ -183,6 +183,7 @@ func (r *RootCmd) printOrganizationSetting(orgContext *OrganizationContext, sett
|
||||
Options: []serpent.Option{},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
return inv.Command.HelpHandler(inv)
|
||||
@@ -198,15 +199,13 @@ func (r *RootCmd) printOrganizationSetting(orgContext *OrganizationContext, sett
|
||||
Options: []serpent.Option{},
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(0),
|
||||
r.InitClient(client),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := inv.Context()
|
||||
var org codersdk.Organization
|
||||
var err error
|
||||
|
||||
if !set.DisableOrgContext {
|
||||
org, err = orgContext.Selected(inv, client)
|
||||
if err != nil {
|
||||
|
||||
+10
-10
@@ -84,28 +84,28 @@ func (s *pingSummary) Write(w io.Writer) {
|
||||
|
||||
func (r *RootCmd) ping() *serpent.Command {
|
||||
var (
|
||||
pingNum int64
|
||||
pingTimeout time.Duration
|
||||
pingWait time.Duration
|
||||
pingTimeLocal bool
|
||||
pingTimeUTC bool
|
||||
pingNum int64
|
||||
pingTimeout time.Duration
|
||||
pingWait time.Duration
|
||||
pingTimeLocal bool
|
||||
pingTimeUTC bool
|
||||
appearanceConfig codersdk.AppearanceConfig
|
||||
)
|
||||
|
||||
client := new(codersdk.Client)
|
||||
cmd := &serpent.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "ping <workspace>",
|
||||
Short: "Ping a workspace",
|
||||
Middleware: serpent.Chain(
|
||||
serpent.RequireNArgs(1),
|
||||
r.InitClient(client),
|
||||
initAppearance(client, &appearanceConfig),
|
||||
),
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
defer cancel()
|
||||
appearanceConfig := initAppearance(ctx, client)
|
||||
|
||||
notifyCtx, notifyCancel := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer notifyCancel()
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user