Compare commits

..

11 Commits

Author SHA1 Message Date
Ben 01c6266e3e docs: add v2.8.3 changelog 2024-02-15 17:20:43 +00:00
Spike Curtis f9011dcba2 feat: expose DERP server debug metrics (#12135) (#12151)
Adds some debug endpoints for looking into the DERP server.

The `api/v2/debug/derp/traffic` endpoint requires the `ss` utility to be present in order to function.  I have *not* added the `iproute2` package to our base image as it adds 11MB, so this endpoint won't be useful by default.  However, in a debugging situation, we could exec into the container and then `apk add iproute2`, or build a special debug image.

The `api/v2/debug/expvar` handler contains DERP metrics as well as commandline and memstats.

Example:

```
{
"alert_failed": 0,
"alert_generated": 0,
"cmdline": ["/Users/spike/repos/coder/build/coder_darwin_arm64","--global-config","/Users/spike/repos/coder/.coderv2","server","--http-address","0.0.0.0:3000","--swagger-enable","--access-url","http://127.0.0.1:3000","--dangerous-allow-cors-requests=true"],
"derp": {"accepts": 1, "average_queue_duration_ms": 0, "bytes_received": 0, "bytes_sent": 0, "counter_packets_dropped_reason": {"gone_disconnected": 0, "gone_not_here": 0, "queue_head": 0, "queue_tail": 0, "unknown_dest": 0, "unknown_dest_on_fwd": 0, "write_error": 0}, "counter_packets_dropped_type": {"disco": 0, "other": 0}, "counter_packets_received_kind": {"disco": 0, "other": 0}, "counter_tcp_rtt": {}, "counter_total_dup_client_conns": 0, "gauge_clients_local": 1, "gauge_clients_remote": 0, "gauge_clients_total": 1, "gauge_current_connections": 1, "gauge_current_dup_client_conns": 0, "gauge_current_dup_client_keys": 0, "gauge_current_file_descriptors": 0, "gauge_current_home_connections": 1, "gauge_memstats_sys0": 20874504, "gauge_watchers": 0, "got_ping": 0, "home_moves_in": 0, "home_moves_out": 0, "multiforwarder_created": 0, "multiforwarder_deleted": 0, "packet_forwarder_delete_other_value": 0, "packets_dropped": 0, "packets_forwarded_in": 0, "packets_forwarded_out": 0, "packets_received": 0, "packets_sent": 0, "peer_gone_disconnected_frames": 0, "peer_gone_not_here_frames": 0, "sent_pong": 0, "unknown_frames": 0, "version": "1.47.0-dev20240214-t64db8c604"},
"memstats": {"Alloc":286506256,"TotalAlloc":297594632,"Sys":310621512,"Lookups":0,"Mallocs":304204,"Frees":171570,"HeapAlloc":286506256,"HeapSys":294060032,"HeapIdle":3694592,"HeapInuse":290365440,"HeapReleased":3620864,"HeapObjects":132634,"StackInuse":3735552,"StackSys":3735552,"MSpanInuse":347256,"MSpanSys":358512,"MCacheInuse":9600,"MCacheSys":15600,"BuckHashSys":1469877,"GCSys":9434896,"OtherSys":1547043,"NextGC":551867656,"LastGC":1707892877408883000,"PauseTotalNs":1247000,"PauseNs":[200333,229375,239875,209542,106958,203792,57125,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"PauseEnd":[1707892876217481000,1707892876219726000,1707892876222273000,1707892876226151000,1707892876234815000,1707892877398146000,1707892877408883000,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"NumGC":7,"NumForcedGC":0,"GCCPUFraction":0.0022425810335762954,"EnableGC":true,"DebugGC":false,"BySize":[{"Size":0,"Mallocs":0,"Frees":0},{"Size":8,"Mallocs":14396,"Frees":9143},{"Size":16,"Mallocs":89090,"Frees":50507},{"Size":24,"Mallocs":40839,"Frees":24456},{"Size":32,"Mallocs":22404,"Frees":12379},{"Size":48,"Mallocs":51174,"Frees":23718},{"Size":64,"Mallocs":15406,"Frees":3501},{"Size":80,"Mallocs":6688,"Frees":2352},{"Size":96,"Mallocs":2567,"Frees":374},{"Size":112,"Mallocs":19371,"Frees":16883},{"Size":128,"Mallocs":2873,"Frees":1061},{"Size":144,"Mallocs":5600,"Frees":2742},{"Size":160,"Mallocs":2159,"Frees":622},{"Size":176,"Mallocs":454,"Frees":86},{"Size":192,"Mallocs":227,"Frees":128},{"Size":208,"Mallocs":1407,"Frees":732},{"Size":224,"Mallocs":1365,"Frees":1090},{"Size":240,"Mallocs":82,"Frees":48},{"Size":256,"Mallocs":310,"Frees":162},{"Size":288,"Mallocs":1945,"Frees":562},{"Size":320,"Mallocs":1200,"Frees":458},{"Size":352,"Mallocs":133,"Frees":33},{"Size":384,"Mallocs":582,"Frees":51},{"Size":416,"Mallocs":747,"Frees":200},{"Size":448,"Mallocs":113,"Frees":22},{"Size":480,"Mallocs":34,"Frees":21},{"Size":512,"Mallocs":951,"Frees":91},{"Size":576,"Mallocs":364,"Frees":122},{"Size":640,"Mallocs":532,"Frees":270},{"Size":704,"Mallocs":93,"Frees":39},{"Size":768,"Mallocs":83,"Frees":35},{"Size":896,"Mallocs":308,"Frees":175},{"Size":1024,"Mallocs":226,"Frees":122},{"Size":1152,"Mallocs":198,"Frees":100},{"Size":1280,"Mallocs":314,"Frees":171},{"Size":1408,"Mallocs":77,"Frees":47},{"Size":1536,"Mallocs":80,"Frees":54},{"Size":1792,"Mallocs":199,"Frees":107},{"Size":2048,"Mallocs":112,"Frees":48},{"Size":2304,"Mallocs":71,"Frees":32},{"Size":2688,"Mallocs":206,"Frees":81},{"Size":3072,"Mallocs":39,"Frees":15},{"Size":3200,"Mallocs":16,"Frees":7},{"Size":3456,"Mallocs":44,"Frees":29},{"Size":4096,"Mallocs":192,"Frees":83},{"Size":4864,"Mallocs":44,"Frees":25},{"Size":5376,"Mallocs":105,"Frees":43},{"Size":6144,"Mallocs":25,"Frees":5},{"Size":6528,"Mallocs":22,"Frees":7},{"Size":6784,"Mallocs":3,"Frees":0},{"Size":6912,"Mallocs":4,"Frees":2},{"Size":8192,"Mallocs":59,"Frees":10},{"Size":9472,"Mallocs":31,"Frees":12},{"Size":9728,"Mallocs":5,"Frees":2},{"Size":10240,"Mallocs":5,"Frees":0},{"Size":10880,"Mallocs":27,"Frees":11},{"Size":12288,"Mallocs":4,"Frees":1},{"Size":13568,"Mallocs":4,"Frees":2},{"Size":14336,"Mallocs":9,"Frees":2},{"Size":16384,"Mallocs":10,"Frees":2},{"Size":18432,"Mallocs":4,"Frees":2}]},
"warning_failed": 0,
"warning_generated": 0
}
```

If we find the DERP metrics useful we could consider how to include them in Prometheus scrapes based on the tailnet `varz` package.  That's for a later PR if at all.
2024-02-15 09:10:29 +04:00
Spike Curtis ae1be27ba6 fix: set node callback each time we reinit the coordinator in servertailnet (#12140) (#12150)
I think this will resolve #12136 but lets get a proper test at the system level before closing.

Before this change, we only register the node callback at start of day for the server tailnet.  If the coordinator changes, like we know happens when we are licensed for the PGCoordinator, we close the connection to the old coord, and open a new one to the new coord.

The callback is designed to direct the updates to the new coordinator, but there is nothing that specifically triggers it to fire after we connect to the new coordinator.

If we have STUN, then period re-STUNs will generally get it to fire eventually, but without STUN it we could go indefinitely without a callback.

This PR changes the servertailnet to re-register the callback each time we reconnect to the coordinator.  Registering a callback (even if it's the same callback) triggers an immediate call with our node information, so the new coordinator will have it.
2024-02-15 08:58:04 +04:00
Colin Adler c4a01a42ce chore(docs): add v2.8.2 changelog 2024-02-12 05:33:05 +00:00
Dean Sheather c0aeb2fc2e fix: copy app ID in healthcheck (#12087)
(cherry picked from commit 429144da22)
2024-02-12 05:30:34 +00:00
Bruno Quaresma 908d236a19 fix(site): enable submit when auto start and stop are both disabled (#12055) 2024-02-07 18:28:33 +00:00
Spike Curtis f519db88fb fix: allow startup scripts larger than 32k (#12060)
Fixes #12057 and adds a regression test.
2024-02-07 18:28:14 +00:00
Ben e996e8b7e8 fix typo 2024-02-07 00:10:37 +00:00
Ben da60671b33 formatting + add TF version 2024-02-07 00:08:19 +00:00
Ben 963a1404c0 fmt 2024-02-06 23:41:58 +00:00
Ben 002110228c docs: add v2.8.0 changelog 2024-02-06 23:39:03 +00:00
2822 changed files with 153547 additions and 279244 deletions
+10 -10
View File
@@ -1,13 +1,13 @@
{
"name": "Development environments on your infrastructure",
"image": "codercom/oss-dogfood:latest",
"name": "Development environments on your infrastructure",
"image": "codercom/oss-dogfood:latest",
"features": {
// See all possible options here https://github.com/devcontainers/features/tree/main/src/docker-in-docker
"ghcr.io/devcontainers/features/docker-in-docker:2": {
"moby": "false"
}
},
// SYS_PTRACE to enable go debugging
"runArgs": ["--cap-add=SYS_PTRACE"]
"features": {
// See all possible options here https://github.com/devcontainers/features/tree/main/src/docker-in-docker
"ghcr.io/devcontainers/features/docker-in-docker:2": {
"moby": "false"
}
},
// SYS_PTRACE to enable go debugging
"runArgs": ["--cap-add=SYS_PTRACE"]
}
+1 -1
View File
@@ -7,7 +7,7 @@ trim_trailing_whitespace = true
insert_final_newline = true
indent_style = tab
[*.{yaml,yml,tf,tfvars,nix}]
[*.{md,json,yaml,yml,tf,tfvars,nix}]
indent_style = space
indent_size = 2
-2
View File
@@ -3,5 +3,3 @@
# chore: format code with semicolons when using prettier (#9555)
988c9af0153561397686c119da9d1336d2433fdd
# chore: use tabs for prettier and biome (#14283)
95a7c0c4f087744a22c2e88dd3c5d30024d5fb02
+2 -4
View File
@@ -1,17 +1,15 @@
# Generated files
coderd/apidoc/docs.go linguist-generated=true
docs/reference/api/*.md linguist-generated=true
docs/reference/cli/*.md linguist-generated=true
docs/api/*.md linguist-generated=true
docs/cli/*.md linguist-generated=true
coderd/apidoc/swagger.json linguist-generated=true
coderd/database/dump.sql linguist-generated=true
peerbroker/proto/*.go linguist-generated=true
provisionerd/proto/*.go linguist-generated=true
provisionerd/proto/version.go linguist-generated=false
provisionersdk/proto/*.go linguist-generated=true
*.tfplan.json linguist-generated=true
*.tfstate.json linguist-generated=true
*.tfstate.dot linguist-generated=true
*.tfplan.dot linguist-generated=true
site/e2e/provisionerGenerated.ts linguist-generated=true
site/src/api/typesGenerated.ts linguist-generated=true
site/src/pages/SetupPage/countries.tsx linguist-generated=true
+2 -2
View File
@@ -4,12 +4,12 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.22.5"
default: "1.21.5"
runs:
using: "composite"
steps:
- name: Setup Go
uses: actions/setup-go@v5
uses: buildjet/setup-go@v4
with:
go-version: ${{ inputs.version }}
+4 -4
View File
@@ -11,13 +11,13 @@ runs:
using: "composite"
steps:
- name: Install pnpm
uses: pnpm/action-setup@v3
uses: pnpm/action-setup@v2
with:
version: 9.6
version: 8
- name: Setup Node
uses: actions/setup-node@v4.0.3
uses: buildjet/setup-node@v3
with:
node-version: 20.16.0
node-version: 18.19.0
# See https://github.com/actions/setup-node#caching-global-packages-data
cache: "pnpm"
cache-dependency-path: ${{ inputs.directory }}/pnpm-lock.yaml
+1 -1
View File
@@ -7,5 +7,5 @@ runs:
- name: Install Terraform
uses: hashicorp/setup-terraform@v3
with:
terraform_version: 1.9.2
terraform_version: 1.5.7
terraform_wrapper: false
+43
View File
@@ -0,0 +1,43 @@
codecov:
require_ci_to_pass: false
notify:
after_n_builds: 5
comment: false
github_checks:
annotations: false
coverage:
range: 50..75
round: down
precision: 2
status:
patch:
default:
informational: yes
project:
default:
target: 65%
informational: true
ignore:
# This is generated code.
- coderd/database/models.go
- coderd/database/queries.sql.go
- coderd/database/databasefake
# These are generated or don't require tests.
- cmd
- coderd/tunnel
- coderd/database/dump
- coderd/database/postgres
- peerbroker/proto
- provisionerd/proto
- provisionersdk/proto
- scripts
- site/.storybook
- rules.go
# Packages used for writing tests.
- cli/clitest
- coderd/coderdtest
- pty/ptytest
+51 -34
View File
@@ -39,10 +39,6 @@ updates:
prefix: "chore"
labels: []
open-pull-requests-limit: 15
groups:
x:
patterns:
- "golang.org/x/*"
ignore:
# Ignore patch updates for all dependencies
- dependency-name: "*"
@@ -65,9 +61,7 @@ updates:
- dependency-name: "terraform"
- package-ecosystem: "npm"
directories:
- "/site"
- "/offlinedocs"
directory: "/site/"
schedule:
interval: "monthly"
time: "06:00"
@@ -77,35 +71,58 @@ updates:
commit-message:
prefix: "chore"
labels: []
groups:
xterm:
patterns:
- "@xterm*"
mui:
patterns:
- "@mui*"
react:
patterns:
- "react"
- "react-dom"
- "@types/react"
- "@types/react-dom"
emotion:
patterns:
- "@emotion*"
exclude-patterns:
- "jest-runner-eslint"
jest:
patterns:
- "jest"
- "@types/jest"
vite:
patterns:
- "vite*"
- "@vitejs/plugin-react"
ignore:
# Ignore major version updates to avoid breaking changes
# Ignore patch updates for all dependencies
- dependency-name: "*"
update-types:
- version-update:semver-patch
# Ignore major updates to Node.js types, because they need to
# correspond to the Node.js engine version
- dependency-name: "@types/node"
update-types:
- version-update:semver-major
open-pull-requests-limit: 15
groups:
site:
patterns:
- "*"
- package-ecosystem: "npm"
directory: "/offlinedocs/"
schedule:
interval: "monthly"
time: "06:00"
timezone: "America/Chicago"
reviewers:
- "coder/ts"
commit-message:
prefix: "chore"
labels: []
ignore:
# Ignore patch updates for all dependencies
- dependency-name: "*"
update-types:
- version-update:semver-patch
# Ignore major updates to Node.js types, because they need to
# correspond to the Node.js engine version
- dependency-name: "@types/node"
update-types:
- version-update:semver-major
groups:
offlinedocs:
patterns:
- "*"
# Update dogfood.
- package-ecosystem: "terraform"
directory: "/dogfood/"
schedule:
interval: "weekly"
time: "06:00"
timezone: "America/Chicago"
commit-message:
prefix: "chore"
labels: []
ignore:
# We likely want to update this ourselves.
- dependency-name: "coder/coder"
-34
View File
@@ -1,34 +0,0 @@
app = "jnb-coder"
primary_region = "jnb"
[experimental]
entrypoint = ["/bin/sh", "-c", "CODER_DERP_SERVER_RELAY_URL=\"http://[${FLY_PRIVATE_IP}]:3000\" /opt/coder wsproxy server"]
auto_rollback = true
[build]
image = "ghcr.io/coder/coder-preview:main"
[env]
CODER_ACCESS_URL = "https://jnb.fly.dev.coder.com"
CODER_HTTP_ADDRESS = "0.0.0.0:3000"
CODER_PRIMARY_ACCESS_URL = "https://dev.coder.com"
CODER_WILDCARD_ACCESS_URL = "*--apps.jnb.fly.dev.coder.com"
CODER_VERBOSE = "true"
[http_service]
internal_port = 3000
force_https = true
auto_stop_machines = true
auto_start_machines = true
min_machines_running = 0
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
[http_service.concurrency]
type = "requests"
soft_limit = 50
hard_limit = 100
[[vm]]
cpu_kind = "shared"
cpus = 2
memory_mb = 512
-6
View File
@@ -22,12 +22,6 @@ primary_region = "cdg"
auto_start_machines = true
min_machines_running = 0
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
[http_service.concurrency]
type = "requests"
soft_limit = 50
hard_limit = 100
[[vm]]
cpu_kind = "shared"
cpus = 2
@@ -22,12 +22,6 @@ primary_region = "gru"
auto_start_machines = true
min_machines_running = 0
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
[http_service.concurrency]
type = "requests"
soft_limit = 50
hard_limit = 100
[[vm]]
cpu_kind = "shared"
cpus = 2
-6
View File
@@ -22,12 +22,6 @@ primary_region = "syd"
auto_start_machines = true
min_machines_running = 0
# Ref: https://fly.io/docs/reference/configuration/#http_service-concurrency
[http_service.concurrency]
type = "requests"
soft_limit = 50
hard_limit = 100
[[vm]]
cpu_kind = "shared"
cpus = 2
+14 -14
View File
@@ -86,12 +86,12 @@ provider "kubernetes" {
}
data "coder_workspace" "me" {}
data "coder_workspace_owner" "me" {}
resource "coder_agent" "main" {
os = "linux"
arch = "amd64"
startup_script = <<-EOT
os = "linux"
arch = "amd64"
startup_script_timeout = 180
startup_script = <<-EOT
set -e
# install and start code-server
@@ -176,21 +176,21 @@ resource "coder_app" "code-server" {
resource "kubernetes_persistent_volume_claim" "home" {
metadata {
name = "coder-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}-home"
name = "coder-${lower(data.coder_workspace.me.owner)}-${lower(data.coder_workspace.me.name)}-home"
namespace = var.namespace
labels = {
"app.kubernetes.io/name" = "coder-pvc"
"app.kubernetes.io/instance" = "coder-pvc-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}"
"app.kubernetes.io/instance" = "coder-pvc-${lower(data.coder_workspace.me.owner)}-${lower(data.coder_workspace.me.name)}"
"app.kubernetes.io/part-of" = "coder"
//Coder-specific labels.
"com.coder.resource" = "true"
"com.coder.workspace.id" = data.coder_workspace.me.id
"com.coder.workspace.name" = data.coder_workspace.me.name
"com.coder.user.id" = data.coder_workspace_owner.me.id
"com.coder.user.username" = data.coder_workspace_owner.me.name
"com.coder.user.id" = data.coder_workspace.me.owner_id
"com.coder.user.username" = data.coder_workspace.me.owner
}
annotations = {
"com.coder.user.email" = data.coder_workspace_owner.me.email
"com.coder.user.email" = data.coder_workspace.me.owner_email
}
}
wait_until_bound = false
@@ -211,20 +211,20 @@ resource "kubernetes_deployment" "main" {
]
wait_for_rollout = false
metadata {
name = "coder-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}"
name = "coder-${lower(data.coder_workspace.me.owner)}-${lower(data.coder_workspace.me.name)}"
namespace = var.namespace
labels = {
"app.kubernetes.io/name" = "coder-workspace"
"app.kubernetes.io/instance" = "coder-workspace-${lower(data.coder_workspace_owner.me.name)}-${lower(data.coder_workspace.me.name)}"
"app.kubernetes.io/instance" = "coder-workspace-${lower(data.coder_workspace.me.owner)}-${lower(data.coder_workspace.me.name)}"
"app.kubernetes.io/part-of" = "coder"
"com.coder.resource" = "true"
"com.coder.workspace.id" = data.coder_workspace.me.id
"com.coder.workspace.name" = data.coder_workspace.me.name
"com.coder.user.id" = data.coder_workspace_owner.me.id
"com.coder.user.username" = data.coder_workspace_owner.me.name
"com.coder.user.id" = data.coder_workspace.me.owner_id
"com.coder.user.username" = data.coder_workspace.me.owner
}
annotations = {
"com.coder.user.email" = data.coder_workspace_owner.me.email
"com.coder.user.email" = data.coder_workspace.me.owner_email
}
}
+101 -185
View File
@@ -37,10 +37,8 @@ jobs:
k8s: ${{ steps.filter.outputs.k8s }}
ci: ${{ steps.filter.outputs.ci }}
db: ${{ steps.filter.outputs.db }}
gomod: ${{ steps.filter.outputs.gomod }}
offlinedocs-only: ${{ steps.filter.outputs.offlinedocs_count == steps.filter.outputs.all_count }}
offlinedocs: ${{ steps.filter.outputs.offlinedocs }}
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -92,9 +90,6 @@ jobs:
- "scaletest/**"
- "tailnet/**"
- "testutil/**"
gomod:
- "go.mod"
- "go.sum"
ts:
- "site/**"
- "Makefile"
@@ -108,54 +103,15 @@ jobs:
- ".github/workflows/ci.yaml"
offlinedocs:
- "offlinedocs/**"
tailnet-integration:
- "tailnet/**"
- "go.mod"
- "go.sum"
- id: debug
run: |
echo "${{ toJSON(steps.filter )}}"
# Disabled due to instability. See: https://github.com/coder/coder/issues/14553
# Re-enable once the flake hash calculation is stable.
# update-flake:
# needs: changes
# if: needs.changes.outputs.gomod == 'true'
# runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
# steps:
# - name: Checkout
# uses: actions/checkout@v4
# with:
# fetch-depth: 1
# # See: https://github.com/stefanzweifel/git-auto-commit-action?tab=readme-ov-file#commits-made-by-this-action-do-not-trigger-new-workflow-runs
# token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
# - name: Setup Go
# uses: ./.github/actions/setup-go
# - name: Update Nix Flake SRI Hash
# run: ./scripts/update-flake.sh
# # auto update flake for dependabot
# - uses: stefanzweifel/git-auto-commit-action@v5
# if: github.actor == 'dependabot[bot]'
# with:
# # Allows dependabot to still rebase!
# commit_message: "[dependabot skip] Update Nix Flake SRI Hash"
# commit_user_name: "dependabot[bot]"
# commit_user_email: "49699333+dependabot[bot]@users.noreply.github.com>"
# commit_author: "dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>"
# # require everyone else to update it themselves
# - name: Ensure No Changes
# if: github.actor != 'dependabot[bot]'
# run: git diff --exit-code
lint:
needs: changes
if: needs.changes.outputs.offlinedocs-only == 'false' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -170,13 +126,12 @@ jobs:
- name: Get golangci-lint cache dir
run: |
linter_ver=$(egrep -o 'GOLANGCI_LINT_VERSION=\S+' dogfood/contents/Dockerfile | cut -d '=' -f 2)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v$linter_ver
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.53.2
dir=$(golangci-lint cache status | awk '/Dir/ { print $2 }')
echo "LINT_CACHE_DIR=$dir" >> $GITHUB_ENV
- name: golangci-lint cache
uses: actions/cache@v4
uses: buildjet/cache@v4
with:
path: |
${{ env.LINT_CACHE_DIR }}
@@ -186,7 +141,7 @@ jobs:
# Check for any typos
- name: Check for typos
uses: crate-ci/typos@v1.24.6
uses: crate-ci/typos@v1.18.0
with:
config: .github/workflows/typos.toml
@@ -199,7 +154,7 @@ jobs:
# Needed for helm chart linting
- name: Install helm
uses: azure/setup-helm@v4
uses: azure/setup-helm@v3
with:
version: v3.9.2
@@ -207,15 +162,9 @@ jobs:
run: |
make --output-sync=line -j lint
- name: Check workflow files
run: |
bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/main/scripts/download-actionlint.bash) 1.6.22
./actionlint -color -shellcheck= -ignore "set-output"
shell: bash
gen:
timeout-minutes: 8
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
needs: changes
if: needs.changes.outputs.docs-only == 'false' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
@@ -233,9 +182,6 @@ jobs:
- name: Setup sqlc
uses: ./.github/actions/setup-sqlc
- name: Setup Terraform
uses: ./.github/actions/setup-tf
- name: go install tools
run: |
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
@@ -255,9 +201,7 @@ jobs:
popd
- name: make gen
# no `-j` flag as `make` fails with:
# coderd/rbac/object_gen.go:1:1: syntax error: package statement must be first
run: "make --output-sync -B gen"
run: "make --output-sync -j -B gen"
- name: Check for unstaged files
run: ./scripts/check_unstaged.sh
@@ -265,7 +209,7 @@ jobs:
fmt:
needs: changes
if: needs.changes.outputs.offlinedocs-only == 'false' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
timeout-minutes: 7
steps:
- name: Checkout
@@ -276,9 +220,12 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
# Use default Go version
- name: Setup Go
uses: ./.github/actions/setup-go
uses: buildjet/setup-go@v5
with:
# This doesn't need caching. It's super fast anyways!
cache: false
go-version: 1.21.5
- name: Install shfmt
run: go install mvdan.cc/sh/v3/cmd/shfmt@v3.7.0
@@ -292,7 +239,7 @@ jobs:
run: ./scripts/check_unstaged.sh
test-go:
runs-on: ${{ matrix.os == 'ubuntu-latest' && github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'macos-latest-xlarge' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'windows-latest-16-cores' || matrix.os }}
runs-on: ${{ matrix.os == 'ubuntu-latest' && github.repository_owner == 'coder' && 'buildjet-4vcpu-ubuntu-2204' || matrix.os == 'macos-latest' && github.repository_owner == 'coder' && 'macos-latest-xlarge' || matrix.os == 'windows-2022' && github.repository_owner == 'coder' && 'windows-latest-16-cores' || matrix.os }}
needs: changes
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
timeout-minutes: 20
@@ -319,6 +266,16 @@ jobs:
id: test
shell: bash
run: |
# Code coverage is more computationally expensive and also
# prevents test caching, so we disable it on alternate operating
# systems.
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
echo "cover=true" >> $GITHUB_OUTPUT
export COVERAGE_FLAGS='-covermode=atomic -coverprofile="gotests.coverage" -coverpkg=./...'
else
echo "cover=false" >> $GITHUB_OUTPUT
fi
# if macOS, install google-chrome for scaletests. As another concern,
# should we really have this kind of external dependency requirement
# on standard CI?
@@ -337,7 +294,7 @@ jobs:
fi
export TS_DEBUG_DISCO=true
gotestsum --junitfile="gotests.xml" --jsonfile="gotests.json" \
--packages="./..." -- $PARALLEL_FLAG -short -failfast
--packages="./..." -- $PARALLEL_FLAG -short -failfast $COVERAGE_FLAGS
- name: Upload test stats to Datadog
timeout-minutes: 1
@@ -347,8 +304,21 @@ jobs:
with:
api-key: ${{ secrets.DATADOG_API_KEY }}
- name: Check code coverage
uses: codecov/codecov-action@v4
# This action has a tendency to error out unexpectedly, it has
# the `fail_ci_if_error` option that defaults to `false`, but
# that is no guarantee, see:
# https://github.com/codecov/codecov-action/issues/788
continue-on-error: true
if: steps.test.outputs.cover && github.actor != 'dependabot[bot]' && !github.event.pull_request.head.repo.fork
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./gotests.coverage
flags: unittest-go-${{ matrix.os }}
test-go-pg:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
needs:
- changes
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
@@ -370,10 +340,8 @@ jobs:
uses: ./.github/actions/setup-tf
- name: Test with PostgreSQL Database
env:
POSTGRES_VERSION: "13"
TS_DEBUG_DISCO: "true"
run: |
export TS_DEBUG_DISCO=true
make test-postgres
- name: Upload test stats to Datadog
@@ -384,48 +352,21 @@ jobs:
with:
api-key: ${{ secrets.DATADOG_API_KEY }}
# NOTE: this could instead be defined as a matrix strategy, but we want to
# only block merging if tests on postgres 13 fail. Using a matrix strategy
# here makes the check in the above `required` job rather complicated.
test-go-pg-16:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
needs:
- changes
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
# This timeout must be greater than the timeout set by `go test` in
# `make test-postgres` to ensure we receive a trace of running
# goroutines. Setting this to the timeout +5m should work quite well
# even if some of the preceding steps are slow.
timeout-minutes: 25
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Setup Terraform
uses: ./.github/actions/setup-tf
- name: Test with PostgreSQL Database
env:
POSTGRES_VERSION: "16"
TS_DEBUG_DISCO: "true"
run: |
make test-postgres
- name: Upload test stats to Datadog
timeout-minutes: 1
- name: Check code coverage
uses: codecov/codecov-action@v4
# This action has a tendency to error out unexpectedly, it has
# the `fail_ci_if_error` option that defaults to `false`, but
# that is no guarantee, see:
# https://github.com/codecov/codecov-action/issues/788
continue-on-error: true
uses: ./.github/actions/upload-datadog
if: success() || failure()
if: github.actor != 'dependabot[bot]' && !github.event.pull_request.head.repo.fork
with:
api-key: ${{ secrets.DATADOG_API_KEY }}
token: ${{ secrets.CODECOV_TOKEN }}
files: ./gotests.coverage
flags: unittest-go-postgres-linux
test-go-race:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
needs: changes
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
timeout-minutes: 25
@@ -453,36 +394,8 @@ jobs:
with:
api-key: ${{ secrets.DATADOG_API_KEY }}
# Tailnet integration tests only run when the `tailnet` directory or `go.sum`
# and `go.mod` are changed. These tests are to ensure we don't add regressions
# to tailnet, either due to our code or due to updating dependencies.
#
# These tests are skipped in the main go test jobs because they require root
# and mess with networking.
test-go-tailnet-integration:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
needs: changes
# Unnecessary to run on main for now
if: needs.changes.outputs.tailnet-integration == 'true' || needs.changes.outputs.ci == 'true'
timeout-minutes: 20
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Setup Go
uses: ./.github/actions/setup-go
# Used by some integration tests.
- name: Install Nginx
run: sudo apt-get update && sudo apt-get install -y nginx
- name: Run Tests
run: make test-tailnet-integration
test-js:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
needs: changes
if: needs.changes.outputs.ts == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
timeout-minutes: 20
@@ -498,20 +411,24 @@ jobs:
- run: pnpm test:ci --max-workers $(nproc)
working-directory: site
- name: Check code coverage
uses: codecov/codecov-action@v4
# This action has a tendency to error out unexpectedly, it has
# the `fail_ci_if_error` option that defaults to `false`, but
# that is no guarantee, see:
# https://github.com/codecov/codecov-action/issues/788
continue-on-error: true
if: github.actor != 'dependabot[bot]' && !github.event.pull_request.head.repo.fork
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./site/coverage/lcov.info
flags: unittest-js
test-e2e:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-16' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-16vcpu-ubuntu-2204' || 'ubuntu-latest' }}
needs: changes
if: needs.changes.outputs.go == 'true' || needs.changes.outputs.ts == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
variant:
- enterprise: false
name: test-e2e
- enterprise: true
name: test-e2e-enterprise
name: ${{ matrix.variant.name }}
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -524,40 +441,44 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
# Assume that the checked-in versions are up-to-date
- run: make gen/mark-fresh
name: make gen
- name: Setup Terraform
uses: ./.github/actions/setup-tf
- run: pnpm build
working-directory: site
- name: go install tools
run: |
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.30
go install storj.io/drpc/cmd/protoc-gen-go-drpc@v0.0.33
go install golang.org/x/tools/cmd/goimports@latest
go install github.com/mikefarah/yq/v4@v4.30.6
go install go.uber.org/mock/mockgen@v0.4.0
- name: Install Protoc
run: |
mkdir -p /tmp/proto
pushd /tmp/proto
curl -L -o protoc.zip https://github.com/protocolbuffers/protobuf/releases/download/v23.3/protoc-23.3-linux-x86_64.zip
unzip protoc.zip
cp -r ./bin/* /usr/local/bin
cp -r ./include /usr/local/bin/include
popd
- name: Build
run: |
make -B site/out/index.html
- run: pnpm playwright:install
working-directory: site
# Run tests that don't require an enterprise license without an enterprise license
- run: pnpm playwright:test --forbid-only --workers 1
if: ${{ !matrix.variant.enterprise }}
- run: pnpm playwright:test --workers 1
env:
DEBUG: pw:api
working-directory: site
# Run all of the tests with an enterprise license
- run: pnpm playwright:test --forbid-only --workers 1
if: ${{ matrix.variant.enterprise }}
env:
DEBUG: pw:api
CODER_E2E_ENTERPRISE_LICENSE: ${{ secrets.CODER_E2E_ENTERPRISE_LICENSE }}
CODER_E2E_REQUIRE_ENTERPRISE_TESTS: "1"
working-directory: site
# Temporarily allow these to fail so that I can gather data about which
# tests are failing.
continue-on-error: true
- name: Upload Playwright Failed Tests
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@v4
with:
name: failed-test-videos${{ matrix.variant.enterprise && '-enterprise' || '-agpl' }}
name: failed-test-videos
path: ./site/test-results/**/*.webm
retention-days: 7
@@ -565,7 +486,7 @@ jobs:
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@v4
with:
name: debug-pprof-dumps${{ matrix.variant.enterprise && '-enterprise' || '-agpl' }}
name: debug-pprof-dumps
path: ./site/test-results/**/debug-pprof-*.txt
retention-days: 7
@@ -643,7 +564,7 @@ jobs:
offlinedocs:
name: offlinedocs
needs: changes
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
if: needs.changes.outputs.offlinedocs == 'true' || needs.changes.outputs.ci == 'true' || needs.changes.outputs.docs == 'true'
steps:
@@ -693,10 +614,8 @@ jobs:
pnpm lint
- name: Build
# no `-j` flag as `make` fails with:
# coderd/rbac/object_gen.go:1:1: syntax error: package statement must be first
run: |
make build/coder_docs_"$(./scripts/version.sh)".tgz
make -j build/coder_docs_"$(./scripts/version.sh)".tgz
required:
runs-on: ubuntu-latest
@@ -739,10 +658,11 @@ jobs:
build:
# This builds and publishes ghcr.io/coder/coder-preview:main for each commit
# to main branch.
# to main branch. We are only building this for amd64 platform. (>95% pulls
# are for amd64)
needs: changes
if: github.ref == 'refs/heads/main' && needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
if: needs.changes.outputs.docs-only == 'false' && !github.event.pull_request.head.repo.fork
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
env:
DOCKER_CLI_EXPERIMENTAL: "enabled"
outputs:
@@ -802,15 +722,13 @@ jobs:
echo "tag=$tag" >> $GITHUB_OUTPUT
# build images for each architecture
# note: omitting the -j argument to avoid race conditions when pushing
make build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
make -j build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
# only push if we are on main branch
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
# build and push multi-arch manifest, this depends on the other images
# being pushed so will automatically push them
# note: omitting the -j argument to avoid race conditions when pushing
make push/build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
make -j push/build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
# Define specific tags
tags=("$tag" "main" "latest")
@@ -937,20 +855,18 @@ jobs:
flyctl deploy --image "$IMAGE" --app paris-coder --config ./.github/fly-wsproxies/paris-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_PARIS" --yes
flyctl deploy --image "$IMAGE" --app sydney-coder --config ./.github/fly-wsproxies/sydney-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SYDNEY" --yes
flyctl deploy --image "$IMAGE" --app sao-paulo-coder --config ./.github/fly-wsproxies/sao-paulo-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_SAO_PAULO" --yes
flyctl deploy --image "$IMAGE" --app jnb-coder --config ./.github/fly-wsproxies/jnb-coder.toml --env "CODER_PROXY_SESSION_TOKEN=$TOKEN_JNB" --yes
env:
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}
IMAGE: ${{ needs.build.outputs.IMAGE }}
TOKEN_PARIS: ${{ secrets.FLY_PARIS_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SYDNEY: ${{ secrets.FLY_SYDNEY_CODER_PROXY_SESSION_TOKEN }}
TOKEN_SAO_PAULO: ${{ secrets.FLY_SAO_PAULO_CODER_PROXY_SESSION_TOKEN }}
TOKEN_JNB: ${{ secrets.FLY_JNB_CODER_PROXY_SESSION_TOKEN }}
# sqlc-vet runs a postgres docker container, runs Coder migrations, and then
# runs sqlc-vet to ensure all queries are valid. This catches any mistakes
# in migrations or sqlc queries that makes a query unable to be prepared.
sqlc-vet:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
needs: changes
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
+2 -4
View File
@@ -13,8 +13,6 @@ on:
- opened
- reopened
- edited
# For jobs that don't run on draft PRs.
- ready_for_review
# Only run one instance per PR to ensure in-order execution.
concurrency: pr-${{ github.ref }}
@@ -36,7 +34,7 @@ jobs:
steps:
- name: cla
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
uses: contributor-assistant/github-action@v2.6.0
uses: contributor-assistant/github-action@v2.3.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# the below token should have repo scope and must be manually added by you in the repository's secret
@@ -54,7 +52,7 @@ jobs:
release-labels:
runs-on: ubuntu-latest
# Skip tagging for draft PRs.
if: ${{ github.event_name == 'pull_request_target' && !github.event.pull_request.draft }}
if: ${{ github.event_name == 'pull_request_target' && success() && !github.event.pull_request.draft }}
steps:
- name: release-labels
uses: actions/github-script@v7
+1 -7
View File
@@ -8,11 +8,6 @@ on:
- scripts/Dockerfile.base
- scripts/Dockerfile
pull_request:
paths:
- scripts/Dockerfile.base
- .github/workflows/docker-base.yaml
schedule:
# Run every week at 09:43 on Monday, Wednesday and Friday. We build this
# frequently to ensure that packages are up-to-date.
@@ -62,12 +57,11 @@ jobs:
platforms: linux/amd64,linux/arm64,linux/arm/v7
pull: true
no-cache: true
push: ${{ github.event_name != 'pull_request' }}
push: true
tags: |
ghcr.io/coder/coder-base:latest
- name: Verify that images are pushed properly
if: github.event_name != 'pull_request'
run: |
# retry 10 times with a 5 second delay as the images may not be
# available immediately
+14 -24
View File
@@ -17,13 +17,8 @@ on:
- "flake.nix"
workflow_dispatch:
permissions:
# Necessary for GCP authentication (https://github.com/google-github-actions/setup-gcloud#usage)
id-token: write
jobs:
build_image:
if: github.actor != 'dependabot[bot]' # Skip Dependabot PRs
runs-on: ubuntu-latest
steps:
- name: Checkout
@@ -60,7 +55,7 @@ jobs:
project: b4q6ltmpzh
token: ${{ secrets.DEPOT_TOKEN }}
buildx-fallback: true
context: "{{defaultContext}}:dogfood/contents"
context: "{{defaultContext}}:dogfood"
pull: true
save: true
push: ${{ github.ref == 'refs/heads/main' }}
@@ -73,7 +68,7 @@ jobs:
token: ${{ secrets.DEPOT_TOKEN }}
buildx-fallback: true
context: "."
file: "dogfood/contents/Dockerfile.nix"
file: "dogfood/Dockerfile.nix"
pull: true
save: true
push: ${{ github.ref == 'refs/heads/main' }}
@@ -89,20 +84,11 @@ jobs:
- name: Setup Terraform
uses: ./.github/actions/setup-tf
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@v2
with:
workload_identity_provider: projects/573722524737/locations/global/workloadIdentityPools/github/providers/github
service_account: coder-ci@coder-dogfood.iam.gserviceaccount.com
- name: Terraform init and validate
run: |
cd dogfood
terraform init -upgrade
terraform validate
cd contents
terraform init -upgrade
terraform validate
- name: Get short commit SHA
if: github.ref == 'refs/heads/main'
@@ -114,18 +100,22 @@ jobs:
id: message
run: echo "pr_title=$(git log --format=%s -n 1 ${{ github.sha }})" >> $GITHUB_OUTPUT
- name: "Get latest Coder binary from the server"
if: github.ref == 'refs/heads/main'
run: |
curl -fsSL "https://dev.coder.com/bin/coder-linux-amd64" -o "./coder"
chmod +x "./coder"
- name: "Push template"
if: github.ref == 'refs/heads/main'
run: |
cd dogfood
terraform apply -auto-approve
./coder templates push $CODER_TEMPLATE_NAME --directory $CODER_TEMPLATE_DIR --yes --name=$CODER_TEMPLATE_VERSION --message="$CODER_TEMPLATE_MESSAGE" --variable jfrog_url=${{ secrets.JFROG_URL }}
env:
# Consumed by coderd provider
# Consumed by Coder CLI
CODER_URL: https://dev.coder.com
CODER_SESSION_TOKEN: ${{ secrets.CODER_SESSION_TOKEN }}
# Template source & details
TF_VAR_CODER_TEMPLATE_NAME: ${{ secrets.CODER_TEMPLATE_NAME }}
TF_VAR_CODER_TEMPLATE_VERSION: ${{ steps.vars.outputs.sha_short }}
TF_VAR_CODER_TEMPLATE_DIR: ./contents
TF_VAR_CODER_TEMPLATE_MESSAGE: ${{ steps.message.outputs.pr_title }}
TF_LOG: info
CODER_TEMPLATE_NAME: ${{ secrets.CODER_TEMPLATE_NAME }}
CODER_TEMPLATE_VERSION: ${{ steps.vars.outputs.sha_short }}
CODER_TEMPLATE_DIR: ./dogfood
CODER_TEMPLATE_MESSAGE: ${{ steps.message.outputs.pr_title }}
+21 -24
View File
@@ -1,26 +1,23 @@
{
"ignorePatterns": [
{
"pattern": "://localhost"
},
{
"pattern": "://.*.?example\\.com"
},
{
"pattern": "developer.github.com"
},
{
"pattern": "docs.github.com"
},
{
"pattern": "support.google.com"
},
{
"pattern": "tailscale.com"
},
{
"pattern": "wireguard.com"
}
],
"aliveStatusCodes": [200, 0]
"ignorePatterns": [
{
"pattern": "://localhost"
},
{
"pattern": "://.*.?example\\.com"
},
{
"pattern": "developer.github.com"
},
{
"pattern": "docs.github.com"
},
{
"pattern": "support.google.com"
},
{
"pattern": "tailscale.com"
}
],
"aliveStatusCodes": [200, 0]
}
+2 -2
View File
@@ -11,7 +11,7 @@ jobs:
# While GitHub's toaster runners are likelier to flake, we want consistency
# between this environment and the regular test environment for DataDog
# statistics and to only show real workflow threats.
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: "buildjet-8vcpu-ubuntu-2204"
# This runner costs 0.016 USD per minute,
# so 0.016 * 240 = 3.84 USD per run.
timeout-minutes: 240
@@ -40,7 +40,7 @@ jobs:
go-timing:
# We run these tests with p=1 so we don't need a lot of compute.
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04' || 'ubuntu-latest' }}
runs-on: "buildjet-2vcpu-ubuntu-2204"
timeout-minutes: 10
steps:
- name: Checkout
+1 -1
View File
@@ -14,4 +14,4 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Assign author
uses: toshimaru/auto-author-assign@v2.1.1
uses: toshimaru/auto-author-assign@v2.1.0
+3 -3
View File
@@ -101,7 +101,7 @@ jobs:
run: |
set -euo pipefail
mkdir -p ~/.kube
echo "${{ secrets.PR_DEPLOYMENTS_KUBECONFIG_BASE64 }}" | base64 --decode > ~/.kube/config
echo "${{ secrets.PR_DEPLOYMENTS_KUBECONFIG }}" > ~/.kube/config
chmod 644 ~/.kube/config
export KUBECONFIG=~/.kube/config
@@ -189,7 +189,7 @@ jobs:
needs: get_info
# Run build job only if there are changes in the files that we care about or if the workflow is manually triggered with --build flag
if: needs.get_info.outputs.BUILD == 'true'
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
# This concurrency only cancels build jobs if a new build is triggred. It will avoid cancelling the current deployemtn in case of docs chnages.
concurrency:
group: build-${{ github.workflow }}-${{ github.ref }}-${{ needs.get_info.outputs.BUILD }}
@@ -253,7 +253,7 @@ jobs:
run: |
set -euo pipefail
mkdir -p ~/.kube
echo "${{ secrets.PR_DEPLOYMENTS_KUBECONFIG_BASE64 }}" | base64 --decode > ~/.kube/config
echo "${{ secrets.PR_DEPLOYMENTS_KUBECONFIG }}" > ~/.kube/config
chmod 644 ~/.kube/config
export KUBECONFIG=~/.kube/config
-20
View File
@@ -1,20 +0,0 @@
name: release-validation
on:
push:
tags:
- "v*"
jobs:
network-performance:
runs-on: ubuntu-latest
steps:
- name: Run Schmoder CI
uses: benc-uk/workflow-dispatch@v1.2.4
with:
workflow: ci.yaml
repo: coder/schmoder
inputs: '{ "num_releases": "3", "commit": "${{ github.sha }}" }'
token: ${{ secrets.CDRCI_SCHMODER_ACTIONS_TOKEN }}
ref: main
+20 -97
View File
@@ -1,16 +1,11 @@
# GitHub release workflow.
name: Release
on:
push:
tags:
- "v*"
workflow_dispatch:
inputs:
release_channel:
type: choice
description: Release channel
options:
- mainline
- stable
release_notes:
description: Release notes for the publishing the release. This is required to create a release.
dry_run:
description: Perform a dry-run release (devel). Note that ref must be an annotated tag when run without dry-run.
type: boolean
@@ -33,13 +28,11 @@ env:
# https://github.blog/changelog/2022-06-10-github-actions-inputs-unified-across-manual-and-reusable-workflows/
CODER_RELEASE: ${{ !inputs.dry_run }}
CODER_DRY_RUN: ${{ inputs.dry_run }}
CODER_RELEASE_CHANNEL: ${{ inputs.release_channel }}
CODER_RELEASE_NOTES: ${{ inputs.release_notes }}
jobs:
release:
name: Build and publish
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
env:
# Necessary for Docker manifest
DOCKER_CLI_EXPERIMENTAL: "enabled"
@@ -69,45 +62,21 @@ jobs:
echo "CODER_FORCE_VERSION=$version" >> $GITHUB_ENV
echo "$version"
# Verify that all expectations for a release are met.
- name: Verify release input
if: ${{ !inputs.dry_run }}
run: |
set -euo pipefail
if [[ "${GITHUB_REF}" != "refs/tags/v"* ]]; then
echo "Ref must be a semver tag when creating a release, did you use scripts/release.sh?"
exit 1
fi
# 2.10.2 -> release/2.10
version="$(./scripts/version.sh)"
release_branch=release/${version%.*}
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
if [[ -z "${branch_contains_tag}" ]]; then
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
exit 1
fi
if [[ -z "${CODER_RELEASE_NOTES}" ]]; then
echo "Release notes are required to create a release, did you use scripts/release.sh?"
exit 1
fi
echo "Release inputs verified:"
echo
echo "- Ref: ${GITHUB_REF}"
echo "- Version: ${version}"
echo "- Release channel: ${CODER_RELEASE_CHANNEL}"
echo "- Release branch: ${release_branch}"
echo "- Release notes: true"
- name: Create release notes file
- name: Create release notes
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# We always have to set this since there might be commits on
# main that didn't have a PR.
CODER_IGNORE_MISSING_COMMIT_METADATA: "1"
run: |
set -euo pipefail
ref=HEAD
old_version="$(git describe --abbrev=0 "$ref^1")"
version="v$(./scripts/version.sh)"
# Generate notes.
release_notes_file="$(mktemp -t release_notes.XXXXXX)"
echo "$CODER_RELEASE_NOTES" > "$release_notes_file"
./scripts/release/generate_release_notes.sh --check-for-changelog --old-version "$old_version" --new-version "$version" --ref "$ref" >> "$release_notes_file"
echo CODER_RELEASE_NOTES_FILE="$release_notes_file" >> $GITHUB_ENV
- name: Show release notes
@@ -128,13 +97,6 @@ jobs:
- name: Setup Node
uses: ./.github/actions/setup-node
# Necessary for signing Windows binaries.
- name: Setup Java
uses: actions/setup-java@v4
with:
distribution: "zulu"
java-version: "11.0"
- name: Install nsis and zstd
run: sudo apt-get install -y nsis zstd
@@ -168,32 +130,6 @@ jobs:
AC_CERTIFICATE_PASSWORD: ${{ secrets.AC_CERTIFICATE_PASSWORD }}
AC_APIKEY_P8_BASE64: ${{ secrets.AC_APIKEY_P8_BASE64 }}
- name: Setup Windows EV Signing Certificate
run: |
set -euo pipefail
touch /tmp/ev_cert.pem
chmod 600 /tmp/ev_cert.pem
echo "$EV_SIGNING_CERT" > /tmp/ev_cert.pem
wget https://github.com/ebourg/jsign/releases/download/6.0/jsign-6.0.jar -O /tmp/jsign-6.0.jar
env:
EV_SIGNING_CERT: ${{ secrets.EV_SIGNING_CERT }}
- name: Test migrations from current ref to main
run: |
POSTGRES_VERSION=13 make test-migrations
# Setup GCloud for signing Windows binaries.
- name: Authenticate to Google Cloud
id: gcloud_auth
uses: google-github-actions/auth@v2
with:
workload_identity_provider: ${{ secrets.GCP_CODE_SIGNING_WORKLOAD_ID_PROVIDER }}
service_account: ${{ secrets.GCP_CODE_SIGNING_SERVICE_ACCOUNT }}
token_format: "access_token"
- name: Setup GCloud SDK
uses: "google-github-actions/setup-gcloud@v2"
- name: Build binaries
run: |
set -euo pipefail
@@ -208,26 +144,16 @@ jobs:
build/coder_helm_"$version".tgz \
build/provisioner_helm_"$version".tgz
env:
CODER_SIGN_WINDOWS: "1"
CODER_SIGN_DARWIN: "1"
AC_CERTIFICATE_FILE: /tmp/apple_cert.p12
AC_CERTIFICATE_PASSWORD_FILE: /tmp/apple_cert_password.txt
AC_APIKEY_ISSUER_ID: ${{ secrets.AC_APIKEY_ISSUER_ID }}
AC_APIKEY_ID: ${{ secrets.AC_APIKEY_ID }}
AC_APIKEY_FILE: /tmp/apple_apikey.p8
EV_KEY: ${{ secrets.EV_KEY }}
EV_KEYSTORE: ${{ secrets.EV_KEYSTORE }}
EV_TSA_URL: ${{ secrets.EV_TSA_URL }}
EV_CERTIFICATE_PATH: /tmp/ev_cert.pem
GCLOUD_ACCESS_TOKEN: ${{ steps.gcloud_auth.outputs.access_token }}
JSIGN_PATH: /tmp/jsign-6.0.jar
- name: Delete Apple Developer certificate and API key
run: rm -f /tmp/{apple_cert.p12,apple_cert_password.txt,apple_apikey.p8}
- name: Delete Windows EV Signing Cert
run: rm /tmp/ev_cert.pem
- name: Determine base image tag
id: image-base-tag
run: |
@@ -297,7 +223,7 @@ jobs:
# build Docker images for each architecture
version="$(./scripts/version.sh)"
make build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
make -j build/coder_"$version"_linux_{amd64,arm64,armv7}.tag
# we can't build multi-arch if the images aren't pushed, so quit now
# if dry-running
@@ -308,7 +234,7 @@ jobs:
# build and push multi-arch manifest, this depends on the other images
# being pushed so will automatically push them.
make push/build/coder_"$version"_linux.tag
make -j push/build/coder_"$version"_linux.tag
# if the current version is equal to the highest (according to semver)
# version in the repo, also create a multi-arch image as ":latest" and
@@ -335,9 +261,6 @@ jobs:
set -euo pipefail
publish_args=()
if [[ $CODER_RELEASE_CHANNEL == "stable" ]]; then
publish_args+=(--stable)
fi
if [[ $CODER_DRY_RUN == *t* ]]; then
publish_args+=(--dry-run)
fi
@@ -396,14 +319,14 @@ jobs:
./build/*.rpm
retention-days: 7
- name: Send repository-dispatch event
- name: Start Packer builds
if: ${{ !inputs.dry_run }}
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.CDRCI_GITHUB_TOKEN }}
repository: coder/packages
event-type: coder-release
client-payload: '{"coder_version": "${{ steps.version.outputs.version }}", "release_channel": "${{ inputs.release_channel }}"}'
client-payload: '{"coder_version": "${{ steps.version.outputs.version }}"}'
publish-homebrew:
name: Publish to Homebrew tap
@@ -488,7 +411,7 @@ jobs:
- name: Sync fork
run: gh repo sync cdrci/winget-pkgs -b master
env:
GH_TOKEN: ${{ secrets.CDRCI_GITHUB_TOKEN }}
GH_TOKEN: ${{ secrets.WINGET_GH_TOKEN }}
- name: Checkout
uses: actions/checkout@v4
+14 -16
View File
@@ -23,19 +23,19 @@ concurrency:
jobs:
codeql:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
with:
languages: go, javascript
- name: Setup Go
uses: ./.github/actions/setup-go
# Workaround to prevent CodeQL from building the dashboard.
- name: Remove Makefile
run: |
@@ -56,7 +56,7 @@ jobs:
"${{ secrets.SLACK_SECURITY_FAILURE_WEBHOOK_URL }}"
trivy:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
runs-on: ${{ github.repository_owner == 'coder' && 'buildjet-8vcpu-ubuntu-2204' || 'ubuntu-latest' }}
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -113,8 +113,16 @@ jobs:
make -j "$image_job"
echo "image=$(cat "$image_job")" >> $GITHUB_OUTPUT
- name: Run Prisma Cloud image scan
uses: PaloAltoNetworks/prisma-cloud-scan@v1
with:
pcc_console_url: ${{ secrets.PRISMA_CLOUD_URL }}
pcc_user: ${{ secrets.PRISMA_CLOUD_ACCESS_KEY }}
pcc_pass: ${{ secrets.PRISMA_CLOUD_SECRET_KEY }}
image_name: ${{ steps.build.outputs.image }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@6e7b7d1fd3e4fef0c5fa8cce1229c54b2c9bd0d8
uses: aquasecurity/trivy-action@d43c1f16c00cfd3978dde6c07f4bbcf9eb6993ca
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
@@ -134,16 +142,6 @@ jobs:
path: trivy-results.sarif
retention-days: 7
# Prisma cloud scan runs last because it fails the entire job if it
# detects vulnerabilities. :|
- name: Run Prisma Cloud image scan
uses: PaloAltoNetworks/prisma-cloud-scan@v1
with:
pcc_console_url: ${{ secrets.PRISMA_CLOUD_URL }}
pcc_user: ${{ secrets.PRISMA_CLOUD_ACCESS_KEY }}
pcc_pass: ${{ secrets.PRISMA_CLOUD_SECRET_KEY }}
image_name: ${{ steps.build.outputs.image }}
- name: Send Slack notification on failure
if: ${{ failure() }}
run: |
+1 -4
View File
@@ -17,10 +17,7 @@ jobs:
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
# days-before-stale: 180
# essentially disabled for now while we work through polish issues
days-before-stale: 3650
days-before-stale: 180
# Pull Requests become stale more quickly due to merge conflicts.
# Also, we promote minimizing WIP.
days-before-pr-stale: 7
+1 -9
View File
@@ -14,14 +14,7 @@ darcula = "darcula"
Hashi = "Hashi"
trialer = "trialer"
encrypter = "encrypter"
# as in helsinki
hel = "hel"
# this is used as proto node
pn = "pn"
# typos doesn't like the EDE in TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA
EDE = "EDE"
# HELO is an SMTP command
HELO = "HELO"
hel = "hel" # as in helsinki
[files]
extend-exclude = [
@@ -39,5 +32,4 @@ extend-exclude = [
"**/pnpm-lock.yaml",
"tailnet/testdata/**",
"site/src/pages/SetupPage/countries.tsx",
"provisioner/terraform/testdata/**",
]
+1 -6
View File
@@ -4,11 +4,6 @@ on:
schedule:
- cron: "0 9 * * 1"
workflow_dispatch: # allows to run manually for testing
pull_request:
branches:
- main
paths:
- "docs/**"
jobs:
check-docs:
@@ -29,7 +24,7 @@ jobs:
file-path: "./README.md"
- name: Send Slack notification
if: failure() && github.event_name == 'schedule'
if: failure()
run: |
curl -X POST -H 'Content-type: application/json' -d '{"msg":"Broken links found in the documentation. Please check the logs at ${{ env.LOGS_URL }}"}' ${{ secrets.DOCS_LINK_SLACK_WEBHOOK }}
echo "Sent Slack notification"
-3
View File
@@ -68,6 +68,3 @@ result
# Filebrowser.db
**/filebrowser.db
# pnpm
.pnpm-store/
-5
View File
@@ -195,11 +195,6 @@ linters-settings:
- name: var-naming
- name: waitgroup-by-value
# irrelevant as of Go v1.22: https://go.dev/blog/loopvar-preview
govet:
disable:
- loopclosure
issues:
# Rules listed here: https://github.com/securego/gosec#available-rules
exclude-rules:
+12 -9
View File
@@ -71,21 +71,24 @@ result
# Filebrowser.db
**/filebrowser.db
# pnpm
.pnpm-store/
# .prettierignore.include:
# Helm templates contain variables that are invalid YAML and can't be formatted
# by Prettier.
helm/**/templates/*.yaml
# Terraform state files used in tests, these are automatically generated.
# Example: provisioner/terraform/testdata/instance-id/instance-id.tfstate.json
**/testdata/**/*.tf*.json
# Testdata shouldn't be formatted.
testdata/
scripts/apitypings/testdata/**/*.ts
enterprise/tailnet/testdata/*.golden.html
tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
site/e2e/provisionerGenerated.ts
# Ignore generated files
**/pnpm-lock.yaml
**/*.gen.json
# Everything in site/ is formatted by Biome. For the rest of the repo though, we
# need broader language support.
site/
# Ignore generated JSON (e.g. examples/examples.gen.json).
**/*.gen.json
+12 -6
View File
@@ -2,13 +2,19 @@
# by Prettier.
helm/**/templates/*.yaml
# Terraform state files used in tests, these are automatically generated.
# Example: provisioner/terraform/testdata/instance-id/instance-id.tfstate.json
**/testdata/**/*.tf*.json
# Testdata shouldn't be formatted.
testdata/
scripts/apitypings/testdata/**/*.ts
enterprise/tailnet/testdata/*.golden.html
tailnet/testdata/*.golden.html
# Generated files shouldn't be formatted.
site/e2e/provisionerGenerated.ts
# Ignore generated files
**/pnpm-lock.yaml
**/*.gen.json
# Everything in site/ is formatted by Biome. For the rest of the repo though, we
# need broader language support.
site/
# Ignore generated JSON (e.g. examples/examples.gen.json).
**/*.gen.json
+3 -3
View File
@@ -4,13 +4,13 @@
printWidth: 80
proseWrap: always
trailingComma: all
useTabs: true
useTabs: false
tabWidth: 2
overrides:
- files:
- README.md
- docs/reference/api/**/*.md
- docs/reference/cli/**/*.md
- docs/api/**/*.md
- docs/cli/**/*.md
- docs/changelogs/*.md
- .github/**/*.{yaml,yml,toml}
- scripts/**/*.{yaml,yml,toml}
+13 -13
View File
@@ -1,15 +1,15 @@
{
"recommendations": [
"github.vscode-codeql",
"golang.go",
"hashicorp.terraform",
"esbenp.prettier-vscode",
"foxundermoon.shell-format",
"emeraldwalk.runonsave",
"zxh404.vscode-proto3",
"redhat.vscode-yaml",
"streetsidesoftware.code-spell-checker",
"EditorConfig.EditorConfig",
"biomejs.biome"
]
"recommendations": [
"github.vscode-codeql",
"golang.go",
"hashicorp.terraform",
"esbenp.prettier-vscode",
"foxundermoon.shell-format",
"emeraldwalk.runonsave",
"zxh404.vscode-proto3",
"redhat.vscode-yaml",
"streetsidesoftware.code-spell-checker",
"dbaeumer.vscode-eslint",
"EditorConfig.EditorConfig"
]
}
+223 -237
View File
@@ -1,239 +1,225 @@
{
"cSpell.words": [
"afero",
"agentsdk",
"apps",
"ASKPASS",
"authcheck",
"autostop",
"awsidentity",
"bodyclose",
"buildinfo",
"buildname",
"circbuf",
"cliflag",
"cliui",
"codecov",
"coderd",
"coderdenttest",
"coderdtest",
"codersdk",
"contravariance",
"cronstrue",
"databasefake",
"dbgen",
"dbmem",
"dbtype",
"DERP",
"derphttp",
"derpmap",
"devel",
"devtunnel",
"dflags",
"drpc",
"drpcconn",
"drpcmux",
"drpcserver",
"Dsts",
"embeddedpostgres",
"enablements",
"enterprisemeta",
"errgroup",
"eventsourcemock",
"externalauth",
"Failf",
"fatih",
"Formik",
"gitauth",
"gitsshkey",
"goarch",
"gographviz",
"goleak",
"gonet",
"gossh",
"gsyslog",
"GTTY",
"hashicorp",
"hclsyntax",
"httpapi",
"httpmw",
"idtoken",
"Iflag",
"incpatch",
"initialisms",
"ipnstate",
"isatty",
"Jobf",
"Keygen",
"kirsle",
"Kubernetes",
"ldflags",
"magicsock",
"manifoldco",
"mapstructure",
"mattn",
"mitchellh",
"moby",
"namesgenerator",
"namespacing",
"netaddr",
"netip",
"netmap",
"netns",
"netstack",
"nettype",
"nfpms",
"nhooyr",
"nmcfg",
"nolint",
"nosec",
"ntqry",
"OIDC",
"oneof",
"opty",
"paralleltest",
"parameterscopeid",
"pqtype",
"prometheusmetrics",
"promhttp",
"protobuf",
"provisionerd",
"provisionerdserver",
"provisionersdk",
"ptty",
"ptys",
"ptytest",
"quickstart",
"reconfig",
"replicasync",
"retrier",
"rpty",
"SCIM",
"sdkproto",
"sdktrace",
"Signup",
"slogtest",
"sourcemapped",
"spinbutton",
"Srcs",
"stdbuf",
"stretchr",
"STTY",
"stuntest",
"subpage",
"tailbroker",
"tailcfg",
"tailexchange",
"tailnet",
"tailnettest",
"Tailscale",
"tanstack",
"tbody",
"TCGETS",
"tcpip",
"TCSETS",
"templateversions",
"testdata",
"testid",
"testutil",
"tfexec",
"tfjson",
"tfplan",
"tfstate",
"thead",
"tios",
"tmpdir",
"tokenconfig",
"Topbar",
"tparallel",
"trialer",
"trimprefix",
"tsdial",
"tslogger",
"tstun",
"turnconn",
"typegen",
"typesafe",
"unconvert",
"Untar",
"Userspace",
"VMID",
"walkthrough",
"weblinks",
"webrtc",
"wgcfg",
"wgconfig",
"wgengine",
"wgmonitor",
"wgnet",
"workspaceagent",
"workspaceagents",
"workspaceapp",
"workspaceapps",
"workspacebuilds",
"workspacename",
"wsjson",
"xerrors",
"xlarge",
"xsmall",
"yamux"
],
"cSpell.ignorePaths": ["site/package.json", ".vscode/settings.json"],
"emeraldwalk.runonsave": {
"commands": [
{
"match": "database/queries/*.sql",
"cmd": "make gen"
},
{
"match": "provisionerd/proto/provisionerd.proto",
"cmd": "make provisionerd/proto/provisionerd.pb.go"
}
]
},
"search.exclude": {
"**.pb.go": true,
"**/*.gen.json": true,
"**/testdata/*": true,
"coderd/apidoc/**": true,
"docs/reference/api/*.md": true,
"docs/reference/cli/*.md": true,
"docs/templates/*.md": true,
"LICENSE": true,
"scripts/metricsdocgen/metrics": true,
"site/out/**": true,
"site/storybook-static/**": true,
"**.map": true,
"pnpm-lock.yaml": true
},
// Ensure files always have a newline.
"files.insertFinalNewline": true,
"go.lintTool": "golangci-lint",
"go.lintFlags": ["--fast"],
"go.coverageDecorator": {
"type": "gutter",
"coveredGutterStyle": "blockgreen",
"uncoveredGutterStyle": "blockred"
},
// The codersdk is used by coderd another other packages extensively.
// To reduce redundancy in tests, it's covered by other packages.
// Since package coverage pairing can't be defined, all packages cover
// all other packages.
"go.testFlags": ["-short", "-coverpkg=./..."],
// We often use a version of TypeScript that's ahead of the version shipped
// with VS Code.
"typescript.tsdk": "./site/node_modules/typescript/lib",
// Playwright tests in VSCode will open a browser to live "view" the test.
"playwright.reuseBrowser": true,
"[javascript][javascriptreact][json][jsonc][typescript][typescriptreact]": {
"editor.defaultFormatter": "biomejs.biome"
// "editor.codeActionsOnSave": {
// "source.organizeImports.biome": "explicit"
// }
},
"[css][html][markdown][yaml]": {
"editor.defaultFormatter": "esbenp.prettier-vscode"
}
"cSpell.words": [
"afero",
"agentsdk",
"apps",
"ASKPASS",
"authcheck",
"autostop",
"awsidentity",
"bodyclose",
"buildinfo",
"buildname",
"circbuf",
"cliflag",
"cliui",
"codecov",
"coderd",
"coderdenttest",
"coderdtest",
"codersdk",
"contravariance",
"cronstrue",
"databasefake",
"dbgen",
"dbmem",
"dbtype",
"DERP",
"derphttp",
"derpmap",
"devel",
"devtunnel",
"dflags",
"drpc",
"drpcconn",
"drpcmux",
"drpcserver",
"Dsts",
"embeddedpostgres",
"enablements",
"enterprisemeta",
"errgroup",
"eventsourcemock",
"externalauth",
"Failf",
"fatih",
"Formik",
"gitauth",
"gitsshkey",
"goarch",
"gographviz",
"goleak",
"gonet",
"gossh",
"gsyslog",
"GTTY",
"hashicorp",
"hclsyntax",
"httpapi",
"httpmw",
"idtoken",
"Iflag",
"incpatch",
"initialisms",
"ipnstate",
"isatty",
"Jobf",
"Keygen",
"kirsle",
"Kubernetes",
"ldflags",
"magicsock",
"manifoldco",
"mapstructure",
"mattn",
"mitchellh",
"moby",
"namesgenerator",
"namespacing",
"netaddr",
"netip",
"netmap",
"netns",
"netstack",
"nettype",
"nfpms",
"nhooyr",
"nmcfg",
"nolint",
"nosec",
"ntqry",
"OIDC",
"oneof",
"opty",
"paralleltest",
"parameterscopeid",
"pqtype",
"prometheusmetrics",
"promhttp",
"protobuf",
"provisionerd",
"provisionerdserver",
"provisionersdk",
"ptty",
"ptys",
"ptytest",
"quickstart",
"reconfig",
"replicasync",
"retrier",
"rpty",
"SCIM",
"sdkproto",
"sdktrace",
"Signup",
"slogtest",
"sourcemapped",
"Srcs",
"stdbuf",
"stretchr",
"STTY",
"stuntest",
"tailbroker",
"tailcfg",
"tailexchange",
"tailnet",
"tailnettest",
"Tailscale",
"tanstack",
"tbody",
"TCGETS",
"tcpip",
"TCSETS",
"templateversions",
"testdata",
"testid",
"testutil",
"tfexec",
"tfjson",
"tfplan",
"tfstate",
"thead",
"tios",
"tmpdir",
"tokenconfig",
"Topbar",
"tparallel",
"trialer",
"trimprefix",
"tsdial",
"tslogger",
"tstun",
"turnconn",
"typegen",
"typesafe",
"unconvert",
"Untar",
"Userspace",
"VMID",
"walkthrough",
"weblinks",
"webrtc",
"wgcfg",
"wgconfig",
"wgengine",
"wgmonitor",
"wgnet",
"workspaceagent",
"workspaceagents",
"workspaceapp",
"workspaceapps",
"workspacebuilds",
"workspacename",
"wsjson",
"xerrors",
"xlarge",
"xsmall",
"yamux"
],
"cSpell.ignorePaths": ["site/package.json", ".vscode/settings.json"],
"emeraldwalk.runonsave": {
"commands": [
{
"match": "database/queries/*.sql",
"cmd": "make gen"
},
{
"match": "provisionerd/proto/provisionerd.proto",
"cmd": "make provisionerd/proto/provisionerd.pb.go"
}
]
},
"eslint.workingDirectories": ["./site"],
"search.exclude": {
"**.pb.go": true,
"**/*.gen.json": true,
"**/testdata/*": true,
"**Generated.ts": true,
"coderd/apidoc/**": true,
"docs/api/*.md": true,
"docs/templates/*.md": true,
"LICENSE": true,
"scripts/metricsdocgen/metrics": true,
"site/out/**": true,
"site/storybook-static/**": true,
"**.map": true,
"pnpm-lock.yaml": true
},
// Ensure files always have a newline.
"files.insertFinalNewline": true,
"go.lintTool": "golangci-lint",
"go.lintFlags": ["--fast"],
"go.coverageDecorator": {
"type": "gutter",
"coveredGutterStyle": "blockgreen",
"uncoveredGutterStyle": "blockred"
},
// The codersdk is used by coderd another other packages extensively.
// To reduce redundancy in tests, it's covered by other packages.
// Since package coverage pairing can't be defined, all packages cover
// all other packages.
"go.testFlags": ["-short", "-coverpkg=./..."],
// We often use a version of TypeScript that's ahead of the version shipped
// with VS Code.
"typescript.tsdk": "./site/node_modules/typescript/lib"
}
+86 -112
View File
@@ -36,7 +36,6 @@ GOOS := $(shell go env GOOS)
GOARCH := $(shell go env GOARCH)
GOOS_BIN_EXT := $(if $(filter windows, $(GOOS)),.exe,)
VERSION := $(shell ./scripts/version.sh)
POSTGRES_VERSION ?= 16
# Use the highest ZSTD compression level in CI.
ifdef CI
@@ -57,9 +56,6 @@ GO_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.go' -not -nam
# All the shell files in the repo, excluding ignored files.
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
# Ensure we don't use the user's git configs which might cause side-effects
GIT_FLAGS = GIT_CONFIG_GLOBAL=/dev/null GIT_CONFIG_SYSTEM=/dev/null
# All ${OS}_${ARCH} combos we build for. Windows binaries have the .exe suffix.
OS_ARCHES := \
linux_amd64 linux_arm64 linux_armv7 \
@@ -204,8 +200,7 @@ endef
# calling this manually.
$(CODER_ALL_BINARIES): go.mod go.sum \
$(GO_SRC_FILES) \
$(shell find ./examples/templates) \
site/static/error.html
$(shell find ./examples/templates)
$(get-mode-os-arch-ext)
if [[ "$$os" != "windows" ]] && [[ "$$ext" != "" ]]; then
@@ -366,8 +361,6 @@ $(foreach chart,$(charts),build/$(chart)_helm_$(VERSION).tgz): build/%_helm_$(VE
site/out/index.html: site/package.json $(shell find ./site $(FIND_EXCLUSIONS) -type f \( -name '*.ts' -o -name '*.tsx' \))
cd site
# prevents this directory from getting to big, and causing "too much data" errors
rm -rf out/assets/
../scripts/pnpm_install.sh
pnpm build
@@ -387,49 +380,32 @@ install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
cp "$<" "$$output_file"
.PHONY: install
BOLD := $(shell tput bold 2>/dev/null)
GREEN := $(shell tput setaf 2 2>/dev/null)
RESET := $(shell tput sgr0 2>/dev/null)
fmt: fmt/ts fmt/go fmt/terraform fmt/shfmt fmt/prettier
fmt: fmt/prettier fmt/terraform fmt/shfmt fmt/go
.PHONY: fmt
fmt/go:
go mod tidy
echo "$(GREEN)==>$(RESET) $(BOLD)fmt/go$(RESET)"
# VS Code users should check out
# https://github.com/mvdan/gofumpt#visual-studio-code
go run mvdan.cc/gofumpt@v0.4.0 -w -l .
.PHONY: fmt/go
fmt/ts:
echo "$(GREEN)==>$(RESET) $(BOLD)fmt/ts$(RESET)"
fmt/prettier:
echo "--- prettier"
cd site
# Avoid writing files in CI to reduce file write activity
ifdef CI
pnpm run check --linter-enabled=false
else
pnpm run check:fix
endif
.PHONY: fmt/ts
fmt/prettier: .prettierignore
echo "$(GREEN)==>$(RESET) $(BOLD)fmt/prettier$(RESET)"
# Avoid writing files in CI to reduce file write activity
ifdef CI
pnpm run format:check
else
pnpm run format
pnpm run format:write
endif
.PHONY: fmt/prettier
fmt/terraform: $(wildcard *.tf)
echo "$(GREEN)==>$(RESET) $(BOLD)fmt/terraform$(RESET)"
terraform fmt -recursive
.PHONY: fmt/terraform
fmt/shfmt: $(SHELL_SRC_FILES)
echo "$(GREEN)==>$(RESET) $(BOLD)fmt/shfmt$(RESET)"
echo "--- shfmt"
# Only do diff check in CI, errors on diff.
ifdef CI
shfmt -d $(SHELL_SRC_FILES)
@@ -438,7 +414,7 @@ else
endif
.PHONY: fmt/shfmt
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons
lint: lint/shellcheck lint/go lint/ts lint/helm lint/site-icons
.PHONY: lint
lint/site-icons:
@@ -447,20 +423,16 @@ lint/site-icons:
lint/ts:
cd site
pnpm lint
pnpm i && pnpm lint
.PHONY: lint/ts
lint/go:
./scripts/check_enterprise_imports.sh
./scripts/check_codersdk_imports.sh
linter_ver=$(shell egrep -o 'GOLANGCI_LINT_VERSION=\S+' dogfood/contents/Dockerfile | cut -d '=' -f 2)
go run github.com/golangci/golangci-lint/cmd/golangci-lint@v$$linter_ver run
linter_ver=$(shell egrep -o 'GOLANGCI_LINT_VERSION=\S+' dogfood/Dockerfile | cut -d '=' -f 2)
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v$$linter_ver
golangci-lint run
.PHONY: lint/go
lint/examples:
go run ./scripts/examplegen/main.go -lint
.PHONY: lint/examples
# Use shfmt to determine the shell files, takes editorconfig into consideration.
lint/shellcheck: $(SHELL_SRC_FILES)
echo "--- shellcheck"
@@ -492,15 +464,15 @@ gen: \
$(DB_GEN_FILES) \
site/src/api/typesGenerated.ts \
coderd/rbac/object_gen.go \
codersdk/rbacresources_gen.go \
site/src/api/rbacresourcesGenerated.ts \
docs/admin/prometheus.md \
docs/reference/cli/README.md \
docs/cli.md \
docs/admin/audit-logs.md \
coderd/apidoc/swagger.json \
.prettierignore.include \
.prettierignore \
provisioner/terraform/testdata/version \
site/.prettierrc.yaml \
site/.prettierignore \
site/.eslintignore \
site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \
examples/examples.gen.json \
@@ -521,14 +493,15 @@ gen/mark-fresh:
$(DB_GEN_FILES) \
site/src/api/typesGenerated.ts \
coderd/rbac/object_gen.go \
codersdk/rbacresources_gen.go \
site/src/api/rbacresourcesGenerated.ts \
docs/admin/prometheus.md \
docs/reference/cli/README.md \
docs/cli.md \
docs/admin/audit-logs.md \
coderd/apidoc/swagger.json \
.prettierignore.include \
.prettierignore \
site/.prettierrc.yaml \
site/.prettierignore \
site/.eslintignore \
site/e2e/provisionerGenerated.ts \
site/src/theme/icons.json \
examples/examples.gen.json \
@@ -562,9 +535,6 @@ coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $
coderd/database/dbmock/dbmock.go: coderd/database/db.go coderd/database/querier.go
go generate ./coderd/database/dbmock/
coderd/database/pubsub/psmock/psmock.go: coderd/database/pubsub/pubsub.go
go generate ./coderd/database/pubsub/psmock
tailnet/tailnettest/coordinatormock.go tailnet/tailnettest/multiagentmock.go tailnet/tailnettest/coordinateemock.go: tailnet/coordinator.go tailnet/multiagent.go
go generate ./tailnet/tailnettest/
@@ -602,7 +572,7 @@ provisionerd/proto/provisionerd.pb.go: provisionerd/proto/provisionerd.proto
site/src/api/typesGenerated.ts: $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
go run ./scripts/apitypings/ > $@
./scripts/pnpm_install.sh
pnpm run format:write:only "$@"
site/e2e/provisionerGenerated.ts: provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
cd site
@@ -611,41 +581,29 @@ site/e2e/provisionerGenerated.ts: provisionerd/proto/provisionerd.pb.go provisio
site/src/theme/icons.json: $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
go run ./scripts/gensite/ -icons "$@"
./scripts/pnpm_install.sh
pnpm -C site/ exec biome format --write src/theme/icons.json
pnpm run format:write:only "$@"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
go run ./scripts/examplegen/main.go > examples/examples.gen.json
coderd/rbac/object_gen.go: scripts/rbacgen/rbacobject.gotmpl scripts/rbacgen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/rbacgen/main.go rbac > coderd/rbac/object_gen.go
codersdk/rbacresources_gen.go: scripts/rbacgen/codersdk.gotmpl scripts/rbacgen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/rbacgen/main.go codersdk > codersdk/rbacresources_gen.go
site/src/api/rbacresourcesGenerated.ts: scripts/rbacgen/codersdk.gotmpl scripts/rbacgen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/rbacgen/main.go typescript > "$@"
coderd/rbac/object_gen.go: scripts/rbacgen/main.go coderd/rbac/object.go
go run scripts/rbacgen/main.go ./coderd/rbac > coderd/rbac/object_gen.go
docs/admin/prometheus.md: scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics
go run scripts/metricsdocgen/main.go
./scripts/pnpm_install.sh
pnpm exec prettier --write ./docs/admin/prometheus.md
pnpm run format:write:only ./docs/admin/prometheus.md
docs/reference/cli/README.md: scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES)
docs/cli.md: scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES)
CI=true BASE_PATH="." go run ./scripts/clidocgen
./scripts/pnpm_install.sh
pnpm exec prettier --write ./docs/reference/cli/README.md ./docs/reference/cli/*.md ./docs/manifest.json
pnpm run format:write:only ./docs/cli.md ./docs/cli/*.md ./docs/manifest.json
docs/admin/audit-logs.md: coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go
go run scripts/auditdocgen/main.go
./scripts/pnpm_install.sh
pnpm exec prettier --write ./docs/admin/audit-logs.md
pnpm run format:write:only ./docs/admin/audit-logs.md
coderd/apidoc/swagger.json: $(shell find ./scripts/apidocgen $(FIND_EXCLUSIONS) -type f) $(wildcard coderd/*.go) $(wildcard enterprise/coderd/*.go) $(wildcard codersdk/*.go) $(wildcard enterprise/wsproxy/wsproxysdk/*.go) $(DB_GEN_FILES) .swaggo docs/manifest.json coderd/rbac/object_gen.go
./scripts/apidocgen/generate.sh
./scripts/pnpm_install.sh
pnpm exec prettier --write ./docs/reference/api ./docs/manifest.json ./coderd/apidoc/swagger.json
pnpm run format:write:only ./docs/api ./docs/manifest.json ./coderd/apidoc/swagger.json
update-golden-files: \
cli/testdata/.gen-golden \
@@ -660,7 +618,7 @@ update-golden-files: \
.PHONY: update-golden-files
cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(wildcard cli/*.tpl) $(GO_SRC_FILES) $(wildcard cli/*_test.go)
go test ./cli -run="Test(CommandHelp|ServerYAML|ErrorExamples)" -update
go test ./cli -run="Test(CommandHelp|ServerYAML)" -update
touch "$@"
enterprise/cli/testdata/.gen-golden: $(wildcard enterprise/cli/testdata/*.golden) $(wildcard cli/*.tpl) $(GO_SRC_FILES) $(wildcard enterprise/cli/*_test.go)
@@ -691,16 +649,27 @@ provisioner/terraform/testdata/.gen-golden: $(wildcard provisioner/terraform/tes
go test ./provisioner/terraform -run="Test.*Golden$$" -update
touch "$@"
provisioner/terraform/testdata/version:
if [[ "$(shell cat provisioner/terraform/testdata/version.txt)" != "$(shell terraform version -json | jq -r '.terraform_version')" ]]; then
./provisioner/terraform/testdata/generate.sh
fi
.PHONY: provisioner/terraform/testdata/version
scripts/ci-report/testdata/.gen-golden: $(wildcard scripts/ci-report/testdata/*) $(wildcard scripts/ci-report/*.go)
go test ./scripts/ci-report -run=TestOutputMatchesGoldenFile -update
touch "$@"
# Generate a prettierrc for the site package that uses relative paths for
# overrides. This allows us to share the same prettier config between the
# site and the root of the repo.
site/.prettierrc.yaml: .prettierrc.yaml
. ./scripts/lib.sh
dependencies yq
echo "# Code generated by Makefile (../$<). DO NOT EDIT." > "$@"
echo "" >> "$@"
# Replace all listed override files with relative paths inside site/.
# - ./ -> ../
# - ./site -> ./
yq \
'.overrides[].files |= map(. | sub("^./"; "") | sub("^"; "../") | sub("../site/"; "./") | sub("../!"; "!../"))' \
"$<" >> "$@"
# Combine .gitignore with .prettierignore.include to generate .prettierignore.
.prettierignore: .gitignore .prettierignore.include
echo "# Code generated by Makefile ($^). DO NOT EDIT." > "$@"
@@ -710,8 +679,42 @@ scripts/ci-report/testdata/.gen-golden: $(wildcard scripts/ci-report/testdata/*)
cat "$$f" >> "$@"
done
# Generate ignore files based on gitignore into the site directory. We turn all
# rules into relative paths for the `site/` directory (where applicable),
# following the pattern format defined by git:
# https://git-scm.com/docs/gitignore#_pattern_format
#
# This is done for compatibility reasons, see:
# https://github.com/prettier/prettier/issues/8048
# https://github.com/prettier/prettier/issues/8506
# https://github.com/prettier/prettier/issues/8679
site/.eslintignore site/.prettierignore: .prettierignore Makefile
rm -f "$@"
touch "$@"
# Skip generated by header, inherit `.prettierignore` header as-is.
while read -r rule; do
# Remove leading ! if present to simplify rule, added back at the end.
tmp="$${rule#!}"
ignore="$${rule%"$$tmp"}"
rule="$$tmp"
case "$$rule" in
# Comments or empty lines (include).
\#*|'') ;;
# Generic rules (include).
\*\**) ;;
# Site prefixed rules (include).
site/*) rule="$${rule#site/}";;
./site/*) rule="$${rule#./site/}";;
# Rules that are non-generic and don't start with site (rewrite).
/*) rule=.."$$rule";;
*/?*) rule=../"$$rule";;
*) ;;
esac
echo "$${ignore}$${rule}" >> "$@"
done < "$<"
test:
$(GIT_FLAGS) gotestsum --format standard-quiet -- -v -short -count=1 ./...
gotestsum --format standard-quiet -- -v -short -count=1 ./...
.PHONY: test
# sqlc-cloud-is-setup will fail if no SQLc auth token is set. Use this as a
@@ -747,7 +750,7 @@ sqlc-vet: test-postgres-docker
test-postgres: test-postgres-docker
# The postgres test is prone to failure, so we limit parallelism for
# more consistent execution.
$(GIT_FLAGS) DB=ci DB_FROM=$(shell go run scripts/migrate-ci/main.go) gotestsum \
DB=ci DB_FROM=$(shell go run scripts/migrate-ci/main.go) gotestsum \
--junitfile="gotests.xml" \
--jsonfile="gotests.json" \
--packages="./..." -- \
@@ -756,20 +759,8 @@ test-postgres: test-postgres-docker
-count=1
.PHONY: test-postgres
test-migrations: test-postgres-docker
echo "--- test migrations"
set -euo pipefail
COMMIT_FROM=$(shell git log -1 --format='%h' HEAD)
echo "COMMIT_FROM=$${COMMIT_FROM}"
COMMIT_TO=$(shell git log -1 --format='%h' origin/main)
echo "COMMIT_TO=$${COMMIT_TO}"
if [[ "$${COMMIT_FROM}" == "$${COMMIT_TO}" ]]; then echo "Nothing to do!"; exit 0; fi
echo "DROP DATABASE IF EXISTS migrate_test_$${COMMIT_FROM}; CREATE DATABASE migrate_test_$${COMMIT_FROM};" | psql 'postgresql://postgres:postgres@localhost:5432/postgres?sslmode=disable'
go run ./scripts/migrate-test/main.go --from="$$COMMIT_FROM" --to="$$COMMIT_TO" --postgres-url="postgresql://postgres:postgres@localhost:5432/migrate_test_$${COMMIT_FROM}?sslmode=disable"
# NOTE: we set --memory to the same size as a GitHub runner.
test-postgres-docker:
docker rm -f test-postgres-docker-${POSTGRES_VERSION} || true
docker rm -f test-postgres-docker || true
docker run \
--env POSTGRES_PASSWORD=postgres \
--env POSTGRES_USER=postgres \
@@ -777,11 +768,10 @@ test-postgres-docker:
--env PGDATA=/tmp \
--tmpfs /tmp \
--publish 5432:5432 \
--name test-postgres-docker-${POSTGRES_VERSION} \
--name test-postgres-docker \
--restart no \
--detach \
--memory 16GB \
gcr.io/coder-dev-1/postgres:${POSTGRES_VERSION} \
gcr.io/coder-dev-1/postgres:13 \
-c shared_buffers=1GB \
-c work_mem=1GB \
-c effective_cache_size=1GB \
@@ -799,28 +789,12 @@ test-postgres-docker:
# Make sure to keep this in sync with test-go-race from .github/workflows/ci.yaml.
test-race:
$(GIT_FLAGS) gotestsum --junitfile="gotests.xml" -- -race -count=1 ./...
gotestsum --junitfile="gotests.xml" -- -race -count=1 ./...
.PHONY: test-race
test-tailnet-integration:
env \
CODER_TAILNET_TESTS=true \
CODER_MAGICSOCK_DEBUG_LOGGING=true \
TS_DEBUG_NETCHECK=true \
GOTRACEBACK=single \
go test \
-exec "sudo -E" \
-timeout=5m \
-count=1 \
./tailnet/test/integration
# Note: we used to add this to the test target, but it's not necessary and we can
# achieve the desired result by specifying -count=1 in the go test invocation
# instead. Keeping it here for convenience.
test-clean:
go clean -testcache
.PHONY: test-clean
.PHONY: test-e2e
test-e2e:
cd ./site && DEBUG=pw:api pnpm playwright:test --forbid-only --workers 1
+30 -33
View File
@@ -20,17 +20,18 @@
<br>
<br>
[Quickstart](#quickstart) | [Docs](https://coder.com/docs) | [Why Coder](https://coder.com/why) | [Enterprise](https://coder.com/docs/enterprise)
[Quickstart](#quickstart) | [Docs](https://coder.com/docs) | [Why Coder](https://coder.com/why) | [Enterprise](https://coder.com/docs/v2/latest/enterprise)
[![discord](https://img.shields.io/discord/747933592273027093?label=discord)](https://discord.gg/coder)
[![codecov](https://codecov.io/gh/coder/coder/branch/main/graph/badge.svg?token=TNLW3OAP6G)](https://codecov.io/gh/coder/coder)
[![release](https://img.shields.io/github/v/release/coder/coder)](https://github.com/coder/coder/releases/latest)
[![godoc](https://pkg.go.dev/badge/github.com/coder/coder.svg)](https://pkg.go.dev/github.com/coder/coder)
[![Go Report Card](https://goreportcard.com/badge/github.com/coder/coder/v2)](https://goreportcard.com/report/github.com/coder/coder/v2)
[![Go Report Card](https://goreportcard.com/badge/github.com/coder/coder)](https://goreportcard.com/report/github.com/coder/coder)
[![license](https://img.shields.io/github/license/coder/coder)](./LICENSE)
</div>
[Coder](https://coder.com) enables organizations to set up development environments in their public or private cloud infrastructure. Cloud development environments are defined with Terraform, connected through a secure high-speed Wireguard® tunnel, and automatically shut down when not used to save on costs. Coder gives engineering teams the flexibility to use the cloud for workloads most beneficial to them.
[Coder](https://coder.com) enables organizations to set up development environments in their public or private cloud infrastructure. Cloud development environments are defined with Terraform, connected through a secure high-speed Wireguard® tunnel, and are automatically shut down when not in use to save on costs. Coder gives engineering teams the flexibility to use the cloud for workloads that are most beneficial to them.
- Define cloud development environments in Terraform
- EC2 VMs, Kubernetes Pods, Docker Containers, etc.
@@ -52,8 +53,8 @@ curl -L https://coder.com/install.sh | sh
# Start the Coder server (caches data in ~/.cache/coder)
coder server
# Navigate to http://localhost:3000 to create your initial user,
# create a Docker template and provision a workspace
# Navigate to http://localhost:3000 to create your initial user
# Create a Docker template, and provision a workspace
```
## Install
@@ -67,11 +68,11 @@ Releases.
curl -L https://coder.com/install.sh | sh
```
You can run the install script with `--dry-run` to see the commands that will be used to install without executing them. Run the install script with `--help` for additional flags.
You can run the install script with `--dry-run` to see the commands that will be used to install without executing them. You can modify the installation process by including flags. Run the install script with `--help` for reference.
> See [install](https://coder.com/docs/install) for additional methods.
> See [install](https://coder.com/docs/v2/latest/install) for additional methods.
Once installed, you can start a production deployment with a single command:
Once installed, you can start a production deployment<sup>1</sup> with a single command:
```shell
# Automatically sets up an external access URL on *.try.coder.app
@@ -81,48 +82,44 @@ coder server
coder server --postgres-url <url> --access-url <url>
```
Use `coder --help` to get a list of flags and environment variables. Use our [install guides](https://coder.com/docs/install) for a complete walkthrough.
> <sup>1</sup> For production deployments, set up an external PostgreSQL instance for reliability.
Use `coder --help` to get a list of flags and environment variables. Use our [install guides](https://coder.com/docs/v2/latest/install) for a full walkthrough.
## Documentation
Browse our docs [here](https://coder.com/docs) or visit a specific section below:
Browse our docs [here](https://coder.com/docs/v2) or visit a specific section below:
- [**Templates**](https://coder.com/docs/templates): Templates are written in Terraform and describe the infrastructure for workspaces
- [**Workspaces**](https://coder.com/docs/workspaces): Workspaces contain the IDEs, dependencies, and configuration information needed for software development
- [**IDEs**](https://coder.com/docs/ides): Connect your existing editor to a workspace
- [**Administration**](https://coder.com/docs/admin): Learn how to operate Coder
- [**Enterprise**](https://coder.com/docs/enterprise): Learn about our paid features built for large teams
- [**Templates**](https://coder.com/docs/v2/latest/templates): Templates are written in Terraform and describe the infrastructure for workspaces
- [**Workspaces**](https://coder.com/docs/v2/latest/workspaces): Workspaces contain the IDEs, dependencies, and configuration information needed for software development
- [**IDEs**](https://coder.com/docs/v2/latest/ides): Connect your existing editor to a workspace
- [**Administration**](https://coder.com/docs/v2/latest/admin): Learn how to operate Coder
- [**Enterprise**](https://coder.com/docs/v2/latest/enterprise): Learn about our paid features built for large teams
## Support
## Community and Support
Feel free to [open an issue](https://github.com/coder/coder/issues/new) if you have questions, run into bugs, or have a feature request.
[Join our Discord](https://discord.gg/coder) to provide feedback on in-progress features and chat with the community using Coder!
[Join our Discord](https://discord.gg/coder) or [Slack](https://cdr.co/join-community) to provide feedback on in-progress features, and chat with the community using Coder!
## Integrations
## Contributing
We are always working on new integrations. Please feel free to open an issue and ask for an integration. Contributions are welcome in any official or community repositories.
Contributions are welcome! Read the [contributing docs](https://coder.com/docs/v2/latest/CONTRIBUTING) to get started.
Find our list of contributors [here](https://github.com/coder/coder/graphs/contributors).
## Related
We are always working on new integrations. Feel free to open an issue to request an integration. Contributions are welcome in any official or community repositories.
### Official
- [**VS Code Extension**](https://marketplace.visualstudio.com/items?itemName=coder.coder-remote): Open any Coder workspace in VS Code with a single click
- [**JetBrains Gateway Extension**](https://plugins.jetbrains.com/plugin/19620-coder): Open any Coder workspace in JetBrains Gateway with a single click
- [**Dev Container Builder**](https://github.com/coder/envbuilder): Build development environments using `devcontainer.json` on Docker, Kubernetes, and OpenShift
- [**Module Registry**](https://registry.coder.com): Extend development environments with common use-cases
- [**Kubernetes Log Stream**](https://github.com/coder/coder-logstream-kube): Stream Kubernetes Pod events to the Coder startup logs
- [**Self-Hosted VS Code Extension Marketplace**](https://github.com/coder/code-marketplace): A private extension marketplace that works in restricted or airgapped networks integrating with [code-server](https://github.com/coder/code-server).
### Community
- [**Provision Coder with Terraform**](https://github.com/ElliotG/coder-oss-tf): Provision Coder on Google GKE, Azure AKS, AWS EKS, DigitalOcean DOKS, IBMCloud K8s, OVHCloud K8s, and Scaleway K8s Kapsule with Terraform
- [**Coder Template GitHub Action**](https://github.com/marketplace/actions/update-coder-template): A GitHub Action that updates Coder templates
## Contributing
We are always happy to see new contributors to Coder. If you are new to the Coder codebase, we have
[a guide on how to get started](https://coder.com/docs/CONTRIBUTING). We'd love to see your
contributions!
## Hiring
Apply [here](https://jobs.ashbyhq.com/coder?utm_source=github&utm_medium=readme&utm_campaign=unknown) if you're interested in joining our team.
- [**Coder GitHub Action**](https://github.com/marketplace/actions/update-coder-template): A GitHub Action that updates Coder templates
- [**Various Templates**](./examples/templates/community-templates.md): Hetzner Cloud, Docker in Docker, and other templates the community has built.
+525 -1064
View File
File diff suppressed because it is too large Load Diff
+107 -482
View File
@@ -5,9 +5,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httptest"
@@ -46,16 +46,14 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentproc/agentproctest"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
@@ -87,11 +85,11 @@ func TestAgent_Stats_SSH(t *testing.T) {
err = session.Shell()
require.NoError(t, err)
var s *proto.Stats
var s *agentsdk.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSSH == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
@@ -113,18 +111,18 @@ func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
require.NoError(t, err)
defer ptyConn.Close()
data, err := json.Marshal(workspacesdk.ReconnectingPTYRequest{
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
Data: "echo test\r\n",
})
require.NoError(t, err)
_, err = ptyConn.Write(data)
require.NoError(t, err)
var s *proto.Stats
var s *agentsdk.Stats
require.Eventuallyf(t, func() bool {
var ok bool
s, ok = <-stats
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPty == 1
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountReconnectingPTY == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats: %+v", s,
)
@@ -179,14 +177,14 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool {
s, ok := <-stats
t.Logf("got stats: ok=%t, ConnectionCount=%d, RxBytes=%d, TxBytes=%d, SessionCountVSCode=%d, ConnectionMedianLatencyMS=%f",
ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVscode, s.ConnectionMedianLatencyMs)
ok, s.ConnectionCount, s.RxBytes, s.TxBytes, s.SessionCountVSCode, s.ConnectionMedianLatencyMS)
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 &&
// Ensure that the connection didn't count as a "normal" SSH session.
// This was a special one, so it should be labeled specially in the stats!
s.SessionCountVscode == 1 &&
s.SessionCountVSCode == 1 &&
// Ensure that connection latency is being counted!
// If it isn't, it's set to -1.
s.ConnectionMedianLatencyMs >= 0
s.ConnectionMedianLatencyMS >= 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats",
)
@@ -245,9 +243,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool {
s, ok := <-stats
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
ok, s.ConnectionCount, s.SessionCountJetbrains)
ok, s.ConnectionCount, s.SessionCountJetBrains)
return ok && s.ConnectionCount > 0 &&
s.SessionCountJetbrains == 1
s.SessionCountJetBrains == 1
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats with conn open",
)
@@ -260,9 +258,9 @@ func TestAgent_Stats_Magic(t *testing.T) {
require.Eventuallyf(t, func() bool {
s, ok := <-stats
t.Logf("got stats after disconnect %t, %d",
ok, s.SessionCountJetbrains)
ok, s.SessionCountJetBrains)
return ok &&
s.SessionCountJetbrains == 0
s.SessionCountJetBrains == 0
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats after conn closes",
)
@@ -282,91 +280,6 @@ func TestAgent_SessionExec(t *testing.T) {
require.Equal(t, "test", strings.TrimSpace(string(output)))
}
//nolint:tparallel // Sub tests need to run sequentially.
func TestAgent_Session_EnvironmentVariables(t *testing.T) {
t.Parallel()
tmpdir := t.TempDir()
// Defined by the coder script runner, hardcoded here since we don't
// have a reference to it.
scriptBinDir := filepath.Join(tmpdir, "coder-script-data", "bin")
manifest := agentsdk.Manifest{
EnvironmentVariables: map[string]string{
"MY_MANIFEST": "true",
"MY_OVERRIDE": "false",
"MY_SESSION_MANIFEST": "false",
},
}
banner := codersdk.ServiceBannerConfig{}
session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) {
opts.ScriptDataDir = tmpdir
opts.EnvironmentVariables["MY_OVERRIDE"] = "true"
})
err := session.Setenv("MY_SESSION_MANIFEST", "true")
require.NoError(t, err)
err = session.Setenv("MY_SESSION", "true")
require.NoError(t, err)
command := "sh"
echoEnv := func(t *testing.T, w io.Writer, env string) {
if runtime.GOOS == "windows" {
_, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env)
require.NoError(t, err)
} else {
_, err := fmt.Fprintf(w, "echo $%s\n", env)
require.NoError(t, err)
}
}
if runtime.GOOS == "windows" {
command = "cmd.exe"
}
stdin, err := session.StdinPipe()
require.NoError(t, err)
defer stdin.Close()
stdout, err := session.StdoutPipe()
require.NoError(t, err)
err = session.Start(command)
require.NoError(t, err)
// Context is fine here since we're not doing a parallel subtest.
ctx := testutil.Context(t, testutil.WaitLong)
go func() {
<-ctx.Done()
_ = session.Close()
}()
s := bufio.NewScanner(stdout)
//nolint:paralleltest // These tests need to run sequentially.
for k, partialV := range map[string]string{
"CODER": "true", // From the agent.
"MY_MANIFEST": "true", // From the manifest.
"MY_OVERRIDE": "true", // From the agent environment variables option, overrides manifest.
"MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env.
"MY_SESSION": "true", // From the session.
"PATH": scriptBinDir + string(filepath.ListSeparator),
} {
t.Run(k, func(t *testing.T) {
echoEnv(t, stdin, k)
// Windows is unreliable, so keep scanning until we find a match.
for s.Scan() {
got := strings.TrimSpace(s.Text())
t.Logf("%s=%s", k, got)
if strings.Contains(got, partialV) {
break
}
}
if err := s.Err(); !errors.Is(err, io.EOF) {
require.NoError(t, err)
}
})
}
}
func TestAgent_GitSSH(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
@@ -614,12 +527,12 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
// Set new banner func and wait for the agent to call it to update the
// banner.
ready := make(chan struct{}, 2)
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
client.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
select {
case ready <- struct{}{}:
default:
}
return []codersdk.BannerConfig{test.banner}, nil
return test.banner, nil
})
<-ready
<-ready // Wait for two updates to ensure the value has propagated.
@@ -838,7 +751,7 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
var ll net.Listener
var err error
for {
randomPort = testutil.RandomPortNoListen(t)
randomPort = pickRandomPort()
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
ll, err = sshClient.ListenTCP(addr)
if err != nil {
@@ -970,99 +883,6 @@ func TestAgent_SCP(t *testing.T) {
require.NoError(t, err)
}
func TestAgent_FileTransferBlocked(t *testing.T) {
t.Parallel()
assertFileTransferBlocked := func(t *testing.T, errorMessage string) {
// NOTE: Checking content of the error message is flaky. Most likely there is a race condition, which results
// in stopping the client in different phases, and returning different errors:
// - client read the full error message: File transfer has been disabled.
// - client's stream was terminated before reading the error message: EOF
// - client just read the error code (Windows): Process exited with status 65
isErr := strings.Contains(errorMessage, agentssh.BlockedFileTransferErrorMessage) ||
strings.Contains(errorMessage, "EOF") ||
strings.Contains(errorMessage, "Process exited with status 65")
require.True(t, isErr, fmt.Sprintf("Message: "+errorMessage))
}
t.Run("SFTP", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
_, err = sftp.NewClient(sshClient)
require.Error(t, err)
assertFileTransferBlocked(t, err.Error())
})
t.Run("SCP with go-scp package", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
scpClient, err := scp.NewClientBySSH(sshClient)
require.NoError(t, err)
defer scpClient.Close()
tempFile := filepath.Join(t.TempDir(), "scp")
err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755")
require.Error(t, err)
assertFileTransferBlocked(t, err.Error())
})
t.Run("Forbidden commands", func(t *testing.T) {
t.Parallel()
for _, c := range agentssh.BlockedFileTransferCommands {
t.Run(c, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
})
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
stdout, err := session.StdoutPipe()
require.NoError(t, err)
//nolint:govet // we don't need `c := c` in Go 1.22
err = session.Start(c)
require.NoError(t, err)
defer session.Close()
msg, err := io.ReadAll(stdout)
require.NoError(t, err)
assertFileTransferBlocked(t, string(msg))
})
}
})
}
func TestAgent_EnvironmentVariables(t *testing.T) {
t.Parallel()
key := "EXAMPLE"
@@ -1517,18 +1337,16 @@ func TestAgent_Lifecycle(t *testing.T) {
agentsdk.Manifest{
DERPMap: derpMap,
Scripts: []codersdk.WorkspaceAgentScript{{
ID: uuid.New(),
LogPath: "coder-startup-script.log",
Script: "echo 1",
RunOnStart: true,
}, {
ID: uuid.New(),
LogPath: "coder-shutdown-script.log",
Script: "echo " + expected,
RunOnStop: true,
}},
},
make(chan *proto.Stats, 50),
make(chan *agentsdk.Stats, 50),
tailnet.NewCoordinator(logger),
)
defer client.Close()
@@ -1701,7 +1519,7 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
require.NoError(t, tr1.ReadUntil(ctx, matchPrompt), "find prompt")
require.NoError(t, tr2.ReadUntil(ctx, matchPrompt), "find prompt")
data, err := json.Marshal(workspacesdk.ReconnectingPTYRequest{
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
Data: "echo test\r",
})
require.NoError(t, err)
@@ -1729,7 +1547,7 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
require.NoError(t, tr3.ReadUntil(ctx, matchEchoOutput), "find echo output")
// Exit should cause the connection to close.
data, err = json.Marshal(workspacesdk.ReconnectingPTYRequest{
data, err = json.Marshal(codersdk.ReconnectingPTYRequest{
Data: "exit\r",
})
require.NoError(t, err)
@@ -1849,7 +1667,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
_ = coordinator.Close()
})
agentID := uuid.New()
statsCh := make(chan *proto.Stats, 50)
statsCh := make(chan *agentsdk.Stats, 50)
fs := afero.NewMemMapFs()
client := agenttest.NewClient(t,
logger.Named("agent"),
@@ -1878,7 +1696,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
})
// Setup a client connection.
newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn {
newClientConn := func(derpMap *tailcfg.DERPMap, name string) *codersdk.WorkspaceAgentConn {
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
DERPMap: derpMap,
@@ -1898,9 +1716,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
coordinator, conn)
t.Cleanup(func() {
t.Logf("closing coordination %s", name)
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
defer ccancel()
err := coordination.Close(cctx)
err := coordination.Close()
if err != nil {
t.Logf("error closing in-memory coordination: %s", err.Error())
}
@@ -1909,9 +1725,9 @@ func TestAgent_UpdatedDERP(t *testing.T) {
// Force DERP.
conn.SetBlockEndpoints(true)
sdkConn := workspacesdk.NewAgentConn(conn, workspacesdk.AgentConnOptions{
sdkConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: agentID,
CloseFunc: func() error { return workspacesdk.ErrSkipClose },
CloseFunc: func() error { return codersdk.ErrSkipClose },
})
t.Cleanup(func() {
t.Logf("closing sdkConn %s", name)
@@ -2000,7 +1816,7 @@ func TestAgent_Reconnect(t *testing.T) {
defer coordinator.Close()
agentID := uuid.New()
statsCh := make(chan *proto.Stats, 50)
statsCh := make(chan *agentsdk.Stats, 50)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
client := agenttest.NewClient(t,
logger,
@@ -2045,7 +1861,7 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
GitAuthConfigs: 1,
DERPMap: &tailcfg.DERPMap{},
},
make(chan *proto.Stats, 50),
make(chan *agentsdk.Stats, 50),
coordinator,
)
defer client.Close()
@@ -2072,21 +1888,11 @@ func TestAgent_WriteVSCodeConfigs(t *testing.T) {
func TestAgent_DebugServer(t *testing.T) {
t.Parallel()
logDir := t.TempDir()
logPath := filepath.Join(logDir, "coder-agent.log")
randLogStr, err := cryptorand.String(32)
require.NoError(t, err)
require.NoError(t, os.WriteFile(logPath, []byte(randLogStr), 0o600))
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
//nolint:dogsled
conn, _, _, _, agnt := setupAgent(t, agentsdk.Manifest{
DERPMap: derpMap,
}, 0, func(c *agenttest.Client, o *agent.Options) {
o.ExchangeToken = func(context.Context) (string, error) {
return "token", nil
}
o.LogDir = logDir
})
}, 0)
awaitReachableCtx := testutil.Context(t, testutil.WaitLong)
ok := conn.AwaitReachable(awaitReachableCtx)
@@ -2167,114 +1973,6 @@ func TestAgent_DebugServer(t *testing.T) {
require.Contains(t, string(resBody), `invalid state "blah", must be a boolean`)
})
})
t.Run("Manifest", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/manifest", nil)
require.NoError(t, err)
res, err := srv.Client().Do(req)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
var v agentsdk.Manifest
require.NoError(t, json.NewDecoder(res.Body).Decode(&v))
require.NotNil(t, v)
})
t.Run("Logs", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL+"/debug/logs", nil)
require.NoError(t, err)
res, err := srv.Client().Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)
defer res.Body.Close()
resBody, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.NotEmpty(t, string(resBody))
require.Contains(t, string(resBody), randLogStr)
})
}
func TestAgent_ScriptLogging(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("bash scripts only")
}
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
logsCh := make(chan *proto.BatchCreateLogsRequest, 100)
lsStart := uuid.UUID{0x11}
lsStop := uuid.UUID{0x22}
//nolint:dogsled
_, _, _, _, agnt := setupAgent(
t,
agentsdk.Manifest{
DERPMap: derpMap,
Scripts: []codersdk.WorkspaceAgentScript{
{
LogSourceID: lsStart,
RunOnStart: true,
Script: `#!/bin/sh
i=0
while [ $i -ne 5 ]
do
i=$(($i+1))
echo "start $i"
done
`,
},
{
LogSourceID: lsStop,
RunOnStop: true,
Script: `#!/bin/sh
i=0
while [ $i -ne 3000 ]
do
i=$(($i+1))
echo "stop $i"
done
`, // send a lot of stop logs to make sure we don't truncate shutdown logs before closing the API conn
},
},
},
0,
func(cl *agenttest.Client, _ *agent.Options) {
cl.SetLogsChannel(logsCh)
},
)
n := 1
for n <= 5 {
logs := testutil.RequireRecvCtx(ctx, t, logsCh)
require.NotNil(t, logs)
for _, l := range logs.GetLogs() {
require.Equal(t, fmt.Sprintf("start %d", n), l.GetOutput())
n++
}
}
err := agnt.Close()
require.NoError(t, err)
n = 1
for n <= 3000 {
logs := testutil.RequireRecvCtx(ctx, t, logsCh)
require.NotNil(t, logs)
for _, l := range logs.GetLogs() {
require.Equal(t, fmt.Sprintf("stop %d", n), l.GetOutput())
n++
}
t.Logf("got %d stop logs", n-1)
}
}
// setupAgentSSHClient creates an agent, dials it, and sets up an ssh.Client for it
@@ -2290,19 +1988,17 @@ func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
func setupSSHSession(
t *testing.T,
manifest agentsdk.Manifest,
banner codersdk.BannerConfig,
serviceBanner codersdk.ServiceBannerConfig,
prepareFS func(fs afero.Fs),
opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
opts = append(opts, func(c *agenttest.Client, o *agent.Options) {
c.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
return []codersdk.BannerConfig{banner}, nil
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, manifest, 0, func(c *agenttest.Client, _ *agent.Options) {
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
return serviceBanner, nil
})
})
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, manifest, 0, opts...)
if prepareFS != nil {
prepareFS(fs)
}
@@ -2320,9 +2016,9 @@ func setupSSHSession(
}
func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
*workspacesdk.AgentConn,
*codersdk.WorkspaceAgentConn,
*agenttest.Client,
<-chan *proto.Stats,
<-chan *agentsdk.Stats,
afero.Fs,
agent.Agent,
) {
@@ -2350,9 +2046,9 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
t.Cleanup(func() {
_ = coordinator.Close()
})
statsCh := make(chan *proto.Stats, 50)
statsCh := make(chan *agentsdk.Stats, 50)
fs := afero.NewMemMapFs()
c := agenttest.NewClient(t, logger.Named("agenttest"), metadata.AgentID, metadata, statsCh, coordinator)
c := agenttest.NewClient(t, logger.Named("agent"), metadata.AgentID, metadata, statsCh, coordinator)
t.Cleanup(c.Close)
options := agent.Options{
@@ -2360,16 +2056,15 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
Filesystem: fs,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
EnvironmentVariables: map[string]string{},
}
for _, opt := range opts {
opt(c, &options)
}
agnt := agent.New(options)
closer := agent.New(options)
t.Cleanup(func() {
_ = agnt.Close()
_ = closer.Close()
})
conn, err := tailnet.NewConn(&tailnet.Options{
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
@@ -2388,14 +2083,12 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
clientID, metadata.AgentID,
coordinator, conn)
t.Cleanup(func() {
cctx, ccancel := context.WithTimeout(testCtx, testutil.WaitShort)
defer ccancel()
err := coordination.Close(cctx)
err := coordination.Close()
if err != nil {
t.Logf("error closing in-mem coordination: %s", err.Error())
}
})
agentConn := workspacesdk.NewAgentConn(conn, workspacesdk.AgentConnOptions{
agentConn := codersdk.NewWorkspaceAgentConn(conn, codersdk.WorkspaceAgentConnOptions{
AgentID: metadata.AgentID,
})
t.Cleanup(func() {
@@ -2408,7 +2101,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
if !agentConn.AwaitReachable(ctx) {
t.Fatal("agent not reachable")
}
return agentConn, c, statsCh, fs, agnt
return agentConn, c, statsCh, fs, closer
}
var dialTestPayload = []byte("dean-was-here123")
@@ -2537,17 +2230,17 @@ func TestAgent_Metrics_SSH(t *testing.T) {
err = session.Shell()
require.NoError(t, err)
expected := []*proto.Stats_Metric{
expected := []agentsdk.AgentMetric{
{
Name: "agent_reconnecting_pty_connections_total",
Type: proto.Stats_Metric_COUNTER,
Type: agentsdk.AgentMetricTypeCounter,
Value: 0,
},
{
Name: "agent_sessions_total",
Type: proto.Stats_Metric_COUNTER,
Type: agentsdk.AgentMetricTypeCounter,
Value: 1,
Labels: []*proto.Stats_Metric_Label{
Labels: []agentsdk.AgentMetricLabel{
{
Name: "magic_type",
Value: "ssh",
@@ -2560,45 +2253,29 @@ func TestAgent_Metrics_SSH(t *testing.T) {
},
{
Name: "agent_ssh_server_failed_connections_total",
Type: proto.Stats_Metric_COUNTER,
Type: agentsdk.AgentMetricTypeCounter,
Value: 0,
},
{
Name: "agent_ssh_server_sftp_connections_total",
Type: proto.Stats_Metric_COUNTER,
Type: agentsdk.AgentMetricTypeCounter,
Value: 0,
},
{
Name: "agent_ssh_server_sftp_server_errors_total",
Type: proto.Stats_Metric_COUNTER,
Type: agentsdk.AgentMetricTypeCounter,
Value: 0,
},
{
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
Value: 0,
Labels: []*proto.Stats_Metric_Label{
{
Name: "connection_type",
Value: "derp",
},
},
},
{
Name: "coderd_agentstats_currently_reachable_peers",
Type: proto.Stats_Metric_GAUGE,
Value: 1,
Labels: []*proto.Stats_Metric_Label{
{
Name: "connection_type",
Value: "p2p",
},
},
},
{
Name: "coderd_agentstats_startup_script_seconds",
Type: proto.Stats_Metric_GAUGE,
Value: 1,
Type: agentsdk.AgentMetricTypeGauge,
Value: 0,
Labels: []agentsdk.AgentMetricLabel{
{
Name: "success",
Value: "true",
},
},
},
}
@@ -2608,33 +2285,17 @@ func TestAgent_Metrics_SSH(t *testing.T) {
if err != nil {
return false
}
count := 0
for _, m := range actual {
count += len(m.GetMetric())
if len(expected) != len(actual) {
return false
}
return count == len(expected)
return verifyCollectedMetrics(t, expected, actual)
}, testutil.WaitLong, testutil.IntervalFast)
i := 0
for _, mf := range actual {
for _, m := range mf.GetMetric() {
assert.Equal(t, expected[i].Name, mf.GetName())
assert.Equal(t, expected[i].Type.String(), mf.GetType().String())
// Value is max expected
if expected[i].Type == proto.Stats_Metric_GAUGE {
assert.GreaterOrEqualf(t, expected[i].Value, m.GetGauge().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetGauge().GetValue())
} else if expected[i].Type == proto.Stats_Metric_COUNTER {
assert.GreaterOrEqualf(t, expected[i].Value, m.GetCounter().GetValue(), "expected %s to be greater than or equal to %f, got %f", expected[i].Name, expected[i].Value, m.GetCounter().GetValue())
}
for j, lbl := range expected[i].Labels {
assert.Equal(t, m.GetLabel()[j], &promgo.LabelPair{
Name: &lbl.Name,
Value: &lbl.Value,
})
}
i++
}
}
require.Len(t, actual, len(expected))
collected := verifyCollectedMetrics(t, expected, actual)
require.True(t, collected, "expected metrics were not collected")
_ = stdin.Close()
err = session.Wait()
@@ -2660,11 +2321,11 @@ func TestAgent_ManageProcessPriority(t *testing.T) {
logger = slog.Make(sloghuman.Sink(io.Discard))
)
requireFileWrite(t, fs, "/proc/self/oom_score_adj", "-500")
// Create some processes.
for i := 0; i < 4; i++ {
// Create a prioritized process.
// Create a prioritized process. This process should
// have it's oom_score_adj set to -500 and its nice
// score should be untouched.
var proc agentproc.Process
if i == 0 {
proc = agentproctest.GenerateProcess(t, fs,
@@ -2682,8 +2343,8 @@ func TestAgent_ManageProcessPriority(t *testing.T) {
},
)
syscaller.EXPECT().GetPriority(proc.PID).Return(20, nil)
syscaller.EXPECT().SetPriority(proc.PID, 10).Return(nil)
syscaller.EXPECT().GetPriority(proc.PID).Return(20, nil)
}
syscaller.EXPECT().
Kill(proc.PID, syscall.Signal(0)).
@@ -2702,9 +2363,6 @@ func TestAgent_ManageProcessPriority(t *testing.T) {
})
actualProcs := <-modProcs
require.Len(t, actualProcs, len(expectedProcs)-1)
for _, proc := range actualProcs {
requireFileEquals(t, fs, fmt.Sprintf("/proc/%d/oom_score_adj", proc.PID), "0")
}
})
t.Run("IgnoreCustomNice", func(t *testing.T) {
@@ -2723,11 +2381,8 @@ func TestAgent_ManageProcessPriority(t *testing.T) {
logger = slog.Make(sloghuman.Sink(io.Discard))
)
err := afero.WriteFile(fs, "/proc/self/oom_score_adj", []byte("0"), 0o644)
require.NoError(t, err)
// Create some processes.
for i := 0; i < 3; i++ {
for i := 0; i < 2; i++ {
proc := agentproctest.GenerateProcess(t, fs)
syscaller.EXPECT().
Kill(proc.PID, syscall.Signal(0)).
@@ -2755,59 +2410,7 @@ func TestAgent_ManageProcessPriority(t *testing.T) {
})
actualProcs := <-modProcs
// We should ignore the process with a custom nice score.
require.Len(t, actualProcs, 2)
for _, proc := range actualProcs {
_, ok := expectedProcs[proc.PID]
require.True(t, ok)
requireFileEquals(t, fs, fmt.Sprintf("/proc/%d/oom_score_adj", proc.PID), "998")
}
})
t.Run("CustomOOMScore", func(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" {
t.Skip("Skipping non-linux environment")
}
var (
fs = afero.NewMemMapFs()
ticker = make(chan time.Time)
syscaller = agentproctest.NewMockSyscaller(gomock.NewController(t))
modProcs = make(chan []*agentproc.Process)
logger = slog.Make(sloghuman.Sink(io.Discard))
)
err := afero.WriteFile(fs, "/proc/self/oom_score_adj", []byte("0"), 0o644)
require.NoError(t, err)
// Create some processes.
for i := 0; i < 3; i++ {
proc := agentproctest.GenerateProcess(t, fs)
syscaller.EXPECT().
Kill(proc.PID, syscall.Signal(0)).
Return(nil)
syscaller.EXPECT().GetPriority(proc.PID).Return(20, nil)
syscaller.EXPECT().SetPriority(proc.PID, 10).Return(nil)
}
_, _, _, _, _ = setupAgent(t, agentsdk.Manifest{}, 0, func(c *agenttest.Client, o *agent.Options) {
o.Syscaller = syscaller
o.ModifiedProcesses = modProcs
o.EnvironmentVariables = map[string]string{
agent.EnvProcPrioMgmt: "1",
agent.EnvProcOOMScore: "-567",
}
o.Filesystem = fs
o.Logger = logger
o.ProcessManagementTick = ticker
})
actualProcs := <-modProcs
// We should ignore the process with a custom nice score.
require.Len(t, actualProcs, 3)
for _, proc := range actualProcs {
requireFileEquals(t, fs, fmt.Sprintf("/proc/%d/oom_score_adj", proc.PID), "-567")
}
require.Len(t, actualProcs, 1)
})
t.Run("DisabledByDefault", func(t *testing.T) {
@@ -2866,6 +2469,28 @@ func TestAgent_ManageProcessPriority(t *testing.T) {
})
}
func verifyCollectedMetrics(t *testing.T, expected []agentsdk.AgentMetric, actual []*promgo.MetricFamily) bool {
t.Helper()
for i, e := range expected {
assert.Equal(t, e.Name, actual[i].GetName())
assert.Equal(t, string(e.Type), strings.ToLower(actual[i].GetType().String()))
for _, m := range actual[i].GetMetric() {
assert.Equal(t, e.Value, m.Counter.GetValue())
if len(m.GetLabel()) > 0 {
for j, lbl := range m.GetLabel() {
assert.Equal(t, e.Labels[j].Name, lbl.GetName())
assert.Equal(t, e.Labels[j].Value, lbl.GetValue())
}
}
m.GetLabel()
}
}
return true
}
type syncWriter struct {
mu sync.Mutex
w io.Writer
@@ -2877,6 +2502,20 @@ func (s *syncWriter) Write(p []byte) (int, error) {
return s.w.Write(p)
}
// pickRandomPort picks a random port number for the ephemeral range. We do this entirely randomly
// instead of opening a listener and closing it to find a port that is likely to be free, since
// sometimes the OS reallocates the port very quickly.
func pickRandomPort() uint16 {
const (
// Overlap of windows, linux in https://en.wikipedia.org/wiki/Ephemeral_port
min = 49152
max = 60999
)
n := max - min
x := rand.Intn(n) //nolint: gosec
return uint16(min + x)
}
// echoOnce accepts a single connection, reads 4 bytes and echos them back
func echoOnce(t *testing.T, ll net.Listener) {
t.Helper()
@@ -2906,17 +2545,3 @@ func requireEcho(t *testing.T, conn net.Conn) {
require.NoError(t, err)
require.Equal(t, "test", string(b))
}
func requireFileWrite(t testing.TB, fs afero.Fs, fp, data string) {
t.Helper()
err := afero.WriteFile(fs, fp, []byte(data), 0o600)
require.NoError(t, err)
}
func requireFileEquals(t testing.TB, fs afero.Fs, fp, expect string) {
t.Helper()
actual, err := afero.ReadFile(fs, fp)
require.NoError(t, err)
require.Equal(t, expect, string(actual))
}
+2 -8
View File
@@ -2,7 +2,6 @@ package agentproctest
import (
"fmt"
"strconv"
"testing"
"github.com/spf13/afero"
@@ -30,9 +29,8 @@ func GenerateProcess(t *testing.T, fs afero.Fs, muts ...func(*agentproc.Process)
cmdline := fmt.Sprintf("%s\x00%s\x00%s", arg1, arg2, arg3)
process := agentproc.Process{
CmdLine: cmdline,
PID: int32(pid),
OOMScoreAdj: 0,
CmdLine: cmdline,
PID: int32(pid),
}
for _, mut := range muts {
@@ -47,9 +45,5 @@ func GenerateProcess(t *testing.T, fs afero.Fs, muts ...func(*agentproc.Process)
err = afero.WriteFile(fs, fmt.Sprintf("%s/cmdline", process.Dir), []byte(process.CmdLine), 0o444)
require.NoError(t, err)
score := strconv.Itoa(process.OOMScoreAdj)
err = afero.WriteFile(fs, fmt.Sprintf("%s/oom_score_adj", process.Dir), []byte(score), 0o444)
require.NoError(t, err)
return process
}
+5 -30
View File
@@ -5,7 +5,6 @@ package agentproc
import (
"errors"
"os"
"path/filepath"
"strconv"
"strings"
@@ -45,31 +44,16 @@ func List(fs afero.Fs, syscaller Syscaller) ([]*Process, error) {
cmdline, err := afero.ReadFile(fs, filepath.Join(defaultProcDir, entry, "cmdline"))
if err != nil {
if isBenignError(err) {
var errNo syscall.Errno
if xerrors.As(err, &errNo) && errNo == syscall.EPERM {
continue
}
return nil, xerrors.Errorf("read cmdline: %w", err)
}
oomScore, err := afero.ReadFile(fs, filepath.Join(defaultProcDir, entry, "oom_score_adj"))
if err != nil {
if isBenignError(err) {
continue
}
return nil, xerrors.Errorf("read oom_score_adj: %w", err)
}
oom, err := strconv.Atoi(strings.TrimSpace(string(oomScore)))
if err != nil {
return nil, xerrors.Errorf("convert oom score: %w", err)
}
processes = append(processes, &Process{
PID: int32(pid),
CmdLine: string(cmdline),
Dir: filepath.Join(defaultProcDir, entry),
OOMScoreAdj: oom,
PID: int32(pid),
CmdLine: string(cmdline),
Dir: filepath.Join(defaultProcDir, entry),
})
}
@@ -123,12 +107,3 @@ func (p *Process) Cmd() string {
func (p *Process) cmdLine() []string {
return strings.Split(p.CmdLine, "\x00")
}
func isBenignError(err error) bool {
var errno syscall.Errno
if !xerrors.As(err, &errno) {
return false
}
return errno == syscall.ESRCH || errno == syscall.EPERM || xerrors.Is(err, os.ErrNotExist)
}
+3 -4
View File
@@ -14,8 +14,7 @@ type Syscaller interface {
const defaultProcDir = "/proc"
type Process struct {
Dir string
CmdLine string
PID int32
OOMScoreAdj int
Dir string
CmdLine string
PID int32
}
+35 -150
View File
@@ -13,19 +13,15 @@ import (
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/robfig/cron/v3"
"github.com/spf13/afero"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
@@ -45,19 +41,13 @@ var (
parser = cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.DowOptional)
)
type ScriptLogger interface {
Send(ctx context.Context, log ...agentsdk.Log) error
Flush(context.Context) error
}
// Options are a set of options for the runner.
type Options struct {
DataDirBase string
LogDir string
Logger slog.Logger
SSHServer *agentssh.Server
Filesystem afero.Fs
GetScriptLogger func(logSourceID uuid.UUID) ScriptLogger
LogDir string
Logger slog.Logger
SSHServer *agentssh.Server
Filesystem afero.Fs
PatchLogs func(ctx context.Context, req agentsdk.PatchLogs) error
}
// New creates a runner for the provided scripts.
@@ -69,7 +59,6 @@ func New(opts Options) *Runner {
cronCtxCancel: cronCtxCancel,
cron: cron.New(cron.WithParser(parser)),
closed: make(chan struct{}),
dataDir: filepath.Join(opts.DataDirBase, "coder-script-data"),
scriptsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "scripts",
@@ -78,21 +67,17 @@ func New(opts Options) *Runner {
}
}
type ScriptCompletedFunc func(context.Context, *proto.WorkspaceAgentScriptCompletedRequest) (*proto.WorkspaceAgentScriptCompletedResponse, error)
type Runner struct {
Options
cronCtx context.Context
cronCtxCancel context.CancelFunc
cmdCloseWait sync.WaitGroup
closed chan struct{}
closeMutex sync.Mutex
cron *cron.Cron
initialized atomic.Bool
scripts []codersdk.WorkspaceAgentScript
dataDir string
scriptCompleted ScriptCompletedFunc
cronCtx context.Context
cronCtxCancel context.CancelFunc
cmdCloseWait sync.WaitGroup
closed chan struct{}
closeMutex sync.Mutex
cron *cron.Cron
initialized atomic.Bool
scripts []codersdk.WorkspaceAgentScript
// scriptsExecuted includes all scripts executed by the workspace agent. Agents
// execute startup scripts, and scripts on a cron schedule. Both will increment
@@ -100,17 +85,6 @@ type Runner struct {
scriptsExecuted *prometheus.CounterVec
}
// DataDir returns the directory where scripts data is stored.
func (r *Runner) DataDir() string {
return r.dataDir
}
// ScriptBinDir returns the directory where scripts can store executable
// binaries.
func (r *Runner) ScriptBinDir() string {
return filepath.Join(r.dataDir, "bin")
}
func (r *Runner) RegisterMetrics(reg prometheus.Registerer) {
if reg == nil {
// If no registry, do nothing.
@@ -122,27 +96,21 @@ func (r *Runner) RegisterMetrics(reg prometheus.Registerer) {
// Init initializes the runner with the provided scripts.
// It also schedules any scripts that have a schedule.
// This function must be called before Execute.
func (r *Runner) Init(scripts []codersdk.WorkspaceAgentScript, scriptCompleted ScriptCompletedFunc) error {
func (r *Runner) Init(scripts []codersdk.WorkspaceAgentScript) error {
if r.initialized.Load() {
return xerrors.New("init: already initialized")
}
r.initialized.Store(true)
r.scripts = scripts
r.scriptCompleted = scriptCompleted
r.Logger.Info(r.cronCtx, "initializing agent scripts", slog.F("script_count", len(scripts)), slog.F("log_dir", r.LogDir))
err := r.Filesystem.MkdirAll(r.ScriptBinDir(), 0o700)
if err != nil {
return xerrors.Errorf("create script bin dir: %w", err)
}
for _, script := range scripts {
if script.Cron == "" {
continue
}
script := script
_, err := r.cron.AddFunc(script.Cron, func() {
err := r.trackRun(r.cronCtx, script, ExecuteCronScripts)
err := r.trackRun(r.cronCtx, script)
if err != nil {
r.Logger.Warn(context.Background(), "run agent script on schedule", slog.Error(err))
}
@@ -179,33 +147,22 @@ func (r *Runner) StartCron() {
}
}
// ExecuteOption describes what scripts we want to execute.
type ExecuteOption int
// ExecuteOption enums.
const (
ExecuteAllScripts ExecuteOption = iota
ExecuteStartScripts
ExecuteStopScripts
ExecuteCronScripts
)
// Execute runs a set of scripts according to a filter.
func (r *Runner) Execute(ctx context.Context, option ExecuteOption) error {
func (r *Runner) Execute(ctx context.Context, filter func(script codersdk.WorkspaceAgentScript) bool) error {
if filter == nil {
// Execute em' all!
filter = func(script codersdk.WorkspaceAgentScript) bool {
return true
}
}
var eg errgroup.Group
for _, script := range r.scripts {
runScript := (option == ExecuteStartScripts && script.RunOnStart) ||
(option == ExecuteStopScripts && script.RunOnStop) ||
(option == ExecuteCronScripts && script.Cron != "") ||
option == ExecuteAllScripts
if !runScript {
if !filter(script) {
continue
}
script := script
eg.Go(func() error {
err := r.trackRun(ctx, script, option)
err := r.trackRun(ctx, script)
if err != nil {
return xerrors.Errorf("run agent script %q: %w", script.LogSourceID, err)
}
@@ -216,8 +173,8 @@ func (r *Runner) Execute(ctx context.Context, option ExecuteOption) error {
}
// trackRun wraps "run" with metrics.
func (r *Runner) trackRun(ctx context.Context, script codersdk.WorkspaceAgentScript, option ExecuteOption) error {
err := r.run(ctx, script, option)
func (r *Runner) trackRun(ctx context.Context, script codersdk.WorkspaceAgentScript) error {
err := r.run(ctx, script)
if err != nil {
r.scriptsExecuted.WithLabelValues("false").Add(1)
} else {
@@ -230,7 +187,7 @@ func (r *Runner) trackRun(ctx context.Context, script codersdk.WorkspaceAgentScr
// If the timeout is exceeded, the process is sent an interrupt signal.
// If the process does not exit after a few seconds, it is forcefully killed.
// This function immediately returns after a timeout, and does not wait for the process to exit.
func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript, option ExecuteOption) error {
func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript) error {
logPath := script.LogPath
if logPath == "" {
logPath = fmt.Sprintf("coder-script-%s.log", script.LogSourceID)
@@ -251,18 +208,7 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
if !filepath.IsAbs(logPath) {
logPath = filepath.Join(r.LogDir, logPath)
}
scriptDataDir := filepath.Join(r.DataDir(), script.LogSourceID.String())
err := r.Filesystem.MkdirAll(scriptDataDir, 0o700)
if err != nil {
return xerrors.Errorf("%s script: create script temp dir: %w", scriptDataDir, err)
}
logger := r.Logger.With(
slog.F("log_source_id", script.LogSourceID),
slog.F("log_path", logPath),
slog.F("script_data_dir", scriptDataDir),
)
logger := r.Logger.With(slog.F("log_path", logPath))
logger.Info(ctx, "running agent script", slog.F("script", script.Script))
fileWriter, err := r.Filesystem.OpenFile(logPath, os.O_CREATE|os.O_RDWR, 0o600)
@@ -292,34 +238,27 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
cmd.WaitDelay = 10 * time.Second
cmd.Cancel = cmdCancel(cmd)
// Expose env vars that can be used in the script for storing data
// and binaries. In the future, we may want to expose more env vars
// for the script to use, like CODER_SCRIPT_DATA_DIR for persistent
// storage.
cmd.Env = append(cmd.Env, "CODER_SCRIPT_DATA_DIR="+scriptDataDir)
cmd.Env = append(cmd.Env, "CODER_SCRIPT_BIN_DIR="+r.ScriptBinDir())
scriptLogger := r.GetScriptLogger(script.LogSourceID)
send, flushAndClose := agentsdk.LogsSender(script.LogSourceID, r.PatchLogs, logger)
// If ctx is canceled here (or in a writer below), we may be
// discarding logs, but that's okay because we're shutting down
// anyway. We could consider creating a new context here if we
// want better control over flush during shutdown.
defer func() {
if err := scriptLogger.Flush(ctx); err != nil {
if err := flushAndClose(ctx); err != nil {
logger.Warn(ctx, "flush startup logs failed", slog.Error(err))
}
}()
infoW := agentsdk.LogsWriter(ctx, scriptLogger.Send, script.LogSourceID, codersdk.LogLevelInfo)
infoW := agentsdk.LogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelInfo)
defer infoW.Close()
errW := agentsdk.LogsWriter(ctx, scriptLogger.Send, script.LogSourceID, codersdk.LogLevelError)
errW := agentsdk.LogsWriter(ctx, send, script.LogSourceID, codersdk.LogLevelError)
defer errW.Close()
cmd.Stdout = io.MultiWriter(fileWriter, infoW)
cmd.Stderr = io.MultiWriter(fileWriter, errW)
start := dbtime.Now()
start := time.Now()
defer func() {
end := dbtime.Now()
end := time.Now()
execTime := end.Sub(start)
exitCode := 0
if err != nil {
@@ -332,60 +271,6 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
} else {
logger.Info(ctx, fmt.Sprintf("%s script completed", logPath), slog.F("execution_time", execTime), slog.F("exit_code", exitCode))
}
if r.scriptCompleted == nil {
logger.Debug(ctx, "r.scriptCompleted unexpectedly nil")
return
}
// We want to check this outside of the goroutine to avoid a race condition
timedOut := errors.Is(err, ErrTimeout)
pipesLeftOpen := errors.Is(err, ErrOutputPipesOpen)
err = r.trackCommandGoroutine(func() {
var stage proto.Timing_Stage
switch option {
case ExecuteStartScripts:
stage = proto.Timing_START
case ExecuteStopScripts:
stage = proto.Timing_STOP
case ExecuteCronScripts:
stage = proto.Timing_CRON
}
var status proto.Timing_Status
switch {
case timedOut:
status = proto.Timing_TIMED_OUT
case pipesLeftOpen:
status = proto.Timing_PIPES_LEFT_OPEN
case exitCode != 0:
status = proto.Timing_EXIT_FAILURE
default:
status = proto.Timing_OK
}
reportTimeout := 30 * time.Second
reportCtx, cancel := context.WithTimeout(context.Background(), reportTimeout)
defer cancel()
_, err := r.scriptCompleted(reportCtx, &proto.WorkspaceAgentScriptCompletedRequest{
Timing: &proto.Timing{
ScriptId: script.ID[:],
Start: timestamppb.New(start),
End: timestamppb.New(end),
ExitCode: int32(exitCode),
Stage: stage,
Status: status,
},
})
if err != nil {
logger.Error(ctx, fmt.Sprintf("reporting script completed: %s", err.Error()))
}
})
if err != nil {
logger.Error(ctx, fmt.Sprintf("reporting script completed: track command goroutine: %s", err.Error()))
}
}()
err = cmd.Start()
@@ -421,7 +306,7 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript,
"This usually means a child process was started with references to stdout or stderr. As a result, this " +
"process may now have been terminated. Consider redirecting the output or using a separate " +
"\"coder_script\" for the process, see " +
"https://coder.com/docs/templates/troubleshooting#startup-script-issues for more information.",
"https://coder.com/docs/v2/latest/templates/troubleshooting#startup-script-issues for more information.",
)
// Inform the user by propagating the message via log writers.
_, _ = fmt.Fprintf(cmd.Stderr, "WARNING: %s. %s\n", message, details)
+28 -150
View File
@@ -2,25 +2,20 @@ package agentscripts_test
import (
"context"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
@@ -29,117 +24,33 @@ func TestMain(m *testing.M) {
func TestExecuteBasic(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
fLogger := newFakeScriptLogger()
runner := setup(t, func(uuid2 uuid.UUID) agentscripts.ScriptLogger {
return fLogger
logs := make(chan agentsdk.PatchLogs, 1)
runner := setup(t, func(ctx context.Context, req agentsdk.PatchLogs) error {
logs <- req
return nil
})
defer runner.Close()
aAPI := agenttest.NewFakeAgentAPI(t, slogtest.Make(t, nil), nil, nil)
err := runner.Init([]codersdk.WorkspaceAgentScript{{
LogSourceID: uuid.New(),
Script: "echo hello",
}}, aAPI.ScriptCompleted)
Script: "echo hello",
}})
require.NoError(t, err)
require.NoError(t, runner.Execute(context.Background(), agentscripts.ExecuteAllScripts))
log := testutil.RequireRecvCtx(ctx, t, fLogger.logs)
require.Equal(t, "hello", log.Output)
}
func TestEnv(t *testing.T) {
t.Parallel()
fLogger := newFakeScriptLogger()
runner := setup(t, func(uuid2 uuid.UUID) agentscripts.ScriptLogger {
return fLogger
})
defer runner.Close()
id := uuid.New()
script := "echo $CODER_SCRIPT_DATA_DIR\necho $CODER_SCRIPT_BIN_DIR\n"
if runtime.GOOS == "windows" {
script = `
cmd.exe /c echo %CODER_SCRIPT_DATA_DIR%
cmd.exe /c echo %CODER_SCRIPT_BIN_DIR%
`
}
aAPI := agenttest.NewFakeAgentAPI(t, slogtest.Make(t, nil), nil, nil)
err := runner.Init([]codersdk.WorkspaceAgentScript{{
LogSourceID: id,
Script: script,
}}, aAPI.ScriptCompleted)
require.NoError(t, err)
ctx := testutil.Context(t, testutil.WaitLong)
done := testutil.Go(t, func() {
err := runner.Execute(ctx, agentscripts.ExecuteAllScripts)
assert.NoError(t, err)
})
defer func() {
select {
case <-ctx.Done():
case <-done:
}
}()
var log []agentsdk.Log
for {
select {
case <-ctx.Done():
require.Fail(t, "timed out waiting for logs")
case l := <-fLogger.logs:
t.Logf("log: %s", l.Output)
log = append(log, l)
}
if len(log) >= 2 {
break
}
}
require.Contains(t, log[0].Output, filepath.Join(runner.DataDir(), id.String()))
require.Contains(t, log[1].Output, runner.ScriptBinDir())
require.NoError(t, runner.Execute(context.Background(), func(script codersdk.WorkspaceAgentScript) bool {
return true
}))
log := <-logs
require.Equal(t, "hello", log.Logs[0].Output)
}
func TestTimeout(t *testing.T) {
t.Parallel()
runner := setup(t, nil)
defer runner.Close()
aAPI := agenttest.NewFakeAgentAPI(t, slogtest.Make(t, nil), nil, nil)
err := runner.Init([]codersdk.WorkspaceAgentScript{{
LogSourceID: uuid.New(),
Script: "sleep infinity",
Timeout: time.Millisecond,
}}, aAPI.ScriptCompleted)
Script: "sleep infinity",
Timeout: time.Millisecond,
}})
require.NoError(t, err)
require.ErrorIs(t, runner.Execute(context.Background(), agentscripts.ExecuteAllScripts), agentscripts.ErrTimeout)
}
func TestScriptReportsTiming(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
fLogger := newFakeScriptLogger()
runner := setup(t, func(uuid2 uuid.UUID) agentscripts.ScriptLogger {
return fLogger
})
aAPI := agenttest.NewFakeAgentAPI(t, slogtest.Make(t, nil), nil, nil)
err := runner.Init([]codersdk.WorkspaceAgentScript{{
DisplayName: "say-hello",
LogSourceID: uuid.New(),
Script: "echo hello",
}}, aAPI.ScriptCompleted)
require.NoError(t, err)
require.NoError(t, runner.Execute(ctx, agentscripts.ExecuteAllScripts))
runner.Close()
log := testutil.RequireRecvCtx(ctx, t, fLogger.logs)
require.Equal(t, "hello", log.Output)
timings := aAPI.GetTimings()
require.Equal(t, 1, len(timings))
timing := timings[0]
require.Equal(t, int32(0), timing.ExitCode)
require.GreaterOrEqual(t, timing.End.AsTime(), timing.Start.AsTime())
require.ErrorIs(t, runner.Execute(context.Background(), nil), agentscripts.ErrTimeout)
}
// TestCronClose exists because cron.Run() can happen after cron.Close().
@@ -151,61 +62,28 @@ func TestCronClose(t *testing.T) {
require.NoError(t, runner.Close(), "close runner")
}
func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscripts.ScriptLogger) *agentscripts.Runner {
func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchLogs) error) *agentscripts.Runner {
t.Helper()
if getScriptLogger == nil {
if patchLogs == nil {
// noop
getScriptLogger = func(uuid uuid.UUID) agentscripts.ScriptLogger {
return noopScriptLogger{}
patchLogs = func(ctx context.Context, req agentsdk.PatchLogs) error {
return nil
}
}
fs := afero.NewMemMapFs()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil)
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, 0, "")
require.NoError(t, err)
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
t.Cleanup(func() {
_ = s.Close()
})
return agentscripts.New(agentscripts.Options{
LogDir: t.TempDir(),
DataDirBase: t.TempDir(),
Logger: logger,
SSHServer: s,
Filesystem: fs,
GetScriptLogger: getScriptLogger,
LogDir: t.TempDir(),
Logger: logger,
SSHServer: s,
Filesystem: fs,
PatchLogs: patchLogs,
})
}
type noopScriptLogger struct{}
func (noopScriptLogger) Send(context.Context, ...agentsdk.Log) error {
return nil
}
func (noopScriptLogger) Flush(context.Context) error {
return nil
}
type fakeScriptLogger struct {
logs chan agentsdk.Log
}
func (f *fakeScriptLogger) Send(ctx context.Context, logs ...agentsdk.Log) error {
for _, log := range logs {
select {
case <-ctx.Done():
return ctx.Err()
case f.logs <- log:
// OK!
}
}
return nil
}
func (*fakeScriptLogger) Flush(context.Context) error {
return nil
}
func newFakeScriptLogger() *fakeScriptLogger {
return &fakeScriptLogger{make(chan agentsdk.Log, 100)}
}
+88 -139
View File
@@ -32,6 +32,7 @@ import (
"github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty"
)
@@ -52,40 +53,8 @@ const (
// MagicProcessCmdlineJetBrains is a string in a process's command line that
// uniquely identifies it as JetBrains software.
MagicProcessCmdlineJetBrains = "idea.vendor.name=JetBrains"
// BlockedFileTransferErrorCode indicates that SSH server restricted the raw command from performing
// the file transfer.
BlockedFileTransferErrorCode = 65 // Error code: host not allowed to connect
BlockedFileTransferErrorMessage = "File transfer has been disabled."
)
// BlockedFileTransferCommands contains a list of restricted file transfer commands.
var BlockedFileTransferCommands = []string{"nc", "rsync", "scp", "sftp"}
// Config sets configuration parameters for the agent SSH server.
type Config struct {
// MaxTimeout sets the absolute connection timeout, none if empty. If set to
// 3 seconds or more, keep alive will be used instead.
MaxTimeout time.Duration
// MOTDFile returns the path to the message of the day file. If set, the
// file will be displayed to the user upon login.
MOTDFile func() string
// ServiceBanner returns the configuration for the Coder service banner.
AnnouncementBanners func() *[]codersdk.BannerConfig
// UpdateEnv updates the environment variables for the command to be
// executed. It can be used to add, modify or replace environment variables.
UpdateEnv func(current []string) (updated []string, err error)
// WorkingDirectory sets the working directory for commands and defines
// where users will land when they connect via SSH. Default is the home
// directory of the user.
WorkingDirectory func() string
// X11DisplayOffset is the offset to add to the X11 display number.
// Default is 10.
X11DisplayOffset *int
// BlockFileTransfer restricts use of file transfer applications.
BlockFileTransfer bool
}
type Server struct {
mu sync.RWMutex // Protects following.
fs afero.Fs
@@ -97,10 +66,14 @@ type Server struct {
// a lock on mu but protected by closing.
wg sync.WaitGroup
logger slog.Logger
srv *ssh.Server
logger slog.Logger
srv *ssh.Server
x11SocketDir string
config *Config
Env map[string]string
AgentToken func() string
Manifest *atomic.Pointer[agentsdk.Manifest]
ServiceBanner *atomic.Pointer[codersdk.ServiceBannerConfig]
connCountVSCode atomic.Int64
connCountJetBrains atomic.Int64
@@ -109,7 +82,7 @@ type Server struct {
metrics *sshServerMetrics
}
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, config *Config) (*Server, error) {
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
@@ -121,30 +94,8 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
if err != nil {
return nil, err
}
if config == nil {
config = &Config{}
}
if config.X11DisplayOffset == nil {
offset := X11DefaultDisplayOffset
config.X11DisplayOffset = &offset
}
if config.UpdateEnv == nil {
config.UpdateEnv = func(current []string) ([]string, error) { return current, nil }
}
if config.MOTDFile == nil {
config.MOTDFile = func() string { return "" }
}
if config.AnnouncementBanners == nil {
config.AnnouncementBanners = func() *[]codersdk.BannerConfig { return &[]codersdk.BannerConfig{} }
}
if config.WorkingDirectory == nil {
config.WorkingDirectory = func() string {
home, err := userHomeDir()
if err != nil {
return ""
}
return home
}
if x11SocketDir == "" {
x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
}
forwardHandler := &ssh.ForwardedTCPHandler{}
@@ -152,13 +103,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
metrics := newSSHServerMetrics(prometheusRegistry)
s := &Server{
listeners: make(map[net.Listener]struct{}),
fs: fs,
conns: make(map[net.Conn]struct{}),
sessions: make(map[ssh.Session]struct{}),
logger: logger,
config: config,
listeners: make(map[net.Listener]struct{}),
fs: fs,
conns: make(map[net.Conn]struct{}),
sessions: make(map[ssh.Session]struct{}),
logger: logger,
x11SocketDir: x11SocketDir,
metrics: metrics,
}
@@ -222,16 +172,14 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
},
}
// The MaxTimeout functionality has been substituted with the introduction
// of the KeepAlive feature. In cases where very short timeouts are set, the
// SSH server will automatically switch to the connection timeout for both
// read and write operations.
if config.MaxTimeout >= 3*time.Second {
// The MaxTimeout functionality has been substituted with the introduction of the KeepAlive feature.
// In cases where very short timeouts are set, the SSH server will automatically switch to the connection timeout for both read and write operations.
if maxTimeout >= 3*time.Second {
srv.ClientAliveCountMax = 3
srv.ClientAliveInterval = config.MaxTimeout / time.Duration(srv.ClientAliveCountMax)
srv.ClientAliveInterval = maxTimeout / time.Duration(srv.ClientAliveCountMax)
srv.MaxTimeout = 0
} else {
srv.MaxTimeout = config.MaxTimeout
srv.MaxTimeout = maxTimeout
}
s.srv = srv
@@ -274,25 +222,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
extraEnv := make([]string, 0)
x11, hasX11 := session.X11()
if hasX11 {
display, handled := s.x11Handler(session.Context(), x11)
handled := s.x11Handler(session.Context(), x11)
if !handled {
_ = session.Exit(1)
logger.Error(ctx, "x11 handler failed")
return
}
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=localhost:%d.%d", display, x11.ScreenNumber))
}
if s.fileTransferBlocked(session) {
s.logger.Warn(ctx, "file transfer blocked", slog.F("session_subsystem", session.Subsystem()), slog.F("raw_command", session.RawCommand()))
if session.Subsystem() == "" { // sftp does not expect error, otherwise it fails with "package too long"
// Response format: <status_code><message body>\n
errorMessage := fmt.Sprintf("\x02%s\n", BlockedFileTransferErrorMessage)
_, _ = session.Write([]byte(errorMessage))
}
_ = session.Exit(BlockedFileTransferErrorCode)
return
extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber))
}
switch ss := session.Subsystem(); ss {
@@ -345,37 +281,6 @@ func (s *Server) sessionHandler(session ssh.Session) {
_ = session.Exit(0)
}
// fileTransferBlocked method checks if the file transfer commands should be blocked.
//
// Warning: consider this mechanism as "Do not trespass" sign, as a violator can still ssh to the host,
// smuggle the `scp` binary, or just manually send files outside with `curl` or `ftp`.
// If a user needs a more sophisticated and battle-proof solution, consider full endpoint security.
func (s *Server) fileTransferBlocked(session ssh.Session) bool {
if !s.config.BlockFileTransfer {
return false // file transfers are permitted
}
// File transfers are restricted.
if session.Subsystem() == "sftp" {
return true
}
cmd := session.Command()
if len(cmd) == 0 {
return false // no command?
}
c := cmd[0]
c = filepath.Base(c) // in case the binary is absolute path, /usr/sbin/scp
for _, cmd := range BlockedFileTransferCommands {
if cmd == c {
return true
}
}
return false
}
func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, extraEnv []string) (retErr error) {
ctx := session.Context()
env := append(session.Environ(), extraEnv...)
@@ -495,24 +400,26 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
session.DisablePTYEmulation()
if isLoginShell(session.RawCommand()) {
banners := s.config.AnnouncementBanners()
if banners != nil {
for _, banner := range *banners {
err := showAnnouncementBanner(session, banner)
if err != nil {
logger.Error(ctx, "agent failed to show announcement banner", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "announcement_banner").Add(1)
break
}
serviceBanner := s.ServiceBanner.Load()
if serviceBanner != nil {
err := showServiceBanner(session, serviceBanner)
if err != nil {
logger.Error(ctx, "agent failed to show service banner", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "service_banner").Add(1)
}
}
}
if !isQuietLogin(s.fs, session.RawCommand()) {
err := showMOTD(s.fs, session, s.config.MOTDFile())
if err != nil {
logger.Error(ctx, "agent failed to show MOTD", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "motd").Add(1)
manifest := s.Manifest.Load()
if manifest != nil {
err := showMOTD(s.fs, session, manifest.MOTDFile)
if err != nil {
logger.Error(ctx, "agent failed to show MOTD", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "motd").Add(1)
}
} else {
logger.Warn(ctx, "metadata lookup failed, unable to show MOTD")
}
}
@@ -650,7 +557,7 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
defer server.Close()
err = server.Serve()
if err == nil || errors.Is(err, io.EOF) {
if errors.Is(err, io.EOF) {
// Unless we call `session.Exit(0)` here, the client won't
// receive `exit-status` because `(*sftp.Server).Close()`
// calls `Close()` on the underlying connection (session),
@@ -682,6 +589,11 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
return nil, xerrors.Errorf("get user shell: %w", err)
}
manifest := s.Manifest.Load()
if manifest == nil {
return nil, xerrors.Errorf("no metadata was provided")
}
// OpenSSH executes all commands with the users current shell.
// We replicate that behavior for IDE support.
caller := "-c"
@@ -726,7 +638,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
}
cmd := pty.CommandContext(ctx, name, args...)
cmd.Dir = s.config.WorkingDirectory()
cmd.Dir = manifest.Directory
// If the metadata directory doesn't exist, we run the command
// in the users home directory.
@@ -740,7 +652,23 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
cmd.Dir = homedir
}
cmd.Env = append(os.Environ(), env...)
executablePath, err := os.Executable()
if err != nil {
return nil, xerrors.Errorf("getting os executable: %w", err)
}
// Set environment variables reliable detection of being inside a
// Coder workspace.
cmd.Env = append(cmd.Env, "CODER=true")
cmd.Env = append(cmd.Env, "CODER_WORKSPACE_NAME="+manifest.WorkspaceName)
cmd.Env = append(cmd.Env, "CODER_WORKSPACE_AGENT_NAME="+manifest.AgentName)
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath))
// Specific Coder subcommands require the agent token exposed!
cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", s.AgentToken()))
// Set SSH connection environment variables (these are also set by OpenSSH
// and thus expected to be present by SSH clients). Since the agent does
@@ -751,9 +679,30 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort))
cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort))
cmd.Env, err = s.config.UpdateEnv(cmd.Env)
if err != nil {
return nil, xerrors.Errorf("apply env: %w", err)
// This adds the ports dialog to code-server that enables
// proxying a port dynamically.
// If this is empty string, do not set anything. Code-server auto defaults
// using its basepath to construct a path based port proxy.
if manifest.VSCodePortProxyURI != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI))
}
// Hide Coder message on code-server's "Getting Started" page
cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true")
// Load environment variables passed via the agent.
// These should override all variables we manually specify.
for envKey, value := range manifest.EnvironmentVariables {
// Expanding environment variables allows for customization
// of the $PATH, among other variables. Customers can prepend
// or append to the $PATH, so allowing expand is required!
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value)))
}
// Agent-level environment variables should take over all!
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
for envKey, value := range s.Env {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
}
return cmd, nil
@@ -948,9 +897,9 @@ func isQuietLogin(fs afero.Fs, rawCommand string) bool {
return err == nil
}
// showAnnouncementBanner will write the service banner if enabled and not blank
// showServiceBanner will write the service banner if enabled and not blank
// along with a blank line for spacing.
func showAnnouncementBanner(session io.Writer, banner codersdk.BannerConfig) error {
func showServiceBanner(session io.Writer, banner *codersdk.ServiceBannerConfig) error {
if banner.Enabled && banner.Message != "" {
// The banner supports Markdown so we might want to parse it but Markdown is
// still fairly readable in its raw form.
+1 -1
View File
@@ -37,7 +37,7 @@ func Test_sessionStart_orphan(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
logger := slogtest.Make(t, nil)
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
defer s.Close()
+25 -5
View File
@@ -17,12 +17,14 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
@@ -36,10 +38,14 @@ func TestNewServer_ServeClient(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -77,11 +83,13 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
t.Cleanup(func() {
_ = s.Close()
})
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
t.Run("Basic", func(t *testing.T) {
t.Parallel()
@@ -108,10 +116,14 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -159,10 +171,14 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -224,10 +240,14 @@ func TestNewServer_Signal(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), 0, "")
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
+55 -90
View File
@@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"math"
"net"
"os"
"path/filepath"
@@ -23,69 +22,61 @@ import (
"cdr.dev/slog"
)
const (
// X11StartPort is the starting port for X11 forwarding, this is the
// port used for "DISPLAY=localhost:0".
X11StartPort = 6000
// X11DefaultDisplayOffset is the default offset for X11 forwarding.
X11DefaultDisplayOffset = 10
)
// x11Callback is called when the client requests X11 forwarding.
func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
// Always allow.
// It adds an Xauthority entry to the Xauthority file.
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
hostname, err := os.Hostname()
if err != nil {
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
return false
}
err = s.fs.MkdirAll(s.x11SocketDir, 0o700)
if err != nil {
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
return false
}
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
if err != nil {
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
return false
}
return true
}
// x11Handler is called when a session has requested X11 forwarding.
// It listens for X11 connections and forwards them to the client.
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, handled bool) {
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
if !valid {
s.logger.Warn(ctx, "failed to get server connection")
return -1, false
return false
}
hostname, err := os.Hostname()
// We want to overwrite the socket so that subsequent connections will succeed.
socketPath := filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))
err := os.Remove(socketPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err))
return false
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
return -1, false
}
ln, display, err := createX11Listener(ctx, *s.config.X11DisplayOffset)
if err != nil {
s.logger.Warn(ctx, "failed to create X11 listener", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
return -1, false
}
s.trackListener(ln, true)
defer func() {
if !handled {
s.trackListener(ln, false)
_ = ln.Close()
}
}()
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(display), x11.AuthProtocol, x11.AuthCookie)
if err != nil {
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
return -1, false
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
return false
}
s.trackListener(listener, true)
go func() {
// Don't leave the listener open after the session is gone.
<-ctx.Done()
_ = ln.Close()
}()
go func() {
defer ln.Close()
defer s.trackListener(ln, false)
defer listener.Close()
defer s.trackListener(listener, false)
handledFirstConnection := false
for {
conn, err := ln.Accept()
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
@@ -93,66 +84,40 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) (displayNumber int, ha
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
return
}
if x11.SingleConnection {
s.logger.Debug(ctx, "single connection requested, closing X11 listener")
_ = ln.Close()
if x11.SingleConnection && handledFirstConnection {
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
_ = conn.Close()
continue
}
handledFirstConnection = true
tcpConn, ok := conn.(*net.TCPConn)
unixConn, ok := conn.(*net.UnixConn)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn))
_ = conn.Close()
continue
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
return
}
tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr)
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
if !ok {
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr()))
_ = conn.Close()
continue
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
return
}
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
OriginatorAddress string
OriginatorPort uint32
}{
OriginatorAddress: tcpAddr.IP.String(),
OriginatorPort: uint32(tcpAddr.Port),
OriginatorAddress: unixAddr.Name,
OriginatorPort: 0,
}))
if err != nil {
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
_ = conn.Close()
continue
return
}
go gossh.DiscardRequests(reqs)
if !s.trackConn(ln, conn, true) {
s.logger.Warn(ctx, "failed to track X11 connection")
_ = conn.Close()
continue
}
go func() {
defer s.trackConn(ln, conn, false)
Bicopy(ctx, conn, channel)
}()
go Bicopy(ctx, conn, channel)
}
}()
return display, true
}
// createX11Listener creates a listener for X11 forwarding, it will use
// the next available port starting from X11StartPort and displayOffset.
func createX11Listener(ctx context.Context, displayOffset int) (ln net.Listener, display int, err error) {
var lc net.ListenConfig
// Look for an open port to listen on.
for port := X11StartPort + displayOffset; port < math.MaxUint16; port++ {
ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
if err == nil {
display = port - X11StartPort
return ln, display, nil
}
}
return nil, -1, xerrors.Errorf("failed to find open port for X11 listener: %w", err)
return true
}
// addXauthEntry adds an Xauthority entry to the Xauthority file.
+11 -33
View File
@@ -1,17 +1,12 @@
package agentssh_test
import (
"bufio"
"bytes"
"context"
"encoding/hex"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
"github.com/gliderlabs/ssh"
@@ -19,11 +14,13 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
)
@@ -36,10 +33,15 @@ func TestServer_X11(t *testing.T) {
ctx := context.Background()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
fs := afero.NewOsFs()
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, &agentssh.Config{})
dir := t.TempDir()
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, 0, dir)
require.NoError(t, err)
defer s.Close()
// The assumption is that these are set before serving SSH connections.
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -55,45 +57,21 @@ func TestServer_X11(t *testing.T) {
sess, err := c.NewSession()
require.NoError(t, err)
wantScreenNumber := 1
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
AuthProtocol: "MIT-MAGIC-COOKIE-1",
AuthCookie: hex.EncodeToString([]byte("cookie")),
ScreenNumber: uint32(wantScreenNumber),
ScreenNumber: 0,
}))
require.NoError(t, err)
assert.True(t, reply)
// Want: ~DISPLAY=localhost:10.1
out, err := sess.Output("echo DISPLAY=$DISPLAY")
err = sess.Shell()
require.NoError(t, err)
sc := bufio.NewScanner(bytes.NewReader(out))
displayNumber := -1
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
t.Log(line)
if strings.HasPrefix(line, "DISPLAY=") {
parts := strings.SplitN(line, "=", 2)
display := parts[1]
parts = strings.SplitN(display, ":", 2)
parts = strings.SplitN(parts[1], ".", 2)
displayNumber, err = strconv.Atoi(parts[0])
require.NoError(t, err)
assert.GreaterOrEqual(t, displayNumber, 10, "display number should be >= 10")
gotScreenNumber, err := strconv.Atoi(parts[1])
require.NoError(t, err)
assert.Equal(t, wantScreenNumber, gotScreenNumber, "screen number should match")
break
}
}
require.NoError(t, sc.Err())
require.NotEqual(t, -1, displayNumber)
x11Chans := c.HandleChannelOpen("x11")
payload := "hello world"
require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
if err == nil {
_, err = conn.Write([]byte(payload))
assert.NoError(t, err)
+122 -146
View File
@@ -9,12 +9,9 @@ import (
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/durationpb"
"storj.io/drpc"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
@@ -30,13 +27,11 @@ import (
"github.com/coder/coder/v2/testutil"
)
const statsInterval = 500 * time.Millisecond
func NewClient(t testing.TB,
logger slog.Logger,
agentID uuid.UUID,
manifest agentsdk.Manifest,
statsChan chan *agentproto.Stats,
statsChan chan *agentsdk.Stats,
coordinator tailnet.Coordinator,
) *Client {
if manifest.AgentID == uuid.Nil {
@@ -48,7 +43,7 @@ func NewClient(t testing.TB,
derpMapUpdates := make(chan *tailcfg.DERPMap)
drpcService := &tailnet.DRPCService{
CoordPtr: &coordPtr,
Logger: logger.Named("tailnetsvc"),
Logger: logger,
DerpMapUpdateFrequency: time.Microsecond,
DerpMapFn: func() *tailcfg.DERPMap { return <-derpMapUpdates },
}
@@ -56,7 +51,7 @@ func NewClient(t testing.TB,
require.NoError(t, err)
mp, err := agentsdk.ProtoFromManifest(manifest)
require.NoError(t, err)
fakeAAPI := NewFakeAgentAPI(t, logger, mp, statsChan)
fakeAAPI := NewFakeAgentAPI(t, logger, mp)
err = agentproto.DRPCRegisterAgent(mux, fakeAAPI)
require.NoError(t, err)
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
@@ -71,6 +66,7 @@ func NewClient(t testing.TB,
t: t,
logger: logger.Named("client"),
agentID: agentID,
statsChan: statsChan,
coordinator: coordinator,
server: server,
fakeAgentAPI: fakeAAPI,
@@ -82,15 +78,19 @@ type Client struct {
t testing.TB
logger slog.Logger
agentID uuid.UUID
metadata map[string]agentsdk.Metadata
statsChan chan *agentsdk.Stats
coordinator tailnet.Coordinator
server *drpcserver.Server
fakeAgentAPI *FakeAgentAPI
LastWorkspaceAgent func()
PatchWorkspaceLogs func() error
mu sync.Mutex // Protects following.
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
mu sync.Mutex // Protects following.
lifecycleStates []codersdk.WorkspaceAgentLifecycle
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
}
func (*Client) RewriteDERPMap(*tailcfg.DERPMap) {}
@@ -108,10 +108,11 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
c.t.Cleanup(c.LastWorkspaceAgent)
serveCtx, cancel := context.WithCancel(ctx)
c.t.Cleanup(cancel)
auth := tailnet.AgentTunnelAuth{}
streamID := tailnet.StreamID{
Name: "agenttest",
ID: c.agentID,
Auth: tailnet.AgentCoordinateeAuth{ID: c.agentID},
Auth: auth,
}
serveCtx = tailnet.WithStreamID(serveCtx, streamID)
go func() {
@@ -120,8 +121,50 @@ func (c *Client) ConnectRPC(ctx context.Context) (drpc.Conn, error) {
return conn, nil
}
func (c *Client) ReportStats(ctx context.Context, _ slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) {
doneCh := make(chan struct{})
ctx, cancel := context.WithCancel(ctx)
go func() {
defer close(doneCh)
setInterval(500 * time.Millisecond)
for {
select {
case <-ctx.Done():
return
case stat := <-statsChan:
select {
case c.statsChan <- stat:
case <-ctx.Done():
return
default:
// We don't want to send old stats.
continue
}
}
}
}()
return closeFunc(func() error {
cancel()
<-doneCh
close(c.statsChan)
return nil
}), nil
}
func (c *Client) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
return c.fakeAgentAPI.GetLifecycleStates()
c.mu.Lock()
defer c.mu.Unlock()
return c.lifecycleStates
}
func (c *Client) PostLifecycle(ctx context.Context, req agentsdk.PostLifecycleRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.lifecycleStates = append(c.lifecycleStates, req.State)
c.logger.Debug(ctx, "post lifecycle", slog.F("req", req))
return nil
}
func (c *Client) GetStartup() <-chan *agentproto.Startup {
@@ -129,7 +172,22 @@ func (c *Client) GetStartup() <-chan *agentproto.Startup {
}
func (c *Client) GetMetadata() map[string]agentsdk.Metadata {
return c.fakeAgentAPI.GetMetadata()
c.mu.Lock()
defer c.mu.Unlock()
return maps.Clone(c.metadata)
}
func (c *Client) PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.metadata == nil {
c.metadata = make(map[string]agentsdk.Metadata)
}
for _, md := range req.Metadata {
c.metadata[md.Key] = md
c.logger.Debug(ctx, "post metadata", slog.F("key", md.Key), slog.F("md", md))
}
return nil
}
func (c *Client) GetStartupLogs() []agentsdk.Log {
@@ -138,8 +196,19 @@ func (c *Client) GetStartupLogs() []agentsdk.Log {
return c.logs
}
func (c *Client) SetAnnouncementBannersFunc(f func() ([]codersdk.BannerConfig, error)) {
c.fakeAgentAPI.SetAnnouncementBannersFunc(f)
func (c *Client) PatchLogs(ctx context.Context, logs agentsdk.PatchLogs) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.PatchWorkspaceLogs != nil {
return c.PatchWorkspaceLogs()
}
c.logs = append(c.logs, logs.Logs...)
c.logger.Debug(ctx, "patch startup logs", slog.F("req", logs))
return nil
}
func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
c.fakeAgentAPI.SetServiceBannerFunc(f)
}
func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
@@ -154,8 +223,10 @@ func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
return nil
}
func (c *Client) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
c.fakeAgentAPI.SetLogsChannel(ch)
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
type FakeAgentAPI struct {
@@ -163,166 +234,71 @@ type FakeAgentAPI struct {
t testing.TB
logger slog.Logger
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
statsCh chan *agentproto.Stats
appHealthCh chan *agentproto.BatchUpdateAppHealthRequest
logsCh chan<- *agentproto.BatchCreateLogsRequest
lifecycleStates []codersdk.WorkspaceAgentLifecycle
metadata map[string]agentsdk.Metadata
timings []*agentproto.Timing
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup
getAnnouncementBannersFunc func() ([]codersdk.BannerConfig, error)
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
}
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}
func (*FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
return &agentproto.ServiceBanner{}, nil
}
func (f *FakeAgentAPI) GetTimings() []*agentproto.Timing {
func (f *FakeAgentAPI) SetServiceBannerFunc(fn func() (codersdk.ServiceBannerConfig, error)) {
f.Lock()
defer f.Unlock()
return slices.Clone(f.timings)
f.getServiceBannerFunc = fn
f.logger.Info(context.Background(), "updated ServiceBannerFunc")
}
func (f *FakeAgentAPI) SetAnnouncementBannersFunc(fn func() ([]codersdk.BannerConfig, error)) {
func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
f.Lock()
defer f.Unlock()
f.getAnnouncementBannersFunc = fn
f.logger.Info(context.Background(), "updated notification banners")
}
func (f *FakeAgentAPI) GetAnnouncementBanners(context.Context, *agentproto.GetAnnouncementBannersRequest) (*agentproto.GetAnnouncementBannersResponse, error) {
f.Lock()
defer f.Unlock()
if f.getAnnouncementBannersFunc == nil {
return &agentproto.GetAnnouncementBannersResponse{AnnouncementBanners: []*agentproto.BannerConfig{}}, nil
if f.getServiceBannerFunc == nil {
return &agentproto.ServiceBanner{}, nil
}
banners, err := f.getAnnouncementBannersFunc()
sb, err := f.getServiceBannerFunc()
if err != nil {
return nil, err
}
bannersProto := make([]*agentproto.BannerConfig, 0, len(banners))
for _, banner := range banners {
bannersProto = append(bannersProto, agentsdk.ProtoFromBannerConfig(banner))
}
return &agentproto.GetAnnouncementBannersResponse{AnnouncementBanners: bannersProto}, nil
return agentsdk.ProtoFromServiceBanner(sb), nil
}
func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
f.logger.Debug(ctx, "update stats called", slog.F("req", req))
// empty request is sent to get the interval; but our tests don't want empty stats requests
if req.Stats != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case f.statsCh <- req.Stats:
// OK!
}
}
return &agentproto.UpdateStatsResponse{ReportInterval: durationpb.New(statsInterval)}, nil
func (*FakeAgentAPI) UpdateStats(context.Context, *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
// TODO implement me
panic("implement me")
}
func (f *FakeAgentAPI) GetLifecycleStates() []codersdk.WorkspaceAgentLifecycle {
f.Lock()
defer f.Unlock()
return slices.Clone(f.lifecycleStates)
}
func (f *FakeAgentAPI) UpdateLifecycle(_ context.Context, req *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
f.Lock()
defer f.Unlock()
s, err := agentsdk.LifecycleStateFromProto(req.GetLifecycle().GetState())
if assert.NoError(f.t, err) {
f.lifecycleStates = append(f.lifecycleStates, s)
}
return req.GetLifecycle(), nil
func (*FakeAgentAPI) UpdateLifecycle(context.Context, *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) {
// TODO implement me
panic("implement me")
}
func (f *FakeAgentAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
f.logger.Debug(ctx, "batch update app health", slog.F("req", req))
select {
case <-ctx.Done():
return nil, ctx.Err()
case f.appHealthCh <- req:
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
func (f *FakeAgentAPI) AppHealthCh() <-chan *agentproto.BatchUpdateAppHealthRequest {
return f.appHealthCh
func (f *FakeAgentAPI) UpdateStartup(_ context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
f.startupCh <- req.GetStartup()
return req.GetStartup(), nil
}
func (f *FakeAgentAPI) UpdateStartup(ctx context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case f.startupCh <- req.GetStartup():
return req.GetStartup(), nil
}
func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
// TODO implement me
panic("implement me")
}
func (f *FakeAgentAPI) GetMetadata() map[string]agentsdk.Metadata {
f.Lock()
defer f.Unlock()
return maps.Clone(f.metadata)
func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
// TODO implement me
panic("implement me")
}
func (f *FakeAgentAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
f.Lock()
defer f.Unlock()
if f.metadata == nil {
f.metadata = make(map[string]agentsdk.Metadata)
}
for _, md := range req.Metadata {
smd := agentsdk.MetadataFromProto(md)
f.metadata[md.Key] = smd
f.logger.Debug(ctx, "post metadata", slog.F("key", md.Key), slog.F("md", md))
}
return &agentproto.BatchUpdateMetadataResponse{}, nil
}
func (f *FakeAgentAPI) SetLogsChannel(ch chan<- *agentproto.BatchCreateLogsRequest) {
f.Lock()
defer f.Unlock()
f.logsCh = ch
}
func (f *FakeAgentAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) {
f.logger.Info(ctx, "batch create logs called", slog.F("req", req))
f.Lock()
ch := f.logsCh
f.Unlock()
if ch != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case ch <- req:
// ok
}
}
return &agentproto.BatchCreateLogsResponse{}, nil
}
func (f *FakeAgentAPI) ScriptCompleted(_ context.Context, req *agentproto.WorkspaceAgentScriptCompletedRequest) (*agentproto.WorkspaceAgentScriptCompletedResponse, error) {
f.Lock()
f.timings = append(f.timings, req.Timing)
f.Unlock()
return &agentproto.WorkspaceAgentScriptCompletedResponse{}, nil
}
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest, statsCh chan *agentproto.Stats) *FakeAgentAPI {
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI {
return &FakeAgentAPI{
t: t,
logger: logger.Named("FakeAgentAPI"),
manifest: manifest,
statsCh: statsCh,
startupCh: make(chan *agentproto.Startup, 100),
appHealthCh: make(chan *agentproto.BatchUpdateAppHealthRequest, 100),
t: t,
logger: logger.Named("FakeAgentAPI"),
manifest: manifest,
startupCh: make(chan *agentproto.Startup, 100),
}
}
-7
View File
@@ -35,14 +35,7 @@ func (a *agent) apiHandler() http.Handler {
ignorePorts: cpy,
cacheDuration: cacheDuration,
}
promHandler := PrometheusMetricsHandler(a.prometheusRegistry, a.logger)
r.Get("/api/v0/listening-ports", lp.handler)
r.Get("/api/v0/netcheck", a.HandleNetcheck)
r.Get("/debug/logs", a.HandleHTTPDebugLogs)
r.Get("/debug/magicsock", a.HandleHTTPDebugMagicsock)
r.Get("/debug/magicsock/debug-logging/{state}", a.HandleHTTPMagicsockDebugLoggingState)
r.Get("/debug/manifest", a.HandleHTTPDebugManifest)
r.Get("/debug/prometheus", promHandler.ServeHTTP)
return r
}
+58 -67
View File
@@ -12,9 +12,12 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/quartz"
"github.com/coder/retry"
)
// WorkspaceAgentApps fetches the workspace apps.
type WorkspaceAgentApps func(context.Context) ([]codersdk.WorkspaceApp, error)
// PostWorkspaceAgentAppHealth updates the workspace app health.
type PostWorkspaceAgentAppHealth func(context.Context, agentsdk.PostAppHealthsRequest) error
@@ -23,26 +26,10 @@ type WorkspaceAppHealthReporter func(ctx context.Context)
// NewWorkspaceAppHealthReporter creates a WorkspaceAppHealthReporter that reports app health to coderd.
func NewWorkspaceAppHealthReporter(logger slog.Logger, apps []codersdk.WorkspaceApp, postWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth) WorkspaceAppHealthReporter {
return NewAppHealthReporterWithClock(logger, apps, postWorkspaceAgentAppHealth, quartz.NewReal())
}
// NewAppHealthReporterWithClock is only called directly by test code. Product code should call
// NewAppHealthReporter.
func NewAppHealthReporterWithClock(
logger slog.Logger,
apps []codersdk.WorkspaceApp,
postWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth,
clk quartz.Clock,
) WorkspaceAppHealthReporter {
logger = logger.Named("apphealth")
return func(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
runHealthcheckLoop := func(ctx context.Context) error {
// no need to run this loop if no apps for this workspace.
if len(apps) == 0 {
return
return nil
}
hasHealthchecksEnabled := false
@@ -57,7 +44,7 @@ func NewAppHealthReporterWithClock(
// no need to run this loop if no health checks are configured.
if !hasHealthchecksEnabled {
return
return nil
}
// run a ticker for each app health check.
@@ -69,29 +56,25 @@ func NewAppHealthReporterWithClock(
}
app := nextApp
go func() {
_ = clk.TickerFunc(ctx, time.Duration(app.Healthcheck.Interval)*time.Second, func() error {
// We time out at the healthcheck interval to prevent getting too backed up, but
// set it 1ms early so that it's not simultaneous with the next tick in testing,
// which makes the test easier to understand.
//
// It would be idiomatic to use the http.Client.Timeout or a context.WithTimeout,
// but we are passing this off to the native http library, which is not aware
// of the clock library we are using. That means in testing, with a mock clock
// it will compare mocked times with real times, and we will get strange results.
// So, we just implement the timeout as a context we cancel with an AfterFunc
reqCtx, reqCancel := context.WithCancel(ctx)
timeout := clk.AfterFunc(
time.Duration(app.Healthcheck.Interval)*time.Second-time.Millisecond,
reqCancel,
"timeout", app.Slug)
defer timeout.Stop()
t := time.NewTicker(time.Duration(app.Healthcheck.Interval) * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
}
// we set the http timeout to the healthcheck interval to prevent getting too backed up.
client := &http.Client{
Timeout: time.Duration(app.Healthcheck.Interval) * time.Second,
}
err := func() error {
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, app.Healthcheck.URL, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, app.Healthcheck.URL, nil)
if err != nil {
return err
}
res, err := http.DefaultClient.Do(req)
res, err := client.Do(req)
if err != nil {
return err
}
@@ -104,7 +87,6 @@ func NewAppHealthReporterWithClock(
return nil
}()
if err != nil {
nowUnhealthy := false
mu.Lock()
if failures[app.ID] < int(app.Healthcheck.Threshold) {
// increment the failure count and keep status the same.
@@ -114,52 +96,61 @@ func NewAppHealthReporterWithClock(
// set to unhealthy if we hit the failure threshold.
// we stop incrementing at the threshold to prevent the failure value from increasing forever.
health[app.ID] = codersdk.WorkspaceAppHealthUnhealthy
nowUnhealthy = true
}
mu.Unlock()
logger.Debug(ctx, "error checking app health",
slog.F("id", app.ID.String()),
slog.F("slug", app.Slug),
slog.F("now_unhealthy", nowUnhealthy), slog.Error(err),
)
} else {
mu.Lock()
// we only need one successful health check to be considered healthy.
health[app.ID] = codersdk.WorkspaceAppHealthHealthy
failures[app.ID] = 0
mu.Unlock()
logger.Debug(ctx, "workspace app healthy", slog.F("id", app.ID.String()), slog.F("slug", app.Slug))
}
return nil
}, "healthcheck", app.Slug)
t.Reset(time.Duration(app.Healthcheck.Interval) * time.Second)
}
}()
}
mu.Lock()
lastHealth := copyHealth(health)
mu.Unlock()
reportTicker := clk.TickerFunc(ctx, time.Second, func() error {
mu.RLock()
changed := healthChanged(lastHealth, health)
mu.RUnlock()
if !changed {
reportTicker := time.NewTicker(time.Second)
defer reportTicker.Stop()
// every second we check if the health values of the apps have changed
// and if there is a change we will report the new values.
for {
select {
case <-ctx.Done():
return nil
}
case <-reportTicker.C:
mu.RLock()
changed := healthChanged(lastHealth, health)
mu.RUnlock()
if !changed {
continue
}
mu.Lock()
lastHealth = copyHealth(health)
mu.Unlock()
err := postWorkspaceAgentAppHealth(ctx, agentsdk.PostAppHealthsRequest{
Healths: lastHealth,
})
if err != nil {
logger.Error(ctx, "failed to report workspace app health", slog.Error(err))
} else {
logger.Debug(ctx, "sent workspace app health", slog.F("health", lastHealth))
mu.Lock()
lastHealth = copyHealth(health)
mu.Unlock()
err := postWorkspaceAgentAppHealth(ctx, agentsdk.PostAppHealthsRequest{
Healths: lastHealth,
})
if err != nil {
logger.Error(ctx, "failed to report workspace app stat", slog.Error(err))
}
}
return nil
}, "report")
_ = reportTicker.Wait() // only possible error is context done
}
}
return func(ctx context.Context) {
for r := retry.New(time.Second, 30*time.Second); r.Wait(ctx); {
err := runHealthcheckLoop(ctx)
if err == nil || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return
}
logger.Error(ctx, "failed running workspace app reporter", slog.Error(err))
}
}
}
+92 -175
View File
@@ -4,39 +4,33 @@ import (
"context"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestAppHealth_Healthy(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
apps := []codersdk.WorkspaceApp{
{
ID: uuid.UUID{1},
Slug: "app1",
Healthcheck: codersdk.Healthcheck{},
Health: codersdk.WorkspaceAppHealthDisabled,
},
{
ID: uuid.UUID{2},
Slug: "app2",
Healthcheck: codersdk.Healthcheck{
// URL: We don't set the URL for this test because the setup will
@@ -46,81 +40,34 @@ func TestAppHealth_Healthy(t *testing.T) {
},
Health: codersdk.WorkspaceAppHealthInitializing,
},
{
ID: uuid.UUID{3},
Slug: "app3",
Healthcheck: codersdk.Healthcheck{
Interval: 2,
Threshold: 1,
},
Health: codersdk.WorkspaceAppHealthInitializing,
},
}
checks2 := 0
checks3 := 0
handlers := []http.Handler{
nil,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
checks2++
httpapi.Write(r.Context(), w, http.StatusOK, nil)
}),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
checks3++
httpapi.Write(r.Context(), w, http.StatusOK, nil)
}),
}
mClock := quartz.NewMock(t)
healthcheckTrap := mClock.Trap().TickerFunc("healthcheck")
defer healthcheckTrap.Close()
reportTrap := mClock.Trap().TickerFunc("report")
defer reportTrap.Close()
fakeAPI, closeFn := setupAppReporter(ctx, t, slices.Clone(apps), handlers, mClock)
getApps, closeFn := setupAppReporter(ctx, t, apps, handlers)
defer closeFn()
healthchecksStarted := make([]string, 2)
for i := 0; i < 2; i++ {
c := healthcheckTrap.MustWait(ctx)
c.Release()
healthchecksStarted[i] = c.Tags[1]
}
slices.Sort(healthchecksStarted)
require.Equal(t, []string{"app2", "app3"}, healthchecksStarted)
apps, err := getApps(ctx)
require.NoError(t, err)
require.EqualValues(t, codersdk.WorkspaceAppHealthDisabled, apps[0].Health)
require.Eventually(t, func() bool {
apps, err := getApps(ctx)
if err != nil {
return false
}
// advance the clock 1ms before the report ticker starts, so that it's not
// simultaneous with the checks.
mClock.Advance(time.Millisecond).MustWait(ctx)
reportTrap.MustWait(ctx).Release()
mClock.Advance(999 * time.Millisecond).MustWait(ctx) // app2 is now healthy
mClock.Advance(time.Millisecond).MustWait(ctx) // report gets triggered
update := testutil.RequireRecvCtx(ctx, t, fakeAPI.AppHealthCh())
require.Len(t, update.GetUpdates(), 2)
applyUpdate(t, apps, update)
require.Equal(t, codersdk.WorkspaceAppHealthHealthy, apps[1].Health)
require.Equal(t, codersdk.WorkspaceAppHealthInitializing, apps[2].Health)
mClock.Advance(999 * time.Millisecond).MustWait(ctx) // app3 is now healthy
mClock.Advance(time.Millisecond).MustWait(ctx) // report gets triggered
update = testutil.RequireRecvCtx(ctx, t, fakeAPI.AppHealthCh())
require.Len(t, update.GetUpdates(), 2)
applyUpdate(t, apps, update)
require.Equal(t, codersdk.WorkspaceAppHealthHealthy, apps[1].Health)
require.Equal(t, codersdk.WorkspaceAppHealthHealthy, apps[2].Health)
// ensure we aren't spamming
require.Equal(t, 2, checks2)
require.Equal(t, 1, checks3)
return apps[1].Health == codersdk.WorkspaceAppHealthHealthy
}, testutil.WaitLong, testutil.IntervalSlow)
}
func TestAppHealth_500(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
apps := []codersdk.WorkspaceApp{
{
ID: uuid.UUID{2},
Slug: "app2",
Healthcheck: codersdk.Healthcheck{
// URL: We don't set the URL for this test because the setup will
@@ -136,40 +83,59 @@ func TestAppHealth_500(t *testing.T) {
httpapi.Write(r.Context(), w, http.StatusInternalServerError, nil)
}),
}
mClock := quartz.NewMock(t)
healthcheckTrap := mClock.Trap().TickerFunc("healthcheck")
defer healthcheckTrap.Close()
reportTrap := mClock.Trap().TickerFunc("report")
defer reportTrap.Close()
fakeAPI, closeFn := setupAppReporter(ctx, t, slices.Clone(apps), handlers, mClock)
getApps, closeFn := setupAppReporter(ctx, t, apps, handlers)
defer closeFn()
healthcheckTrap.MustWait(ctx).Release()
// advance the clock 1ms before the report ticker starts, so that it's not
// simultaneous with the checks.
mClock.Advance(time.Millisecond).MustWait(ctx)
reportTrap.MustWait(ctx).Release()
require.Eventually(t, func() bool {
apps, err := getApps(ctx)
if err != nil {
return false
}
mClock.Advance(999 * time.Millisecond).MustWait(ctx) // check gets triggered
mClock.Advance(time.Millisecond).MustWait(ctx) // report gets triggered, but unsent since we are at the threshold
mClock.Advance(999 * time.Millisecond).MustWait(ctx) // 2nd check, crosses threshold
mClock.Advance(time.Millisecond).MustWait(ctx) // 2nd report, sends update
update := testutil.RequireRecvCtx(ctx, t, fakeAPI.AppHealthCh())
require.Len(t, update.GetUpdates(), 1)
applyUpdate(t, apps, update)
require.Equal(t, codersdk.WorkspaceAppHealthUnhealthy, apps[0].Health)
return apps[0].Health == codersdk.WorkspaceAppHealthUnhealthy
}, testutil.WaitLong, testutil.IntervalSlow)
}
func TestAppHealth_Timeout(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
apps := []codersdk.WorkspaceApp{
{
Slug: "app2",
Healthcheck: codersdk.Healthcheck{
// URL: We don't set the URL for this test because the setup will
// create a httptest server for us and set it for us.
Interval: 1,
Threshold: 1,
},
Health: codersdk.WorkspaceAppHealthInitializing,
},
}
handlers := []http.Handler{
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// sleep longer than the interval to cause the health check to time out
time.Sleep(2 * time.Second)
httpapi.Write(r.Context(), w, http.StatusOK, nil)
}),
}
getApps, closeFn := setupAppReporter(ctx, t, apps, handlers)
defer closeFn()
require.Eventually(t, func() bool {
apps, err := getApps(ctx)
if err != nil {
return false
}
return apps[0].Health == codersdk.WorkspaceAppHealthUnhealthy
}, testutil.WaitLong, testutil.IntervalSlow)
}
func TestAppHealth_NotSpamming(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
apps := []codersdk.WorkspaceApp{
{
ID: uuid.UUID{2},
Slug: "app2",
Healthcheck: codersdk.Healthcheck{
// URL: We don't set the URL for this test because the setup will
@@ -181,66 +147,22 @@ func TestAppHealth_Timeout(t *testing.T) {
},
}
counter := new(int32)
handlers := []http.Handler{
http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
// allow the request to time out
<-r.Context().Done()
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(counter, 1)
}),
}
mClock := quartz.NewMock(t)
start := mClock.Now()
// for this test, it's easier to think in the number of milliseconds elapsed
// since start.
ms := func(n int) time.Time {
return start.Add(time.Duration(n) * time.Millisecond)
}
healthcheckTrap := mClock.Trap().TickerFunc("healthcheck")
defer healthcheckTrap.Close()
reportTrap := mClock.Trap().TickerFunc("report")
defer reportTrap.Close()
timeoutTrap := mClock.Trap().AfterFunc("timeout")
defer timeoutTrap.Close()
fakeAPI, closeFn := setupAppReporter(ctx, t, apps, handlers, mClock)
_, closeFn := setupAppReporter(ctx, t, apps, handlers)
defer closeFn()
healthcheckTrap.MustWait(ctx).Release()
// advance the clock 1ms before the report ticker starts, so that it's not
// simultaneous with the checks.
mClock.Set(ms(1)).MustWait(ctx)
reportTrap.MustWait(ctx).Release()
w := mClock.Set(ms(1000)) // 1st check starts
timeoutTrap.MustWait(ctx).Release()
mClock.Set(ms(1001)).MustWait(ctx) // report tick, no change
mClock.Set(ms(1999)) // timeout pops
w.MustWait(ctx) // 1st check finished
w = mClock.Set(ms(2000)) // 2nd check starts
timeoutTrap.MustWait(ctx).Release()
mClock.Set(ms(2001)).MustWait(ctx) // report tick, no change
mClock.Set(ms(2999)) // timeout pops
w.MustWait(ctx) // 2nd check finished
// app is now unhealthy after 2 timeouts
mClock.Set(ms(3000)) // 3rd check starts
timeoutTrap.MustWait(ctx).Release()
mClock.Set(ms(3001)).MustWait(ctx) // report tick, sends changes
update := testutil.RequireRecvCtx(ctx, t, fakeAPI.AppHealthCh())
require.Len(t, update.GetUpdates(), 1)
applyUpdate(t, apps, update)
require.Equal(t, codersdk.WorkspaceAppHealthUnhealthy, apps[0].Health)
// Ensure we haven't made more than 2 (expected 1 + 1 for buffer) requests in the last second.
// if there is a bug where we are spamming the healthcheck route this will catch it.
time.Sleep(time.Second)
require.LessOrEqual(t, atomic.LoadInt32(counter), int32(2))
}
func setupAppReporter(
ctx context.Context, t *testing.T,
apps []codersdk.WorkspaceApp,
handlers []http.Handler,
clk quartz.Clock,
) (*agenttest.FakeAgentAPI, func()) {
func setupAppReporter(ctx context.Context, t *testing.T, apps []codersdk.WorkspaceApp, handlers []http.Handler) (agent.WorkspaceAgentApps, func()) {
closers := []func(){}
for _, app := range apps {
require.NotEqual(t, uuid.Nil, app.ID, "all apps must have ID set")
}
for i, handler := range handlers {
if handler == nil {
continue
@@ -252,39 +174,34 @@ func setupAppReporter(
closers = append(closers, ts.Close)
}
// We don't care about manifest or stats in this test since it's not using
// a full agent and these RPCs won't get called.
//
// We use a proper fake agent API so we can test the conversion code and the
// request code as well. Before we were bypassing these by using a custom
// post function.
fakeAAPI := agenttest.NewFakeAgentAPI(t, slogtest.Make(t, nil), nil, nil)
var mu sync.Mutex
workspaceAgentApps := func(context.Context) ([]codersdk.WorkspaceApp, error) {
mu.Lock()
defer mu.Unlock()
var newApps []codersdk.WorkspaceApp
return append(newApps, apps...), nil
}
postWorkspaceAgentAppHealth := func(_ context.Context, req agentsdk.PostAppHealthsRequest) error {
mu.Lock()
for id, health := range req.Healths {
for i, app := range apps {
if app.ID != id {
continue
}
app.Health = health
apps[i] = app
}
}
mu.Unlock()
go agent.NewAppHealthReporterWithClock(
slogtest.Make(t, nil).Leveled(slog.LevelDebug),
apps, agentsdk.AppHealthPoster(fakeAAPI), clk,
)(ctx)
return nil
}
return fakeAAPI, func() {
go agent.NewWorkspaceAppHealthReporter(slogtest.Make(t, nil).Leveled(slog.LevelDebug), apps, postWorkspaceAgentAppHealth)(ctx)
return workspaceAgentApps, func() {
for _, closeFn := range closers {
closeFn()
}
}
}
func applyUpdate(t *testing.T, apps []codersdk.WorkspaceApp, req *proto.BatchUpdateAppHealthRequest) {
t.Helper()
for _, update := range req.Updates {
updateID, err := uuid.FromBytes(update.Id)
require.NoError(t, err)
updateHealth := codersdk.WorkspaceAppHealth(strings.ToLower(proto.AppHealth_name[int32(update.Health)]))
for i, app := range apps {
if app.ID != updateID {
continue
}
app.Health = updateHealth
apps[i] = app
}
}
}
-51
View File
@@ -1,51 +0,0 @@
package agent
import (
"context"
"runtime"
"sync"
"cdr.dev/slog"
)
// checkpoint allows a goroutine to communicate when it is OK to proceed beyond some async condition
// to other dependent goroutines.
type checkpoint struct {
logger slog.Logger
mu sync.Mutex
called bool
done chan struct{}
err error
}
// complete the checkpoint. Pass nil to indicate the checkpoint was ok. It is an error to call this
// more than once.
func (c *checkpoint) complete(err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.called {
b := make([]byte, 2048)
n := runtime.Stack(b, false)
c.logger.Critical(context.Background(), "checkpoint complete called more than once", slog.F("stacktrace", b[:n]))
return
}
c.called = true
c.err = err
close(c.done)
}
func (c *checkpoint) wait(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.done:
return c.err
}
}
func newCheckpoint(logger slog.Logger) *checkpoint {
return &checkpoint{
logger: logger,
done: make(chan struct{}),
}
}
-49
View File
@@ -1,49 +0,0 @@
package agent
import (
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
)
func TestCheckpoint_CompleteWait(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
ctx := testutil.Context(t, testutil.WaitShort)
uut := newCheckpoint(logger)
err := xerrors.New("test")
uut.complete(err)
got := uut.wait(ctx)
require.Equal(t, err, got)
}
func TestCheckpoint_CompleteTwice(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctx := testutil.Context(t, testutil.WaitShort)
uut := newCheckpoint(logger)
err := xerrors.New("test")
uut.complete(err)
uut.complete(nil) // drops CRITICAL log
got := uut.wait(ctx)
require.Equal(t, err, got)
}
func TestCheckpoint_WaitComplete(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
ctx := testutil.Context(t, testutil.WaitShort)
uut := newCheckpoint(logger)
err := xerrors.New("test")
errCh := make(chan error, 1)
go func() {
errCh <- uut.wait(ctx)
}()
uut.complete(err)
got := testutil.RequireRecvCtx(ctx, t, errCh)
require.Equal(t, err, got)
}
-31
View File
@@ -1,31 +0,0 @@
package agent
import (
"net/http"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
)
func (a *agent) HandleNetcheck(rw http.ResponseWriter, r *http.Request) {
ni := a.TailnetConn().GetNetInfo()
ifReport, err := healthsdk.RunInterfacesReport()
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to run interfaces report",
Detail: err.Error(),
})
return
}
httpapi.Write(r.Context(), rw, http.StatusOK, healthsdk.AgentNetcheckReport{
BaseReport: healthsdk.BaseReport{
Severity: health.SeverityOK,
},
NetInfo: ni,
Interfaces: ifReport,
})
}
+15 -24
View File
@@ -10,7 +10,8 @@ import (
"tailscale.com/util/clientmetric"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk/agentsdk"
)
type agentMetrics struct {
@@ -19,7 +20,6 @@ type agentMetrics struct {
// startupScriptSeconds is the time in seconds that the start script(s)
// took to run. This is reported once per agent.
startupScriptSeconds *prometheus.GaugeVec
currentConnections *prometheus.GaugeVec
}
func newAgentMetrics(registerer prometheus.Registerer) *agentMetrics {
@@ -46,24 +46,15 @@ func newAgentMetrics(registerer prometheus.Registerer) *agentMetrics {
}, []string{"success"})
registerer.MustRegister(startupScriptSeconds)
currentConnections := prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: "coderd",
Subsystem: "agentstats",
Name: "currently_reachable_peers",
Help: "The number of peers (e.g. clients) that are currently reachable over the encrypted network.",
}, []string{"connection_type"})
registerer.MustRegister(currentConnections)
return &agentMetrics{
connectionsTotal: connectionsTotal,
reconnectingPTYErrors: reconnectingPTYErrors,
startupScriptSeconds: startupScriptSeconds,
currentConnections: currentConnections,
}
}
func (a *agent) collectMetrics(ctx context.Context) []*proto.Stats_Metric {
var collected []*proto.Stats_Metric
func (a *agent) collectMetrics(ctx context.Context) []agentsdk.AgentMetric {
var collected []agentsdk.AgentMetric
// Tailscale internal metrics
metrics := clientmetric.Metrics()
@@ -72,7 +63,7 @@ func (a *agent) collectMetrics(ctx context.Context) []*proto.Stats_Metric {
continue
}
collected = append(collected, &proto.Stats_Metric{
collected = append(collected, agentsdk.AgentMetric{
Name: m.Name(),
Type: asMetricType(m.Type()),
Value: float64(m.Value()),
@@ -90,16 +81,16 @@ func (a *agent) collectMetrics(ctx context.Context) []*proto.Stats_Metric {
labels := toAgentMetricLabels(metric.Label)
if metric.Counter != nil {
collected = append(collected, &proto.Stats_Metric{
collected = append(collected, agentsdk.AgentMetric{
Name: metricFamily.GetName(),
Type: proto.Stats_Metric_COUNTER,
Type: agentsdk.AgentMetricTypeCounter,
Value: metric.Counter.GetValue(),
Labels: labels,
})
} else if metric.Gauge != nil {
collected = append(collected, &proto.Stats_Metric{
collected = append(collected, agentsdk.AgentMetric{
Name: metricFamily.GetName(),
Type: proto.Stats_Metric_GAUGE,
Type: agentsdk.AgentMetricTypeGauge,
Value: metric.Gauge.GetValue(),
Labels: labels,
})
@@ -111,14 +102,14 @@ func (a *agent) collectMetrics(ctx context.Context) []*proto.Stats_Metric {
return collected
}
func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []*proto.Stats_Metric_Label {
func toAgentMetricLabels(metricLabels []*prompb.LabelPair) []agentsdk.AgentMetricLabel {
if len(metricLabels) == 0 {
return nil
}
labels := make([]*proto.Stats_Metric_Label, 0, len(metricLabels))
labels := make([]agentsdk.AgentMetricLabel, 0, len(metricLabels))
for _, metricLabel := range metricLabels {
labels = append(labels, &proto.Stats_Metric_Label{
labels = append(labels, agentsdk.AgentMetricLabel{
Name: metricLabel.GetName(),
Value: metricLabel.GetValue(),
})
@@ -139,12 +130,12 @@ func isIgnoredMetric(metricName string) bool {
return false
}
func asMetricType(typ clientmetric.Type) proto.Stats_Metric_Type {
func asMetricType(typ clientmetric.Type) agentsdk.AgentMetricType {
switch typ {
case clientmetric.TypeGauge:
return proto.Stats_Metric_GAUGE
return agentsdk.AgentMetricTypeGauge
case clientmetric.TypeCounter:
return proto.Stats_Metric_COUNTER
return agentsdk.AgentMetricTypeCounter
default:
panic(fmt.Sprintf("unknown metric type: %d", typ))
}
+1 -2
View File
@@ -9,7 +9,6 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentListeningPort, error) {
@@ -33,7 +32,7 @@ func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.WorkspaceAgentL
seen := make(map[uint16]struct{}, len(tabs))
ports := []codersdk.WorkspaceAgentListeningPort{}
for _, tab := range tabs {
if tab.LocalAddr == nil || tab.LocalAddr.Port < workspacesdk.AgentMinimumListeningPort {
if tab.LocalAddr == nil || tab.LocalAddr.Port < codersdk.WorkspaceAgentMinimumListeningPort {
continue
}
+460 -1075
View File
File diff suppressed because it is too large Load Diff
+1 -49
View File
@@ -41,7 +41,6 @@ message WorkspaceApp {
UNHEALTHY = 4;
}
Health health = 12;
bool hidden = 13;
}
message WorkspaceAgentScript {
@@ -53,8 +52,6 @@ message WorkspaceAgentScript {
bool run_on_stop = 6;
bool start_blocks_login = 7;
google.protobuf.Duration timeout = 8;
string display_name = 9;
bytes id = 10;
}
message WorkspaceAgentMetadata {
@@ -250,50 +247,7 @@ message BatchCreateLogsRequest {
repeated Log logs = 2;
}
message BatchCreateLogsResponse {
bool log_limit_exceeded = 1;
}
message GetAnnouncementBannersRequest {}
message GetAnnouncementBannersResponse {
repeated BannerConfig announcement_banners = 1;
}
message BannerConfig {
bool enabled = 1;
string message = 2;
string background_color = 3;
}
message WorkspaceAgentScriptCompletedRequest {
Timing timing = 1;
}
message WorkspaceAgentScriptCompletedResponse {
}
message Timing {
bytes script_id = 1;
google.protobuf.Timestamp start = 2;
google.protobuf.Timestamp end = 3;
int32 exit_code = 4;
enum Stage {
START = 0;
STOP = 1;
CRON = 2;
}
Stage stage = 5;
enum Status {
OK = 0;
EXIT_FAILURE = 1;
TIMED_OUT = 2;
PIPES_LEFT_OPEN = 3;
}
Status status = 6;
}
message BatchCreateLogsResponse {}
service Agent {
rpc GetManifest(GetManifestRequest) returns (Manifest);
@@ -304,6 +258,4 @@ service Agent {
rpc UpdateStartup(UpdateStartupRequest) returns (Startup);
rpc BatchUpdateMetadata(BatchUpdateMetadataRequest) returns (BatchUpdateMetadataResponse);
rpc BatchCreateLogs(BatchCreateLogsRequest) returns (BatchCreateLogsResponse);
rpc GetAnnouncementBanners(GetAnnouncementBannersRequest) returns (GetAnnouncementBannersResponse);
rpc ScriptCompleted(WorkspaceAgentScriptCompletedRequest) returns (WorkspaceAgentScriptCompletedResponse);
}
+1 -81
View File
@@ -46,8 +46,6 @@ type DRPCAgentClient interface {
UpdateStartup(ctx context.Context, in *UpdateStartupRequest) (*Startup, error)
BatchUpdateMetadata(ctx context.Context, in *BatchUpdateMetadataRequest) (*BatchUpdateMetadataResponse, error)
BatchCreateLogs(ctx context.Context, in *BatchCreateLogsRequest) (*BatchCreateLogsResponse, error)
GetAnnouncementBanners(ctx context.Context, in *GetAnnouncementBannersRequest) (*GetAnnouncementBannersResponse, error)
ScriptCompleted(ctx context.Context, in *WorkspaceAgentScriptCompletedRequest) (*WorkspaceAgentScriptCompletedResponse, error)
}
type drpcAgentClient struct {
@@ -132,24 +130,6 @@ func (c *drpcAgentClient) BatchCreateLogs(ctx context.Context, in *BatchCreateLo
return out, nil
}
func (c *drpcAgentClient) GetAnnouncementBanners(ctx context.Context, in *GetAnnouncementBannersRequest) (*GetAnnouncementBannersResponse, error) {
out := new(GetAnnouncementBannersResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/GetAnnouncementBanners", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcAgentClient) ScriptCompleted(ctx context.Context, in *WorkspaceAgentScriptCompletedRequest) (*WorkspaceAgentScriptCompletedResponse, error) {
out := new(WorkspaceAgentScriptCompletedResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/ScriptCompleted", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentServer interface {
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -159,8 +139,6 @@ type DRPCAgentServer interface {
UpdateStartup(context.Context, *UpdateStartupRequest) (*Startup, error)
BatchUpdateMetadata(context.Context, *BatchUpdateMetadataRequest) (*BatchUpdateMetadataResponse, error)
BatchCreateLogs(context.Context, *BatchCreateLogsRequest) (*BatchCreateLogsResponse, error)
GetAnnouncementBanners(context.Context, *GetAnnouncementBannersRequest) (*GetAnnouncementBannersResponse, error)
ScriptCompleted(context.Context, *WorkspaceAgentScriptCompletedRequest) (*WorkspaceAgentScriptCompletedResponse, error)
}
type DRPCAgentUnimplementedServer struct{}
@@ -197,17 +175,9 @@ func (s *DRPCAgentUnimplementedServer) BatchCreateLogs(context.Context, *BatchCr
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) GetAnnouncementBanners(context.Context, *GetAnnouncementBannersRequest) (*GetAnnouncementBannersResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) ScriptCompleted(context.Context, *WorkspaceAgentScriptCompletedRequest) (*WorkspaceAgentScriptCompletedResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentDescription struct{}
func (DRPCAgentDescription) NumMethods() int { return 10 }
func (DRPCAgentDescription) NumMethods() int { return 8 }
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -283,24 +253,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*BatchCreateLogsRequest),
)
}, DRPCAgentServer.BatchCreateLogs, true
case 8:
return "/coder.agent.v2.Agent/GetAnnouncementBanners", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
GetAnnouncementBanners(
ctx,
in1.(*GetAnnouncementBannersRequest),
)
}, DRPCAgentServer.GetAnnouncementBanners, true
case 9:
return "/coder.agent.v2.Agent/ScriptCompleted", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
ScriptCompleted(
ctx,
in1.(*WorkspaceAgentScriptCompletedRequest),
)
}, DRPCAgentServer.ScriptCompleted, true
default:
return "", nil, nil, nil, false
}
@@ -437,35 +389,3 @@ func (x *drpcAgent_BatchCreateLogsStream) SendAndClose(m *BatchCreateLogsRespons
}
return x.CloseSend()
}
type DRPCAgent_GetAnnouncementBannersStream interface {
drpc.Stream
SendAndClose(*GetAnnouncementBannersResponse) error
}
type drpcAgent_GetAnnouncementBannersStream struct {
drpc.Stream
}
func (x *drpcAgent_GetAnnouncementBannersStream) SendAndClose(m *GetAnnouncementBannersResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCAgent_ScriptCompletedStream interface {
drpc.Stream
SendAndClose(*WorkspaceAgentScriptCompletedResponse) error
}
type drpcAgent_ScriptCompletedStream struct {
drpc.Stream
}
func (x *drpcAgent_ScriptCompletedStream) SendAndClose(m *WorkspaceAgentScriptCompletedResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
-38
View File
@@ -1,38 +0,0 @@
package proto
import (
"context"
"storj.io/drpc"
)
// DRPCAgentClient20 is the Agent API at v2.0. Notably, it is missing GetAnnouncementBanners, but
// is useful when you want to be maximally compatible with Coderd Release Versions from 2.9+
type DRPCAgentClient20 interface {
DRPCConn() drpc.Conn
GetManifest(ctx context.Context, in *GetManifestRequest) (*Manifest, error)
GetServiceBanner(ctx context.Context, in *GetServiceBannerRequest) (*ServiceBanner, error)
UpdateStats(ctx context.Context, in *UpdateStatsRequest) (*UpdateStatsResponse, error)
UpdateLifecycle(ctx context.Context, in *UpdateLifecycleRequest) (*Lifecycle, error)
BatchUpdateAppHealths(ctx context.Context, in *BatchUpdateAppHealthRequest) (*BatchUpdateAppHealthResponse, error)
UpdateStartup(ctx context.Context, in *UpdateStartupRequest) (*Startup, error)
BatchUpdateMetadata(ctx context.Context, in *BatchUpdateMetadataRequest) (*BatchUpdateMetadataResponse, error)
BatchCreateLogs(ctx context.Context, in *BatchCreateLogsRequest) (*BatchCreateLogsResponse, error)
}
// DRPCAgentClient21 is the Agent API at v2.1. It is useful if you want to be maximally compatible
// with Coderd Release Versions from 2.12+
type DRPCAgentClient21 interface {
DRPCConn() drpc.Conn
GetManifest(ctx context.Context, in *GetManifestRequest) (*Manifest, error)
GetServiceBanner(ctx context.Context, in *GetServiceBannerRequest) (*ServiceBanner, error)
UpdateStats(ctx context.Context, in *UpdateStatsRequest) (*UpdateStatsResponse, error)
UpdateLifecycle(ctx context.Context, in *UpdateLifecycleRequest) (*Lifecycle, error)
BatchUpdateAppHealths(ctx context.Context, in *BatchUpdateAppHealthRequest) (*BatchUpdateAppHealthResponse, error)
UpdateStartup(ctx context.Context, in *UpdateStartupRequest) (*Startup, error)
BatchUpdateMetadata(ctx context.Context, in *BatchUpdateMetadataRequest) (*BatchUpdateMetadataResponse, error)
BatchCreateLogs(ctx context.Context, in *BatchCreateLogsRequest) (*BatchCreateLogsResponse, error)
GetAnnouncementBanners(ctx context.Context, in *GetAnnouncementBannersRequest) (*GetAnnouncementBannersResponse, error)
}
+3 -2
View File
@@ -14,7 +14,8 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty"
)
@@ -196,7 +197,7 @@ func (s *ptyState) waitForStateOrContext(ctx context.Context, state State) (Stat
func readConnLoop(ctx context.Context, conn net.Conn, ptty pty.PTYCmd, metrics *prometheus.CounterVec, logger slog.Logger) {
decoder := json.NewDecoder(conn)
for {
var req workspacesdk.ReconnectingPTYRequest
var req codersdk.ReconnectingPTYRequest
err := decoder.Decode(&req)
if xerrors.Is(err, io.EOF) {
return
-7
View File
@@ -81,13 +81,6 @@ func newScreen(ctx context.Context, cmd *pty.Cmd, options *Options, logger slog.
rpty.id = hex.EncodeToString(buf)
settings := []string{
// Disable the startup message that appears for five seconds.
"startup_message off",
// Some message are hard-coded, the best we can do is set msgwait to 0
// which seems to hide them. This can happen for example if screen shows
// the version message when starting up.
"msgminwait 0",
"msgwait 0",
// Tell screen not to handle motion for xterm* terminals which allows
// scrolling the terminal via the mouse wheel or scroll bar (by default
// screen uses it to cycle through the command history). There does not
-59
View File
@@ -1,10 +1,7 @@
package agent
import (
"bytes"
"context"
"encoding/json"
"io"
"net/netip"
"sync"
"testing"
@@ -17,7 +14,6 @@ import (
"tailscale.com/types/netlogtype"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogjson"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/testutil"
@@ -214,58 +210,3 @@ func newFakeStatsDest() *fakeStatsDest {
resps: make(chan *proto.UpdateStatsResponse),
}
}
func Test_logDebouncer(t *testing.T) {
t.Parallel()
var (
buf bytes.Buffer
logger = slog.Make(slogjson.Sink(&buf))
ctx = context.Background()
)
debouncer := &logDebouncer{
logger: logger,
messages: map[string]time.Time{},
interval: time.Minute,
}
fields := map[string]interface{}{
"field_1": float64(1),
"field_2": "2",
}
debouncer.Error(ctx, "my message", "field_1", 1, "field_2", "2")
debouncer.Warn(ctx, "another message", "field_1", 1, "field_2", "2")
// Shouldn't log this.
debouncer.Warn(ctx, "another message", "field_1", 1, "field_2", "2")
require.Len(t, debouncer.messages, 2)
type entry struct {
Msg string `json:"msg"`
Level string `json:"level"`
Fields map[string]interface{} `json:"fields"`
}
assertLog := func(msg string, level string, fields map[string]interface{}) {
line, err := buf.ReadString('\n')
require.NoError(t, err)
var e entry
err = json.Unmarshal([]byte(line), &e)
require.NoError(t, err)
require.Equal(t, msg, e.Msg)
require.Equal(t, level, e.Level)
require.Equal(t, fields, e.Fields)
}
assertLog("my message", "ERROR", fields)
assertLog("another message", "WARN", fields)
debouncer.messages["another message"] = time.Now().Add(-2 * time.Minute)
debouncer.Warn(ctx, "another message", "field_1", 1, "field_2", "2")
assertLog("another message", "WARN", fields)
// Assert nothing else was written.
_, err := buf.ReadString('\n')
require.ErrorIs(t, err, io.EOF)
}
+54 -77
View File
@@ -18,8 +18,10 @@ import (
"cloud.google.com/go/compute/metadata"
"golang.org/x/xerrors"
"gopkg.in/natefinch/lumberjack.v2"
"tailscale.com/util/clientmetric"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/common/expfmt"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
@@ -27,19 +29,17 @@ import (
"cdr.dev/slog/sloggers/slogstackdriver"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/reaper"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/serpent"
)
func (r *RootCmd) workspaceAgent() *serpent.Command {
func (r *RootCmd) workspaceAgent() *clibase.Cmd {
var (
auth string
logDir string
scriptDataDir string
pprofAddress string
noReap bool
sshMaxTimeout time.Duration
@@ -49,16 +49,13 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
slogHumanPath string
slogJSONPath string
slogStackdriverPath string
blockFileTransfer bool
agentHeaderCommand string
agentHeader []string
)
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "agent",
Short: `Starts the Coder workspace agent.`,
// This command isn't useful to manually execute.
Hidden: true,
Handler: func(inv *serpent.Invocation) error {
Handler: func(inv *clibase.Invocation) error {
ctx, cancel := context.WithCancel(inv.Context())
defer cancel()
@@ -127,7 +124,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
args := append(os.Args, "--no-reap")
err := reaper.ForkReap(
reaper.WithExecArgs(args...),
reaper.WithCatchSignals(StopSignals...),
reaper.WithCatchSignals(InterruptSignals...),
)
if err != nil {
logger.Error(ctx, "agent process reaper unable to fork", slog.Error(err))
@@ -146,12 +143,12 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
// Note that we don't want to handle these signals in the
// process that runs as PID 1, that's why we do this after
// the reaper forked.
ctx, stopNotify := inv.SignalNotifyContext(ctx, StopSignals...)
ctx, stopNotify := inv.SignalNotifyContext(ctx, InterruptSignals...)
defer stopNotify()
// DumpHandler does signal handling, so we call it after the
// reaper.
go DumpHandler(ctx, "agent")
go DumpHandler(ctx)
logWriter := &lumberjackWriteCloseFixer{w: &lumberjack.Logger{
Filename: filepath.Join(logDir, "coder-agent.log"),
@@ -178,14 +175,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
// with large payloads can take a bit. e.g. startup scripts
// may take a while to insert.
client.SDK.HTTPClient.Timeout = 30 * time.Second
// Attach header transport so we process --agent-header and
// --agent-header-command flags
headerTransport, err := headerTransport(ctx, r.agentURL, agentHeader, agentHeaderCommand)
if err != nil {
return xerrors.Errorf("configure header transport: %w", err)
}
headerTransport.Transport = client.SDK.HTTPClient.Transport
client.SDK.HTTPClient.Transport = headerTransport
// Enable pprof handler
// This prevents the pprof import from being accidentally deleted.
@@ -289,21 +278,12 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
subsystems = append(subsystems, subsystem)
}
environmentVariables := map[string]string{
"GIT_ASKPASS": executablePath,
}
if v, ok := os.LookupEnv(agent.EnvProcPrioMgmt); ok {
environmentVariables[agent.EnvProcPrioMgmt] = v
}
if v, ok := os.LookupEnv(agent.EnvProcOOMScore); ok {
environmentVariables[agent.EnvProcOOMScore] = v
}
procTicker := time.NewTicker(time.Second)
defer procTicker.Stop()
agnt := agent.New(agent.Options{
Client: client,
Logger: logger,
LogDir: logDir,
ScriptDataDir: scriptDataDir,
TailnetListenPort: uint16(tailnetListenPort),
ExchangeToken: func(ctx context.Context) (string, error) {
if exchangeToken == nil {
@@ -316,22 +296,22 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
client.SetSessionToken(resp.SessionToken)
return resp.SessionToken, nil
},
EnvironmentVariables: environmentVariables,
IgnorePorts: ignorePorts,
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
EnvironmentVariables: map[string]string{
"GIT_ASKPASS": executablePath,
agent.EnvProcPrioMgmt: os.Getenv(agent.EnvProcPrioMgmt),
},
IgnorePorts: ignorePorts,
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
PrometheusRegistry: prometheusRegistry,
Syscaller: agentproc.NewSyscaller(),
// Intentionally set this to nil. It's mainly used
// for testing.
ModifiedProcesses: nil,
BlockFileTransfer: blockFileTransfer,
})
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)
prometheusSrvClose := ServeHandler(ctx, logger, promHandler, prometheusAddress, "prometheus")
prometheusSrvClose := ServeHandler(ctx, logger, prometheusMetricsHandler(prometheusRegistry, logger), prometheusAddress, "prometheus")
defer prometheusSrvClose()
debugSrvClose := ServeHandler(ctx, logger, agnt.HTTPDebug(), debugAddress, "debug")
@@ -342,53 +322,34 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
},
}
cmd.Options = serpent.OptionSet{
cmd.Options = clibase.OptionSet{
{
Flag: "auth",
Default: "token",
Description: "Specify the authentication type to use for the agent.",
Env: "CODER_AGENT_AUTH",
Value: serpent.StringOf(&auth),
Value: clibase.StringOf(&auth),
},
{
Flag: "log-dir",
Default: os.TempDir(),
Description: "Specify the location for the agent log files.",
Env: "CODER_AGENT_LOG_DIR",
Value: serpent.StringOf(&logDir),
},
{
Flag: "script-data-dir",
Default: os.TempDir(),
Description: "Specify the location for storing script data.",
Env: "CODER_AGENT_SCRIPT_DATA_DIR",
Value: serpent.StringOf(&scriptDataDir),
Value: clibase.StringOf(&logDir),
},
{
Flag: "pprof-address",
Default: "127.0.0.1:6060",
Env: "CODER_AGENT_PPROF_ADDRESS",
Value: serpent.StringOf(&pprofAddress),
Value: clibase.StringOf(&pprofAddress),
Description: "The address to serve pprof.",
},
{
Flag: "agent-header-command",
Env: "CODER_AGENT_HEADER_COMMAND",
Value: serpent.StringOf(&agentHeaderCommand),
Description: "An external command that outputs additional HTTP headers added to all requests. The command must output each header as `key=value` on its own line.",
},
{
Flag: "agent-header",
Env: "CODER_AGENT_HEADER",
Value: serpent.StringArrayOf(&agentHeader),
Description: "Additional HTTP headers added to all requests. Provide as " + `key=value` + ". Can be specified multiple times.",
},
{
Flag: "no-reap",
Env: "",
Description: "Do not start a process reaper.",
Value: serpent.BoolOf(&noReap),
Value: clibase.BoolOf(&noReap),
},
{
Flag: "ssh-max-timeout",
@@ -396,27 +357,27 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
Default: "72h",
Env: "CODER_AGENT_SSH_MAX_TIMEOUT",
Description: "Specify the max timeout for a SSH connection, it is advisable to set it to a minimum of 60s, but no more than 72h.",
Value: serpent.DurationOf(&sshMaxTimeout),
Value: clibase.DurationOf(&sshMaxTimeout),
},
{
Flag: "tailnet-listen-port",
Default: "0",
Env: "CODER_AGENT_TAILNET_LISTEN_PORT",
Description: "Specify a static port for Tailscale to use for listening.",
Value: serpent.Int64Of(&tailnetListenPort),
Value: clibase.Int64Of(&tailnetListenPort),
},
{
Flag: "prometheus-address",
Default: "127.0.0.1:2112",
Env: "CODER_AGENT_PROMETHEUS_ADDRESS",
Value: serpent.StringOf(&prometheusAddress),
Value: clibase.StringOf(&prometheusAddress),
Description: "The bind address to serve Prometheus metrics.",
},
{
Flag: "debug-address",
Default: "127.0.0.1:2113",
Env: "CODER_AGENT_DEBUG_ADDRESS",
Value: serpent.StringOf(&debugAddress),
Value: clibase.StringOf(&debugAddress),
Description: "The bind address to serve a debug HTTP server.",
},
{
@@ -425,7 +386,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
Flag: "log-human",
Env: "CODER_AGENT_LOGGING_HUMAN",
Default: "/dev/stderr",
Value: serpent.StringOf(&slogHumanPath),
Value: clibase.StringOf(&slogHumanPath),
},
{
Name: "JSON Log Location",
@@ -433,7 +394,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
Flag: "log-json",
Env: "CODER_AGENT_LOGGING_JSON",
Default: "",
Value: serpent.StringOf(&slogJSONPath),
Value: clibase.StringOf(&slogJSONPath),
},
{
Name: "Stackdriver Log Location",
@@ -441,14 +402,7 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
Flag: "log-stackdriver",
Env: "CODER_AGENT_LOGGING_STACKDRIVER",
Default: "",
Value: serpent.StringOf(&slogStackdriverPath),
},
{
Flag: "block-file-transfer",
Default: "false",
Env: "CODER_AGENT_BLOCK_FILE_TRANSFER",
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
Value: serpent.BoolOf(&blockFileTransfer),
Value: clibase.StringOf(&slogStackdriverPath),
},
}
@@ -536,3 +490,26 @@ func urlPort(u string) (int, error) {
}
return -1, xerrors.Errorf("invalid port: %s", u)
}
func prometheusMetricsHandler(prometheusRegistry *prometheus.Registry, logger slog.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
// Based on: https://github.com/tailscale/tailscale/blob/280255acae604796a1113861f5a84e6fa2dc6121/ipn/localapi/localapi.go#L489
clientmetric.WritePrometheusExpositionFormat(w)
metricFamilies, err := prometheusRegistry.Gather()
if err != nil {
logger.Error(context.Background(), "Prometheus handler can't gather metric families", slog.Error(err))
return
}
for _, metricFamily := range metricFamilies {
_, err = expfmt.MetricFamilyToText(w, metricFamily)
if err != nil {
logger.Error(context.Background(), "expfmt.MetricFamilyToText failed", slog.Error(err))
return
}
}
})
}
+3 -46
View File
@@ -3,13 +3,10 @@ package cli_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"strings"
"sync/atomic"
"testing"
"github.com/google/uuid"
@@ -22,7 +19,6 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)
@@ -95,8 +91,7 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
@@ -135,8 +130,7 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).
DialAgent(ctx, resources[0].Agents[0].ID, nil)
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
@@ -179,7 +173,7 @@ func TestWorkspaceAgent(t *testing.T) {
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
assert.NotEmpty(t, resources[0].Agents[0].Version)
}
dialer, err := workspacesdk.New(client).DialAgent(ctx, resources[0].Agents[0].ID, nil)
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer dialer.Close()
require.True(t, dialer.AwaitReachable(ctx))
@@ -232,43 +226,6 @@ func TestWorkspaceAgent(t *testing.T) {
require.Equal(t, codersdk.AgentSubsystemEnvbox, resources[0].Agents[0].Subsystems[0])
require.Equal(t, codersdk.AgentSubsystemExectrace, resources[0].Agents[0].Subsystems[1])
})
t.Run("Header", func(t *testing.T) {
t.Parallel()
var url string
var called int64
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "wow", r.Header.Get("X-Testing"))
assert.Equal(t, "Ethan was Here!", r.Header.Get("Cool-Header"))
assert.Equal(t, "very-wow-"+url, r.Header.Get("X-Process-Testing"))
assert.Equal(t, "more-wow", r.Header.Get("X-Process-Testing2"))
atomic.AddInt64(&called, 1)
w.WriteHeader(http.StatusGone)
}))
defer srv.Close()
url = srv.URL
coderURLEnv := "$CODER_URL"
if runtime.GOOS == "windows" {
coderURLEnv = "%CODER_URL%"
}
logDir := t.TempDir()
inv, _ := clitest.New(t,
"agent",
"--auth", "token",
"--agent-token", "fake-token",
"--agent-url", srv.URL,
"--log-dir", logDir,
"--agent-header", "X-Testing=wow",
"--agent-header", "Cool-Header=Ethan was Here!",
"--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
)
clitest.Start(t, inv)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&called) > 0
}, testutil.WaitShort, testutil.IntervalFast)
})
}
func matchAgentWithVersion(rs []codersdk.WorkspaceResource) bool {
+6 -6
View File
@@ -6,22 +6,22 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
func (r *RootCmd) autoupdate() *serpent.Command {
func (r *RootCmd) autoupdate() *clibase.Cmd {
client := new(codersdk.Client)
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Annotations: workspaceCommand,
Use: "autoupdate <workspace> <always|never>",
Short: "Toggle auto-update policy for a workspace",
Middleware: serpent.Chain(
serpent.RequireNArgs(2),
Middleware: clibase.Chain(
clibase.RequireNArgs(2),
r.InitClient(client),
),
Handler: func(inv *serpent.Invocation) error {
Handler: func(inv *clibase.Invocation) error {
policy := strings.ToLower(inv.Args[1])
err := validateAutoUpdatePolicy(policy)
if err != nil {
+1 -1
View File
@@ -24,7 +24,7 @@ func TestAutoUpdate(t *testing.T) {
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, member, template.ID)
workspace := coderdtest.CreateWorkspace(t, member, owner.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
require.Equal(t, codersdk.AutomaticUpdatesNever, workspace.AutomaticUpdates)
+80
View File
@@ -0,0 +1,80 @@
// Package clibase offers an all-in-one solution for a highly configurable CLI
// application. Within Coder, we use it for all of our subcommands, which
// demands more functionality than cobra/viber offers.
//
// The Command interface is loosely based on the chi middleware pattern and
// http.Handler/HandlerFunc.
package clibase
import (
"strings"
"golang.org/x/exp/maps"
)
// Group describes a hierarchy of groups that an option or command belongs to.
type Group struct {
Parent *Group `json:"parent,omitempty"`
Name string `json:"name,omitempty"`
YAML string `json:"yaml,omitempty"`
Description string `json:"description,omitempty"`
}
// Ancestry returns the group and all of its parents, in order.
func (g *Group) Ancestry() []Group {
if g == nil {
return nil
}
groups := []Group{*g}
for p := g.Parent; p != nil; p = p.Parent {
// Prepend to the slice so that the order is correct.
groups = append([]Group{*p}, groups...)
}
return groups
}
func (g *Group) FullName() string {
var names []string
for _, g := range g.Ancestry() {
names = append(names, g.Name)
}
return strings.Join(names, " / ")
}
// Annotations is an arbitrary key-mapping used to extend the Option and Command types.
// Its methods won't panic if the map is nil.
type Annotations map[string]string
// Mark sets a value on the annotations map, creating one
// if it doesn't exist. Mark does not mutate the original and
// returns a copy. It is suitable for chaining.
func (a Annotations) Mark(key string, value string) Annotations {
var aa Annotations
if a != nil {
aa = maps.Clone(a)
} else {
aa = make(Annotations)
}
aa[key] = value
return aa
}
// IsSet returns true if the key is set in the annotations map.
func (a Annotations) IsSet(key string) bool {
if a == nil {
return false
}
_, ok := a[key]
return ok
}
// Get retrieves a key from the map, returning false if the key is not found
// or the map is nil.
func (a Annotations) Get(key string) (string, bool) {
if a == nil {
return "", false
}
v, ok := a[key]
return v, ok
}
+621
View File
@@ -0,0 +1,621 @@
package clibase
import (
"context"
"errors"
"flag"
"fmt"
"io"
"os"
"os/signal"
"strings"
"testing"
"unicode"
"cdr.dev/slog"
"github.com/spf13/pflag"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"gopkg.in/yaml.v3"
"github.com/coder/coder/v2/coderd/util/slice"
)
// Cmd describes an executable command.
type Cmd struct {
// Parent is the direct parent of the command.
Parent *Cmd
// Children is a list of direct descendants.
Children []*Cmd
// Use is provided in form "command [flags] [args...]".
Use string
// Aliases is a list of alternative names for the command.
Aliases []string
// Short is a one-line description of the command.
Short string
// Hidden determines whether the command should be hidden from help.
Hidden bool
// RawArgs determines whether the command should receive unparsed arguments.
// No flags are parsed when set, and the command is responsible for parsing
// its own flags.
RawArgs bool
// Long is a detailed description of the command,
// presented on its help page. It may contain examples.
Long string
Options OptionSet
Annotations Annotations
// Middleware is called before the Handler.
// Use Chain() to combine multiple middlewares.
Middleware MiddlewareFunc
Handler HandlerFunc
HelpHandler HandlerFunc
}
// AddSubcommands adds the given subcommands, setting their
// Parent field automatically.
func (c *Cmd) AddSubcommands(cmds ...*Cmd) {
for _, cmd := range cmds {
cmd.Parent = c
c.Children = append(c.Children, cmd)
}
}
// Walk calls fn for the command and all its children.
func (c *Cmd) Walk(fn func(*Cmd)) {
fn(c)
for _, child := range c.Children {
child.Parent = c
child.Walk(fn)
}
}
// PrepareAll performs initialization and linting on the command and all its children.
func (c *Cmd) PrepareAll() error {
if c.Use == "" {
return xerrors.New("command must have a Use field so that it has a name")
}
var merr error
for i := range c.Options {
opt := &c.Options[i]
if opt.Name == "" {
switch {
case opt.Flag != "":
opt.Name = opt.Flag
case opt.Env != "":
opt.Name = opt.Env
case opt.YAML != "":
opt.Name = opt.YAML
default:
merr = errors.Join(merr, xerrors.Errorf("option must have a Name, Flag, Env or YAML field"))
}
}
if opt.Description != "" {
// Enforce that description uses sentence form.
if unicode.IsLower(rune(opt.Description[0])) {
merr = errors.Join(merr, xerrors.Errorf("option %q description should start with a capital letter", opt.Name))
}
if !strings.HasSuffix(opt.Description, ".") {
merr = errors.Join(merr, xerrors.Errorf("option %q description should end with a period", opt.Name))
}
}
}
slices.SortFunc(c.Options, func(a, b Option) int {
return slice.Ascending(a.Name, b.Name)
})
slices.SortFunc(c.Children, func(a, b *Cmd) int {
return slice.Ascending(a.Name(), b.Name())
})
for _, child := range c.Children {
child.Parent = c
err := child.PrepareAll()
if err != nil {
merr = errors.Join(merr, xerrors.Errorf("command %v: %w", child.Name(), err))
}
}
return merr
}
// Name returns the first word in the Use string.
func (c *Cmd) Name() string {
return strings.Split(c.Use, " ")[0]
}
// FullName returns the full invocation name of the command,
// as seen on the command line.
func (c *Cmd) FullName() string {
var names []string
if c.Parent != nil {
names = append(names, c.Parent.FullName())
}
names = append(names, c.Name())
return strings.Join(names, " ")
}
// FullName returns usage of the command, preceded
// by the usage of its parents.
func (c *Cmd) FullUsage() string {
var uses []string
if c.Parent != nil {
uses = append(uses, c.Parent.FullName())
}
uses = append(uses, c.Use)
return strings.Join(uses, " ")
}
// FullOptions returns the options of the command and its parents.
func (c *Cmd) FullOptions() OptionSet {
var opts OptionSet
if c.Parent != nil {
opts = append(opts, c.Parent.FullOptions()...)
}
opts = append(opts, c.Options...)
return opts
}
// Invoke creates a new invocation of the command, with
// stdio discarded.
//
// The returned invocation is not live until Run() is called.
func (c *Cmd) Invoke(args ...string) *Invocation {
return &Invocation{
Command: c,
Args: args,
Stdout: io.Discard,
Stderr: io.Discard,
Stdin: strings.NewReader(""),
Logger: slog.Make(),
}
}
// Invocation represents an instance of a command being executed.
type Invocation struct {
ctx context.Context
Command *Cmd
parsedFlags *pflag.FlagSet
Args []string
// Environ is a list of environment variables. Use EnvsWithPrefix to parse
// os.Environ.
Environ Environ
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader
Logger slog.Logger
Net Net
// testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
}
// WithOS returns the invocation as a main package, filling in the invocation's unset
// fields with OS defaults.
func (inv *Invocation) WithOS() *Invocation {
return inv.with(func(i *Invocation) {
i.Stdout = os.Stdout
i.Stderr = os.Stderr
i.Stdin = os.Stdin
i.Args = os.Args[1:]
i.Environ = ParseEnviron(os.Environ(), "")
i.Net = osNet{}
})
}
// WithTestSignalNotifyContext allows overriding the default implementation of SignalNotifyContext.
// This should only be used in testing.
func (inv *Invocation) WithTestSignalNotifyContext(
_ testing.TB, // ensure we only call this from tests
f func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc),
) *Invocation {
return inv.with(func(i *Invocation) {
i.signalNotifyContext = f
})
}
// SignalNotifyContext is equivalent to signal.NotifyContext, but supports being overridden in
// tests.
func (inv *Invocation) SignalNotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
if inv.signalNotifyContext == nil {
return signal.NotifyContext(parent, signals...)
}
return inv.signalNotifyContext(parent, signals...)
}
func (inv *Invocation) WithTestParsedFlags(
_ testing.TB, // ensure we only call this from tests
parsedFlags *pflag.FlagSet,
) *Invocation {
return inv.with(func(i *Invocation) {
i.parsedFlags = parsedFlags
})
}
func (inv *Invocation) Context() context.Context {
if inv.ctx == nil {
return context.Background()
}
return inv.ctx
}
func (inv *Invocation) ParsedFlags() *pflag.FlagSet {
if inv.parsedFlags == nil {
panic("flags not parsed, has Run() been called?")
}
return inv.parsedFlags
}
type runState struct {
allArgs []string
commandDepth int
flagParseErr error
}
func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
fs2 := pflag.NewFlagSet("", pflag.ContinueOnError)
fs2.Usage = func() {}
fs.VisitAll(func(f *pflag.Flag) {
if f.Name == without {
return
}
fs2.AddFlag(f)
})
return fs2
}
// run recursively executes the command and its children.
// allArgs is wired through the stack so that global flags can be accepted
// anywhere in the command invocation.
func (inv *Invocation) run(state *runState) error {
err := inv.Command.Options.ParseEnv(inv.Environ)
if err != nil {
return xerrors.Errorf("parsing env: %w", err)
}
// Now the fun part, argument parsing!
children := make(map[string]*Cmd)
for _, child := range inv.Command.Children {
child.Parent = inv.Command
for _, name := range append(child.Aliases, child.Name()) {
if _, ok := children[name]; ok {
return xerrors.Errorf("duplicate command name: %s", name)
}
children[name] = child
}
}
if inv.parsedFlags == nil {
inv.parsedFlags = pflag.NewFlagSet(inv.Command.Name(), pflag.ContinueOnError)
// We handle Usage ourselves.
inv.parsedFlags.Usage = func() {}
}
// If we find a duplicate flag, we want the deeper command's flag to override
// the shallow one. Unfortunately, pflag has no way to remove a flag, so we
// have to create a copy of the flagset without a value.
inv.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) {
if inv.parsedFlags.Lookup(f.Name) != nil {
inv.parsedFlags = copyFlagSetWithout(inv.parsedFlags, f.Name)
}
inv.parsedFlags.AddFlag(f)
})
var parsedArgs []string
if !inv.Command.RawArgs {
// Flag parsing will fail on intermediate commands in the command tree,
// so we check the error after looking for a child command.
state.flagParseErr = inv.parsedFlags.Parse(state.allArgs)
parsedArgs = inv.parsedFlags.Args()
}
// Set value sources for flags.
for i, opt := range inv.Command.Options {
if fl := inv.parsedFlags.Lookup(opt.Flag); fl != nil && fl.Changed {
inv.Command.Options[i].ValueSource = ValueSourceFlag
}
}
// Read YAML configs, if any.
for _, opt := range inv.Command.Options {
path, ok := opt.Value.(*YAMLConfigPath)
if !ok || path.String() == "" {
continue
}
byt, err := os.ReadFile(path.String())
if err != nil {
return xerrors.Errorf("reading yaml: %w", err)
}
var n yaml.Node
err = yaml.Unmarshal(byt, &n)
if err != nil {
return xerrors.Errorf("decoding yaml: %w", err)
}
err = inv.Command.Options.UnmarshalYAML(&n)
if err != nil {
return xerrors.Errorf("applying yaml: %w", err)
}
}
err = inv.Command.Options.SetDefaults()
if err != nil {
return xerrors.Errorf("setting defaults: %w", err)
}
// Run child command if found (next child only)
// We must do subcommand detection after flag parsing so we don't mistake flag
// values for subcommand names.
if len(parsedArgs) > state.commandDepth {
nextArg := parsedArgs[state.commandDepth]
if child, ok := children[nextArg]; ok {
child.Parent = inv.Command
inv.Command = child
state.commandDepth++
return inv.run(state)
}
}
// Flag parse errors are irrelevant for raw args commands.
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf(
"parsing flags (%v) for %q: %w",
state.allArgs,
inv.Command.FullName(), state.flagParseErr,
)
}
// All options should be set. Check all required options have sources,
// meaning they were set by the user in some way (env, flag, etc).
var missing []string
for _, opt := range inv.Command.Options {
if opt.Required && opt.ValueSource == ValueSourceNone {
missing = append(missing, opt.Flag)
}
}
if len(missing) > 0 {
return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", "))
}
if inv.Command.RawArgs {
// If we're at the root command, then the name is omitted
// from the arguments, so we can just use the entire slice.
if state.commandDepth == 0 {
inv.Args = state.allArgs
} else {
argPos, err := findArg(inv.Command.Name(), state.allArgs, inv.parsedFlags)
if err != nil {
panic(err)
}
inv.Args = state.allArgs[argPos+1:]
}
} else {
// In non-raw-arg mode, we want to skip over flags.
inv.Args = parsedArgs[state.commandDepth:]
}
mw := inv.Command.Middleware
if mw == nil {
mw = Chain()
}
ctx := inv.ctx
if ctx == nil {
ctx = context.Background()
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
inv = inv.WithContext(ctx)
if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
if inv.Command.HelpHandler == nil {
return xerrors.Errorf("no handler or help for command %s", inv.Command.FullName())
}
return inv.Command.HelpHandler(inv)
}
err = mw(inv.Command.Handler)(inv)
if err != nil {
return &RunCommandError{
Cmd: inv.Command,
Err: err,
}
}
return nil
}
type RunCommandError struct {
Cmd *Cmd
Err error
}
func (e *RunCommandError) Unwrap() error {
return e.Err
}
func (e *RunCommandError) Error() string {
return fmt.Sprintf("running command %q: %+v", e.Cmd.FullName(), e.Err)
}
// findArg returns the index of the first occurrence of arg in args, skipping
// over all flags.
func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
for i := 0; i < len(args); i++ {
arg := args[i]
if !strings.HasPrefix(arg, "-") {
if arg == want {
return i, nil
}
continue
}
// This is a flag!
if strings.Contains(arg, "=") {
// The flag contains the value in the same arg, just skip.
continue
}
// We need to check if NoOptValue is set, then we should not wait
// for the next arg to be the value.
f := fs.Lookup(strings.TrimLeft(arg, "-"))
if f == nil {
return -1, xerrors.Errorf("unknown flag: %s", arg)
}
if f.NoOptDefVal != "" {
continue
}
if i == len(args)-1 {
return -1, xerrors.Errorf("flag %s requires a value", arg)
}
// Skip the value.
i++
}
return -1, xerrors.Errorf("arg %s not found", want)
}
// Run executes the command.
// If two command share a flag name, the first command wins.
//
//nolint:revive
func (inv *Invocation) Run() (err error) {
defer func() {
// Pflag is panicky, so additional context is helpful in tests.
if flag.Lookup("test.v") == nil {
return
}
if r := recover(); r != nil {
err = xerrors.Errorf("panic recovered for %s: %v", inv.Command.FullName(), r)
panic(err)
}
}()
// We close Stdin to prevent deadlocks, e.g. when the command
// has ended but an io.Copy is still reading from Stdin.
defer func() {
if inv.Stdin == nil {
return
}
rc, ok := inv.Stdin.(io.ReadCloser)
if !ok {
return
}
e := rc.Close()
err = errors.Join(err, e)
}()
err = inv.run(&runState{
allArgs: inv.Args,
})
return err
}
// WithContext returns a copy of the Invocation with the given context.
func (inv *Invocation) WithContext(ctx context.Context) *Invocation {
return inv.with(func(i *Invocation) {
i.ctx = ctx
})
}
// with returns a copy of the Invocation with the given function applied.
func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
i2 := *inv
fn(&i2)
return &i2
}
// MiddlewareFunc returns the next handler in the chain,
// or nil if there are no more.
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
func chain(ms ...MiddlewareFunc) MiddlewareFunc {
return MiddlewareFunc(func(next HandlerFunc) HandlerFunc {
if len(ms) > 0 {
return chain(ms[1:]...)(ms[0](next))
}
return next
})
}
// Chain returns a Handler that first calls middleware in order.
//
//nolint:revive
func Chain(ms ...MiddlewareFunc) MiddlewareFunc {
// We need to reverse the array to provide top-to-bottom execution
// order when defining a command.
reversed := make([]MiddlewareFunc, len(ms))
for i := range ms {
reversed[len(ms)-1-i] = ms[i]
}
return chain(reversed...)
}
func RequireNArgs(want int) MiddlewareFunc {
return RequireRangeArgs(want, want)
}
// RequireRangeArgs returns a Middleware that requires the number of arguments
// to be between start and end (inclusive). If end is -1, then the number of
// arguments must be at least start.
func RequireRangeArgs(start, end int) MiddlewareFunc {
if start < 0 {
panic("start must be >= 0")
}
return func(next HandlerFunc) HandlerFunc {
return func(i *Invocation) error {
got := len(i.Args)
switch {
case start == end && got != start:
switch start {
case 0:
if len(i.Command.Children) > 0 {
return xerrors.Errorf("unrecognized subcommand %q", i.Args[0])
}
return xerrors.Errorf("wanted no args but got %v %v", got, i.Args)
default:
return xerrors.Errorf(
"wanted %v args but got %v %v",
start,
got,
i.Args,
)
}
case start > 0 && end == -1:
switch {
case got < start:
return xerrors.Errorf(
"wanted at least %v args but got %v",
start,
got,
)
default:
return next(i)
}
case start > end:
panic("start must be <= end")
case got < start || got > end:
return xerrors.Errorf(
"wanted between %v and %v args but got %v",
start, end,
got,
)
default:
return next(i)
}
}
}
}
// HandlerFunc handles an Invocation of a command.
type HandlerFunc func(i *Invocation) error
+719
View File
@@ -0,0 +1,719 @@
package clibase_test
import (
"bytes"
"context"
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/clibase"
)
// ioBufs is the standard input, output, and error for a command.
type ioBufs struct {
Stdin bytes.Buffer
Stdout bytes.Buffer
Stderr bytes.Buffer
}
// fakeIO sets Stdin, Stdout, and Stderr to buffers.
func fakeIO(i *clibase.Invocation) *ioBufs {
var b ioBufs
i.Stdout = &b.Stdout
i.Stderr = &b.Stderr
i.Stdin = &b.Stdin
return &b
}
func TestCommand(t *testing.T) {
t.Parallel()
cmd := func() *clibase.Cmd {
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
)
return &clibase.Cmd{
Use: "root [subcommand]",
Options: clibase.OptionSet{
clibase.Option{
Name: "verbose",
Flag: "verbose",
Value: clibase.BoolOf(&verbose),
},
clibase.Option{
Name: "prefix",
Flag: "prefix",
Value: clibase.StringOf(&prefix),
},
},
Children: []*clibase.Cmd{
{
Use: "required-flag --req-bool=true --req-string=foo",
Short: "Example with required flags",
Options: clibase.OptionSet{
clibase.Option{
Name: "req-bool",
Flag: "req-bool",
Value: clibase.BoolOf(&reqBool),
Required: true,
},
clibase.Option{
Name: "req-string",
Flag: "req-string",
Value: clibase.Validate(clibase.StringOf(&reqStr), func(value *clibase.String) error {
ok := strings.Contains(value.String(), " ")
if !ok {
return xerrors.Errorf("string must contain a space")
}
return nil
}),
Required: true,
},
},
Handler: func(i *clibase.Invocation) error {
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
return nil
},
},
{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: clibase.Chain(
clibase.RequireNArgs(1),
),
Aliases: []string{"up"},
Options: clibase.OptionSet{
clibase.Option{
Name: "lower",
Flag: "lower",
Value: clibase.BoolOf(&lower),
},
},
Handler: func(i *clibase.Invocation) error {
_, _ = i.Stdout.Write([]byte(prefix))
w := i.Args[0]
if lower {
w = strings.ToLower(w)
} else {
w = strings.ToUpper(w)
}
_, _ = i.Stdout.Write(
[]byte(
w,
),
)
if verbose {
i.Stdout.Write([]byte("!!!"))
}
return nil
},
},
},
}
}
t.Run("SimpleOK", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke("toupper", "hello")
io := fakeIO(i)
i.Run()
require.Equal(t, "HELLO", io.Stdout.String())
})
t.Run("Alias", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"up", "hello",
)
io := fakeIO(i)
i.Run()
require.Equal(t, "HELLO", io.Stdout.String())
})
t.Run("NoSubcommand", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"na",
)
io := fakeIO(i)
err := i.Run()
require.Empty(t, io.Stdout.String())
require.Error(t, err)
})
t.Run("BadArgs", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"toupper",
)
io := fakeIO(i)
err := i.Run()
require.Empty(t, io.Stdout.String())
require.Error(t, err)
})
t.Run("UnknownFlags", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"toupper", "--unknown",
)
io := fakeIO(i)
err := i.Run()
require.Empty(t, io.Stdout.String())
require.Error(t, err)
})
t.Run("Verbose", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"--verbose", "toupper", "hello",
)
io := fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t, "HELLO!!!", io.Stdout.String())
})
t.Run("Verbose=", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"--verbose=true", "toupper", "hello",
)
io := fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t, "HELLO!!!", io.Stdout.String())
})
t.Run("PrefixSpace", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"--prefix", "conv: ", "toupper", "hello",
)
io := fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t, "conv: HELLO", io.Stdout.String())
})
t.Run("GlobalFlagsAnywhere", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"toupper", "--prefix", "conv: ", "hello", "--verbose",
)
io := fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t, "conv: HELLO!!!", io.Stdout.String())
})
t.Run("LowerVerbose", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"toupper", "--verbose", "hello", "--lower",
)
io := fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t, "hello!!!", io.Stdout.String())
})
t.Run("ParsedFlags", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"toupper", "--verbose", "hello", "--lower",
)
_ = fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t,
"true",
i.ParsedFlags().Lookup("verbose").Value.String(),
)
})
t.Run("NoDeepChild", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"root", "level", "level", "toupper", "--verbose", "hello", "--lower",
)
fio := fakeIO(i)
require.Error(t, i.Run(), fio.Stdout.String())
})
t.Run("RequiredFlagsMissing", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"required-flag",
)
fio := fakeIO(i)
err := i.Run()
require.Error(t, err, fio.Stdout.String())
require.ErrorContains(t, err, "Missing values")
})
t.Run("RequiredFlagsMissingBool", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"required-flag", "--req-string", "foo bar",
)
fio := fakeIO(i)
err := i.Run()
require.Error(t, err, fio.Stdout.String())
require.ErrorContains(t, err, "Missing values for the required flags: req-bool")
})
t.Run("RequiredFlagsMissingString", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"required-flag", "--req-bool", "true",
)
fio := fakeIO(i)
err := i.Run()
require.Error(t, err, fio.Stdout.String())
require.ErrorContains(t, err, "Missing values for the required flags: req-string")
})
t.Run("RequiredFlagsInvalid", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"required-flag", "--req-string", "nospace",
)
fio := fakeIO(i)
err := i.Run()
require.Error(t, err, fio.Stdout.String())
require.ErrorContains(t, err, "string must contain a space")
})
t.Run("RequiredFlagsOK", func(t *testing.T) {
t.Parallel()
i := cmd().Invoke(
"required-flag", "--req-bool", "true", "--req-string", "foo bar",
)
fio := fakeIO(i)
err := i.Run()
require.NoError(t, err, fio.Stdout.String())
})
}
func TestCommand_DeepNest(t *testing.T) {
t.Parallel()
cmd := &clibase.Cmd{
Use: "1",
Children: []*clibase.Cmd{
{
Use: "2",
Children: []*clibase.Cmd{
{
Use: "3",
Handler: func(i *clibase.Invocation) error {
i.Stdout.Write([]byte("3"))
return nil
},
},
},
},
},
}
inv := cmd.Invoke("2", "3")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Equal(t, "3", stdio.Stdout.String())
}
func TestCommand_FlagOverride(t *testing.T) {
t.Parallel()
var flag string
cmd := &clibase.Cmd{
Use: "1",
Options: clibase.OptionSet{
{
Name: "flag",
Flag: "f",
Value: clibase.DiscardValue,
},
},
Children: []*clibase.Cmd{
{
Use: "2",
Options: clibase.OptionSet{
{
Name: "flag",
Flag: "f",
Value: clibase.StringOf(&flag),
},
},
Handler: func(i *clibase.Invocation) error {
return nil
},
},
},
}
err := cmd.Invoke("2", "--f", "mhmm").Run()
require.NoError(t, err)
require.Equal(t, "mhmm", flag)
}
func TestCommand_MiddlewareOrder(t *testing.T) {
t.Parallel()
mw := func(letter string) clibase.MiddlewareFunc {
return func(next clibase.HandlerFunc) clibase.HandlerFunc {
return (func(i *clibase.Invocation) error {
_, _ = i.Stdout.Write([]byte(letter))
return next(i)
})
}
}
cmd := &clibase.Cmd{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: clibase.Chain(
mw("A"),
mw("B"),
mw("C"),
),
Handler: (func(i *clibase.Invocation) error {
return nil
}),
}
i := cmd.Invoke(
"hello", "world",
)
io := fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t, "ABC", io.Stdout.String())
}
func TestCommand_RawArgs(t *testing.T) {
t.Parallel()
cmd := func() *clibase.Cmd {
return &clibase.Cmd{
Use: "root",
Options: clibase.OptionSet{
{
Name: "password",
Flag: "password",
Value: clibase.StringOf(new(string)),
},
},
Children: []*clibase.Cmd{
{
Use: "sushi <args...>",
Short: "Throws back raw output",
RawArgs: true,
Handler: (func(i *clibase.Invocation) error {
if v := i.ParsedFlags().Lookup("password").Value.String(); v != "codershack" {
return xerrors.Errorf("password %q is wrong!", v)
}
i.Stdout.Write([]byte(strings.Join(i.Args, " ")))
return nil
}),
},
},
}
}
t.Run("OK", func(t *testing.T) {
// Flag parsed before the raw arg command should still work.
t.Parallel()
i := cmd().Invoke(
"--password", "codershack", "sushi", "hello", "--verbose", "world",
)
io := fakeIO(i)
require.NoError(t, i.Run())
require.Equal(t, "hello --verbose world", io.Stdout.String())
})
t.Run("BadFlag", func(t *testing.T) {
// Verbose before the raw arg command should fail.
t.Parallel()
i := cmd().Invoke(
"--password", "codershack", "--verbose", "sushi", "hello", "world",
)
io := fakeIO(i)
require.Error(t, i.Run())
require.Empty(t, io.Stdout.String())
})
t.Run("NoPassword", func(t *testing.T) {
// Flag parsed before the raw arg command should still work.
t.Parallel()
i := cmd().Invoke(
"sushi", "hello", "--verbose", "world",
)
_ = fakeIO(i)
require.Error(t, i.Run())
})
}
func TestCommand_RootRaw(t *testing.T) {
t.Parallel()
cmd := &clibase.Cmd{
RawArgs: true,
Handler: func(i *clibase.Invocation) error {
i.Stdout.Write([]byte(strings.Join(i.Args, " ")))
return nil
},
}
inv := cmd.Invoke("hello", "--verbose", "--friendly")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Equal(t, "hello --verbose --friendly", stdio.Stdout.String())
}
func TestCommand_HyphenHyphen(t *testing.T) {
t.Parallel()
cmd := &clibase.Cmd{
Handler: (func(i *clibase.Invocation) error {
i.Stdout.Write([]byte(strings.Join(i.Args, " ")))
return nil
}),
}
inv := cmd.Invoke("--", "--verbose", "--friendly")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Equal(t, "--verbose --friendly", stdio.Stdout.String())
}
func TestCommand_ContextCancels(t *testing.T) {
t.Parallel()
var gotCtx context.Context
cmd := &clibase.Cmd{
Handler: (func(i *clibase.Invocation) error {
gotCtx = i.Context()
if err := gotCtx.Err(); err != nil {
return xerrors.Errorf("unexpected context error: %w", i.Context().Err())
}
return nil
}),
}
err := cmd.Invoke().Run()
require.NoError(t, err)
require.Error(t, gotCtx.Err())
}
func TestCommand_Help(t *testing.T) {
t.Parallel()
cmd := func() *clibase.Cmd {
return &clibase.Cmd{
Use: "root",
HelpHandler: (func(i *clibase.Invocation) error {
i.Stdout.Write([]byte("abdracadabra"))
return nil
}),
Handler: (func(i *clibase.Invocation) error {
return xerrors.New("should not be called")
}),
}
}
t.Run("NoHandler", func(t *testing.T) {
t.Parallel()
c := cmd()
c.HelpHandler = nil
err := c.Invoke("--help").Run()
require.Error(t, err)
})
t.Run("Long", func(t *testing.T) {
t.Parallel()
inv := cmd().Invoke("--help")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stdio.Stdout.String(), "abdracadabra")
})
t.Run("Short", func(t *testing.T) {
t.Parallel()
inv := cmd().Invoke("-h")
stdio := fakeIO(inv)
err := inv.Run()
require.NoError(t, err)
require.Contains(t, stdio.Stdout.String(), "abdracadabra")
})
}
func TestCommand_SliceFlags(t *testing.T) {
t.Parallel()
cmd := func(want ...string) *clibase.Cmd {
var got []string
return &clibase.Cmd{
Use: "root",
Options: clibase.OptionSet{
{
Name: "arr",
Flag: "arr",
Default: "bad,bad,bad",
Value: clibase.StringArrayOf(&got),
},
},
Handler: (func(i *clibase.Invocation) error {
require.Equal(t, want, got)
return nil
}),
}
}
err := cmd("good", "good", "good").Invoke("--arr", "good", "--arr", "good", "--arr", "good").Run()
require.NoError(t, err)
err = cmd("bad", "bad", "bad").Invoke().Run()
require.NoError(t, err)
}
func TestCommand_EmptySlice(t *testing.T) {
t.Parallel()
cmd := func(want ...string) *clibase.Cmd {
var got []string
return &clibase.Cmd{
Use: "root",
Options: clibase.OptionSet{
{
Name: "arr",
Flag: "arr",
Default: "def,def,def",
Env: "ARR",
Value: clibase.StringArrayOf(&got),
},
},
Handler: (func(i *clibase.Invocation) error {
require.Equal(t, want, got)
return nil
}),
}
}
// Base-case, uses default.
err := cmd("def", "def", "def").Invoke().Run()
require.NoError(t, err)
// Empty-env uses default, too.
inv := cmd("def", "def", "def").Invoke()
inv.Environ.Set("ARR", "")
require.NoError(t, err)
// Reset to nothing at all via flag.
inv = cmd().Invoke("--arr", "")
inv.Environ.Set("ARR", "cant see")
err = inv.Run()
require.NoError(t, err)
// Reset to a specific value with flag.
inv = cmd("great").Invoke("--arr", "great")
inv.Environ.Set("ARR", "")
err = inv.Run()
require.NoError(t, err)
}
func TestCommand_DefaultsOverride(t *testing.T) {
t.Parallel()
test := func(name string, want string, fn func(t *testing.T, inv *clibase.Invocation)) {
t.Run(name, func(t *testing.T) {
t.Parallel()
var (
got string
config clibase.YAMLConfigPath
)
cmd := &clibase.Cmd{
Options: clibase.OptionSet{
{
Name: "url",
Flag: "url",
Default: "def.com",
Env: "URL",
Value: clibase.StringOf(&got),
YAML: "url",
},
{
Name: "config",
Flag: "config",
Default: "",
Value: &config,
},
},
Handler: (func(i *clibase.Invocation) error {
_, _ = fmt.Fprintf(i.Stdout, "%s", got)
return nil
}),
}
inv := cmd.Invoke()
stdio := fakeIO(inv)
fn(t, inv)
err := inv.Run()
require.NoError(t, err)
require.Equal(t, want, stdio.Stdout.String())
})
}
test("DefaultOverNothing", "def.com", func(t *testing.T, inv *clibase.Invocation) {})
test("FlagOverDefault", "good.com", func(t *testing.T, inv *clibase.Invocation) {
inv.Args = []string{"--url", "good.com"}
})
test("EnvOverDefault", "good.com", func(t *testing.T, inv *clibase.Invocation) {
inv.Environ.Set("URL", "good.com")
})
test("FlagOverEnv", "good.com", func(t *testing.T, inv *clibase.Invocation) {
inv.Environ.Set("URL", "bad.com")
inv.Args = []string{"--url", "good.com"}
})
test("FlagOverYAML", "good.com", func(t *testing.T, inv *clibase.Invocation) {
fi, err := os.CreateTemp(t.TempDir(), "config.yaml")
require.NoError(t, err)
defer fi.Close()
_, err = fi.WriteString("url: bad.com")
require.NoError(t, err)
inv.Args = []string{"--config", fi.Name(), "--url", "good.com"}
})
test("YAMLOverDefault", "good.com", func(t *testing.T, inv *clibase.Invocation) {
fi, err := os.CreateTemp(t.TempDir(), "config.yaml")
require.NoError(t, err)
defer fi.Close()
_, err = fi.WriteString("url: good.com")
require.NoError(t, err)
inv.Args = []string{"--config", fi.Name()}
})
}
+76
View File
@@ -0,0 +1,76 @@
package clibase
import "strings"
// name returns the name of the environment variable.
func envName(line string) string {
return strings.ToUpper(
strings.SplitN(line, "=", 2)[0],
)
}
// value returns the value of the environment variable.
func envValue(line string) string {
tokens := strings.SplitN(line, "=", 2)
if len(tokens) < 2 {
return ""
}
return tokens[1]
}
// Var represents a single environment variable of form
// NAME=VALUE.
type EnvVar struct {
Name string
Value string
}
type Environ []EnvVar
func (e Environ) ToOS() []string {
var env []string
for _, v := range e {
env = append(env, v.Name+"="+v.Value)
}
return env
}
func (e Environ) Lookup(name string) (string, bool) {
for _, v := range e {
if v.Name == name {
return v.Value, true
}
}
return "", false
}
func (e Environ) Get(name string) string {
v, _ := e.Lookup(name)
return v
}
func (e *Environ) Set(name, value string) {
for i, v := range *e {
if v.Name == name {
(*e)[i].Value = value
return
}
}
*e = append(*e, EnvVar{Name: name, Value: value})
}
// ParseEnviron returns all environment variables starting with
// prefix without said prefix.
func ParseEnviron(environ []string, prefix string) Environ {
var filtered []EnvVar
for _, line := range environ {
name := envName(line)
if strings.HasPrefix(name, prefix) {
filtered = append(filtered, EnvVar{
Name: strings.TrimPrefix(name, prefix),
Value: envValue(line),
})
}
}
return filtered
}
+44
View File
@@ -0,0 +1,44 @@
package clibase_test
import (
"reflect"
"testing"
"github.com/coder/coder/v2/cli/clibase"
)
func TestFilterNamePrefix(t *testing.T) {
t.Parallel()
type args struct {
environ []string
prefix string
}
tests := []struct {
name string
args args
want clibase.Environ
}{
{"empty", args{[]string{}, "SHIRE"}, nil},
{
"ONE",
args{
[]string{
"SHIRE_BRANDYBUCK=hmm",
},
"SHIRE_",
},
[]clibase.EnvVar{
{Name: "BRANDYBUCK", Value: "hmm"},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := clibase.ParseEnviron(tt.args.environ, tt.args.prefix); !reflect.DeepEqual(got, tt.want) {
t.Errorf("FilterNamePrefix() = %v, want %v", got, tt.want)
}
})
}
}
+50
View File
@@ -0,0 +1,50 @@
package clibase
import (
"net"
"strconv"
"github.com/pion/udp"
"golang.org/x/xerrors"
)
// Net abstracts CLI commands interacting with the operating system networking.
//
// At present, it covers opening local listening sockets, since doing this
// in testing is a challenge without flakes, since it's hard to pick a port we
// know a priori will be free.
type Net interface {
// Listen has the same semantics as `net.Listen` but also supports `udp`
Listen(network, address string) (net.Listener, error)
}
// osNet is an implementation that call the real OS for networking.
type osNet struct{}
func (osNet) Listen(network, address string) (net.Listener, error) {
switch network {
case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
return net.Listen(network, address)
case "udp":
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, xerrors.Errorf("split %q: %w", address, err)
}
var portInt int
portInt, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("parse port %v from %q as int: %w", port, address, err)
}
// Use pion here so that we get a stream-style net.Conn listener, instead
// of a packet-oriented connection that can read and write to multiple
// addresses.
return udp.Listen(network, &net.UDPAddr{
IP: net.ParseIP(host),
Port: portInt,
})
default:
return nil, xerrors.Errorf("unknown listen network %q", network)
}
}
+346
View File
@@ -0,0 +1,346 @@
package clibase
import (
"bytes"
"encoding/json"
"os"
"strings"
"github.com/hashicorp/go-multierror"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
)
type ValueSource string
const (
ValueSourceNone ValueSource = ""
ValueSourceFlag ValueSource = "flag"
ValueSourceEnv ValueSource = "env"
ValueSourceYAML ValueSource = "yaml"
ValueSourceDefault ValueSource = "default"
)
// Option is a configuration option for a CLI application.
type Option struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
// Required means this value must be set by some means. It requires
// `ValueSource != ValueSourceNone`
// If `Default` is set, then `Required` is ignored.
Required bool `json:"required,omitempty"`
// Flag is the long name of the flag used to configure this option. If unset,
// flag configuring is disabled.
Flag string `json:"flag,omitempty"`
// FlagShorthand is the one-character shorthand for the flag. If unset, no
// shorthand is used.
FlagShorthand string `json:"flag_shorthand,omitempty"`
// Env is the environment variable used to configure this option. If unset,
// environment configuring is disabled.
Env string `json:"env,omitempty"`
// YAML is the YAML key used to configure this option. If unset, YAML
// configuring is disabled.
YAML string `json:"yaml,omitempty"`
// Default is parsed into Value if set.
Default string `json:"default,omitempty"`
// Value includes the types listed in values.go.
Value pflag.Value `json:"value,omitempty"`
// Annotations enable extensions to clibase higher up in the stack. It's useful for
// help formatting and documentation generation.
Annotations Annotations `json:"annotations,omitempty"`
// Group is a group hierarchy that helps organize this option in help, configs
// and other documentation.
Group *Group `json:"group,omitempty"`
// UseInstead is a list of options that should be used instead of this one.
// The field is used to generate a deprecation warning.
UseInstead []Option `json:"use_instead,omitempty"`
Hidden bool `json:"hidden,omitempty"`
ValueSource ValueSource `json:"value_source,omitempty"`
}
// optionNoMethods is just a wrapper around Option so we can defer to the
// default json.Unmarshaler behavior.
type optionNoMethods Option
func (o *Option) UnmarshalJSON(data []byte) error {
// If an option has no values, we have no idea how to unmarshal it.
// So just discard the json data.
if o.Value == nil {
o.Value = &DiscardValue
}
return json.Unmarshal(data, (*optionNoMethods)(o))
}
func (o Option) YAMLPath() string {
if o.YAML == "" {
return ""
}
var gs []string
for _, g := range o.Group.Ancestry() {
gs = append(gs, g.YAML)
}
return strings.Join(append(gs, o.YAML), ".")
}
// OptionSet is a group of options that can be applied to a command.
type OptionSet []Option
// UnmarshalJSON implements json.Unmarshaler for OptionSets. Options have an
// interface Value type that cannot handle unmarshalling because the types cannot
// be inferred. Since it is a slice, instantiating the Options first does not
// help.
//
// However, we typically do instantiate the slice to have the correct types.
// So this unmarshaller will attempt to find the named option in the existing
// set, if it cannot, the value is discarded. If the option exists, the value
// is unmarshalled into the existing option, and replaces the existing option.
//
// The value is discarded if it's type cannot be inferred. This behavior just
// feels "safer", although it should never happen if the correct option set
// is passed in. The situation where this could occur is if a client and server
// are on different versions with different options.
func (optSet *OptionSet) UnmarshalJSON(data []byte) error {
dec := json.NewDecoder(bytes.NewBuffer(data))
// Should be a json array, so consume the starting open bracket.
t, err := dec.Token()
if err != nil {
return xerrors.Errorf("read array open bracket: %w", err)
}
if t != json.Delim('[') {
return xerrors.Errorf("expected array open bracket, got %q", t)
}
// As long as json elements exist, consume them. The counter is used for
// better errors.
var i int
OptionSetDecodeLoop:
for dec.More() {
var opt Option
// jValue is a placeholder value that allows us to capture the
// raw json for the value to attempt to unmarshal later.
var jValue jsonValue
opt.Value = &jValue
err := dec.Decode(&opt)
if err != nil {
return xerrors.Errorf("decode %d option: %w", i, err)
}
// This counter is used to contextualize errors to show which element of
// the array we failed to decode. It is only used in the error above, as
// if the above works, we can instead use the Option.Name which is more
// descriptive and useful. So increment here for the next decode.
i++
// Try to see if the option already exists in the option set.
// If it does, just update the existing option.
for optIndex, have := range *optSet {
if have.Name == opt.Name {
if jValue != nil {
err := json.Unmarshal(jValue, &(*optSet)[optIndex].Value)
if err != nil {
return xerrors.Errorf("decode option %q value: %w", have.Name, err)
}
// Set the opt's value
opt.Value = (*optSet)[optIndex].Value
} else {
// Hopefully the user passed empty values in the option set. There is no easy way
// to tell, and if we do not do this, it breaks json.Marshal if we do it again on
// this new option set.
opt.Value = (*optSet)[optIndex].Value
}
// Override the existing.
(*optSet)[optIndex] = opt
// Go to the next option to decode.
continue OptionSetDecodeLoop
}
}
// If the option doesn't exist, the value will be discarded.
// We do this because we cannot infer the type of the value.
opt.Value = DiscardValue
*optSet = append(*optSet, opt)
}
t, err = dec.Token()
if err != nil {
return xerrors.Errorf("read array close bracket: %w", err)
}
if t != json.Delim(']') {
return xerrors.Errorf("expected array close bracket, got %q", t)
}
return nil
}
// Add adds the given Options to the OptionSet.
func (optSet *OptionSet) Add(opts ...Option) {
*optSet = append(*optSet, opts...)
}
// Filter will only return options that match the given filter. (return true)
func (optSet OptionSet) Filter(filter func(opt Option) bool) OptionSet {
cpy := make(OptionSet, 0)
for _, opt := range optSet {
if filter(opt) {
cpy = append(cpy, opt)
}
}
return cpy
}
// FlagSet returns a pflag.FlagSet for the OptionSet.
func (optSet *OptionSet) FlagSet() *pflag.FlagSet {
if optSet == nil {
return &pflag.FlagSet{}
}
fs := pflag.NewFlagSet("", pflag.ContinueOnError)
for _, opt := range *optSet {
if opt.Flag == "" {
continue
}
var noOptDefValue string
{
no, ok := opt.Value.(NoOptDefValuer)
if ok {
noOptDefValue = no.NoOptDefValue()
}
}
val := opt.Value
if val == nil {
val = DiscardValue
}
fs.AddFlag(&pflag.Flag{
Name: opt.Flag,
Shorthand: opt.FlagShorthand,
Usage: opt.Description,
Value: val,
DefValue: "",
Changed: false,
Deprecated: "",
NoOptDefVal: noOptDefValue,
Hidden: opt.Hidden,
})
}
fs.Usage = func() {
_, _ = os.Stderr.WriteString("Override (*FlagSet).Usage() to print help text.\n")
}
return fs
}
// ParseEnv parses the given environment variables into the OptionSet.
// Use EnvsWithPrefix to filter out prefixes.
func (optSet *OptionSet) ParseEnv(vs []EnvVar) error {
if optSet == nil {
return nil
}
var merr *multierror.Error
// We parse environment variables first instead of using a nested loop to
// avoid N*M complexity when there are a lot of options and environment
// variables.
envs := make(map[string]string)
for _, v := range vs {
envs[v.Name] = v.Value
}
for i, opt := range *optSet {
if opt.Env == "" {
continue
}
envVal, ok := envs[opt.Env]
if !ok {
// Homebrew strips all environment variables that do not start with `HOMEBREW_`.
// This prevented using brew to invoke the Coder agent, because the environment
// variables to not get passed down.
//
// A customer wanted to use their custom tap inside a workspace, which was failing
// because the agent lacked the environment variables to authenticate with Git.
envVal, ok = envs[`HOMEBREW_`+opt.Env]
}
// Currently, empty values are treated as if the environment variable is
// unset. This behavior is technically not correct as there is now no
// way for a user to change a Default value to an empty string from
// the environment. Unfortunately, we have old configuration files
// that rely on the faulty behavior.
//
// TODO: We should remove this hack in May 2023, when deployments
// have had months to migrate to the new behavior.
if !ok || envVal == "" {
continue
}
(*optSet)[i].ValueSource = ValueSourceEnv
if err := opt.Value.Set(envVal); err != nil {
merr = multierror.Append(
merr, xerrors.Errorf("parse %q: %w", opt.Name, err),
)
}
}
return merr.ErrorOrNil()
}
// SetDefaults sets the default values for each Option, skipping values
// that already have a value source.
func (optSet *OptionSet) SetDefaults() error {
if optSet == nil {
return nil
}
var merr *multierror.Error
for i, opt := range *optSet {
// Skip values that may have already been set by the user.
if opt.ValueSource != ValueSourceNone {
continue
}
if opt.Default == "" {
continue
}
if opt.Value == nil {
merr = multierror.Append(
merr,
xerrors.Errorf(
"parse %q: no Value field set\nFull opt: %+v",
opt.Name, opt,
),
)
continue
}
(*optSet)[i].ValueSource = ValueSourceDefault
if err := opt.Value.Set(opt.Default); err != nil {
merr = multierror.Append(
merr, xerrors.Errorf("parse %q: %w", opt.Name, err),
)
}
}
return merr.ErrorOrNil()
}
// ByName returns the Option with the given name, or nil if no such option
// exists.
func (optSet *OptionSet) ByName(name string) *Option {
for i := range *optSet {
opt := &(*optSet)[i]
if opt.Name == name {
return opt
}
}
return nil
}
+391
View File
@@ -0,0 +1,391 @@
package clibase_test
import (
"bytes"
"encoding/json"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
)
func TestOptionSet_ParseFlags(t *testing.T) {
t.Parallel()
t.Run("SimpleString", func(t *testing.T) {
t.Parallel()
var workspaceName clibase.String
os := clibase.OptionSet{
clibase.Option{
Name: "Workspace Name",
Value: &workspaceName,
Flag: "workspace-name",
FlagShorthand: "n",
},
}
var err error
err = os.FlagSet().Parse([]string{"--workspace-name", "foo"})
require.NoError(t, err)
require.EqualValues(t, "foo", workspaceName)
err = os.FlagSet().Parse([]string{"-n", "f"})
require.NoError(t, err)
require.EqualValues(t, "f", workspaceName)
})
t.Run("StringArray", func(t *testing.T) {
t.Parallel()
var names clibase.StringArray
os := clibase.OptionSet{
clibase.Option{
Name: "name",
Value: &names,
Flag: "name",
FlagShorthand: "n",
},
}
err := os.SetDefaults()
require.NoError(t, err)
err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"})
require.NoError(t, err)
require.EqualValues(t, []string{"foo", "bar"}, names)
})
t.Run("ExtraFlags", func(t *testing.T) {
t.Parallel()
var workspaceName clibase.String
os := clibase.OptionSet{
clibase.Option{
Name: "Workspace Name",
Value: &workspaceName,
},
}
err := os.FlagSet().Parse([]string{"--some-unknown", "foo"})
require.Error(t, err)
})
t.Run("RegexValid", func(t *testing.T) {
t.Parallel()
var regexpString clibase.Regexp
os := clibase.OptionSet{
clibase.Option{
Name: "RegexpString",
Value: &regexpString,
Flag: "regexp-string",
},
}
err := os.FlagSet().Parse([]string{"--regexp-string", "$test^"})
require.NoError(t, err)
})
t.Run("RegexInvalid", func(t *testing.T) {
t.Parallel()
var regexpString clibase.Regexp
os := clibase.OptionSet{
clibase.Option{
Name: "RegexpString",
Value: &regexpString,
Flag: "regexp-string",
},
}
err := os.FlagSet().Parse([]string{"--regexp-string", "(("})
require.Error(t, err)
})
}
func TestOptionSet_ParseEnv(t *testing.T) {
t.Parallel()
t.Run("SimpleString", func(t *testing.T) {
t.Parallel()
var workspaceName clibase.String
os := clibase.OptionSet{
clibase.Option{
Name: "Workspace Name",
Value: &workspaceName,
Env: "WORKSPACE_NAME",
},
}
err := os.ParseEnv([]clibase.EnvVar{
{Name: "WORKSPACE_NAME", Value: "foo"},
})
require.NoError(t, err)
require.EqualValues(t, "foo", workspaceName)
})
t.Run("EmptyValue", func(t *testing.T) {
t.Parallel()
var workspaceName clibase.String
os := clibase.OptionSet{
clibase.Option{
Name: "Workspace Name",
Value: &workspaceName,
Default: "defname",
Env: "WORKSPACE_NAME",
},
}
err := os.SetDefaults()
require.NoError(t, err)
err = os.ParseEnv(clibase.ParseEnviron([]string{"CODER_WORKSPACE_NAME="}, "CODER_"))
require.NoError(t, err)
require.EqualValues(t, "defname", workspaceName)
})
t.Run("StringSlice", func(t *testing.T) {
t.Parallel()
var actual clibase.StringArray
expected := []string{"foo", "bar", "baz"}
os := clibase.OptionSet{
clibase.Option{
Name: "name",
Value: &actual,
Env: "NAMES",
},
}
err := os.SetDefaults()
require.NoError(t, err)
err = os.ParseEnv([]clibase.EnvVar{
{Name: "NAMES", Value: "foo,bar,baz"},
})
require.NoError(t, err)
require.EqualValues(t, expected, actual)
})
t.Run("StructMapStringString", func(t *testing.T) {
t.Parallel()
var actual clibase.Struct[map[string]string]
expected := map[string]string{"foo": "bar", "baz": "zap"}
os := clibase.OptionSet{
clibase.Option{
Name: "labels",
Value: &actual,
Env: "LABELS",
},
}
err := os.SetDefaults()
require.NoError(t, err)
err = os.ParseEnv([]clibase.EnvVar{
{Name: "LABELS", Value: `{"foo":"bar","baz":"zap"}`},
})
require.NoError(t, err)
require.EqualValues(t, expected, actual.Value)
})
t.Run("Homebrew", func(t *testing.T) {
t.Parallel()
var agentToken clibase.String
os := clibase.OptionSet{
clibase.Option{
Name: "Agent Token",
Value: &agentToken,
Env: "AGENT_TOKEN",
},
}
err := os.ParseEnv([]clibase.EnvVar{
{Name: "HOMEBREW_AGENT_TOKEN", Value: "foo"},
})
require.NoError(t, err)
require.EqualValues(t, "foo", agentToken)
})
}
func TestOptionSet_JsonMarshal(t *testing.T) {
t.Parallel()
// This unit test ensures if the source optionset is missing the option
// and cannot determine the type, it will not panic. The unmarshal will
// succeed with a best effort.
t.Run("MissingSrcOption", func(t *testing.T) {
t.Parallel()
var str clibase.String = "something"
var arr clibase.StringArray = []string{"foo", "bar"}
opts := clibase.OptionSet{
clibase.Option{
Name: "StringOpt",
Value: &str,
},
clibase.Option{
Name: "ArrayOpt",
Value: &arr,
},
}
data, err := json.Marshal(opts)
require.NoError(t, err, "marshal option set")
tgt := clibase.OptionSet{}
err = json.Unmarshal(data, &tgt)
require.NoError(t, err, "unmarshal option set")
for i := range opts {
compareOptionsExceptValues(t, opts[i], tgt[i])
require.Empty(t, tgt[i].Value.String(), "unknown value types are empty")
}
})
t.Run("RegexCase", func(t *testing.T) {
t.Parallel()
val := clibase.Regexp(*regexp.MustCompile(".*"))
opts := clibase.OptionSet{
clibase.Option{
Name: "Regex",
Value: &val,
Default: ".*",
},
}
data, err := json.Marshal(opts)
require.NoError(t, err, "marshal option set")
var foundVal clibase.Regexp
newOpts := clibase.OptionSet{
clibase.Option{
Name: "Regex",
Value: &foundVal,
},
}
err = json.Unmarshal(data, &newOpts)
require.NoError(t, err, "unmarshal option set")
require.EqualValues(t, opts[0].Value.String(), newOpts[0].Value.String())
})
t.Run("AllValues", func(t *testing.T) {
t.Parallel()
vals := coderdtest.DeploymentValues(t)
opts := vals.Options()
sources := []clibase.ValueSource{
clibase.ValueSourceNone,
clibase.ValueSourceFlag,
clibase.ValueSourceEnv,
clibase.ValueSourceYAML,
clibase.ValueSourceDefault,
}
for i := range opts {
opts[i].ValueSource = sources[i%len(sources)]
}
data, err := json.Marshal(opts)
require.NoError(t, err, "marshal option set")
newOpts := (&codersdk.DeploymentValues{}).Options()
err = json.Unmarshal(data, &newOpts)
require.NoError(t, err, "unmarshal option set")
for i := range opts {
exp := opts[i]
found := newOpts[i]
compareOptionsExceptValues(t, exp, found)
compareValues(t, exp, found)
}
thirdOpts := (&codersdk.DeploymentValues{}).Options()
data, err = json.Marshal(newOpts)
require.NoError(t, err, "marshal option set")
err = json.Unmarshal(data, &thirdOpts)
require.NoError(t, err, "unmarshal option set")
// Compare to the original opts again
for i := range opts {
exp := opts[i]
found := thirdOpts[i]
compareOptionsExceptValues(t, exp, found)
compareValues(t, exp, found)
}
})
}
func compareOptionsExceptValues(t *testing.T, exp, found clibase.Option) {
t.Helper()
require.Equalf(t, exp.Name, found.Name, "option name %q", exp.Name)
require.Equalf(t, exp.Description, found.Description, "option description %q", exp.Name)
require.Equalf(t, exp.Required, found.Required, "option required %q", exp.Name)
require.Equalf(t, exp.Flag, found.Flag, "option flag %q", exp.Name)
require.Equalf(t, exp.FlagShorthand, found.FlagShorthand, "option flag shorthand %q", exp.Name)
require.Equalf(t, exp.Env, found.Env, "option env %q", exp.Name)
require.Equalf(t, exp.YAML, found.YAML, "option yaml %q", exp.Name)
require.Equalf(t, exp.Default, found.Default, "option default %q", exp.Name)
require.Equalf(t, exp.ValueSource, found.ValueSource, "option value source %q", exp.Name)
require.Equalf(t, exp.Hidden, found.Hidden, "option hidden %q", exp.Name)
require.Equalf(t, exp.Annotations, found.Annotations, "option annotations %q", exp.Name)
require.Equalf(t, exp.Group, found.Group, "option group %q", exp.Name)
// UseInstead is the same comparison problem, just check the length
require.Equalf(t, len(exp.UseInstead), len(found.UseInstead), "option use instead %q", exp.Name)
}
func compareValues(t *testing.T, exp, found clibase.Option) {
t.Helper()
if (exp.Value == nil || found.Value == nil) || (exp.Value.String() != found.Value.String() && found.Value.String() == "") {
// If the string values are different, this can be a "nil" issue.
// So only run this case if the found string is the empty string.
// We use MarshalYAML for struct strings, and it will return an
// empty string '""' for nil slices/maps/etc.
// So use json to compare.
expJSON, err := json.Marshal(exp.Value)
require.NoError(t, err, "marshal")
foundJSON, err := json.Marshal(found.Value)
require.NoError(t, err, "marshal")
expJSON = normalizeJSON(expJSON)
foundJSON = normalizeJSON(foundJSON)
assert.Equalf(t, string(expJSON), string(foundJSON), "option value %q", exp.Name)
} else {
assert.Equal(t,
exp.Value.String(),
found.Value.String(),
"option value %q", exp.Name)
}
}
// normalizeJSON handles the fact that an empty map/slice is not the same
// as a nil empty/slice. For our purposes, they are the same.
func normalizeJSON(data []byte) []byte {
if bytes.Equal(data, []byte("[]")) || bytes.Equal(data, []byte("{}")) {
return []byte("null")
}
return data
}
+593
View File
@@ -0,0 +1,593 @@
package clibase
import (
"encoding/csv"
"encoding/json"
"fmt"
"net"
"net/url"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
"gopkg.in/yaml.v3"
)
// NoOptDefValuer describes behavior when no
// option is passed into the flag.
//
// This is useful for boolean or otherwise binary flags.
type NoOptDefValuer interface {
NoOptDefValue() string
}
// Validator is a wrapper around a pflag.Value that allows for validation
// of the value after or before it has been set.
type Validator[T pflag.Value] struct {
Value T
// validate is called after the value is set.
validate func(T) error
}
func Validate[T pflag.Value](opt T, validate func(value T) error) *Validator[T] {
return &Validator[T]{Value: opt, validate: validate}
}
func (i *Validator[T]) String() string {
return i.Value.String()
}
func (i *Validator[T]) Set(input string) error {
err := i.Value.Set(input)
if err != nil {
return err
}
if i.validate != nil {
err = i.validate(i.Value)
if err != nil {
return err
}
}
return nil
}
func (i *Validator[T]) Type() string {
return i.Value.Type()
}
func (i *Validator[T]) MarshalYAML() (interface{}, error) {
m, ok := any(i.Value).(yaml.Marshaler)
if !ok {
return i.Value, nil
}
return m.MarshalYAML()
}
func (i *Validator[T]) UnmarshalYAML(n *yaml.Node) error {
return n.Decode(i.Value)
}
func (i *Validator[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(i.Value)
}
func (i *Validator[T]) UnmarshalJSON(b []byte) error {
return json.Unmarshal(b, i.Value)
}
func (i *Validator[T]) Underlying() pflag.Value { return i.Value }
// values.go contains a standard set of value types that can be used as
// Option Values.
type Int64 int64
func Int64Of(i *int64) *Int64 {
return (*Int64)(i)
}
func (i *Int64) Set(s string) error {
ii, err := strconv.ParseInt(s, 10, 64)
*i = Int64(ii)
return err
}
func (i Int64) Value() int64 {
return int64(i)
}
func (i Int64) String() string {
return strconv.Itoa(int(i))
}
func (Int64) Type() string {
return "int"
}
type Bool bool
func BoolOf(b *bool) *Bool {
return (*Bool)(b)
}
func (b *Bool) Set(s string) error {
if s == "" {
*b = Bool(false)
return nil
}
bb, err := strconv.ParseBool(s)
*b = Bool(bb)
return err
}
func (*Bool) NoOptDefValue() string {
return "true"
}
func (b Bool) String() string {
return strconv.FormatBool(bool(b))
}
func (b Bool) Value() bool {
return bool(b)
}
func (Bool) Type() string {
return "bool"
}
type String string
func StringOf(s *string) *String {
return (*String)(s)
}
func (*String) NoOptDefValue() string {
return ""
}
func (s *String) Set(v string) error {
*s = String(v)
return nil
}
func (s String) String() string {
return string(s)
}
func (s String) Value() string {
return string(s)
}
func (String) Type() string {
return "string"
}
var _ pflag.SliceValue = &StringArray{}
// StringArray is a slice of strings that implements pflag.Value and pflag.SliceValue.
type StringArray []string
func StringArrayOf(ss *[]string) *StringArray {
return (*StringArray)(ss)
}
func (s *StringArray) Append(v string) error {
*s = append(*s, v)
return nil
}
func (s *StringArray) Replace(vals []string) error {
*s = vals
return nil
}
func (s *StringArray) GetSlice() []string {
return *s
}
func readAsCSV(v string) ([]string, error) {
return csv.NewReader(strings.NewReader(v)).Read()
}
func writeAsCSV(vals []string) string {
var sb strings.Builder
err := csv.NewWriter(&sb).Write(vals)
if err != nil {
return fmt.Sprintf("error: %s", err)
}
return sb.String()
}
func (s *StringArray) Set(v string) error {
if v == "" {
*s = nil
return nil
}
ss, err := readAsCSV(v)
if err != nil {
return err
}
*s = append(*s, ss...)
return nil
}
func (s StringArray) String() string {
return writeAsCSV([]string(s))
}
func (s StringArray) Value() []string {
return []string(s)
}
func (StringArray) Type() string {
return "string-array"
}
type Duration time.Duration
func DurationOf(d *time.Duration) *Duration {
return (*Duration)(d)
}
func (d *Duration) Set(v string) error {
dd, err := time.ParseDuration(v)
*d = Duration(dd)
return err
}
func (d *Duration) Value() time.Duration {
return time.Duration(*d)
}
func (d *Duration) String() string {
return time.Duration(*d).String()
}
func (Duration) Type() string {
return "duration"
}
func (d *Duration) MarshalYAML() (interface{}, error) {
return yaml.Node{
Kind: yaml.ScalarNode,
Value: d.String(),
}, nil
}
func (d *Duration) UnmarshalYAML(n *yaml.Node) error {
return d.Set(n.Value)
}
type URL url.URL
func URLOf(u *url.URL) *URL {
return (*URL)(u)
}
func (u *URL) Set(v string) error {
uu, err := url.Parse(v)
if err != nil {
return err
}
*u = URL(*uu)
return nil
}
func (u *URL) String() string {
uu := url.URL(*u)
return uu.String()
}
func (u *URL) MarshalYAML() (interface{}, error) {
return yaml.Node{
Kind: yaml.ScalarNode,
Value: u.String(),
}, nil
}
func (u *URL) UnmarshalYAML(n *yaml.Node) error {
return u.Set(n.Value)
}
func (u *URL) MarshalJSON() ([]byte, error) {
return json.Marshal(u.String())
}
func (u *URL) UnmarshalJSON(b []byte) error {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
return u.Set(s)
}
func (*URL) Type() string {
return "url"
}
func (u *URL) Value() *url.URL {
return (*url.URL)(u)
}
// HostPort is a host:port pair.
type HostPort struct {
Host string
Port string
}
func (hp *HostPort) Set(v string) error {
if v == "" {
return xerrors.Errorf("must not be empty")
}
var err error
hp.Host, hp.Port, err = net.SplitHostPort(v)
return err
}
func (hp *HostPort) String() string {
if hp.Host == "" && hp.Port == "" {
return ""
}
// Warning: net.JoinHostPort must be used over concatenation to support
// IPv6 addresses.
return net.JoinHostPort(hp.Host, hp.Port)
}
func (hp *HostPort) MarshalJSON() ([]byte, error) {
return json.Marshal(hp.String())
}
func (hp *HostPort) UnmarshalJSON(b []byte) error {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
if s == "" {
hp.Host = ""
hp.Port = ""
return nil
}
return hp.Set(s)
}
func (hp *HostPort) MarshalYAML() (interface{}, error) {
return yaml.Node{
Kind: yaml.ScalarNode,
Value: hp.String(),
}, nil
}
func (hp *HostPort) UnmarshalYAML(n *yaml.Node) error {
return hp.Set(n.Value)
}
func (*HostPort) Type() string {
return "host:port"
}
var (
_ yaml.Marshaler = new(Struct[struct{}])
_ yaml.Unmarshaler = new(Struct[struct{}])
)
// Struct is a special value type that encodes an arbitrary struct.
// It implements the flag.Value interface, but in general these values should
// only be accepted via config for ergonomics.
//
// The string encoding type is YAML.
type Struct[T any] struct {
Value T
}
//nolint:revive
func (s *Struct[T]) Set(v string) error {
return yaml.Unmarshal([]byte(v), &s.Value)
}
//nolint:revive
func (s *Struct[T]) String() string {
byt, err := yaml.Marshal(s.Value)
if err != nil {
return "decode failed: " + err.Error()
}
return string(byt)
}
// nolint:revive
func (s *Struct[T]) MarshalYAML() (interface{}, error) {
var n yaml.Node
err := n.Encode(s.Value)
if err != nil {
return nil, err
}
return n, nil
}
// nolint:revive
func (s *Struct[T]) UnmarshalYAML(n *yaml.Node) error {
// HACK: for compatibility with flags, we use nil slices instead of empty
// slices. In most cases, nil slices and empty slices are treated
// the same, so this behavior may be removed at some point.
if typ := reflect.TypeOf(s.Value); typ.Kind() == reflect.Slice && len(n.Content) == 0 {
reflect.ValueOf(&s.Value).Elem().Set(reflect.Zero(typ))
return nil
}
return n.Decode(&s.Value)
}
//nolint:revive
func (s *Struct[T]) Type() string {
return fmt.Sprintf("struct[%T]", s.Value)
}
// nolint:revive
func (s *Struct[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(s.Value)
}
// nolint:revive
func (s *Struct[T]) UnmarshalJSON(b []byte) error {
return json.Unmarshal(b, &s.Value)
}
// DiscardValue does nothing but implements the pflag.Value interface.
// It's useful in cases where you want to accept an option, but access the
// underlying value directly instead of through the Option methods.
var DiscardValue discardValue
type discardValue struct{}
func (discardValue) Set(string) error {
return nil
}
func (discardValue) String() string {
return ""
}
func (discardValue) Type() string {
return "discard"
}
func (discardValue) UnmarshalJSON([]byte) error {
return nil
}
// jsonValue is intentionally not exported. It is just used to store the raw JSON
// data for a value to defer it's unmarshal. It implements the pflag.Value to be
// usable in an Option.
type jsonValue json.RawMessage
func (jsonValue) Set(string) error {
return xerrors.Errorf("json value is read-only")
}
func (jsonValue) String() string {
return ""
}
func (jsonValue) Type() string {
return "json"
}
func (j *jsonValue) UnmarshalJSON(data []byte) error {
if j == nil {
return xerrors.New("json.RawMessage: UnmarshalJSON on nil pointer")
}
*j = append((*j)[0:0], data...)
return nil
}
var _ pflag.Value = (*Enum)(nil)
type Enum struct {
Choices []string
Value *string
}
func EnumOf(v *string, choices ...string) *Enum {
return &Enum{
Choices: choices,
Value: v,
}
}
func (e *Enum) Set(v string) error {
for _, c := range e.Choices {
if v == c {
*e.Value = v
return nil
}
}
return xerrors.Errorf("invalid choice: %s, should be one of %v", v, e.Choices)
}
func (e *Enum) Type() string {
return fmt.Sprintf("enum[%v]", strings.Join(e.Choices, "\\|"))
}
func (e *Enum) String() string {
return *e.Value
}
type Regexp regexp.Regexp
func (r *Regexp) MarshalJSON() ([]byte, error) {
return json.Marshal(r.String())
}
func (r *Regexp) UnmarshalJSON(data []byte) error {
var source string
err := json.Unmarshal(data, &source)
if err != nil {
return err
}
exp, err := regexp.Compile(source)
if err != nil {
return xerrors.Errorf("invalid regex expression: %w", err)
}
*r = Regexp(*exp)
return nil
}
func (r *Regexp) MarshalYAML() (interface{}, error) {
return yaml.Node{
Kind: yaml.ScalarNode,
Value: r.String(),
}, nil
}
func (r *Regexp) UnmarshalYAML(n *yaml.Node) error {
return r.Set(n.Value)
}
func (r *Regexp) Set(v string) error {
exp, err := regexp.Compile(v)
if err != nil {
return xerrors.Errorf("invalid regex expression: %w", err)
}
*r = Regexp(*exp)
return nil
}
func (r Regexp) String() string {
return r.Value().String()
}
func (r *Regexp) Value() *regexp.Regexp {
if r == nil {
return nil
}
return (*regexp.Regexp)(r)
}
func (Regexp) Type() string {
return "regexp"
}
var _ pflag.Value = (*YAMLConfigPath)(nil)
// YAMLConfigPath is a special value type that encodes a path to a YAML
// configuration file where options are read from.
type YAMLConfigPath string
func (p *YAMLConfigPath) Set(v string) error {
*p = YAMLConfigPath(v)
return nil
}
func (p *YAMLConfigPath) String() string {
return string(*p)
}
func (*YAMLConfigPath) Type() string {
return "yaml-config-path"
}
+299
View File
@@ -0,0 +1,299 @@
package clibase
import (
"errors"
"fmt"
"strings"
"github.com/mitchellh/go-wordwrap"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
"gopkg.in/yaml.v3"
)
var (
_ yaml.Marshaler = new(OptionSet)
_ yaml.Unmarshaler = new(OptionSet)
)
// deepMapNode returns the mapping node at the given path,
// creating it if it doesn't exist.
func deepMapNode(n *yaml.Node, path []string, headComment string) *yaml.Node {
if len(path) == 0 {
return n
}
// Name is every two nodes.
for i := 0; i < len(n.Content)-1; i += 2 {
if n.Content[i].Value == path[0] {
// Found matching name, recurse.
return deepMapNode(n.Content[i+1], path[1:], headComment)
}
}
// Not found, create it.
nameNode := yaml.Node{
Kind: yaml.ScalarNode,
Value: path[0],
HeadComment: headComment,
}
valueNode := yaml.Node{
Kind: yaml.MappingNode,
}
n.Content = append(n.Content, &nameNode)
n.Content = append(n.Content, &valueNode)
return deepMapNode(&valueNode, path[1:], headComment)
}
// MarshalYAML converts the option set to a YAML node, that can be
// converted into bytes via yaml.Marshal.
//
// The node is returned to enable post-processing higher up in
// the stack.
//
// It is isomorphic with FromYAML.
func (optSet *OptionSet) MarshalYAML() (any, error) {
root := yaml.Node{
Kind: yaml.MappingNode,
}
for _, opt := range *optSet {
if opt.YAML == "" {
continue
}
defValue := opt.Default
if defValue == "" {
defValue = "<unset>"
}
comment := wordwrap.WrapString(
fmt.Sprintf("%s\n(default: %s, type: %s)", opt.Description, defValue, opt.Value.Type()),
80,
)
nameNode := yaml.Node{
Kind: yaml.ScalarNode,
Value: opt.YAML,
HeadComment: comment,
}
_, isValidator := opt.Value.(interface{ Underlying() pflag.Value })
var valueNode yaml.Node
if opt.Value == nil {
valueNode = yaml.Node{
Kind: yaml.ScalarNode,
Value: "null",
}
} else if m, ok := opt.Value.(yaml.Marshaler); ok && !isValidator {
// Validators do a wrap, and should be handled by the else statement.
v, err := m.MarshalYAML()
if err != nil {
return nil, xerrors.Errorf(
"marshal %q: %w", opt.Name, err,
)
}
valueNode, ok = v.(yaml.Node)
if !ok {
return nil, xerrors.Errorf(
"marshal %q: unexpected underlying type %T",
opt.Name, v,
)
}
} else {
// The all-other types case.
//
// A bit of a hack, we marshal and then unmarshal to get
// the underlying node.
byt, err := yaml.Marshal(opt.Value)
if err != nil {
return nil, xerrors.Errorf(
"marshal %q: %w", opt.Name, err,
)
}
var docNode yaml.Node
err = yaml.Unmarshal(byt, &docNode)
if err != nil {
return nil, xerrors.Errorf(
"unmarshal %q: %w", opt.Name, err,
)
}
if len(docNode.Content) != 1 {
return nil, xerrors.Errorf(
"unmarshal %q: expected one node, got %d",
opt.Name, len(docNode.Content),
)
}
valueNode = *docNode.Content[0]
}
var group []string
for _, g := range opt.Group.Ancestry() {
if g.YAML == "" {
return nil, xerrors.Errorf(
"group yaml name is empty for %q, groups: %+v",
opt.Name,
opt.Group,
)
}
group = append(group, g.YAML)
}
var groupDesc string
if opt.Group != nil {
groupDesc = wordwrap.WrapString(opt.Group.Description, 80)
}
parentValueNode := deepMapNode(
&root, group,
groupDesc,
)
parentValueNode.Content = append(
parentValueNode.Content,
&nameNode,
&valueNode,
)
}
return &root, nil
}
// mapYAMLNodes converts parent into a map with keys of form "group.subgroup.option"
// and values as the corresponding YAML nodes.
func mapYAMLNodes(parent *yaml.Node) (map[string]*yaml.Node, error) {
if parent.Kind != yaml.MappingNode {
return nil, xerrors.Errorf("expected mapping node, got type %v", parent.Kind)
}
if len(parent.Content)%2 != 0 {
return nil, xerrors.Errorf("expected an even number of k/v pairs, got %d", len(parent.Content))
}
var (
key string
m = make(map[string]*yaml.Node, len(parent.Content)/2)
merr error
)
for i, child := range parent.Content {
if i%2 == 0 {
if child.Kind != yaml.ScalarNode {
// We immediately because the rest of the code is bound to fail
// if we don't know to expect a key or a value.
return nil, xerrors.Errorf("expected scalar node for key, got type %v", child.Kind)
}
key = child.Value
continue
}
// We don't know if this is a grouped simple option or complex option,
// so we store both "key" and "group.key". Since we're storing pointers,
// the additional memory is of little concern.
m[key] = child
if child.Kind != yaml.MappingNode {
continue
}
sub, err := mapYAMLNodes(child)
if err != nil {
merr = errors.Join(merr, xerrors.Errorf("mapping node %q: %w", key, err))
continue
}
for k, v := range sub {
m[key+"."+k] = v
}
}
return m, nil
}
func (o *Option) setFromYAMLNode(n *yaml.Node) error {
o.ValueSource = ValueSourceYAML
if um, ok := o.Value.(yaml.Unmarshaler); ok {
return um.UnmarshalYAML(n)
}
switch n.Kind {
case yaml.ScalarNode:
return o.Value.Set(n.Value)
case yaml.SequenceNode:
// We treat empty values as nil for consistency with other option
// mechanisms.
if len(n.Content) == 0 {
o.Value = nil
return nil
}
return n.Decode(o.Value)
case yaml.MappingNode:
return xerrors.Errorf("mapping nodes must implement yaml.Unmarshaler")
default:
return xerrors.Errorf("unexpected node kind %v", n.Kind)
}
}
// UnmarshalYAML converts the given YAML node into the option set.
// It is isomorphic with ToYAML.
func (optSet *OptionSet) UnmarshalYAML(rootNode *yaml.Node) error {
// The rootNode will be a DocumentNode if it's read from a file. We do
// not support multiple documents in a single file.
if rootNode.Kind == yaml.DocumentNode {
if len(rootNode.Content) != 1 {
return xerrors.Errorf("expected one node in document, got %d", len(rootNode.Content))
}
rootNode = rootNode.Content[0]
}
yamlNodes, err := mapYAMLNodes(rootNode)
if err != nil {
return xerrors.Errorf("mapping nodes: %w", err)
}
matchedNodes := make(map[string]*yaml.Node, len(yamlNodes))
var merr error
for i := range *optSet {
opt := &(*optSet)[i]
if opt.YAML == "" {
continue
}
var group []string
for _, g := range opt.Group.Ancestry() {
if g.YAML == "" {
return xerrors.Errorf(
"group yaml name is empty for %q, groups: %+v",
opt.Name,
opt.Group,
)
}
group = append(group, g.YAML)
delete(yamlNodes, strings.Join(group, "."))
}
key := strings.Join(append(group, opt.YAML), ".")
node, ok := yamlNodes[key]
if !ok {
continue
}
matchedNodes[key] = node
if opt.ValueSource != ValueSourceNone {
continue
}
if err := opt.setFromYAMLNode(node); err != nil {
merr = errors.Join(merr, xerrors.Errorf("setting %q: %w", opt.YAML, err))
}
}
// Remove all matched nodes and their descendants from yamlNodes so we
// can accurately report unknown options.
for k := range yamlNodes {
var key string
for _, part := range strings.Split(k, ".") {
if key != "" {
key += "."
}
key += part
if _, ok := matchedNodes[key]; ok {
delete(yamlNodes, k)
}
}
}
for k := range yamlNodes {
merr = errors.Join(merr, xerrors.Errorf("unknown option %q", k))
}
return merr
}
+202
View File
@@ -0,0 +1,202 @@
package clibase_test
import (
"testing"
"github.com/spf13/pflag"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"gopkg.in/yaml.v3"
"github.com/coder/coder/v2/cli/clibase"
)
func TestOptionSet_YAML(t *testing.T) {
t.Parallel()
t.Run("RequireKey", func(t *testing.T) {
t.Parallel()
var workspaceName clibase.String
os := clibase.OptionSet{
clibase.Option{
Name: "Workspace Name",
Value: &workspaceName,
Default: "billie",
},
}
node, err := os.MarshalYAML()
require.NoError(t, err)
require.Len(t, node.(*yaml.Node).Content, 0)
})
t.Run("SimpleString", func(t *testing.T) {
t.Parallel()
var workspaceName clibase.String
os := clibase.OptionSet{
clibase.Option{
Name: "Workspace Name",
Value: &workspaceName,
Default: "billie",
Description: "The workspace's name.",
Group: &clibase.Group{YAML: "names"},
YAML: "workspaceName",
},
}
err := os.SetDefaults()
require.NoError(t, err)
n, err := os.MarshalYAML()
require.NoError(t, err)
// Visually inspect for now.
byt, err := yaml.Marshal(n)
require.NoError(t, err)
t.Logf("Raw YAML:\n%s", string(byt))
})
}
func TestOptionSet_YAMLUnknownOptions(t *testing.T) {
t.Parallel()
os := clibase.OptionSet{
{
Name: "Workspace Name",
Default: "billie",
Description: "The workspace's name.",
YAML: "workspaceName",
Value: new(clibase.String),
},
}
const yamlDoc = `something: else`
err := yaml.Unmarshal([]byte(yamlDoc), &os)
require.Error(t, err)
require.Empty(t, os[0].Value.String())
os[0].YAML = "something"
err = yaml.Unmarshal([]byte(yamlDoc), &os)
require.NoError(t, err)
require.Equal(t, "else", os[0].Value.String())
}
// TestOptionSet_YAMLIsomorphism tests that the YAML representations of an
// OptionSet converts to the same OptionSet when read back in.
func TestOptionSet_YAMLIsomorphism(t *testing.T) {
t.Parallel()
// This is used to form a generic.
//nolint:unused
type kid struct {
Name string `yaml:"name"`
Age int `yaml:"age"`
}
for _, tc := range []struct {
name string
os clibase.OptionSet
zeroValue func() pflag.Value
}{
{
name: "SimpleString",
os: clibase.OptionSet{
{
Name: "Workspace Name",
Default: "billie",
Description: "The workspace's name.",
Group: &clibase.Group{YAML: "names"},
YAML: "workspaceName",
},
},
zeroValue: func() pflag.Value {
return clibase.StringOf(new(string))
},
},
{
name: "Array",
os: clibase.OptionSet{
{
YAML: "names",
Default: "jill,jack,joan",
},
},
zeroValue: func() pflag.Value {
return clibase.StringArrayOf(&[]string{})
},
},
{
name: "ComplexObject",
os: clibase.OptionSet{
{
YAML: "kids",
Default: `- name: jill
age: 12
- name: jack
age: 13`,
},
},
zeroValue: func() pflag.Value {
return &clibase.Struct[[]kid]{}
},
},
{
name: "DeepGroup",
os: clibase.OptionSet{
{
YAML: "names",
Default: "jill,jack,joan",
Group: &clibase.Group{YAML: "kids", Parent: &clibase.Group{YAML: "family"}},
},
},
zeroValue: func() pflag.Value {
return clibase.StringArrayOf(&[]string{})
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Set initial values.
for i := range tc.os {
tc.os[i].Value = tc.zeroValue()
}
err := tc.os.SetDefaults()
require.NoError(t, err)
y, err := tc.os.MarshalYAML()
require.NoError(t, err)
toByt, err := yaml.Marshal(y)
require.NoError(t, err)
t.Logf("Raw YAML:\n%s", string(toByt))
var y2 yaml.Node
err = yaml.Unmarshal(toByt, &y2)
require.NoError(t, err)
os2 := slices.Clone(tc.os)
for i := range os2 {
os2[i].Value = tc.zeroValue()
os2[i].ValueSource = clibase.ValueSourceNone
}
// os2 values should be zeroed whereas tc.os should be
// set to defaults.
// This check makes sure we aren't mixing pointers.
require.NotEqual(t, tc.os, os2)
err = os2.UnmarshalYAML(&y2)
require.NoError(t, err)
want := tc.os
for i := range want {
want[i].ValueSource = clibase.ValueSourceYAML
}
require.Equal(t, tc.os, os2)
})
}
}
+2 -2
View File
@@ -14,9 +14,9 @@ import (
"cdr.dev/slog/sloggers/sloghuman"
"cdr.dev/slog/sloggers/slogjson"
"cdr.dev/slog/sloggers/slogstackdriver"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
type (
@@ -86,7 +86,7 @@ func FromDeploymentValues(vals *codersdk.DeploymentValues) Option {
}
}
func (b *Builder) Build(inv *serpent.Invocation) (log slog.Logger, closeLog func(), err error) {
func (b *Builder) Build(inv *clibase.Invocation) (log slog.Logger, closeLog func(), err error) {
var (
sinks = []slog.Sink{}
closers = []func() error{}
+14 -14
View File
@@ -8,10 +8,10 @@ import (
"strings"
"testing"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/clilog"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -23,7 +23,7 @@ func TestBuilder(t *testing.T) {
t.Run("NoConfiguration", func(t *testing.T) {
t.Parallel()
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: testHandler(t),
}
@@ -35,7 +35,7 @@ func TestBuilder(t *testing.T) {
t.Parallel()
tempFile := filepath.Join(t.TempDir(), "test.log")
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: testHandler(t,
clilog.WithHuman(tempFile),
@@ -51,7 +51,7 @@ func TestBuilder(t *testing.T) {
t.Parallel()
tempFile := filepath.Join(t.TempDir(), "test.log")
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: testHandler(t,
clilog.WithHuman(tempFile),
@@ -68,7 +68,7 @@ func TestBuilder(t *testing.T) {
t.Parallel()
tempFile := filepath.Join(t.TempDir(), "test.log")
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: testHandler(t, clilog.WithHuman(tempFile)),
}
@@ -81,7 +81,7 @@ func TestBuilder(t *testing.T) {
t.Parallel()
tempFile := filepath.Join(t.TempDir(), "test.log")
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: testHandler(t, clilog.WithJSON(tempFile), clilog.WithVerbose()),
}
@@ -107,7 +107,7 @@ func TestBuilder(t *testing.T) {
// Use the default deployment values.
dv := coderdtest.DeploymentValues(t)
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: testHandler(t, clilog.FromDeploymentValues(dv)),
}
@@ -127,15 +127,15 @@ func TestBuilder(t *testing.T) {
dv := &codersdk.DeploymentValues{
Logging: codersdk.LoggingConfig{
Filter: []string{"foo", "baz"},
Human: serpent.String(tempFile),
JSON: serpent.String(tempJSON),
Human: clibase.String(tempFile),
JSON: clibase.String(tempJSON),
},
Verbose: true,
Trace: codersdk.TraceConfig{
Enable: true,
},
}
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: testHandler(t, clilog.FromDeploymentValues(dv)),
}
@@ -150,9 +150,9 @@ func TestBuilder(t *testing.T) {
t.Parallel()
tempFile := filepath.Join(t.TempDir(), "doesnotexist", "test.log")
cmd := &serpent.Command{
cmd := &clibase.Cmd{
Use: "test",
Handler: func(inv *serpent.Invocation) error {
Handler: func(inv *clibase.Invocation) error {
logger, closeLog, err := clilog.New(
clilog.WithFilter("foo", "baz"),
clilog.WithHuman(tempFile),
@@ -181,10 +181,10 @@ var (
filterLog = "this is an important debug message you want to see"
)
func testHandler(t testing.TB, opts ...clilog.Option) serpent.HandlerFunc {
func testHandler(t testing.TB, opts ...clilog.Option) clibase.HandlerFunc {
t.Helper()
return func(inv *serpent.Invocation) error {
return func(inv *clibase.Invocation) error {
logger, closeLog, err := clilog.New(opts...).Build(inv)
if err != nil {
return err
+8 -18
View File
@@ -20,16 +20,16 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// New creates a CLI instance with a configuration pointed to a
// temporary testing directory.
func New(t testing.TB, args ...string) (*serpent.Invocation, config.Root) {
func New(t testing.TB, args ...string) (*clibase.Invocation, config.Root) {
var root cli.RootCmd
cmd, err := root.Command(root.AGPL())
@@ -56,15 +56,15 @@ func (l *logWriter) Write(p []byte) (n int, err error) {
}
func NewWithCommand(
t testing.TB, cmd *serpent.Command, args ...string,
) (*serpent.Invocation, config.Root) {
t testing.TB, cmd *clibase.Cmd, args ...string,
) (*clibase.Invocation, config.Root) {
configDir := config.Root(t.TempDir())
// I really would like to fail test on error logs, but realistically, turning on by default
// in all our CLI tests is going to create a lot of flaky noise.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).
Leveled(slog.LevelDebug).
Named("cli")
i := &serpent.Invocation{
i := &clibase.Invocation{
Command: cmd,
Args: append([]string{"--global-config", string(configDir)}, args...),
Stdin: io.LimitReader(nil, 0),
@@ -140,11 +140,7 @@ func extractTar(t *testing.T, data []byte, directory string) {
// Start runs the command in a goroutine and cleans it up when the test
// completed.
func Start(t *testing.T, inv *serpent.Invocation) {
StartWithAssert(t, inv, nil)
}
func StartWithAssert(t *testing.T, inv *serpent.Invocation, assertCallback func(t *testing.T, err error)) { //nolint:revive
func Start(t *testing.T, inv *clibase.Invocation) {
t.Helper()
closeCh := make(chan struct{})
@@ -159,12 +155,6 @@ func StartWithAssert(t *testing.T, inv *serpent.Invocation, assertCallback func(
go func() {
defer close(closeCh)
err := waiter.Wait()
if assertCallback != nil {
assertCallback(t, err)
return
}
switch {
case errors.Is(err, context.Canceled):
return
@@ -175,7 +165,7 @@ func StartWithAssert(t *testing.T, inv *serpent.Invocation, assertCallback func(
}
// Run runs the command and asserts that there is no error.
func Run(t *testing.T, inv *serpent.Invocation) {
func Run(t *testing.T, inv *clibase.Invocation) {
t.Helper()
err := inv.Run()
@@ -228,7 +218,7 @@ func (w *ErrorWaiter) RequireAs(want interface{}) {
// StartWithWaiter runs the command in a goroutine but returns the error instead
// of asserting it. This is useful for testing error cases.
func StartWithWaiter(t *testing.T, inv *serpent.Invocation) *ErrorWaiter {
func StartWithWaiter(t *testing.T, inv *clibase.Invocation) *ErrorWaiter {
t.Helper()
var (
+38 -48
View File
@@ -11,15 +11,14 @@ import (
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// UpdateGoldenFiles indicates golden files should be updated.
@@ -49,7 +48,7 @@ func DefaultCases() []CommandHelpCase {
// TestCommandHelp will test the help output of the given commands
// using golden files.
func TestCommandHelp(t *testing.T, getRoot func(t *testing.T) *serpent.Command, cases []CommandHelpCase) {
func TestCommandHelp(t *testing.T, getRoot func(t *testing.T) *clibase.Cmd, cases []CommandHelpCase) {
t.Parallel()
rootClient, replacements := prepareTestData(t)
@@ -88,45 +87,40 @@ ExtractCommandPathsLoop:
StartWithWaiter(t, inv.WithContext(ctx)).RequireSuccess()
TestGoldenFile(t, tt.Name, outBuf.Bytes(), replacements)
actual := outBuf.Bytes()
if len(actual) == 0 {
t.Fatal("no output")
}
for k, v := range replacements {
actual = bytes.ReplaceAll(actual, []byte(k), []byte(v))
}
actual = NormalizeGoldenFile(t, actual)
goldenPath := filepath.Join("testdata", strings.Replace(tt.Name, " ", "_", -1)+".golden")
if *UpdateGoldenFiles {
t.Logf("update golden file for: %q: %s", tt.Name, goldenPath)
err := os.WriteFile(goldenPath, actual, 0o600)
require.NoError(t, err, "update golden file")
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes")
expected = NormalizeGoldenFile(t, expected)
require.Equal(
t, string(expected), string(actual),
"golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes",
goldenPath,
)
})
}
}
// TestGoldenFile will test the given bytes slice input against the
// golden file with the given file name, optionally using the given replacements.
func TestGoldenFile(t *testing.T, fileName string, actual []byte, replacements map[string]string) {
if len(actual) == 0 {
t.Fatal("no output")
}
for k, v := range replacements {
actual = bytes.ReplaceAll(actual, []byte(k), []byte(v))
}
actual = normalizeGoldenFile(t, actual)
goldenPath := filepath.Join("testdata", strings.ReplaceAll(fileName, " ", "_")+".golden")
if *UpdateGoldenFiles {
t.Logf("update golden file for: %q: %s", fileName, goldenPath)
err := os.WriteFile(goldenPath, actual, 0o600)
require.NoError(t, err, "update golden file")
}
expected, err := os.ReadFile(goldenPath)
require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes")
expected = normalizeGoldenFile(t, expected)
require.Equal(
t, string(expected), string(actual),
"golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes",
goldenPath,
)
}
// normalizeGoldenFile replaces any strings that are system or timing dependent
// NormalizeGoldenFile replaces any strings that are system or timing dependent
// with a placeholder so that the golden files can be compared with a simple
// equality check.
func normalizeGoldenFile(t *testing.T, byt []byte) []byte {
func NormalizeGoldenFile(t *testing.T, byt []byte) []byte {
// Replace any timestamps with a placeholder.
byt = timestampRegex.ReplaceAll(byt, []byte("[timestamp]"))
@@ -154,7 +148,7 @@ func normalizeGoldenFile(t *testing.T, byt []byte) []byte {
return byt
}
func extractVisibleCommandPaths(cmdPath []string, cmds []*serpent.Command) [][]string {
func extractVisibleCommandPaths(cmdPath []string, cmds []*clibase.Cmd) [][]string {
var cmdPaths [][]string
for _, c := range cmds {
if c.Hidden {
@@ -173,22 +167,18 @@ func prepareTestData(t *testing.T) (*codersdk.Client, map[string]string) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// This needs to be a fixed timezone because timezones increase the length
// of timestamp strings. The increased length can pad table formatting's
// and differ the table header spacings.
//nolint:gocritic
db, pubsub := dbtestutil.NewDB(t, dbtestutil.WithTimezone("UTC"))
db, pubsub := dbtestutil.NewDB(t)
rootClient := coderdtest.New(t, &coderdtest.Options{
Database: db,
Pubsub: pubsub,
IncludeProvisionerDaemon: true,
})
firstUser := coderdtest.CreateFirstUser(t, rootClient)
secondUser, err := rootClient.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
Email: "testuser2@coder.com",
Username: "testuser2",
Password: coderdtest.FirstUserParams.Password,
OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
secondUser, err := rootClient.CreateUser(ctx, codersdk.CreateUserRequest{
Email: "testuser2@coder.com",
Username: "testuser2",
Password: coderdtest.FirstUserParams.Password,
OrganizationID: firstUser.OrganizationID,
})
require.NoError(t, err)
version := coderdtest.CreateTemplateVersion(t, rootClient, firstUser.OrganizationID, nil)
@@ -196,7 +186,7 @@ func prepareTestData(t *testing.T) (*codersdk.Client, map[string]string) {
template := coderdtest.CreateTemplate(t, rootClient, firstUser.OrganizationID, version.ID, func(req *codersdk.CreateTemplateRequest) {
req.Name = "test-template"
})
workspace := coderdtest.CreateWorkspace(t, rootClient, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
workspace := coderdtest.CreateWorkspace(t, rootClient, firstUser.OrganizationID, template.ID, func(req *codersdk.CreateWorkspaceRequest) {
req.Name = "test-workspace"
})
workspaceBuild := coderdtest.AwaitWorkspaceBuildJobCompleted(t, rootClient, workspace.LatestBuild.ID)
+4 -4
View File
@@ -3,7 +3,7 @@ package clitest
import (
"testing"
"github.com/coder/serpent"
"github.com/coder/coder/v2/cli/clibase"
)
// HandlersOK asserts that all commands have a handler.
@@ -11,11 +11,11 @@ import (
// non-root commands (like 'groups' or 'users'), a handler is required.
// These handlers are likely just the 'help' handler, but this must be
// explicitly set.
func HandlersOK(t *testing.T, cmd *serpent.Command) {
cmd.Walk(func(cmd *serpent.Command) {
func HandlersOK(t *testing.T, cmd *clibase.Cmd) {
cmd.Walk(func(cmd *clibase.Cmd) {
if cmd.Handler == nil {
// If you see this error, make the Handler a helper invoker.
// Handler: func(inv *serpent.Invocation) error {
// Handler: func(inv *clibase.Invocation) error {
// return inv.Command.HelpHandler(inv)
// },
t.Errorf("command %q has no handler, change to a helper invoker using: 'inv.Command.HelpHandler(inv)'", cmd.Name())
+6 -186
View File
@@ -2,20 +2,13 @@ package cliui
import (
"context"
"fmt"
"io"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/tailnet"
)
var errAgentShuttingDown = xerrors.New("agent is shutting down")
@@ -25,7 +18,6 @@ type AgentOptions struct {
Fetch func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error)
FetchLogs func(ctx context.Context, agentID uuid.UUID, after int64, follow bool) (<-chan []codersdk.WorkspaceAgentLog, io.Closer, error)
Wait bool // If true, wait for the agent to be ready (startup script).
DocsURL string
}
// Agent displays a spinning indicator that waits for a workspace agent to connect.
@@ -120,7 +112,7 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
if agent.Status == codersdk.WorkspaceAgentTimeout {
now := time.Now()
sw.Log(now, codersdk.LogLevelInfo, "The workspace agent is having trouble connecting, wait for it to connect or restart your workspace.")
sw.Log(now, codersdk.LogLevelInfo, troubleshootingMessage(agent, fmt.Sprintf("%s/templates#agent-connection-issues", opts.DocsURL)))
sw.Log(now, codersdk.LogLevelInfo, troubleshootingMessage(agent, "https://coder.com/docs/v2/latest/templates#agent-connection-issues"))
for agent.Status == codersdk.WorkspaceAgentTimeout {
if agent, err = fetch(); err != nil {
return xerrors.Errorf("fetch: %w", err)
@@ -136,14 +128,11 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
}
stage := "Running workspace agent startup scripts"
follow := opts.Wait && agent.LifecycleState.Starting()
follow := opts.Wait
if !follow {
stage += " (non-blocking)"
}
sw.Start(stage)
if follow {
sw.Log(time.Time{}, codersdk.LogLevelInfo, "==> ︎ To connect immediately, reconnect with --wait=no or CODER_SSH_WAIT=no, see --help for more information.")
}
err = func() error { // Use func because of defer in for loop.
logStream, logsCloser, err := opts.FetchLogs(ctx, agent.ID, 0, follow)
@@ -213,25 +202,19 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
case codersdk.WorkspaceAgentLifecycleReady:
sw.Complete(stage, safeDuration(sw, agent.ReadyAt, agent.StartedAt))
case codersdk.WorkspaceAgentLifecycleStartTimeout:
// Backwards compatibility: Avoid printing warning if
// coderd is old and doesn't set ReadyAt for timeouts.
if agent.ReadyAt == nil {
sw.Fail(stage, 0)
} else {
sw.Fail(stage, safeDuration(sw, agent.ReadyAt, agent.StartedAt))
}
sw.Fail(stage, 0)
sw.Log(time.Time{}, codersdk.LogLevelWarn, "Warning: A startup script timed out and your workspace may be incomplete.")
case codersdk.WorkspaceAgentLifecycleStartError:
sw.Fail(stage, safeDuration(sw, agent.ReadyAt, agent.StartedAt))
// Use zero time (omitted) to separate these from the startup logs.
sw.Log(time.Time{}, codersdk.LogLevelWarn, "Warning: A startup script exited with an error and your workspace may be incomplete.")
sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, fmt.Sprintf("%s/templates#startup-script-exited-with-an-error", opts.DocsURL)))
sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, "https://coder.com/docs/v2/latest/templates#startup-script-exited-with-an-error"))
default:
switch {
case agent.LifecycleState.Starting():
// Use zero time (omitted) to separate these from the startup logs.
sw.Log(time.Time{}, codersdk.LogLevelWarn, "Notice: The startup scripts are still running and your workspace may be incomplete.")
sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, fmt.Sprintf("%s/templates#your-workspace-may-be-incomplete", opts.DocsURL)))
sw.Log(time.Time{}, codersdk.LogLevelWarn, troubleshootingMessage(agent, "https://coder.com/docs/v2/latest/templates#your-workspace-may-be-incomplete"))
// Note: We don't complete or fail the stage here, it's
// intentionally left open to indicate this stage didn't
// complete.
@@ -253,7 +236,7 @@ func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentO
stage := "The workspace agent lost connection"
sw.Start(stage)
sw.Log(time.Now(), codersdk.LogLevelWarn, "Wait for it to reconnect or restart your workspace.")
sw.Log(time.Now(), codersdk.LogLevelWarn, troubleshootingMessage(agent, fmt.Sprintf("%s/templates#agent-connection-issues", opts.DocsURL)))
sw.Log(time.Now(), codersdk.LogLevelWarn, troubleshootingMessage(agent, "https://coder.com/docs/v2/latest/templates#agent-connection-issues"))
disconnectedAt := agent.DisconnectedAt
for agent.Status == codersdk.WorkspaceAgentDisconnected {
@@ -298,166 +281,3 @@ type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}
func PeerDiagnostics(w io.Writer, d tailnet.PeerDiagnostics) {
if d.PreferredDERP > 0 {
rn, ok := d.DERPRegionNames[d.PreferredDERP]
if !ok {
rn = "unknown"
}
_, _ = fmt.Fprintf(w, "✔ preferred DERP region: %d (%s)\n", d.PreferredDERP, rn)
} else {
_, _ = fmt.Fprint(w, "✘ not connected to DERP\n")
}
if d.SentNode {
_, _ = fmt.Fprint(w, "✔ sent local data to Coder networking coordinator\n")
} else {
_, _ = fmt.Fprint(w, "✘ have not sent local data to Coder networking coordinator\n")
}
if d.ReceivedNode != nil {
dp := d.ReceivedNode.DERP
dn := ""
// should be 127.3.3.40:N where N is the DERP region
ap := strings.Split(dp, ":")
if len(ap) == 2 {
dp = ap[1]
di, err := strconv.Atoi(dp)
if err == nil {
var ok bool
dn, ok = d.DERPRegionNames[di]
if ok {
dn = fmt.Sprintf("(%s)", dn)
} else {
dn = "(unknown)"
}
}
}
_, _ = fmt.Fprintf(w,
"✔ received remote agent data from Coder networking coordinator\n preferred DERP region: %s %s\n endpoints: %s\n",
dp, dn, strings.Join(d.ReceivedNode.Endpoints, ", "))
} else {
_, _ = fmt.Fprint(w, "✘ have not received remote agent data from Coder networking coordinator\n")
}
if !d.LastWireguardHandshake.IsZero() {
ago := time.Since(d.LastWireguardHandshake)
symbol := "✔"
// wireguard is supposed to refresh handshake on 5 minute intervals
if ago > 5*time.Minute {
symbol = "⚠"
}
_, _ = fmt.Fprintf(w, "%s Wireguard handshake %s ago\n", symbol, ago.Round(time.Second))
} else {
_, _ = fmt.Fprint(w, "✘ Wireguard is not connected\n")
}
}
type ConnDiags struct {
ConnInfo workspacesdk.AgentConnectionInfo
PingP2P bool
DisableDirect bool
LocalNetInfo *tailcfg.NetInfo
LocalInterfaces *healthsdk.InterfacesReport
AgentNetcheck *healthsdk.AgentNetcheckReport
ClientIPIsAWS bool
AgentIPIsAWS bool
Verbose bool
TroubleshootingURL string
}
func (d ConnDiags) Write(w io.Writer) {
_, _ = fmt.Fprintln(w, "")
general, client, agent := d.splitDiagnostics()
for _, msg := range general {
_, _ = fmt.Fprintln(w, msg)
}
if len(general) > 0 {
_, _ = fmt.Fprintln(w, "")
}
if len(client) > 0 {
_, _ = fmt.Fprint(w, "Possible client-side issues with direct connection:\n\n")
for _, msg := range client {
_, _ = fmt.Fprintf(w, " - %s\n\n", msg)
}
}
if len(agent) > 0 {
_, _ = fmt.Fprint(w, "Possible agent-side issues with direct connections:\n\n")
for _, msg := range agent {
_, _ = fmt.Fprintf(w, " - %s\n\n", msg)
}
}
}
func (d ConnDiags) splitDiagnostics() (general, client, agent []string) {
if d.AgentNetcheck != nil {
for _, msg := range d.AgentNetcheck.Interfaces.Warnings {
agent = append(agent, msg.Message)
}
if len(d.AgentNetcheck.Interfaces.Warnings) > 0 {
agent[len(agent)-1] += fmt.Sprintf("\n%s#low-mtu", d.TroubleshootingURL)
}
}
if d.LocalInterfaces != nil {
for _, msg := range d.LocalInterfaces.Warnings {
client = append(client, msg.Message)
}
if len(d.LocalInterfaces.Warnings) > 0 {
client[len(client)-1] += fmt.Sprintf("\n%s#low-mtu", d.TroubleshootingURL)
}
}
if d.PingP2P && !d.Verbose {
return general, client, agent
}
if d.DisableDirect {
general = append(general, "❗ Direct connections are disabled locally, by `--disable-direct` or `CODER_DISABLE_DIRECT`")
if !d.Verbose {
return general, client, agent
}
}
if d.ConnInfo.DisableDirectConnections {
general = append(general,
fmt.Sprintf("❗ Your Coder administrator has blocked direct connections\n %s#disabled-deployment-wide", d.TroubleshootingURL))
if !d.Verbose {
return general, client, agent
}
}
if !d.ConnInfo.DERPMap.HasSTUN() {
general = append(general,
fmt.Sprintf("❗ The DERP map is not configured to use STUN\n %s#no-stun-servers", d.TroubleshootingURL))
} else if d.LocalNetInfo != nil && !d.LocalNetInfo.UDP {
client = append(client,
fmt.Sprintf("Client could not connect to STUN over UDP\n %s#udp-blocked", d.TroubleshootingURL))
}
if d.LocalNetInfo != nil && d.LocalNetInfo.MappingVariesByDestIP.EqualBool(true) {
client = append(client,
fmt.Sprintf("Client is potentially behind a hard NAT, as multiple endpoints were retrieved from different STUN servers\n %s#endpoint-dependent-nat-hard-nat", d.TroubleshootingURL))
}
if d.AgentNetcheck != nil && d.AgentNetcheck.NetInfo != nil {
if d.AgentNetcheck.NetInfo.MappingVariesByDestIP.EqualBool(true) {
agent = append(agent,
fmt.Sprintf("Agent is potentially behind a hard NAT, as multiple endpoints were retrieved from different STUN servers\n %s#endpoint-dependent-nat-hard-nat", d.TroubleshootingURL))
}
if !d.AgentNetcheck.NetInfo.UDP {
agent = append(agent,
fmt.Sprintf("Agent could not connect to STUN over UDP\n %s#udp-blocked", d.TroubleshootingURL))
}
}
if d.ClientIPIsAWS {
client = append(client,
fmt.Sprintf("Client IP address is within an AWS range (AWS uses hard NAT)\n %s#endpoint-dependent-nat-hard-nat", d.TroubleshootingURL))
}
if d.AgentIPIsAWS {
agent = append(agent,
fmt.Sprintf("Agent IP address is within an AWS range (AWS uses hard NAT)\n %s#endpoint-dependent-nat-hard-nat", d.TroubleshootingURL))
}
return general, client, agent
}
+7 -400
View File
@@ -6,7 +6,6 @@ import (
"context"
"io"
"os"
"regexp"
"strings"
"sync/atomic"
"testing"
@@ -16,18 +15,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/healthcheck/health"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/healthsdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
func TestAgent(t *testing.T) {
@@ -98,8 +92,6 @@ func TestAgent(t *testing.T) {
iter: []func(context.Context, *testing.T, *codersdk.WorkspaceAgent, <-chan string, chan []codersdk.WorkspaceAgentLog) error{
func(_ context.Context, _ *testing.T, agent *codersdk.WorkspaceAgent, _ <-chan string, _ chan []codersdk.WorkspaceAgentLog) error {
agent.Status = codersdk.WorkspaceAgentConnecting
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting
agent.StartedAt = ptr.Ref(time.Now())
return nil
},
func(_ context.Context, t *testing.T, agent *codersdk.WorkspaceAgent, output <-chan string, _ chan []codersdk.WorkspaceAgentLog) error {
@@ -109,7 +101,6 @@ func TestAgent(t *testing.T) {
agent.Status = codersdk.WorkspaceAgentConnected
agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStartTimeout
agent.FirstConnectedAt = ptr.Ref(time.Now())
agent.ReadyAt = ptr.Ref(time.Now())
return nil
},
},
@@ -232,7 +223,6 @@ func TestAgent(t *testing.T) {
},
want: []string{
"⧗ Running workspace agent startup scripts",
"︎ To connect immediately, reconnect with --wait=no or CODER_SSH_WAIT=no, see --help for more information.",
"testing: Hello world",
"Bye now",
"✔ Running workspace agent startup scripts",
@@ -261,9 +251,9 @@ func TestAgent(t *testing.T) {
},
},
want: []string{
"⧗ Running workspace agent startup scripts (non-blocking)",
"⧗ Running workspace agent startup scripts",
"Hello world",
"✘ Running workspace agent startup scripts (non-blocking)",
"✘ Running workspace agent startup scripts",
"Warning: A startup script exited with an error and your workspace may be incomplete.",
"For more information and troubleshooting, see",
},
@@ -313,7 +303,6 @@ func TestAgent(t *testing.T) {
},
want: []string{
"⧗ Running workspace agent startup scripts",
"︎ To connect immediately, reconnect with --wait=no or CODER_SSH_WAIT=no, see --help for more information.",
"Hello world",
"✔ Running workspace agent startup scripts",
},
@@ -390,8 +379,8 @@ func TestAgent(t *testing.T) {
output := make(chan string, 100) // Buffered to avoid blocking, overflow is discarded.
logs := make(chan []codersdk.WorkspaceAgentLog, 1)
cmd := &serpent.Command{
Handler: func(inv *serpent.Invocation) error {
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
tc.opts.Fetch = func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) {
t.Log("iter", len(tc.iter))
var err error
@@ -458,8 +447,8 @@ func TestAgent(t *testing.T) {
t.Parallel()
var fetchCalled uint64
cmd := &serpent.Command{
Handler: func(inv *serpent.Invocation) error {
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
buf := bytes.Buffer{}
err := cliui.Agent(inv.Context(), &buf, uuid.Nil, cliui.AgentOptions{
FetchInterval: 10 * time.Millisecond,
@@ -487,385 +476,3 @@ func TestAgent(t *testing.T) {
require.NoError(t, cmd.Invoke().Run())
})
}
func TestPeerDiagnostics(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
diags tailnet.PeerDiagnostics
want []*regexp.Regexp // must be ordered, can omit lines
}{
{
name: "noPreferredDERP",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: make(map[int]string),
SentNode: true,
ReceivedNode: &tailcfg.Node{DERP: "127.3.3.40:999"},
LastWireguardHandshake: time.Now(),
},
want: []*regexp.Regexp{
regexp.MustCompile("^✘ not connected to DERP$"),
},
},
{
name: "preferredDERP",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 23,
DERPRegionNames: map[int]string{
23: "testo",
},
SentNode: true,
ReceivedNode: &tailcfg.Node{DERP: "127.3.3.40:999"},
LastWireguardHandshake: time.Now(),
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✔ preferred DERP region: 23 \(testo\)$`),
},
},
{
name: "sentNode",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{},
SentNode: true,
ReceivedNode: &tailcfg.Node{DERP: "127.3.3.40:999"},
LastWireguardHandshake: time.Time{},
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✔ sent local data to Coder networking coordinator$`),
},
},
{
name: "didntSendNode",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{},
SentNode: false,
ReceivedNode: &tailcfg.Node{DERP: "127.3.3.40:999"},
LastWireguardHandshake: time.Time{},
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✘ have not sent local data to Coder networking coordinator$`),
},
},
{
name: "receivedNodeDERPOKNoEndpoints",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{999: "Embedded"},
SentNode: true,
ReceivedNode: &tailcfg.Node{DERP: "127.3.3.40:999"},
LastWireguardHandshake: time.Time{},
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✔ received remote agent data from Coder networking coordinator$`),
regexp.MustCompile(`preferred DERP region: 999 \(Embedded\)$`),
regexp.MustCompile(`endpoints: $`),
},
},
{
name: "receivedNodeDERPUnknownNoEndpoints",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{},
SentNode: true,
ReceivedNode: &tailcfg.Node{DERP: "127.3.3.40:999"},
LastWireguardHandshake: time.Time{},
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✔ received remote agent data from Coder networking coordinator$`),
regexp.MustCompile(`preferred DERP region: 999 \(unknown\)$`),
regexp.MustCompile(`endpoints: $`),
},
},
{
name: "receivedNodeEndpointsNoDERP",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{999: "Embedded"},
SentNode: true,
ReceivedNode: &tailcfg.Node{Endpoints: []string{"99.88.77.66:4555", "33.22.11.0:3444"}},
LastWireguardHandshake: time.Time{},
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✔ received remote agent data from Coder networking coordinator$`),
regexp.MustCompile(`preferred DERP region:\s*$`),
regexp.MustCompile(`endpoints: 99\.88\.77\.66:4555, 33\.22\.11\.0:3444$`),
},
},
{
name: "didntReceiveNode",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{},
SentNode: false,
ReceivedNode: nil,
LastWireguardHandshake: time.Time{},
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✘ have not received remote agent data from Coder networking coordinator$`),
},
},
{
name: "noWireguardHandshake",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{},
SentNode: false,
ReceivedNode: nil,
LastWireguardHandshake: time.Time{},
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✘ Wireguard is not connected$`),
},
},
{
name: "wireguardHandshakeRecent",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{},
SentNode: false,
ReceivedNode: nil,
LastWireguardHandshake: time.Now().Add(-5 * time.Second),
},
want: []*regexp.Regexp{
regexp.MustCompile(`^✔ Wireguard handshake \d+s ago$`),
},
},
{
name: "wireguardHandshakeOld",
diags: tailnet.PeerDiagnostics{
PreferredDERP: 0,
DERPRegionNames: map[int]string{},
SentNode: false,
ReceivedNode: nil,
LastWireguardHandshake: time.Now().Add(-450 * time.Second), // 7m30s
},
want: []*regexp.Regexp{
regexp.MustCompile(`^⚠ Wireguard handshake 7m\d+s ago$`),
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r, w := io.Pipe()
go func() {
defer w.Close()
cliui.PeerDiagnostics(w, tc.diags)
}()
s := bufio.NewScanner(r)
i := 0
got := make([]string, 0)
for s.Scan() {
got = append(got, s.Text())
if i < len(tc.want) {
reg := tc.want[i]
if reg.Match(s.Bytes()) {
i++
}
}
}
if i < len(tc.want) {
t.Logf("failed to match regexp: %s\ngot:\n%s", tc.want[i].String(), strings.Join(got, "\n"))
t.FailNow()
}
})
}
}
func TestConnDiagnostics(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
diags cliui.ConnDiags
want []string
}{
{
name: "DirectBlocked",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
DisableDirectConnections: true,
},
},
want: []string{
`❗ Your Coder administrator has blocked direct connections`,
},
},
{
name: "NoStun",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
},
LocalNetInfo: &tailcfg.NetInfo{},
},
want: []string{
`The DERP map is not configured to use STUN`,
},
},
{
name: "ClientHasStunNoUDP",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
999: {
Nodes: []*tailcfg.DERPNode{
{
STUNPort: 1337,
},
},
},
},
},
},
LocalNetInfo: &tailcfg.NetInfo{
UDP: false,
},
},
want: []string{
`Client could not connect to STUN over UDP`,
},
},
{
name: "AgentHasStunNoUDP",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
999: {
Nodes: []*tailcfg.DERPNode{
{
STUNPort: 1337,
},
},
},
},
},
},
AgentNetcheck: &healthsdk.AgentNetcheckReport{
NetInfo: &tailcfg.NetInfo{
UDP: false,
},
},
},
want: []string{
`Agent could not connect to STUN over UDP`,
},
},
{
name: "ClientHardNat",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
},
LocalNetInfo: &tailcfg.NetInfo{
MappingVariesByDestIP: "true",
},
},
want: []string{
`Client is potentially behind a hard NAT, as multiple endpoints were retrieved from different STUN servers`,
},
},
{
name: "AgentHardNat",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
},
LocalNetInfo: &tailcfg.NetInfo{},
AgentNetcheck: &healthsdk.AgentNetcheckReport{
NetInfo: &tailcfg.NetInfo{MappingVariesByDestIP: "true"},
},
},
want: []string{
`Agent is potentially behind a hard NAT, as multiple endpoints were retrieved from different STUN servers`,
},
},
{
name: "AgentInterfaceWarnings",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
},
AgentNetcheck: &healthsdk.AgentNetcheckReport{
Interfaces: healthsdk.InterfacesReport{
BaseReport: healthsdk.BaseReport{
Warnings: []health.Message{
health.Messagef(health.CodeInterfaceSmallMTU, "Network interface eth0 has MTU 1280, (less than 1378), which may degrade the quality of direct connections"),
},
},
},
},
},
want: []string{
`Network interface eth0 has MTU 1280, (less than 1378), which may degrade the quality of direct connections`,
},
},
{
name: "LocalInterfaceWarnings",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
},
LocalInterfaces: &healthsdk.InterfacesReport{
BaseReport: healthsdk.BaseReport{
Warnings: []health.Message{
health.Messagef(health.CodeInterfaceSmallMTU, "Network interface eth1 has MTU 1310, (less than 1378), which may degrade the quality of direct connections"),
},
},
},
},
want: []string{
`Network interface eth1 has MTU 1310, (less than 1378), which may degrade the quality of direct connections`,
},
},
{
name: "ClientAWSIP",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
},
ClientIPIsAWS: true,
AgentIPIsAWS: false,
},
want: []string{
`Client IP address is within an AWS range (AWS uses hard NAT)`,
},
},
{
name: "AgentAWSIP",
diags: cliui.ConnDiags{
ConnInfo: workspacesdk.AgentConnectionInfo{
DERPMap: &tailcfg.DERPMap{},
},
ClientIPIsAWS: false,
AgentIPIsAWS: true,
},
want: []string{
`Agent IP address is within an AWS range (AWS uses hard NAT)`,
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r, w := io.Pipe()
go func() {
defer w.Close()
tc.diags.Write(w)
}()
bytes, err := io.ReadAll(r)
require.NoError(t, err)
output := string(bytes)
for _, want := range tc.want {
require.Contains(t, output, want)
}
})
}
}
+22 -34
View File
@@ -22,7 +22,6 @@ type Styles struct {
DateTimeStamp,
Error,
Field,
Hyperlink,
Keyword,
Placeholder,
Prompt,
@@ -38,21 +37,17 @@ var (
)
var (
// ANSI color codes
red = Color("1")
green = Color("2")
yellow = Color("3")
magenta = Color("5")
white = Color("7")
brightBlue = Color("12")
brightMagenta = Color("13")
Green = Color("#04B575")
Red = Color("#ED567A")
Fuchsia = Color("#EE6FF8")
Yellow = Color("#ECFD65")
Blue = Color("#5000ff")
)
// Color returns a color for the given string.
func Color(s string) termenv.Color {
colorOnce.Do(func() {
color = termenv.NewOutput(os.Stdout).EnvColorProfile()
color = termenv.NewOutput(os.Stdout).ColorProfile()
if flag.Lookup("test.v") != nil {
// Use a consistent colorless profile in tests so that results
// are deterministic.
@@ -128,49 +123,42 @@ func init() {
DefaultStyles = Styles{
Code: pretty.Style{
ifTerm(pretty.XPad(1, 1)),
pretty.FgColor(Color("#ED567A")),
pretty.BgColor(Color("#2C2C2C")),
pretty.FgColor(Red),
pretty.BgColor(color.Color("#2c2c2c")),
},
DateTimeStamp: pretty.Style{
pretty.FgColor(brightBlue),
pretty.FgColor(color.Color("#7571F9")),
},
Error: pretty.Style{
pretty.FgColor(red),
pretty.FgColor(Red),
},
Field: pretty.Style{
pretty.XPad(1, 1),
pretty.FgColor(Color("#FFFFFF")),
pretty.BgColor(Color("#2B2A2A")),
},
Fuchsia: pretty.Style{
pretty.FgColor(brightMagenta),
},
FocusedPrompt: pretty.Style{
pretty.FgColor(white),
pretty.Wrap("> ", ""),
pretty.FgColor(brightBlue),
},
Hyperlink: pretty.Style{
pretty.FgColor(magenta),
pretty.Underline(),
pretty.FgColor(color.Color("#FFFFFF")),
pretty.BgColor(color.Color("#2b2a2a")),
},
Keyword: pretty.Style{
pretty.FgColor(green),
pretty.FgColor(Green),
},
Placeholder: pretty.Style{
pretty.FgColor(magenta),
pretty.FgColor(color.Color("#4d46b3")),
},
Prompt: pretty.Style{
pretty.FgColor(white),
pretty.Wrap(" ", ""),
pretty.FgColor(color.Color("#5C5C5C")),
pretty.Wrap("> ", ""),
},
Warn: pretty.Style{
pretty.FgColor(yellow),
pretty.FgColor(Yellow),
},
Wrap: pretty.Style{
pretty.LineWrap(80),
},
}
DefaultStyles.FocusedPrompt = append(
DefaultStyles.Prompt,
pretty.FgColor(Blue),
)
}
// ValidateNotEmpty is a helper function to disallow empty inputs!
+4 -4
View File
@@ -3,13 +3,13 @@ package cliui
import (
"fmt"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func DeprecationWarning(message string) serpent.MiddlewareFunc {
return func(next serpent.HandlerFunc) serpent.HandlerFunc {
return func(i *serpent.Invocation) error {
func DeprecationWarning(message string) clibase.MiddlewareFunc {
return func(next clibase.HandlerFunc) clibase.HandlerFunc {
return func(i *clibase.Invocation) error {
_, _ = fmt.Fprintln(i.Stdout, "\n"+pretty.Sprint(DefaultStyles.Wrap,
pretty.Sprint(
DefaultStyles.Warn,
-3
View File
@@ -37,9 +37,6 @@ func ExternalAuth(ctx context.Context, writer io.Writer, opts ExternalAuthOption
if auth.Authenticated {
return nil
}
if auth.Optional {
continue
}
_, _ = fmt.Fprintf(writer, "You must authenticate with %s to create a workspace with this template. Visit:\n\n\t%s\n\n", auth.DisplayName, auth.AuthenticateURL)
+3 -3
View File
@@ -8,11 +8,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
func TestExternalAuth(t *testing.T) {
@@ -22,8 +22,8 @@ func TestExternalAuth(t *testing.T) {
defer cancel()
ptty := ptytest.New(t)
cmd := &serpent.Command{
Handler: func(inv *serpent.Invocation) error {
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
var fetched atomic.Bool
return cliui.ExternalAuth(inv.Context(), inv.Stdout, cliui.ExternalAuthOptions{
Fetch: func(ctx context.Context) ([]codersdk.TemplateVersionExternalAuth, error) {
+8 -8
View File
@@ -1,8 +1,8 @@
package cliui
import (
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/serpent"
)
var defaultQuery = "owner:me"
@@ -11,12 +11,12 @@ var defaultQuery = "owner:me"
// and allows easy integration to a CLI command.
// Example usage:
//
// func (r *RootCmd) MyCmd() *serpent.Command {
// func (r *RootCmd) MyCmd() *clibase.Cmd {
// var (
// filter cliui.WorkspaceFilter
// ...
// )
// cmd := &serpent.Command{
// cmd := &clibase.Cmd{
// ...
// }
// filter.AttachOptions(&cmd.Options)
@@ -44,20 +44,20 @@ func (w *WorkspaceFilter) Filter() codersdk.WorkspaceFilter {
return f
}
func (w *WorkspaceFilter) AttachOptions(opts *serpent.OptionSet) {
func (w *WorkspaceFilter) AttachOptions(opts *clibase.OptionSet) {
*opts = append(*opts,
serpent.Option{
clibase.Option{
Flag: "all",
FlagShorthand: "a",
Description: "Specifies whether all workspaces will be listed or not.",
Value: serpent.BoolOf(&w.all),
Value: clibase.BoolOf(&w.all),
},
serpent.Option{
clibase.Option{
Flag: "search",
Description: "Search for a workspace with a query.",
Default: defaultQuery,
Value: serpent.StringOf(&w.searchQuery),
Value: clibase.StringOf(&w.searchQuery),
},
)
}
+15 -20
View File
@@ -7,15 +7,14 @@ import (
"reflect"
"strings"
"github.com/jedib0t/go-pretty/v6/table"
"golang.org/x/xerrors"
"github.com/coder/serpent"
"github.com/coder/coder/v2/cli/clibase"
)
type OutputFormat interface {
ID() string
AttachOptions(opts *serpent.OptionSet)
AttachOptions(opts *clibase.OptionSet)
Format(ctx context.Context, data any) (string, error)
}
@@ -50,7 +49,7 @@ func NewOutputFormatter(formats ...OutputFormat) *OutputFormatter {
// AttachOptions attaches the --output flag to the given command, and any
// additional flags required by the output formatters.
func (f *OutputFormatter) AttachOptions(opts *serpent.OptionSet) {
func (f *OutputFormatter) AttachOptions(opts *clibase.OptionSet) {
for _, format := range f.formats {
format.AttachOptions(opts)
}
@@ -61,12 +60,12 @@ func (f *OutputFormatter) AttachOptions(opts *serpent.OptionSet) {
}
*opts = append(*opts,
serpent.Option{
clibase.Option{
Flag: "output",
FlagShorthand: "o",
Default: f.formats[0].ID(),
Value: serpent.EnumOf(&f.formatID, formatNames...),
Description: "Output format.",
Value: clibase.StringOf(&f.formatID),
Description: "Output format. Available formats: " + strings.Join(formatNames, ", ") + ".",
},
)
}
@@ -107,7 +106,7 @@ func TableFormat(out any, defaultColumns []string) OutputFormat {
}
// Get the list of table column headers.
headers, defaultSort, err := typeToTableHeaders(v.Type().Elem(), true)
headers, defaultSort, err := typeToTableHeaders(v.Type().Elem())
if err != nil {
panic("parse table headers: " + err.Error())
}
@@ -130,25 +129,21 @@ func (*tableFormat) ID() string {
}
// AttachOptions implements OutputFormat.
func (f *tableFormat) AttachOptions(opts *serpent.OptionSet) {
func (f *tableFormat) AttachOptions(opts *clibase.OptionSet) {
*opts = append(*opts,
serpent.Option{
clibase.Option{
Flag: "column",
FlagShorthand: "c",
Default: strings.Join(f.defaultColumns, ","),
Value: serpent.EnumArrayOf(&f.columns, f.allColumns...),
Description: "Columns to display in table output.",
Value: clibase.StringArrayOf(&f.columns),
Description: "Columns to display in table output. Available columns: " + strings.Join(f.allColumns, ", ") + ".",
},
)
}
// Format implements OutputFormat.
func (f *tableFormat) Format(_ context.Context, data any) (string, error) {
headers := make(table.Row, len(f.allColumns))
for i, header := range f.allColumns {
headers[i] = header
}
return renderTable(data, f.sort, headers, f.columns)
return DisplayTable(data, f.sort, f.columns)
}
type jsonFormat struct{}
@@ -166,7 +161,7 @@ func (jsonFormat) ID() string {
}
// AttachOptions implements OutputFormat.
func (jsonFormat) AttachOptions(_ *serpent.OptionSet) {}
func (jsonFormat) AttachOptions(_ *clibase.OptionSet) {}
// Format implements OutputFormat.
func (jsonFormat) Format(_ context.Context, data any) (string, error) {
@@ -192,7 +187,7 @@ func (textFormat) ID() string {
return "text"
}
func (textFormat) AttachOptions(_ *serpent.OptionSet) {}
func (textFormat) AttachOptions(_ *clibase.OptionSet) {}
func (textFormat) Format(_ context.Context, data any) (string, error) {
return fmt.Sprintf("%s", data), nil
@@ -218,7 +213,7 @@ func (d *DataChangeFormat) ID() string {
return d.format.ID()
}
func (d *DataChangeFormat) AttachOptions(opts *serpent.OptionSet) {
func (d *DataChangeFormat) AttachOptions(opts *clibase.OptionSet) {
d.format.AttachOptions(opts)
}
+16 -15
View File
@@ -8,13 +8,13 @@ import (
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/serpent"
)
type format struct {
id string
attachOptionsFn func(opts *serpent.OptionSet)
attachOptionsFn func(opts *clibase.OptionSet)
formatFn func(ctx context.Context, data any) (string, error)
}
@@ -24,7 +24,7 @@ func (f *format) ID() string {
return f.id
}
func (f *format) AttachOptions(opts *serpent.OptionSet) {
func (f *format) AttachOptions(opts *clibase.OptionSet) {
if f.attachOptionsFn != nil {
f.attachOptionsFn(opts)
}
@@ -85,12 +85,12 @@ func Test_OutputFormatter(t *testing.T) {
cliui.JSONFormat(),
&format{
id: "foo",
attachOptionsFn: func(opts *serpent.OptionSet) {
opts.Add(serpent.Option{
attachOptionsFn: func(opts *clibase.OptionSet) {
opts.Add(clibase.Option{
Name: "foo",
Flag: "foo",
FlagShorthand: "f",
Value: serpent.DiscardValue,
Value: clibase.DiscardValue,
Description: "foo flag 1234",
})
},
@@ -101,16 +101,16 @@ func Test_OutputFormatter(t *testing.T) {
},
)
cmd := &serpent.Command{}
cmd := &clibase.Cmd{}
f.AttachOptions(&cmd.Options)
fs := cmd.Options.FlagSet()
selected := cmd.Options.ByFlag("output")
require.NotNil(t, selected)
require.Equal(t, "json", selected.Value.String())
selected, err := fs.GetString("output")
require.NoError(t, err)
require.Equal(t, "json", selected)
usage := fs.FlagUsages()
require.Contains(t, usage, "Output format.")
require.Contains(t, usage, "Available formats: json, foo")
require.Contains(t, usage, "foo flag 1234")
ctx := context.Background()
@@ -129,10 +129,11 @@ func Test_OutputFormatter(t *testing.T) {
require.Equal(t, "foo", out)
require.EqualValues(t, 1, atomic.LoadInt64(&called))
require.Error(t, fs.Set("output", "bar"))
require.NoError(t, fs.Set("output", "bar"))
out, err = f.Format(ctx, data)
require.NoError(t, err)
require.Equal(t, "foo", out)
require.EqualValues(t, 2, atomic.LoadInt64(&called))
require.Error(t, err)
require.ErrorContains(t, err, "bar")
require.Equal(t, "", out)
require.EqualValues(t, 1, atomic.LoadInt64(&called))
})
}
+6 -14
View File
@@ -5,12 +5,12 @@ import (
"fmt"
"strings"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.TemplateVersionParameter, defaultOverrides map[string]string) (string, error) {
func RichParameter(inv *clibase.Invocation, templateVersionParameter codersdk.TemplateVersionParameter) (string, error) {
label := templateVersionParameter.Name
if templateVersionParameter.DisplayName != "" {
label = templateVersionParameter.DisplayName
@@ -26,11 +26,6 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
_, _ = fmt.Fprintln(inv.Stdout, " "+strings.TrimSpace(strings.Join(strings.Split(templateVersionParameter.DescriptionPlaintext, "\n"), "\n "))+"\n")
}
defaultValue := templateVersionParameter.DefaultValue
if v, ok := defaultOverrides[templateVersionParameter.Name]; ok {
defaultValue = v
}
var err error
var value string
if templateVersionParameter.Type == "list(string)" {
@@ -43,10 +38,7 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
return "", err
}
values, err := MultiSelect(inv, MultiSelectOptions{
Options: options,
Defaults: options,
})
values, err := MultiSelect(inv, options)
if err == nil {
v, err := json.Marshal(&values)
if err != nil {
@@ -66,7 +58,7 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
var richParameterOption *codersdk.TemplateVersionParameterOption
richParameterOption, err = RichSelect(inv, RichSelectOptions{
Options: templateVersionParameter.Options,
Default: defaultValue,
Default: templateVersionParameter.DefaultValue,
HideSearch: true,
})
if err == nil {
@@ -77,7 +69,7 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
} else {
text := "Enter a value"
if !templateVersionParameter.Required {
text += fmt.Sprintf(" (default: %q)", defaultValue)
text += fmt.Sprintf(" (default: %q)", templateVersionParameter.DefaultValue)
}
text += ":"
@@ -95,7 +87,7 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
// If they didn't specify anything, use the default value if set.
if len(templateVersionParameter.Options) == 0 && value == "" {
value = defaultValue
value = templateVersionParameter.DefaultValue
}
return value, nil
+5 -5
View File
@@ -13,8 +13,8 @@ import (
"github.com/mattn/go-isatty"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
// PromptOptions supply a set of options to the prompt.
@@ -30,13 +30,13 @@ const skipPromptFlag = "yes"
// SkipPromptOption adds a "--yes/-y" flag to the cmd that can be used to skip
// prompts.
func SkipPromptOption() serpent.Option {
return serpent.Option{
func SkipPromptOption() clibase.Option {
return clibase.Option{
Flag: skipPromptFlag,
FlagShorthand: "y",
Description: "Bypass prompts.",
// Discard
Value: serpent.BoolOf(new(bool)),
Value: clibase.BoolOf(new(bool)),
}
}
@@ -46,7 +46,7 @@ const (
)
// Prompt asks the user for input.
func Prompt(inv *serpent.Invocation, opts PromptOptions) (string, error) {
func Prompt(inv *clibase.Invocation, opts PromptOptions) (string, error) {
// If the cmd has a "yes" flag for skipping confirm prompts, honor it.
// If it's not a "Confirm" prompt, then don't skip. As the default value of
// "yes" makes no sense.
+7 -7
View File
@@ -11,11 +11,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/pty"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
func TestPrompt(t *testing.T) {
@@ -77,7 +77,7 @@ func TestPrompt(t *testing.T) {
resp, err := newPrompt(ptty, cliui.PromptOptions{
Text: "ShouldNotSeeThis",
IsConfirm: true,
}, func(inv *serpent.Invocation) {
}, func(inv *clibase.Invocation) {
inv.Command.Options = append(inv.Command.Options, cliui.SkipPromptOption())
inv.Args = []string{"-y"}
})
@@ -145,10 +145,10 @@ func TestPrompt(t *testing.T) {
})
}
func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *serpent.Invocation)) (string, error) {
func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *clibase.Invocation)) (string, error) {
value := ""
cmd := &serpent.Command{
Handler: func(inv *serpent.Invocation) error {
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
var err error
value, err = cliui.Prompt(inv, opts)
return err
@@ -210,8 +210,8 @@ func TestPasswordTerminalState(t *testing.T) {
// nolint:unused
func passwordHelper() {
cmd := &serpent.Command{
Handler: func(inv *serpent.Invocation) error {
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
cliui.Prompt(inv, cliui.PromptOptions{
Text: "Password:",
Secret: true,
+5 -35
View File
@@ -54,11 +54,6 @@ func (err *ProvisionerJobError) Error() string {
return err.Message
}
const (
ProvisioningStateQueued = "Queued"
ProvisioningStateRunning = "Running"
)
// ProvisionerJob renders a provisioner job with interactive cancellation.
func ProvisionerJob(ctx context.Context, wr io.Writer, opts ProvisionerJobOptions) error {
if opts.FetchInterval == 0 {
@@ -68,9 +63,8 @@ func ProvisionerJob(ctx context.Context, wr io.Writer, opts ProvisionerJobOption
defer cancelFunc()
var (
currentStage = ProvisioningStateQueued
currentStage = "Queued"
currentStageStartedAt = time.Now().UTC()
currentQueuePos = -1
errChan = make(chan error, 1)
job codersdk.ProvisionerJob
@@ -80,20 +74,7 @@ func ProvisionerJob(ctx context.Context, wr io.Writer, opts ProvisionerJobOption
sw := &stageWriter{w: wr, verbose: opts.Verbose, silentLogs: opts.Silent}
printStage := func() {
out := currentStage
if currentStage == ProvisioningStateQueued && currentQueuePos > 0 {
var queuePos string
if currentQueuePos == 1 {
queuePos = "next"
} else {
queuePos = fmt.Sprintf("position: %d", currentQueuePos)
}
out = pretty.Sprintf(DefaultStyles.Warn, "%s (%s)", currentStage, queuePos)
}
sw.Start(out)
sw.Start(currentStage)
}
updateStage := func(stage string, startedAt time.Time) {
@@ -122,26 +103,15 @@ func ProvisionerJob(ctx context.Context, wr io.Writer, opts ProvisionerJobOption
errChan <- xerrors.Errorf("fetch: %w", err)
return
}
if job.QueuePosition != currentQueuePos {
initialState := currentQueuePos == -1
currentQueuePos = job.QueuePosition
// Print an update when the queue position changes, but:
// - not initially, because the stage is printed at startup
// - not when we're first in the queue, because it's redundant
if !initialState && currentQueuePos != 0 {
printStage()
}
}
if job.StartedAt == nil {
return
}
if currentStage != ProvisioningStateQueued {
if currentStage != "Queued" {
// If another stage is already running, there's no need
// for us to notify the user we're running!
return
}
updateStage(ProvisioningStateRunning, *job.StartedAt)
updateStage("Running", *job.StartedAt)
}
if opts.Cancel != nil {
@@ -173,8 +143,8 @@ func ProvisionerJob(ctx context.Context, wr io.Writer, opts ProvisionerJobOption
}
// The initial stage needs to print after the signal handler has been registered.
updateJob()
printStage()
updateJob()
logs, closer, err := opts.Logs()
if err != nil {
+26 -120
View File
@@ -2,10 +2,8 @@ package cliui_test
import (
"context"
"fmt"
"io"
"os"
"regexp"
"runtime"
"sync"
"testing"
@@ -13,13 +11,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/testutil"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/serpent"
)
// This cannot be ran in parallel because it uses a signal.
@@ -29,11 +25,7 @@ func TestProvisionerJob(t *testing.T) {
t.Parallel()
test := newProvisionerJob(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
testutil.Go(t, func() {
go func() {
<-test.Next
test.JobMutex.Lock()
test.Job.Status = codersdk.ProvisionerJobRunning
@@ -47,26 +39,20 @@ func TestProvisionerJob(t *testing.T) {
test.Job.CompletedAt = &now
close(test.Logs)
test.JobMutex.Unlock()
})
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
test.PTY.ExpectMatch(cliui.ProvisioningStateQueued)
test.Next <- struct{}{}
test.PTY.ExpectMatch(cliui.ProvisioningStateQueued)
test.PTY.ExpectMatch(cliui.ProvisioningStateRunning)
test.Next <- struct{}{}
test.PTY.ExpectMatch(cliui.ProvisioningStateRunning)
return true
}, testutil.IntervalFast)
}()
test.PTY.ExpectMatch("Queued")
test.Next <- struct{}{}
test.PTY.ExpectMatch("Queued")
test.PTY.ExpectMatch("Running")
test.Next <- struct{}{}
test.PTY.ExpectMatch("Running")
})
t.Run("Stages", func(t *testing.T) {
t.Parallel()
test := newProvisionerJob(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
testutil.Go(t, func() {
go func() {
<-test.Next
test.JobMutex.Lock()
test.Job.Status = codersdk.ProvisionerJobRunning
@@ -84,86 +70,13 @@ func TestProvisionerJob(t *testing.T) {
test.Job.CompletedAt = &now
close(test.Logs)
test.JobMutex.Unlock()
})
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
test.PTY.ExpectMatch(cliui.ProvisioningStateQueued)
test.Next <- struct{}{}
test.PTY.ExpectMatch(cliui.ProvisioningStateQueued)
test.PTY.ExpectMatch("Something")
test.Next <- struct{}{}
test.PTY.ExpectMatch("Something")
return true
}, testutil.IntervalFast)
})
t.Run("Queue Position", func(t *testing.T) {
t.Parallel()
stage := cliui.ProvisioningStateQueued
tests := []struct {
name string
queuePos int
expected string
}{
{
name: "first",
queuePos: 0,
expected: fmt.Sprintf("%s$", stage),
},
{
name: "next",
queuePos: 1,
expected: fmt.Sprintf(`%s %s$`, stage, regexp.QuoteMeta("(next)")),
},
{
name: "other",
queuePos: 4,
expected: fmt.Sprintf(`%s %s$`, stage, regexp.QuoteMeta("(position: 4)")),
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
test := newProvisionerJob(t)
test.JobMutex.Lock()
test.Job.QueuePosition = tc.queuePos
test.Job.QueueSize = tc.queuePos
test.JobMutex.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
testutil.Go(t, func() {
<-test.Next
test.JobMutex.Lock()
test.Job.Status = codersdk.ProvisionerJobRunning
now := dbtime.Now()
test.Job.StartedAt = &now
test.JobMutex.Unlock()
<-test.Next
test.JobMutex.Lock()
test.Job.Status = codersdk.ProvisionerJobSucceeded
now = dbtime.Now()
test.Job.CompletedAt = &now
close(test.Logs)
test.JobMutex.Unlock()
})
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
test.PTY.ExpectRegexMatch(tc.expected)
test.Next <- struct{}{}
test.PTY.ExpectMatch(cliui.ProvisioningStateQueued) // step completed
test.PTY.ExpectMatch(cliui.ProvisioningStateRunning)
test.Next <- struct{}{}
test.PTY.ExpectMatch(cliui.ProvisioningStateRunning)
return true
}, testutil.IntervalFast)
})
}
}()
test.PTY.ExpectMatch("Queued")
test.Next <- struct{}{}
test.PTY.ExpectMatch("Queued")
test.PTY.ExpectMatch("Something")
test.Next <- struct{}{}
test.PTY.ExpectMatch("Something")
})
// This cannot be ran in parallel because it uses a signal.
@@ -177,11 +90,7 @@ func TestProvisionerJob(t *testing.T) {
}
test := newProvisionerJob(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
testutil.Go(t, func() {
go func() {
<-test.Next
currentProcess, err := os.FindProcess(os.Getpid())
assert.NoError(t, err)
@@ -194,15 +103,12 @@ func TestProvisionerJob(t *testing.T) {
test.Job.CompletedAt = &now
close(test.Logs)
test.JobMutex.Unlock()
})
testutil.Eventually(ctx, t, func(ctx context.Context) (done bool) {
test.PTY.ExpectMatch(cliui.ProvisioningStateQueued)
test.Next <- struct{}{}
test.PTY.ExpectMatch("Gracefully canceling")
test.Next <- struct{}{}
test.PTY.ExpectMatch(cliui.ProvisioningStateQueued)
return true
}, testutil.IntervalFast)
}()
test.PTY.ExpectMatch("Queued")
test.Next <- struct{}{}
test.PTY.ExpectMatch("Gracefully canceling")
test.Next <- struct{}{}
test.PTY.ExpectMatch("Queued")
})
}
@@ -221,8 +127,8 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
}
jobLock := sync.Mutex{}
logs := make(chan codersdk.ProvisionerJobLog, 1)
cmd := &serpent.Command{
Handler: func(inv *serpent.Invocation) error {
cmd := &clibase.Cmd{
Handler: func(inv *clibase.Invocation) error {
return cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{
FetchInterval: time.Millisecond,
Fetch: func() (codersdk.ProvisionerJob, error) {
+83 -431
View File
@@ -1,59 +1,61 @@
package cliui
import (
"errors"
"flag"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/AlecAivazis/survey/v2"
"github.com/AlecAivazis/survey/v2/terminal"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
const defaultSelectModelHeight = 7
func init() {
survey.SelectQuestionTemplate = `
{{- define "option"}}
{{- " " }}{{- if eq .SelectedIndex .CurrentIndex }}{{color "green" }}{{ .Config.Icons.SelectFocus.Text }} {{else}}{{color "default"}} {{end}}
{{- .CurrentOpt.Value}}
{{- color "reset"}}
{{end}}
type terminateMsg struct{}
{{- if not .ShowAnswer }}
{{- if .Config.Icons.Help.Text }}
{{- if .FilterMessage }}{{ "Search:" }}{{ .FilterMessage }}
{{- else }}
{{- color "black+h"}}{{- "Type to search" }}{{color "reset"}}
{{- end }}
{{- "\n" }}
{{- end }}
{{- "\n" }}
{{- range $ix, $option := .PageEntries}}
{{- template "option" $.IterateOption $ix $option}}
{{- end}}
{{- end }}`
func installSignalHandler(p *tea.Program) func() {
ch := make(chan struct{})
go func() {
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
defer func() {
signal.Stop(sig)
close(ch)
}()
for {
select {
case <-ch:
return
case <-sig:
p.Send(terminateMsg{})
}
}
}()
return func() {
ch <- struct{}{}
}
survey.MultiSelectQuestionTemplate = `
{{- define "option"}}
{{- if eq .SelectedIndex .CurrentIndex }}{{color .Config.Icons.SelectFocus.Format }}{{ .Config.Icons.SelectFocus.Text }}{{color "reset"}}{{else}} {{end}}
{{- if index .Checked .CurrentOpt.Index }}{{color .Config.Icons.MarkedOption.Format }} {{ .Config.Icons.MarkedOption.Text }} {{else}}{{color .Config.Icons.UnmarkedOption.Format }} {{ .Config.Icons.UnmarkedOption.Text }} {{end}}
{{- color "reset"}}
{{- " "}}{{- .CurrentOpt.Value}}
{{end}}
{{- if .ShowHelp }}{{- color .Config.Icons.Help.Format }}{{ .Config.Icons.Help.Text }} {{ .Help }}{{color "reset"}}{{"\n"}}{{end}}
{{- if not .ShowAnswer }}
{{- "\n"}}
{{- range $ix, $option := .PageEntries}}
{{- template "option" $.IterateOption $ix $option}}
{{- end}}
{{- end}}`
}
type SelectOptions struct {
Options []string
// Default will be highlighted first if it's a valid option.
Default string
Message string
Size int
HideSearch bool
}
@@ -66,7 +68,7 @@ type RichSelectOptions struct {
}
// RichSelect displays a list of user options including name and description.
func RichSelect(inv *serpent.Invocation, richOptions RichSelectOptions) (*codersdk.TemplateVersionParameterOption, error) {
func RichSelect(inv *clibase.Invocation, richOptions RichSelectOptions) (*codersdk.TemplateVersionParameterOption, error) {
opts := make([]string, len(richOptions.Options))
var defaultOpt string
for i, option := range richOptions.Options {
@@ -100,7 +102,7 @@ func RichSelect(inv *serpent.Invocation, richOptions RichSelectOptions) (*coders
}
// Select displays a list of user options.
func Select(inv *serpent.Invocation, opts SelectOptions) (string, error) {
func Select(inv *clibase.Invocation, opts SelectOptions) (string, error) {
// The survey library used *always* fails when testing on Windows,
// as it requires a live TTY (can't be a conpty). We should fork
// this library to add a dummy fallback, that simply reads/writes
@@ -110,416 +112,66 @@ func Select(inv *serpent.Invocation, opts SelectOptions) (string, error) {
return opts.Options[0], nil
}
initialModel := selectModel{
search: textinput.New(),
hideSearch: opts.HideSearch,
options: opts.Options,
height: opts.Size,
message: opts.Message,
var defaultOption interface{}
if opts.Default != "" {
defaultOption = opts.Default
}
if initialModel.height == 0 {
initialModel.height = defaultSelectModelHeight
}
initialModel.search.Prompt = ""
initialModel.search.Focus()
p := tea.NewProgram(
initialModel,
tea.WithoutSignalHandler(),
tea.WithContext(inv.Context()),
tea.WithInput(inv.Stdin),
tea.WithOutput(inv.Stdout),
)
closeSignalHandler := installSignalHandler(p)
defer closeSignalHandler()
m, err := p.Run()
if err != nil {
return "", err
}
model, ok := m.(selectModel)
if !ok {
return "", xerrors.New(fmt.Sprintf("unknown model found %T (%+v)", m, m))
}
if model.canceled {
return "", Canceled
}
return model.selected, nil
}
type selectModel struct {
search textinput.Model
options []string
cursor int
height int
message string
selected string
canceled bool
hideSearch bool
}
func (selectModel) Init() tea.Cmd {
return nil
}
//nolint:revive // The linter complains about modifying 'm' but this is typical practice for bubbletea
func (m selectModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case terminateMsg:
m.canceled = true
return m, tea.Quit
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC:
m.canceled = true
return m, tea.Quit
case tea.KeyEnter:
options := m.filteredOptions()
if len(options) != 0 {
m.selected = options[m.cursor]
return m, tea.Quit
}
case tea.KeyUp:
options := m.filteredOptions()
if m.cursor > 0 {
m.cursor--
} else {
m.cursor = len(options) - 1
}
case tea.KeyDown:
options := m.filteredOptions()
if m.cursor < len(options)-1 {
m.cursor++
} else {
m.cursor = 0
}
var value string
err := survey.AskOne(&survey.Select{
Options: opts.Options,
Default: defaultOption,
PageSize: opts.Size,
}, &value, survey.WithIcons(func(is *survey.IconSet) {
is.Help.Text = "Type to search"
if opts.HideSearch {
is.Help.Text = ""
}
}), survey.WithStdio(fileReadWriter{
Reader: inv.Stdin,
}, fileReadWriter{
Writer: inv.Stdout,
}, inv.Stdout))
if errors.Is(err, terminal.InterruptErr) {
return value, Canceled
}
if !m.hideSearch {
oldSearch := m.search.Value()
m.search, cmd = m.search.Update(msg)
// If the search query has changed then we need to ensure
// the cursor is still pointing at a valid option.
if m.search.Value() != oldSearch {
options := m.filteredOptions()
if m.cursor > len(options)-1 {
m.cursor = max(0, len(options)-1)
}
}
}
return m, cmd
return value, err
}
func (m selectModel) View() string {
var s strings.Builder
msg := pretty.Sprintf(pretty.Bold(), "? %s", m.message)
if m.selected != "" {
selected := pretty.Sprint(DefaultStyles.Keyword, m.selected)
_, _ = s.WriteString(fmt.Sprintf("%s %s\n", msg, selected))
return s.String()
}
if m.hideSearch {
_, _ = s.WriteString(fmt.Sprintf("%s [Use arrows to move]\n", msg))
} else {
_, _ = s.WriteString(fmt.Sprintf(
"%s %s[Use arrows to move, type to filter]\n",
msg,
m.search.View(),
))
}
options, start := m.viewableOptions()
for i, option := range options {
// Is this the currently selected option?
style := pretty.Wrap(" ", "")
if m.cursor == start+i {
style = pretty.Style{
pretty.Wrap("> ", ""),
DefaultStyles.Keyword,
}
}
_, _ = s.WriteString(pretty.Sprint(style, option))
_, _ = s.WriteString("\n")
}
return s.String()
}
func (m selectModel) viewableOptions() ([]string, int) {
options := m.filteredOptions()
halfHeight := m.height / 2
bottom := 0
top := len(options)
switch {
case m.cursor <= halfHeight:
top = min(top, m.height)
case m.cursor < top-halfHeight:
bottom = max(0, m.cursor-halfHeight)
top = min(top, m.cursor+halfHeight+1)
default:
bottom = max(0, top-m.height)
}
return options[bottom:top], bottom
}
func (m selectModel) filteredOptions() []string {
options := []string{}
for _, o := range m.options {
filter := strings.ToLower(m.search.Value())
option := strings.ToLower(o)
if strings.Contains(option, filter) {
options = append(options, o)
}
}
return options
}
type MultiSelectOptions struct {
Message string
Options []string
Defaults []string
}
func MultiSelect(inv *serpent.Invocation, opts MultiSelectOptions) ([]string, error) {
func MultiSelect(inv *clibase.Invocation, items []string) ([]string, error) {
// Similar hack is applied to Select()
if flag.Lookup("test.v") != nil {
return opts.Defaults, nil
return items, nil
}
options := make([]*multiSelectOption, len(opts.Options))
for i, option := range opts.Options {
chosen := false
for _, d := range opts.Defaults {
if option == d {
chosen = true
break
}
}
options[i] = &multiSelectOption{
option: option,
chosen: chosen,
}
prompt := &survey.MultiSelect{
Options: items,
Default: items,
}
initialModel := multiSelectModel{
search: textinput.New(),
options: options,
message: opts.Message,
}
initialModel.search.Prompt = ""
initialModel.search.Focus()
p := tea.NewProgram(
initialModel,
tea.WithoutSignalHandler(),
tea.WithContext(inv.Context()),
tea.WithInput(inv.Stdin),
tea.WithOutput(inv.Stdout),
)
closeSignalHandler := installSignalHandler(p)
defer closeSignalHandler()
m, err := p.Run()
if err != nil {
return nil, err
}
model, ok := m.(multiSelectModel)
if !ok {
return nil, xerrors.New(fmt.Sprintf("unknown model found %T (%+v)", m, m))
}
if model.canceled {
var values []string
err := survey.AskOne(prompt, &values, survey.WithStdio(fileReadWriter{
Reader: inv.Stdin,
}, fileReadWriter{
Writer: inv.Stdout,
}, inv.Stdout))
if errors.Is(err, terminal.InterruptErr) {
return nil, Canceled
}
return model.selectedOptions(), nil
return values, err
}
type multiSelectOption struct {
option string
chosen bool
type fileReadWriter struct {
io.Reader
io.Writer
}
type multiSelectModel struct {
search textinput.Model
options []*multiSelectOption
cursor int
message string
canceled bool
selected bool
}
func (multiSelectModel) Init() tea.Cmd {
return nil
}
//nolint:revive // For same reason as previous Update definition
func (m multiSelectModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case terminateMsg:
m.canceled = true
return m, tea.Quit
case tea.KeyMsg:
switch msg.Type {
case tea.KeyCtrlC:
m.canceled = true
return m, tea.Quit
case tea.KeyEnter:
if len(m.options) != 0 {
m.selected = true
return m, tea.Quit
}
case tea.KeySpace:
options := m.filteredOptions()
if len(options) != 0 {
options[m.cursor].chosen = !options[m.cursor].chosen
}
// We back out early here otherwise a space will be inserted
// into the search field.
return m, nil
case tea.KeyUp:
options := m.filteredOptions()
if m.cursor > 0 {
m.cursor--
} else {
m.cursor = len(options) - 1
}
case tea.KeyDown:
options := m.filteredOptions()
if m.cursor < len(options)-1 {
m.cursor++
} else {
m.cursor = 0
}
case tea.KeyRight:
options := m.filteredOptions()
for _, option := range options {
option.chosen = true
}
case tea.KeyLeft:
options := m.filteredOptions()
for _, option := range options {
option.chosen = false
}
}
func (f fileReadWriter) Fd() uintptr {
if file, ok := f.Reader.(*os.File); ok {
return file.Fd()
}
oldSearch := m.search.Value()
m.search, cmd = m.search.Update(msg)
// If the search query has changed then we need to ensure
// the cursor is still pointing at a valid option.
if m.search.Value() != oldSearch {
options := m.filteredOptions()
if m.cursor > len(options)-1 {
m.cursor = max(0, len(options)-1)
}
if file, ok := f.Writer.(*os.File); ok {
return file.Fd()
}
return m, cmd
}
func (m multiSelectModel) View() string {
var s strings.Builder
msg := pretty.Sprintf(pretty.Bold(), "? %s", m.message)
if m.selected {
selected := pretty.Sprint(DefaultStyles.Keyword, strings.Join(m.selectedOptions(), ", "))
_, _ = s.WriteString(fmt.Sprintf("%s %s\n", msg, selected))
return s.String()
}
_, _ = s.WriteString(fmt.Sprintf(
"%s %s[Use arrows to move, space to select, <right> to all, <left> to none, type to filter]\n",
msg,
m.search.View(),
))
for i, option := range m.filteredOptions() {
cursor := " "
chosen := "[ ]"
o := option.option
if m.cursor == i {
cursor = pretty.Sprint(DefaultStyles.Keyword, "> ")
chosen = pretty.Sprint(DefaultStyles.Keyword, "[ ]")
o = pretty.Sprint(DefaultStyles.Keyword, o)
}
if option.chosen {
chosen = pretty.Sprint(DefaultStyles.Keyword, "[x]")
}
_, _ = s.WriteString(fmt.Sprintf(
"%s%s %s\n",
cursor,
chosen,
o,
))
}
return s.String()
}
func (m multiSelectModel) filteredOptions() []*multiSelectOption {
options := []*multiSelectOption{}
for _, o := range m.options {
filter := strings.ToLower(m.search.Value())
option := strings.ToLower(o.option)
if strings.Contains(option, filter) {
options = append(options, o)
}
}
return options
}
func (m multiSelectModel) selectedOptions() []string {
selected := []string{}
for _, o := range m.options {
if o.chosen {
selected = append(selected, o.option)
}
}
return selected
return 0
}

Some files were not shown because too many files have changed in this diff Show More