Compare commits

50 Commits

| Author | SHA1 | Date |
|---|---|---|
| | eee13c42a4 | |
| | 65b48c0f84 | |
| | 30cdf29e52 | |
| | b1d2bb6d71 | |
| | 94bad2a956 | |
| | 111714c7ed | |
| | 1f9c516c5c | |
| | 3645c65bb2 | |
| | d3d2d2fb1e | |
| | 086fb1f5d5 | |
| | a73a535a5b | |
| | 96e01c3018 | |
| | 6b10a0359b | |
| | b62583ad4b | |
| | 3d6727a2cb | |
| | b163962a14 | |
| | 9aca4ea27c | |
| | b0c10131ea | |
| | c8c7e13e96 | |
| | 249b7ea38e | |
| | 1333096e25 | |
| | 54bc9324dd | |
| | 109e5f2b19 | |
| | ee176b4207 | |
| | 7e1e16be33 | |
| | 5cfe8082ce | |
| | 6b7f672834 | |
| | c55f6252a1 | |
| | 842553b677 | |
| | 05a771ba77 | |
| | 70a0d42e65 | |
| | 6b1d73b466 | |
| | d7b9596145 | |
| | 7a0aa1a40a | |
| | 4d8ea43e11 | |
| | 6fddae98f6 | |
| | e33fbb6087 | |
| | 2337393e13 | |
| | d7357a1b0a | |
| | afbf1af29c | |
| | 1d834c747c | |
| | a80edec752 | |
| | 2a6473e8c6 | |
| | 1f9c0b9b7f | |
| | 5494afabd8 | |
| | 07c6e86a50 | |
| | b543821a1c | |
| | e8b7045a9b | |
| | 2571089528 | |
| | 1fb733fe1e | |
@@ -1,249 +0,0 @@
# Modern Go (1.18–1.26)

Reference for writing idiomatic Go. Covers what changed, what it
replaced, and what to reach for. Respect the project's `go.mod` `go`
line: don't emit features from a version newer than what the module
declares. Check `go.mod` before writing code.

## How modern Go thinks differently

**Generics** (1.18): Design reusable code with type parameters instead
of `interface{}` casts, code generation, or the `sort.Interface`
pattern. Use `any` for unconstrained types, `comparable` for map keys
and equality, `cmp.Ordered` for sortable types. Type inference usually
makes explicit type arguments unnecessary (improved in 1.21).

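A minimal sketch of both constraints in practice (the names `Max` and `Keys` are illustrative, not from the doc; 1.21's `min`/`max` builtins cover the first case today):

```go
package main

import (
	"cmp"
	"fmt"
)

// Max works for any ordered type: one definition serves ints,
// floats, and strings, with no interface{} casts or code generation.
func Max[T cmp.Ordered](a, b T) T {
	if a > b {
		return a
	}
	return b
}

// Keys works for any map whose key type is comparable.
func Keys[K comparable, V any](m map[K]V) []K {
	ks := make([]K, 0, len(m))
	for k := range m {
		ks = append(ks, k)
	}
	return ks
}

func main() {
	// Type inference: no explicit type arguments needed.
	fmt.Println(Max(3, 5))
	fmt.Println(Max("ab", "ba"))
	fmt.Println(len(Keys(map[string]int{"x": 1, "y": 2})))
}
```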
**Per-iteration loop variables** (1.22): Each loop iteration gets its
own variable copy. Closures inside loops capture the correct value. The
`v := v` shadow trick is dead. Remove it when you see it.

**Iterators** (1.23): `iter.Seq[V]` and `iter.Seq2[K,V]` are the
standard iterator types. Containers expose `.All()` methods returning
these. Combined with `slices.Collect`, `slices.Sorted`, `maps.Keys`,
etc., they replace ad-hoc "loop and append" code with composable,
lazy pipelines. When a sequence is consumed only once, prefer an
iterator over materializing a slice.

**Error trees** (1.20–1.26): Errors compose as trees, not chains.
`errors.Join` aggregates multiple errors. `fmt.Errorf` accepts multiple
`%w` verbs. `errors.Is`/`As` traverse the full tree. Custom error
types that wrap multiple causes must implement `Unwrap() []error` (the
slice form), not `Unwrap() error`, or tree traversal won't find the
children. `errors.AsType[T]` (1.26) is the type-safe way to match
error types. Propagate cancellation reasons with
`context.WithCancelCause`.

**Structured logging** (1.21): `log/slog` is the standard structured
logger. This project uses `cdr.dev/slog/v3` instead, which has a
different API. Do not use `log/slog` directly.

## Replace these patterns

The left column reflects common patterns from pre-1.22 Go. Write the
right column instead. The "Since" column tells you the minimum `go`
directive version required in `go.mod`.

| Old pattern | Modern replacement | Since |
|---|---|---|
| `interface{}` | `any` | 1.18 |
| `v := v` inside loops | remove it | 1.22 |
| `for i := 0; i < n; i++` | `for i := range n` | 1.22 |
| `for i := 0; i < b.N; i++` (benchmarks) | `for b.Loop()` (correct timing, future-proof) | 1.24 |
| `sort.Slice(s, func(i,j int) bool{…})` | `slices.SortFunc(s, cmpFn)` | 1.21 |
| `wg.Add(1); go func(){ defer wg.Done(); … }()` | `wg.Go(func(){…})` | 1.25 |
| `func ptr[T any](v T) *T { return &v }` | `new(expr)` e.g. `new(time.Now())` | 1.26 |
| `var target *E; errors.As(err, &target)` | `t, ok := errors.AsType[*E](err)` | 1.26 |
| Custom multi-error type | `errors.Join(err1, err2, …)` | 1.20 |
| Single `%w` for multiple causes | `fmt.Errorf("…: %w, %w", e1, e2)` | 1.20 |
| `rand.Seed(time.Now().UnixNano())` | delete it (auto-seeded); prefer `math/rand/v2` | 1.20/1.22 |
| `sync.Once` + captured variable | `sync.OnceValue(func() T {…})` / `OnceValues` | 1.21 |
| Custom `min`/`max` helpers | `min(a, b)` / `max(a, b)` builtins (any ordered type) | 1.21 |
| `for k := range m { delete(m, k) }` | `clear(m)` (also zeroes slices) | 1.21 |
| Index+slice or `SplitN(s, sep, 2)` | `strings.Cut(s, sep)` / `bytes.Cut` | 1.18 |
| `TrimPrefix` + check if anything was trimmed | `strings.CutPrefix` / `CutSuffix` (returns ok bool) | 1.20 |
| `strings.Split` + loop when no slice is needed | `strings.SplitSeq` / `Lines` / `FieldsSeq` (iterator, no alloc) | 1.24 |
| `"2006-01-02"` / `"2006-01-02 15:04:05"` / `"15:04:05"` | `time.DateOnly` / `time.DateTime` / `time.TimeOnly` | 1.20 |
| Manual `Before`/`After`/`Equal` chains for comparison | `time.Time.Compare` (returns -1/0/+1; works with `slices.SortFunc`) | 1.20 |
| Loop collecting map keys into slice | `slices.Sorted(maps.Keys(m))` | 1.23 |
| `fmt.Sprintf` + append to `[]byte` | `fmt.Appendf(buf, …)` (also `Append`, `Appendln`) | 1.19 |
| `reflect.TypeOf((*T)(nil)).Elem()` | `reflect.TypeFor[T]()` | 1.22 |
| `*(*[4]byte)(slice)` unsafe cast | `[4]byte(slice)` direct conversion | 1.20 |
| `atomic.LoadInt64` / `StoreInt64` | `atomic.Int64` (also `Bool`, `Uint64`, `Pointer[T]`) | 1.19 |
| `crypto/rand.Read(buf)` + hex/base64 encode | `crypto/rand.Text()` (one call) | 1.24 |
| Checking `crypto/rand.Read` error | don't: return is always nil | 1.24 |
| `time.Sleep` in tests | `testing/synctest` (deterministic fake clock) | 1.24/1.25 |
| `json:",omitempty"` on zero-value structs like `time.Time{}` | `json:",omitzero"` (uses `IsZero()` method) | 1.24 |
| `strings.Title` | `golang.org/x/text/cases` | 1.18 |
| `net.IP` in new code | `net/netip.Addr` (immutable, comparable, lighter) | 1.18 |
| `tools.go` with blank imports | `tool` directive in `go.mod` | 1.24 |
| `runtime.SetFinalizer` | `runtime.AddCleanup` (multiple per object, no pointer cycles) | 1.24 |
| `httputil.ReverseProxy.Director` | `.Rewrite` hook + `ProxyRequest` (Director deprecated in 1.26) | 1.20 |
| `sql.NullString`, `sql.NullInt64`, etc. | `sql.Null[T]` | 1.22 |
| Manual `ctx, cancel := context.WithCancel(…)` + `t.Cleanup(cancel)` | `t.Context()` (auto-canceled when test ends) | 1.24 |
| `if d < 0 { d = -d }` on durations | `d.Abs()` (handles `math.MinInt64`) | 1.19 |
| Implement only `TextMarshaler` | also implement `TextAppender` for alloc-free marshaling | 1.24 |
| Custom `Unwrap() error` on multi-cause errors | `Unwrap() []error` (slice form; required for tree traversal) | 1.20 |

## New capabilities

These enable things that weren't practical before. Reach for them in the
described situations.

| What | Since | When to use it |
|---|---|---|
| `cmp.Or(a, b, c)` | 1.22 | Defaults/fallback chains: returns first non-zero value. Replaces verbose `if a != "" { return a }` cascades. |
| `context.WithoutCancel(ctx)` | 1.21 | Background work that must outlive the request (e.g. async cleanup after HTTP response). Derived context keeps parent's values but ignores cancellation. |
| `context.AfterFunc(ctx, fn)` | 1.21 | Register cleanup that fires on context cancellation without spawning a goroutine that blocks on `<-ctx.Done()`. |
| `context.WithCancelCause` / `Cause` | 1.20 | When callers need to know WHY a context was canceled, not just that it was. Retrieve cause with `context.Cause(ctx)`. |
| `context.WithDeadlineCause` / `WithTimeoutCause` | 1.21 | Attach a domain-specific error to deadline/timeout expiry (e.g. distinguish "DB query timed out" from "HTTP request timed out"). |
| `errors.ErrUnsupported` | 1.21 | Standard sentinel for "not supported." Use instead of per-package custom sentinels. Check with `errors.Is`. |
| `http.ResponseController` | 1.20 | Per-request flush, hijack, and deadline control without type-asserting `ResponseWriter` to `http.Flusher` or `http.Hijacker`. |
| Enhanced `ServeMux` routing | 1.22 | `"GET /items/{id}"` patterns in `http.ServeMux`. Access with `r.PathValue("id")`. Wildcards: `{name}`, catch-all: `{path...}`, exact: `{$}`. Eliminates many third-party router dependencies. |
| `os.Root` / `OpenRoot` | 1.24 | Confined directory access that prevents symlink escape. 1.25 adds `MkdirAll`, `ReadFile`, `WriteFile` for real use. |
| `os.CopyFS` | 1.23 | Copy an entire `fs.FS` to local filesystem in one call. |
| `os/signal.NotifyContext` with cause | 1.26 | Cancellation cause identifies which signal (SIGTERM vs SIGINT) triggered shutdown. |
| `io/fs.SkipAll` / `filepath.SkipAll` | 1.20 | Return from `WalkDir` callback to stop walking entirely. Cleaner than a sentinel error. |
| `GOMEMLIMIT` env / `debug.SetMemoryLimit` | 1.19 | Soft memory limit for GC. Use alongside or instead of `GOGC` in memory-constrained containers. |
| `net/url.JoinPath` | 1.19 | Join URL path segments correctly. Replaces error-prone string concatenation. |
| `go test -skip` | 1.20 | Skip tests matching a pattern. Useful when running a subset of a large test suite. |

## Key packages

### `slices` (1.21, iterators added 1.23)

Replaces `sort.Slice`, manual search loops, and manual contains checks.

Search: `Contains`, `ContainsFunc`, `Index`, `IndexFunc`,
`BinarySearch`, `BinarySearchFunc`.

Sort: `Sort`, `SortFunc`, `SortStableFunc`, `IsSorted`, `IsSortedFunc`,
`Min`, `MinFunc`, `Max`, `MaxFunc`.

Transform: `Clone`, `Compact`, `CompactFunc`, `Grow`, `Clip`,
`Concat` (1.22), `Repeat` (1.23), `Reverse`, `Insert`, `Delete`,
`Replace`.

Compare: `Equal`, `EqualFunc`, `Compare`.

Iterators (1.23): `All`, `Values`, `Backward`, `Collect`, `AppendSeq`,
`Sorted`, `SortedFunc`, `SortedStableFunc`, `Chunk`.

### `maps` (1.21, iterators added 1.23)

Core: `Clone`, `Copy`, `Equal`, `EqualFunc`, `DeleteFunc`.

Iterators (1.23): `All`, `Keys`, `Values`, `Insert`, `Collect`.

### `cmp` (1.21, `Or` added 1.22)

`Ordered` constraint for any ordered type. `Compare(a, b)` returns
-1/0/+1. `Less(a, b)` returns bool. `Or(vals...)` returns first
non-zero value.

### `iter` (1.23)

`Seq[V]` is `func(yield func(V) bool)`. `Seq2[K,V]` is
`func(yield func(K, V) bool)`. Return these from your container's
`.All()` methods. Consume with `for v := range seq` or pass to
`slices.Collect`, `slices.Sorted`, `maps.Collect`, etc.

### `math/rand/v2` (1.22)

Replaces `math/rand`. `IntN` not `Intn`. Generic `N[T]()` for any
integer type. Default source is `ChaCha8` (crypto-quality). No global
`Seed`. Use `rand.New(source)` for reproducible sequences.

### `log/slog` (1.21)

`slog.Info`, `slog.Warn`, `slog.Error`, `slog.Debug` with key-value
pairs. `slog.With(attrs...)` for logger with preset fields.
`slog.GroupAttrs` (1.25) for clean group creation. Implement
`slog.Handler` for custom backends.

**Note:** This project uses `cdr.dev/slog/v3`, not `log/slog`. The
API is different. Read existing code for usage patterns.

## Pitfalls

Things that are easy to get wrong, even when you know the modern API
exists. Check your output against these.

**Version misuse.** The replacement table has a "Since" column. If the
project's `go.mod` says `go 1.22`, you cannot use `wg.Go` (1.25),
`errors.AsType` (1.26), `new(expr)` (1.26), `b.Loop()` (1.24), or
`testing/synctest` (1.24). Fall back to the older pattern. Always
check before reaching for a replacement.

**`slices.Sort` vs `slices.SortFunc`.** `slices.Sort` requires
`cmp.Ordered` types (int, string, float64, etc.). For structs, custom
types, or multi-field sorting, use `slices.SortFunc` with a comparator
function. Using `slices.Sort` on a non-ordered type is a compile error.

**`for range n` discards the index.** If you need the index, write
`for i := range n`. Writing `for range n` and then trying to use `i`
inside the loop is a compile error.

**Don't hand-roll iterators when the stdlib returns one.** Functions
like `maps.Keys`, `slices.Values`, `strings.SplitSeq`, and
`strings.Lines` already return `iter.Seq` or `iter.Seq2`. Don't
reimplement them. Compose with `slices.Collect`, `slices.Sorted`, etc.

**Don't mix `math/rand` and `math/rand/v2`.** They have different
function names (`Intn` vs `IntN`) and different default sources. Pick
one per package. Prefer v2 for new code. The v1 global source is
auto-seeded since 1.20, so delete `rand.Seed` calls either way.

**Iterator protocol.** When implementing `iter.Seq`, you must respect
the `yield` return value. If `yield` returns `false`, stop iteration
immediately and return. Ignoring it violates the contract and causes
panics when consumers break out of `for range` loops early.

**`errors.Join` with nil.** `errors.Join` skips nil arguments. This is
intentional and useful for aggregating optional errors, but don't
assume the result is always non-nil. `errors.Join(nil, nil)` returns
nil.

**`cmp.Or` evaluates all arguments.** Unlike a chain of `if`
statements, `cmp.Or(a(), b(), c())` calls all three functions. If any
have side effects or are expensive, use `if`/`else` instead.

**Timer channel semantics changed in 1.23.** Code that checks
`len(timer.C)` to see if a value is pending no longer works (channel
capacity is 0). Use a non-blocking `select` receive instead:
`select { case <-timer.C: default: }`.

**`context.WithoutCancel` still propagates values.** The derived
context inherits all values from the parent. If any middleware stores
request-scoped state (deadlines, trace IDs) via `context.WithValue`,
the background work sees it. This is usually desired but can be
surprising if the values hold references that should not outlive the
request.

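A sketch showing both properties at once — cancellation is ignored, values flow through (the `ctxKey` type and `detach` helper are illustrative; assumes Go 1.21+):

```go
package main

import (
	"context"
	"fmt"
)

type ctxKey struct{}

// detach cancels the parent and reports: parent canceled?,
// detached unaffected?, and the value seen through the detached ctx.
func detach() (bool, bool, string) {
	parent, cancel := context.WithCancel(context.Background())
	parent = context.WithValue(parent, ctxKey{}, "trace-123")

	// Keeps parent's values but ignores its cancellation.
	detached := context.WithoutCancel(parent)
	cancel()

	v, _ := detached.Value(ctxKey{}).(string)
	return parent.Err() != nil, detached.Err() == nil, v
}

func main() {
	canceled, alive, v := detach()
	fmt.Println(canceled, alive, v) // true true trace-123
}
```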
## Behavioral changes that affect code

- **Timers** (1.23): unstopped `Timer`/`Ticker` are GC'd immediately.
  Channels are unbuffered: no stale values after `Reset`/`Stop`. You no
  longer need `defer t.Stop()` to prevent leaks.
- **Error tree traversal** (1.20): `errors.Is`/`As` follow
  `Unwrap() []error`, not just `Unwrap() error`. Multi-error types must
  expose the slice form for child errors to be found.
- **`math/rand` auto-seeded** (1.20): the global RNG is auto-seeded.
  `rand.Seed` is a no-op in 1.24+. Don't call it.
- **GODEBUG compat** (1.21): behavioral changes are gated by `go.mod`'s
  `go` line. Upgrading the version opts into new defaults.
- **Build tags** (1.18): `//go:build` is the only syntax. `// +build`
  is gone.
- **Tool install** (1.18): `go get` no longer builds. Use
  `go install pkg@version`.
- **Doc comments** (1.19): support `[links]`, lists, and headings.
- **`go test -skip`** (1.20): skip tests by name pattern from the
  command line.
- **`go fix ./...` modernizers** (1.26): auto-rewrites code to use
  newer idioms. Run after Go version upgrades.

## Transparent improvements (no code changes)

Swiss Tables maps, Green Tea GC, PGO, faster `io.ReadAll`,
stack-allocated slices, reduced cgo overhead, container-aware
GOMAXPROCS. Free on upgrade.

@@ -1,4 +0,0 @@
# All artifacts of the build processed are dumped here.
# Ignore it for docker context, as all Dockerfiles should build their own
# binaries.
build

@@ -4,7 +4,10 @@ description: |
inputs:
  version:
    description: "The Go version to use."
    default: "1.25.7"
    default: "1.25.6"
  use-preinstalled-go:
    description: "Whether to use preinstalled Go."
    default: "false"
  use-cache:
    description: "Whether to use the cache."
    default: "true"
@@ -12,9 +15,9 @@ runs:
  using: "composite"
  steps:
    - name: Setup Go
      uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5.6.0
      uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2
      with:
        go-version: ${{ inputs.version }}
        go-version: ${{ inputs.use-preinstalled-go == 'false' && inputs.version || '' }}
        cache: ${{ inputs.use-cache }}

    - name: Install gotestsum

@@ -7,5 +7,5 @@ runs:
    - name: Install Terraform
      uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
      with:
        terraform_version: 1.14.5
        terraform_version: 1.14.1
        terraform_wrapper: false

@@ -35,7 +35,7 @@ jobs:
      tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -157,7 +157,7 @@ jobs:
    runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -247,7 +247,7 @@ jobs:
    runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -272,7 +272,7 @@ jobs:
    if: ${{ !cancelled() }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -329,7 +329,7 @@ jobs:
    timeout-minutes: 20
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -381,7 +381,7 @@ jobs:
          - windows-2022
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -422,6 +422,10 @@ jobs:
      - name: Setup Go
        uses: ./.github/actions/setup-go
        with:
          # Runners have Go baked-in and Go will automatically
          # download the toolchain configured in go.mod, so we don't
          # need to reinstall it. It's faster on Windows runners.
          use-preinstalled-go: ${{ runner.os == 'Windows' }}
          use-cache: true

      - name: Setup Terraform

@@ -485,14 +489,6 @@ jobs:
          # macOS will output "The default interactive shell is now zsh" intermittently in CI.
          touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile

      - name: Increase PTY limit (macOS)
        if: runner.os == 'macOS'
        shell: bash
        run: |
          # Increase PTY limit to avoid exhaustion during tests.
          # Default is 511; 999 is the maximum value on CI runner.
          sudo sysctl -w kern.tty.ptmx_max=999

      - name: Test with PostgreSQL Database (Linux)
        if: runner.os == 'Linux'
        uses: ./.github/actions/test-go-pg

@@ -582,7 +578,7 @@ jobs:
    timeout-minutes: 25
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -644,7 +640,7 @@ jobs:
    timeout-minutes: 25
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -716,7 +712,7 @@ jobs:
    timeout-minutes: 20
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -743,7 +739,7 @@ jobs:
    timeout-minutes: 20
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -776,7 +772,7 @@ jobs:
    name: ${{ matrix.variant.name }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -856,7 +852,7 @@ jobs:
    if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -937,7 +933,7 @@ jobs:

    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -1009,7 +1005,7 @@ jobs:
    if: always()
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -1124,7 +1120,7 @@ jobs:
    runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -1179,7 +1175,7 @@ jobs:
      IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -1576,7 +1572,7 @@ jobs:
    if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -1,21 +0,0 @@
# This workflow triggers a Vercel deploy hook which builds+deploys coder.com
# (a Next.js app), to keep coder.com/docs URLs in sync with docs/manifest.json
#
# https://vercel.com/docs/deploy-hooks#triggering-a-deploy-hook

name: Update coder.com/docs

on:
  push:
    branches:
      - main
    paths:
      - "docs/manifest.json"

jobs:
  deploy-docs:
    runs-on: ubuntu-latest
    steps:
      - name: Deploy docs site
        run: |
          curl -X POST "${{ secrets.DEPLOY_DOCS_VERCEL_WEBHOOK }}"

@@ -36,7 +36,7 @@ jobs:
      verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -65,7 +65,7 @@ jobs:
      packages: write # to retag image as dogfood
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -146,7 +146,7 @@ jobs:
    needs: deploy
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -38,7 +38,7 @@ jobs:
    if: github.repository_owner == 'coder'
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -58,11 +58,11 @@ jobs:
        run: mkdir base-build-context

      - name: Install depot.dev CLI
        uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
        uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0

      # This uses OIDC authentication, so no auth variables are required.
      - name: Build base Docker image via depot.dev
        uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
        uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
        with:
          project: wl5hnrrkns
          context: base-build-context

@@ -26,7 +26,7 @@ jobs:
    runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -75,7 +75,7 @@ jobs:
          BRANCH_NAME: ${{ steps.branch-name.outputs.current_branch }}

      - name: Set up Depot CLI
        uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
        uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
@@ -88,7 +88,7 @@ jobs:
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and push Non-Nix image
        uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
        uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
        with:
          project: b4q6ltmpzh
          token: ${{ secrets.DEPOT_TOKEN }}
@@ -125,7 +125,7 @@ jobs:
      id-token: write
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -28,7 +28,7 @@ jobs:
          - windows-2022
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
        uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
        with:
          egress-policy: audit

@@ -64,6 +64,11 @@ jobs:

      - name: Setup Go
        uses: ./.github/actions/setup-go
        with:
          # Runners have Go baked-in and Go will automatically
          # download the toolchain configured in go.mod, so we don't
          # need to reinstall it. It's faster on Windows runners.
          use-preinstalled-go: ${{ runner.os == 'Windows' }}

      - name: Setup Terraform
        uses: ./.github/actions/setup-tf

@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
runs-on: "ubuntu-latest"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
pull-requests: write # needed for commenting on PRs
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
|
||||
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:

steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit


@@ -158,7 +158,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

@@ -386,12 +386,12 @@ jobs:

- name: Install depot.dev CLI
if: steps.image-base-tag.outputs.tag != ''
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0

# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
if: steps.image-base-tag.outputs.tag != ''
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
with:
project: wl5hnrrkns
context: base-build-context
@@ -796,7 +796,7 @@ jobs:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

@@ -872,7 +872,7 @@ jobs:

steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

@@ -955,3 +955,35 @@ jobs:
# different repo.
GH_TOKEN: ${{ secrets.CDRCI_GITHUB_TOKEN }}
VERSION: ${{ needs.release.outputs.version }}

# publish-sqlc pushes the latest schema to sqlc cloud.
# At present these pushes cannot be tagged, so the last push is always the latest.
publish-sqlc:
name: "Publish to schema sqlc cloud"
runs-on: "ubuntu-latest"
needs: release
if: ${{ !inputs.dry_run }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 1
persist-credentials: false

# We need golang to run the migration main.go
- name: Setup Go
uses: ./.github/actions/setup-go

- name: Setup sqlc
uses: ./.github/actions/setup-sqlc

- name: Push schema to sqlc cloud
# Don't block a release on this
continue-on-error: true
run: |
make sqlc-push

@@ -20,7 +20,7 @@ jobs:

steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit


@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

@@ -69,7 +69,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"

- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif

@@ -18,12 +18,12 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

- name: stale
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit

@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit


@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit


@@ -98,6 +98,3 @@ AGENTS.local.md

# Ignore plans written by AI agents.
PLAN.md

# Ignore any dev licenses
license.txt

@@ -198,12 +198,13 @@ reviewer time and clutters the diff.
**Don't delete existing comments** that explain non-obvious behavior. These
comments preserve important context about why code works a certain way.

**When adding tests for new behavior**, read existing tests first to understand what's covered. Add new cases for uncovered behavior. Edit existing tests as needed, but don't change what they verify.
**When adding tests for new behavior**, add new test cases instead of modifying
existing ones. This preserves coverage for the original behavior and makes it
clear what the new test covers.

## Detailed Development Guides

@.claude/docs/ARCHITECTURE.md
@.claude/docs/GO.md
@.claude/docs/OAUTH2.md
@.claude/docs/TESTING.md
@.claude/docs/TROUBLESHOOTING.md

@@ -427,7 +427,6 @@ SITE_GEN_FILES := \
site/src/api/typesGenerated.ts \
site/src/api/rbacresourcesGenerated.ts \
site/src/api/countriesGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
site/src/theme/icons.json

site/out/index.html: \
@@ -655,7 +654,6 @@ GEN_FILES := \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
@@ -711,7 +709,6 @@ gen/mark-fresh:
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
@@ -722,7 +719,6 @@ gen/mark-fresh:
coderd/rbac/scopes_constants_gen.go \
site/src/api/rbacresourcesGenerated.ts \
site/src/api/countriesGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
docs/admin/integrations/prometheus.md \
docs/reference/cli/index.md \
docs/admin/security/audit-logs.md \
@@ -847,12 +843,6 @@ vpn/vpn.pb.go: vpn/vpn.proto
--go_opt=paths=source_relative \
./vpn/vpn.proto

agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/boundary.proto agent/proto/agent.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
./agent/boundarylogproxy/codec/boundary.proto

enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
protoc \
--go_out=. \
@@ -864,7 +854,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
go run -C ./scripts/apitypings main.go > $@
./scripts/biome_format.sh src/api/typesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
touch "$@"

site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
@@ -873,7 +863,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot

site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
go run ./scripts/gensite/ -icons "$@"
./scripts/biome_format.sh src/theme/icons.json
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
touch "$@"

examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
@@ -911,22 +901,15 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope

site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/typegen/main.go rbac typescript > "$@"
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
touch "$@"

site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
go run scripts/typegen/main.go countries > "$@"
./scripts/biome_format.sh src/api/countriesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
touch "$@"

site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go
go run ./scripts/modeloptionsgen/main.go | tail -n +2 > "$@"
cd site && pnpm biome format --write src/api/chatModelOptionsGenerated.json

scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
go run ./scripts/metricsdocgen/scanner > $@

docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics
go run scripts/metricsdocgen/main.go
pnpm exec markdownlint-cli2 --fix ./docs/admin/integrations/prometheus.md
pnpm exec markdown-table-formatter ./docs/admin/integrations/prometheus.md
@@ -964,11 +947,11 @@ coderd/apidoc/.gen: \
touch "$@"

docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
./scripts/biome_format.sh ../docs/manifest.json
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
touch "$@"

coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
touch "$@"

update-golden-files:
@@ -1013,19 +996,11 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
touch "$@"

helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
fi
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
touch "$@"

helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
fi
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
touch "$@"

coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)

+5
-25
@@ -41,7 +41,6 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
@@ -112,12 +111,6 @@ type Client interface {
ConnectRPC28(ctx context.Context) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
// ConnectRPC28WithRole is like ConnectRPC28 but sends an explicit
// role query parameter to the server. The workspace agent should
// use role "agent" to enable connection monitoring.
ConnectRPC28WithRole(ctx context.Context, role string) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
tailnet.DERPMapRewriter
agentsdk.RefreshableSessionTokenProvider
}
@@ -303,8 +296,7 @@ type agent struct {
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API

filesAPI *agentfiles.API
processAPI *agentproc.API
filesAPI *agentfiles.API

socketServerEnabled bool
socketPath string
@@ -377,7 +369,6 @@ func (a *agent) init() {
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)

a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv)

a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
@@ -410,7 +401,7 @@ func (a *agent) initSocketServer() {
agentsocket.WithPath(a.socketPath),
)
if err != nil {
a.logger.Error(a.hardCtx, "failed to create socket server", slog.Error(err), slog.F("path", a.socketPath))
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err), slog.F("path", a.socketPath))
return
}

@@ -420,12 +411,7 @@ func (a *agent) initSocketServer() {

// startBoundaryLogProxyServer starts the boundary log proxy socket server.
func (a *agent) startBoundaryLogProxyServer() {
if a.boundaryLogProxySocketPath == "" {
a.logger.Warn(a.hardCtx, "boundary log proxy socket path not defined; not starting proxy")
return
}

proxy := boundarylogproxy.NewServer(a.logger, a.boundaryLogProxySocketPath, a.prometheusRegistry)
proxy := boundarylogproxy.NewServer(a.logger, a.boundaryLogProxySocketPath)
if err := proxy.Start(); err != nil {
a.logger.Warn(a.hardCtx, "failed to start boundary log proxy", slog.Error(err))
return
@@ -1011,10 +997,8 @@ func (a *agent) run() (retErr error) {
return xerrors.Errorf("refresh token: %w", err)
}

// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs.
// We pass role "agent" to enable connection monitoring on the server, which tracks
// the agent's connectivity state (first_connected_at, last_connected_at, disconnected_at).
aAPI, tAPI, err := a.client.ConnectRPC28WithRole(a.hardCtx, "agent")
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
aAPI, tAPI, err := a.client.ConnectRPC28(a.hardCtx)
if err != nil {
return err
}
@@ -2038,10 +2022,6 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
}

if err := a.processAPI.Close(); err != nil {
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
}

if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {

@@ -29,7 +29,6 @@ func (api *API) Routes() http.Handler {

r.Post("/list-directory", api.HandleLS)
r.Get("/read-file", api.HandleReadFile)
r.Get("/read-file-lines", api.HandleReadFileLines)
r.Post("/write-file", api.HandleWriteFile)
r.Post("/edit-files", api.HandleEditFiles)


+7
-283
@@ -10,10 +10,11 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"syscall"

"github.com/icholy/replace"
"github.com/spf13/afero"
"golang.org/x/text/transform"
"golang.org/x/xerrors"

"cdr.dev/slog/v3"
@@ -22,22 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)

// ReadFileLinesResponse is the JSON response for the line-based file reader.
type ReadFileLinesResponse struct {
// Success indicates whether the read was successful.
Success bool `json:"success"`
// FileSize is the original file size in bytes.
FileSize int64 `json:"file_size,omitempty"`
// TotalLines is the total number of lines in the file.
TotalLines int `json:"total_lines,omitempty"`
// LinesRead is the count of lines returned in this response.
LinesRead int `json:"lines_read,omitempty"`
// Content is the line-numbered file content.
Content string `json:"content,omitempty"`
// Error is the error message when success is false.
Error string `json:"error,omitempty"`
}

type HTTPResponseCode = int

func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
@@ -118,166 +103,6 @@ func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path str
return 0, nil
}

func (api *API) HandleReadFileLines(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
offset := parser.PositiveInt64(query, 1, "offset")
limit := parser.PositiveInt64(query, 0, "limit")
maxFileSize := parser.PositiveInt64(query, workspacesdk.DefaultMaxFileSize, "max_file_size")
maxLineBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxLineBytes, "max_line_bytes")
maxResponseLines := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseLines, "max_response_lines")
maxResponseBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseBytes, "max_response_bytes")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}

resp := api.readFileLines(ctx, path, offset, limit, workspacesdk.ReadFileLinesLimits{
MaxFileSize: maxFileSize,
MaxLineBytes: int(maxLineBytes),
MaxResponseLines: int(maxResponseLines),
MaxResponseBytes: int(maxResponseBytes),
})
httpapi.Write(ctx, rw, http.StatusOK, resp)
}

func (api *API) readFileLines(_ context.Context, path string, offset, limit int64, limits workspacesdk.ReadFileLinesLimits) ReadFileLinesResponse {
errResp := func(msg string) ReadFileLinesResponse {
return ReadFileLinesResponse{Success: false, Error: msg}
}

if !filepath.IsAbs(path) {
return errResp(fmt.Sprintf("file path must be absolute: %q", path))
}

f, err := api.filesystem.Open(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return errResp(fmt.Sprintf("file does not exist: %s", path))
}
if errors.Is(err, os.ErrPermission) {
return errResp(fmt.Sprintf("permission denied: %s", path))
}
return errResp(fmt.Sprintf("open file: %s", err))
}
defer f.Close()

stat, err := f.Stat()
if err != nil {
return errResp(fmt.Sprintf("stat file: %s", err))
}

if stat.IsDir() {
return errResp(fmt.Sprintf("not a file: %s", path))
}

fileSize := stat.Size()
if fileSize > limits.MaxFileSize {
return errResp(fmt.Sprintf(
"file is %d bytes which exceeds the maximum of %d bytes. Use grep, sed, or awk to extract the content you need, or use offset and limit to read a portion.",
fileSize, limits.MaxFileSize,
))
}

// Read the entire file (up to MaxFileSize).
data, err := io.ReadAll(f)
if err != nil {
return errResp(fmt.Sprintf("read file: %s", err))
}

// Split into lines.
content := string(data)
// Handle empty file.
if content == "" {
return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: 0,
LinesRead: 0,
Content: "",
}
}

lines := strings.Split(content, "\n")
totalLines := len(lines)

// offset is 1-based line number.
if offset < 1 {
offset = 1
}
if offset > int64(totalLines) {
return errResp(fmt.Sprintf(
"offset %d is beyond the file length of %d lines",
offset, totalLines,
))
}

// Default limit.
if limit <= 0 {
limit = int64(limits.MaxResponseLines)
}

startIdx := int(offset - 1) // convert to 0-based
endIdx := startIdx + int(limit)
if endIdx > totalLines {
endIdx = totalLines
}

var numbered []string
totalBytesAccumulated := 0

for i := startIdx; i < endIdx; i++ {
line := lines[i]

// Per-line truncation.
if len(line) > limits.MaxLineBytes {
line = line[:limits.MaxLineBytes] + "... [truncated]"
}

// Format with 1-based line number.
numberedLine := fmt.Sprintf("%d\t%s", i+1, line)
lineBytes := len(numberedLine)

// Check total byte budget.
newTotal := totalBytesAccumulated + lineBytes
if len(numbered) > 0 {
newTotal++ // account for \n joiner
}
if newTotal > limits.MaxResponseBytes {
return errResp(fmt.Sprintf(
"output would exceed %d bytes. Read less at a time using offset and limit parameters.",
limits.MaxResponseBytes,
))
}

// Check line count.
if len(numbered) >= limits.MaxResponseLines {
return errResp(fmt.Sprintf(
"output would exceed %d lines. Read less at a time using offset and limit parameters.",
limits.MaxResponseLines,
))
}

numbered = append(numbered, numberedLine)
totalBytesAccumulated = newTotal
}

return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: totalLines,
LinesRead: len(numbered),
Content: strings.Join(numbered, "\n"),
}
}

func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

@@ -420,21 +245,9 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
}

data, err := io.ReadAll(f)
if err != nil {
return http.StatusInternalServerError, xerrors.Errorf("read %s: %w", path, err)
}
content := string(data)

for _, edit := range edits {
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
transforms := make([]transform.Transformer, len(edits))
for i, edit := range edits {
transforms[i] = replace.String(edit.Search, edit.Replace)
}

// Create an adjacent file to ensure it will be on the same device and can be
@@ -445,7 +258,8 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
}
defer tmpfile.Close()

if _, err := tmpfile.Write([]byte(content)); err != nil {
_, err = io.Copy(tmpfile, replace.Chain(f, transforms...))
if err != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
@@ -459,93 +273,3 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.

return 0, nil
}

// fuzzyReplace attempts to find `search` inside `content` and replace its first
// occurrence with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
//
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1 – exact substring (replace all occurrences).
if strings.Contains(content, search) {
return strings.ReplaceAll(content, search, replace), true
}

// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")

// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}

// Pass 2 – trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}

// Pass 3 – trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}

return content, false
}

// seekLines scans contentLines looking for a contiguous subsequence that matches
// searchLines according to the provided `eq` function. It returns the start and
// end (exclusive) indices into contentLines of the match.
func seekLines(contentLines, searchLines []string, eq func(a, b string) bool) (start, end int, ok bool) {
if len(searchLines) == 0 {
return 0, 0, true
}
if len(searchLines) > len(contentLines) {
return 0, 0, false
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
return i, i + len(searchLines), true
}
return 0, 0, false
}

// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
var b strings.Builder
for _, l := range contentLines[:start] {
_, _ = b.WriteString(l)
}
_, _ = b.WriteString(replacement)
for _, l := range contentLines[end:] {
_, _ = b.WriteString(l)
}
return b.String()
}

func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}

@@ -649,106 +649,6 @@ func TestEditFiles(t *testing.T) {
				filepath.Join(tmpdir, "file3"): "edited3 3",
			},
		},
		{
			name:     "TrailingWhitespace",
			contents: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "foo \nbar\t\t\nbaz"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "trailing-ws"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "foo\nbar\nbaz",
							Replace: "replaced",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "replaced"},
		},
		{
			name:     "TabsVsSpaces",
			contents: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tfoo()\n\t}"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "tabs-vs-spaces"),
					Edits: []workspacesdk.FileEdit{
						{
							// Search uses spaces but file uses tabs.
							Search:  " if true {\n foo()\n }",
							Replace: "\tif true {\n\t\tbar()\n\t}",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tbar()\n\t}"},
		},
		{
			name:     "DifferentIndentDepth",
			contents: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tnested()"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "indent-depth"),
					Edits: []workspacesdk.FileEdit{
						{
							// Search has wrong indent depth (1 tab instead of 3).
							Search:  "\tdeep()\n\tnested()",
							Replace: "\t\t\tdeep()\n\t\t\tchanged()",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tchanged()"},
		},
		{
			name:     "ExactMatchPreferred",
			contents: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "hello world"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "exact-preferred"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "hello world",
							Replace: "goodbye world",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
		},
		{
			name:     "NoMatchStillSucceeds",
			contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "no-match"),
					Edits: []workspacesdk.FileEdit{
						{
							Search:  "this does not exist in the file",
							Replace: "whatever",
						},
					},
				},
			},
			// File should remain unchanged.
			expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
		},
		{
			name:     "MixedWhitespaceMultiline",
			contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
			edits: []workspacesdk.FileEdits{
				{
					Path: filepath.Join(tmpdir, "mixed-ws"),
					Edits: []workspacesdk.FileEdit{
						{
							// Search uses spaces, file uses tabs.
							Search:  " result := compute()\n fmt.Println(result)\n",
							Replace: "\tresult := compute()\n\tlog.Println(result)\n",
						},
					},
				},
			},
			expected: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tlog.Println(result)\n}"},
		},
		{
			name:     "MultiError",
			contents: map[string]string{
@@ -837,188 +737,3 @@ func TestEditFiles(t *testing.T) {
		})
	}
}

func TestReadFileLines(t *testing.T) {
	t.Parallel()

	tmpdir := os.TempDir()
	noPermsFilePath := filepath.Join(tmpdir, "no-perms-lines")

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	fs := newTestFs(afero.NewMemMapFs(), func(call, file string) error {
		if file == noPermsFilePath {
			return os.ErrPermission
		}
		return nil
	})
	api := agentfiles.NewAPI(logger, fs)

	dirPath := filepath.Join(tmpdir, "a-directory-lines")
	err := fs.MkdirAll(dirPath, 0o755)
	require.NoError(t, err)

	emptyFilePath := filepath.Join(tmpdir, "empty-file")
	err = afero.WriteFile(fs, emptyFilePath, []byte(""), 0o644)
	require.NoError(t, err)

	basicFilePath := filepath.Join(tmpdir, "basic-file")
	err = afero.WriteFile(fs, basicFilePath, []byte("line1\nline2\nline3"), 0o644)
	require.NoError(t, err)

	longLine := string(bytes.Repeat([]byte("x"), 1025))
	longLineFilePath := filepath.Join(tmpdir, "long-line-file")
	err = afero.WriteFile(fs, longLineFilePath, []byte(longLine), 0o644)
	require.NoError(t, err)

	largeFilePath := filepath.Join(tmpdir, "large-file")
	err = afero.WriteFile(fs, largeFilePath, bytes.Repeat([]byte("x"), 1<<20+1), 0o644)
	require.NoError(t, err)

	tests := []struct {
		name       string
		path       string
		offset     int64
		limit      int64
		expSuccess bool
		expError   string
		expContent string
		expTotal   int
		expRead    int
		expSize    int64
		// useCodersdk is set for cases where the handler returns
		// codersdk.Response (query param validation) instead of ReadFileLinesResponse.
		useCodersdk bool
	}{
		{
			name:        "NoPath",
			path:        "",
			useCodersdk: true,
			expError:    "is required",
		},
		{
			name:     "RelativePath",
			path:     "relative/path",
			expError: "file path must be absolute",
		},
		{
			name:     "NonExistent",
			path:     filepath.Join(tmpdir, "does-not-exist"),
			expError: "file does not exist",
		},
		{
			name:     "IsDir",
			path:     dirPath,
			expError: "not a file",
		},
		{
			name:     "NoPermissions",
			path:     noPermsFilePath,
			expError: "permission denied",
		},
		{
			name:       "EmptyFile",
			path:       emptyFilePath,
			expSuccess: true,
			expTotal:   0,
			expRead:    0,
			expSize:    0,
		},
		{
			name:       "BasicRead",
			path:       basicFilePath,
			expSuccess: true,
			expContent: "1\tline1\n2\tline2\n3\tline3",
			expTotal:   3,
			expRead:    3,
			expSize:    int64(len("line1\nline2\nline3")),
		},
		{
			name:       "Offset2",
			path:       basicFilePath,
			offset:     2,
			expSuccess: true,
			expContent: "2\tline2\n3\tline3",
			expTotal:   3,
			expRead:    2,
			expSize:    int64(len("line1\nline2\nline3")),
		},
		{
			name:       "Limit1",
			path:       basicFilePath,
			limit:      1,
			expSuccess: true,
			expContent: "1\tline1",
			expTotal:   3,
			expRead:    1,
			expSize:    int64(len("line1\nline2\nline3")),
		},
		{
			name:       "Offset2Limit1",
			path:       basicFilePath,
			offset:     2,
			limit:      1,
			expSuccess: true,
			expContent: "2\tline2",
			expTotal:   3,
			expRead:    1,
			expSize:    int64(len("line1\nline2\nline3")),
		},
		{
			name:     "OffsetBeyondFile",
			path:     basicFilePath,
			offset:   100,
			expError: "offset 100 is beyond the file length of 3 lines",
		},
		{
			name:       "LongLineTruncation",
			path:       longLineFilePath,
			expSuccess: true,
			expContent: "1\t" + string(bytes.Repeat([]byte("x"), 1024)) + "... [truncated]",
			expTotal:   1,
			expRead:    1,
			expSize:    1025,
		},
		{
			name:     "LargeFile",
			path:     largeFilePath,
			expError: "exceeds the maximum",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
			defer cancel()

			w := httptest.NewRecorder()
			r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/read-file-lines?path=%s&offset=%d&limit=%d", tt.path, tt.offset, tt.limit), nil)
			api.Routes().ServeHTTP(w, r)

			if tt.useCodersdk {
				// Query param validation errors return codersdk.Response.
				require.Equal(t, http.StatusBadRequest, w.Code)
				require.Contains(t, w.Body.String(), tt.expError)
				return
			}

			var resp agentfiles.ReadFileLinesResponse
			err := json.NewDecoder(w.Body).Decode(&resp)
			require.NoError(t, err)

			if tt.expSuccess {
				require.Equal(t, http.StatusOK, w.Code)
				require.True(t, resp.Success)
				require.Equal(t, tt.expContent, resp.Content)
				require.Equal(t, tt.expTotal, resp.TotalLines)
				require.Equal(t, tt.expRead, resp.LinesRead)
				require.Equal(t, tt.expSize, resp.FileSize)
			} else {
				require.Equal(t, http.StatusOK, w.Code)
				require.False(t, resp.Success)
				require.Contains(t, resp.Error, tt.expError)
			}
		})
	}
}

@@ -1,175 +0,0 @@
package agentproc

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"

	"github.com/go-chi/chi/v5"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/agent/agentexec"
	"github.com/coder/coder/v2/coderd/httpapi"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
)

// API exposes process-related operations through the agent.
type API struct {
	logger  slog.Logger
	manager *manager
}

// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *API {
	return &API{
		logger:  logger,
		manager: newManager(logger, execer, updateEnv),
	}
}

// Close shuts down the process manager, killing all running
// processes.
func (api *API) Close() error {
	return api.manager.Close()
}

// Routes returns the HTTP handler for process-related routes.
func (api *API) Routes() http.Handler {
	r := chi.NewRouter()
	r.Post("/start", api.handleStartProcess)
	r.Get("/list", api.handleListProcesses)
	r.Get("/{id}/output", api.handleProcessOutput)
	r.Post("/{id}/signal", api.handleSignalProcess)
	return r
}

// handleStartProcess starts a new process.
func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	var req workspacesdk.StartProcessRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Request body must be valid JSON.",
			Detail:  err.Error(),
		})
		return
	}

	if req.Command == "" {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Command is required.",
		})
		return
	}

	proc, err := api.manager.start(req)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Failed to start process.",
			Detail:  err.Error(),
		})
		return
	}

	httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
		ID:      proc.id,
		Started: true,
	})
}

// handleListProcesses lists all tracked processes.
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	infos := api.manager.list()
	httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
		Processes: infos,
	})
}

// handleProcessOutput returns the output of a process.
func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	id := chi.URLParam(r, "id")
	proc, ok := api.manager.get(id)
	if !ok {
		httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
			Message: fmt.Sprintf("Process %q not found.", id),
		})
		return
	}

	output, truncated := proc.output()
	info := proc.info()

	httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ProcessOutputResponse{
		Output:    output,
		Truncated: truncated,
		Running:   info.Running,
		ExitCode:  info.ExitCode,
	})
}

// handleSignalProcess sends a signal to a running process.
func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	id := chi.URLParam(r, "id")

	var req workspacesdk.SignalProcessRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Request body must be valid JSON.",
			Detail:  err.Error(),
		})
		return
	}

	if req.Signal == "" {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Signal is required.",
		})
		return
	}

	if req.Signal != "kill" && req.Signal != "terminate" {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: fmt.Sprintf(
				"Unsupported signal %q. Use \"kill\" or \"terminate\".",
				req.Signal,
			),
		})
		return
	}

	if err := api.manager.signal(id, req.Signal); err != nil {
		switch {
		case errors.Is(err, errProcessNotFound):
			httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
				Message: fmt.Sprintf("Process %q not found.", id),
			})
		case errors.Is(err, errProcessNotRunning):
			httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
				Message: fmt.Sprintf(
					"Process %q is not running.", id,
				),
			})
		default:
			httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
				Message: "Failed to signal process.",
				Detail:  err.Error(),
			})
		}
		return
	}

	httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
		Message: fmt.Sprintf(
			"Signal %q sent to process %q.", req.Signal, id,
		),
	})
}

@@ -1,691 +0,0 @@
package agentproc_test

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"cdr.dev/slog/v3"
	"cdr.dev/slog/v3/sloggers/slogtest"
	"github.com/coder/coder/v2/agent/agentexec"
	"github.com/coder/coder/v2/agent/agentproc"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/coder/v2/testutil"
)

// postStart sends a POST /start request and returns the recorder.
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	body, err := json.Marshal(req)
	require.NoError(t, err)

	w := httptest.NewRecorder()
	r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
	handler.ServeHTTP(w, r)
	return w
}

// getList sends a GET /list request and returns the recorder.
func getList(t *testing.T, handler http.Handler) *httptest.ResponseRecorder {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	w := httptest.NewRecorder()
	r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
	handler.ServeHTTP(w, r)
	return w
}

// getOutput sends a GET /{id}/output request and returns the
// recorder.
func getOutput(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	w := httptest.NewRecorder()
	r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/%s/output", id), nil)
	handler.ServeHTTP(w, r)
	return w
}

// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	body, err := json.Marshal(req)
	require.NoError(t, err)

	w := httptest.NewRecorder()
	r := httptest.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("/%s/signal", id), bytes.NewReader(body))
	handler.ServeHTTP(w, r)
	return w
}

// newTestAPI creates a new API with a test logger and default
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
	t.Helper()
	return newTestAPIWithUpdateEnv(t, nil)
}

// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
	t.Helper()

	logger := slogtest.Make(t, &slogtest.Options{
		IgnoreErrors: true,
	}).Leveled(slog.LevelDebug)
	api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv)
	t.Cleanup(func() {
		_ = api.Close()
	})
	return api.Routes()
}

// waitForExit polls the output endpoint until the process is
// no longer running or the context expires.
func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.ProcessOutputResponse {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
	defer cancel()

	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			t.Fatal("timed out waiting for process to exit")
		case <-ticker.C:
			w := getOutput(t, handler, id)
			require.Equal(t, http.StatusOK, w.Code)

			var resp workspacesdk.ProcessOutputResponse
			err := json.NewDecoder(w.Body).Decode(&resp)
			require.NoError(t, err)

			if !resp.Running {
				return resp
			}
		}
	}
}

// startAndGetID is a helper that starts a process and returns
// the process ID.
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
	t.Helper()

	w := postStart(t, handler, req)
	require.Equal(t, http.StatusOK, w.Code)

	var resp workspacesdk.StartProcessResponse
	err := json.NewDecoder(w.Body).Decode(&resp)
	require.NoError(t, err)
	require.True(t, resp.Started)
	require.NotEmpty(t, resp.ID)
	return resp.ID
}

func TestStartProcess(t *testing.T) {
	t.Parallel()

	t.Run("ForegroundCommand", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)
		w := postStart(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo hello",
		})
		require.Equal(t, http.StatusOK, w.Code)

		var resp workspacesdk.StartProcessResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.True(t, resp.Started)
		require.NotEmpty(t, resp.ID)
	})

	t.Run("BackgroundCommand", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)
		w := postStart(t, handler, workspacesdk.StartProcessRequest{
			Command:    "echo background",
			Background: true,
		})
		require.Equal(t, http.StatusOK, w.Code)

		var resp workspacesdk.StartProcessResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.True(t, resp.Started)
		require.NotEmpty(t, resp.ID)
	})

	t.Run("EmptyCommand", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)
		w := postStart(t, handler, workspacesdk.StartProcessRequest{
			Command: "",
		})
		require.Equal(t, http.StatusBadRequest, w.Code)

		var resp codersdk.Response
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Contains(t, resp.Message, "Command is required")
	})

	t.Run("MalformedJSON", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
		defer cancel()

		w := httptest.NewRecorder()
		r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", strings.NewReader("{invalid json"))
		handler.ServeHTTP(w, r)

		require.Equal(t, http.StatusBadRequest, w.Code)

		var resp codersdk.Response
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Contains(t, resp.Message, "valid JSON")
	})

	t.Run("CustomWorkDir", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)
		tmpDir := t.TempDir()

		// Write a marker file to verify the command ran in
		// the correct directory. Comparing pwd output is
		// unreliable on Windows where Git Bash returns POSIX
		// paths.
		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "touch marker.txt && ls marker.txt",
			WorkDir: tmpDir,
		})

		resp := waitForExit(t, handler, id)
		require.NotNil(t, resp.ExitCode)
		require.Equal(t, 0, *resp.ExitCode)
		require.Contains(t, resp.Output, "marker.txt")
	})

	t.Run("CustomEnv", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		// Use a unique env var name to avoid collisions in
		// parallel tests.
		envKey := fmt.Sprintf("TEST_PROC_ENV_%d", time.Now().UnixNano())
		envVal := "custom_value_12345"

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: fmt.Sprintf("printenv %s", envKey),
			Env:     map[string]string{envKey: envVal},
		})

		resp := waitForExit(t, handler, id)
		require.NotNil(t, resp.ExitCode)
		require.Equal(t, 0, *resp.ExitCode)
		require.Contains(t, strings.TrimSpace(resp.Output), envVal)
	})

	t.Run("UpdateEnvHook", func(t *testing.T) {
		t.Parallel()

		envKey := fmt.Sprintf("TEST_UPDATE_ENV_%d", time.Now().UnixNano())
		envVal := "injected_by_hook"

		handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
			return append(current, fmt.Sprintf("%s=%s", envKey, envVal)), nil
		})

		// The process should see the variable even though it
		// was not passed in req.Env.
		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: fmt.Sprintf("printenv %s", envKey),
		})

		resp := waitForExit(t, handler, id)
		require.NotNil(t, resp.ExitCode)
		require.Equal(t, 0, *resp.ExitCode)
		require.Contains(t, strings.TrimSpace(resp.Output), envVal)
	})

	t.Run("UpdateEnvHookOverriddenByReqEnv", func(t *testing.T) {
		t.Parallel()

		envKey := fmt.Sprintf("TEST_OVERRIDE_%d", time.Now().UnixNano())
		hookVal := "from_hook"
		reqVal := "from_request"

		handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
			return append(current, fmt.Sprintf("%s=%s", envKey, hookVal)), nil
		})

		// req.Env should take precedence over the hook.
		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: fmt.Sprintf("printenv %s", envKey),
			Env:     map[string]string{envKey: reqVal},
		})

		resp := waitForExit(t, handler, id)
		require.NotNil(t, resp.ExitCode)
		require.Equal(t, 0, *resp.ExitCode)
		// When duplicate env vars exist, shells use the last
		// value. Since req.Env is appended after the hook,
		// the request value wins.
		require.Contains(t, strings.TrimSpace(resp.Output), reqVal)
	})
}

func TestListProcesses(t *testing.T) {
	t.Parallel()

	t.Run("NoProcesses", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)
		w := getList(t, handler)
		require.Equal(t, http.StatusOK, w.Code)

		var resp workspacesdk.ListProcessesResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.NotNil(t, resp.Processes)
		require.Empty(t, resp.Processes)
	})

	t.Run("MixedRunningAndExited", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		// Start a process that exits quickly.
		exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo done",
		})
		waitForExit(t, handler, exitedID)

		// Start a long-running process.
		runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command:    "sleep 300",
			Background: true,
		})

		// List should contain both.
		w := getList(t, handler)
		require.Equal(t, http.StatusOK, w.Code)

		var resp workspacesdk.ListProcessesResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Len(t, resp.Processes, 2)

		procMap := make(map[string]workspacesdk.ProcessInfo)
		for _, p := range resp.Processes {
			procMap[p.ID] = p
		}

		exited, ok := procMap[exitedID]
		require.True(t, ok, "exited process should be in list")
		require.False(t, exited.Running)
		require.NotNil(t, exited.ExitCode)

		running, ok := procMap[runningID]
		require.True(t, ok, "running process should be in list")
		require.True(t, running.Running)

		// Clean up the long-running process.
		sw := postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
			Signal: "kill",
		})
		require.Equal(t, http.StatusOK, sw.Code)
	})
}

func TestProcessOutput(t *testing.T) {
	t.Parallel()

	t.Run("ExitedProcess", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo hello-output",
		})

		resp := waitForExit(t, handler, id)
		require.False(t, resp.Running)
		require.NotNil(t, resp.ExitCode)
		require.Equal(t, 0, *resp.ExitCode)
		require.Contains(t, resp.Output, "hello-output")
	})

	t.Run("RunningProcess", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command:    "sleep 300",
			Background: true,
		})

		w := getOutput(t, handler, id)
		require.Equal(t, http.StatusOK, w.Code)

		var resp workspacesdk.ProcessOutputResponse
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.True(t, resp.Running)

		// Kill and wait for the process so cleanup does
		// not hang.
		postSignal(
			t, handler, id,
			workspacesdk.SignalProcessRequest{Signal: "kill"},
		)
		waitForExit(t, handler, id)
	})

	t.Run("NonexistentProcess", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)
		w := getOutput(t, handler, "nonexistent-id-12345")
		require.Equal(t, http.StatusNotFound, w.Code)

		var resp codersdk.Response
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Contains(t, resp.Message, "not found")
	})
}

func TestSignalProcess(t *testing.T) {
	t.Parallel()

	t.Run("KillRunning", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command:    "sleep 300",
			Background: true,
		})

		w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
			Signal: "kill",
		})
		require.Equal(t, http.StatusOK, w.Code)

		// Verify the process exits.
		resp := waitForExit(t, handler, id)
		require.False(t, resp.Running)
	})

	t.Run("TerminateRunning", func(t *testing.T) {
		t.Parallel()

		if runtime.GOOS == "windows" {
			t.Skip("SIGTERM is not supported on Windows")
		}

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command:    "sleep 300",
			Background: true,
		})

		w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
			Signal: "terminate",
		})
		require.Equal(t, http.StatusOK, w.Code)

		// Verify the process exits.
		resp := waitForExit(t, handler, id)
		require.False(t, resp.Running)
	})

	t.Run("NonexistentProcess", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)
		w := postSignal(t, handler, "nonexistent-id-12345", workspacesdk.SignalProcessRequest{
			Signal: "kill",
		})
		require.Equal(t, http.StatusNotFound, w.Code)
	})

	t.Run("AlreadyExitedProcess", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo done",
		})

		// Wait for exit first.
		waitForExit(t, handler, id)

		// Signaling an exited process should return 409
		// Conflict via the errProcessNotRunning sentinel.
		w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
			Signal: "kill",
		})
		assert.Equal(t, http.StatusConflict, w.Code,
			"expected 409 for signaling exited process, got %d", w.Code)
	})

	t.Run("EmptySignal", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command:    "sleep 300",
			Background: true,
		})

		w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
			Signal: "",
		})
		require.Equal(t, http.StatusBadRequest, w.Code)

		var resp codersdk.Response
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Contains(t, resp.Message, "Signal is required")

		// Clean up.
		postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
			Signal: "kill",
		})
	})

	t.Run("InvalidSignal", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command:    "sleep 300",
			Background: true,
		})

		w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
			Signal: "SIGFOO",
		})
		require.Equal(t, http.StatusBadRequest, w.Code)

		var resp codersdk.Response
		err := json.NewDecoder(w.Body).Decode(&resp)
		require.NoError(t, err)
		require.Contains(t, resp.Message, "Unsupported signal")

		// Clean up.
		postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
			Signal: "kill",
		})
	})
}

func TestProcessLifecycle(t *testing.T) {
	t.Parallel()

	t.Run("StartWaitCheckOutput", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "echo lifecycle-test && echo second-line",
		})

		resp := waitForExit(t, handler, id)
		require.False(t, resp.Running)
		require.NotNil(t, resp.ExitCode)
		require.Equal(t, 0, *resp.ExitCode)
		require.Contains(t, resp.Output, "lifecycle-test")
		require.Contains(t, resp.Output, "second-line")
	})

	t.Run("NonZeroExitCode", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
			Command: "exit 42",
		})

		resp := waitForExit(t, handler, id)
		require.False(t, resp.Running)
		require.NotNil(t, resp.ExitCode)
		require.Equal(t, 42, *resp.ExitCode)
	})

	t.Run("StartSignalVerifyExit", func(t *testing.T) {
		t.Parallel()

		handler := newTestAPI(t)

		// Start a long-running background process.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "sleep 300",
|
||||
Background: true,
|
||||
})
|
||||
|
||||
// Verify it's running.
|
||||
w := getOutput(t, handler, id)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var running workspacesdk.ProcessOutputResponse
|
||||
err := json.NewDecoder(w.Body).Decode(&running)
|
||||
require.NoError(t, err)
|
||||
require.True(t, running.Running)
|
||||
|
||||
// Signal it.
|
||||
sw := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
|
||||
Signal: "kill",
|
||||
})
|
||||
require.Equal(t, http.StatusOK, sw.Code)
|
||||
|
||||
// Verify it exits.
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
})
|
||||
|
||||
t.Run("OutputExceedsBuffer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
// Generate output that exceeds MaxHeadBytes +
|
||||
// MaxTailBytes. Each line is ~100 chars, and we
|
||||
// need more than 32KB total (16KB head + 16KB
|
||||
// tail).
|
||||
lineCount := (agentproc.MaxHeadBytes+agentproc.MaxTailBytes)/50 + 500
|
||||
cmd := fmt.Sprintf(
|
||||
"for i in $(seq 1 %d); do echo \"line-$i-padding-to-make-this-longer-than-fifty-characters-total\"; done",
|
||||
lineCount,
|
||||
)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: cmd,
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
|
||||
// The output should be truncated with head/tail
|
||||
// strategy metadata.
|
||||
require.NotNil(t, resp.Truncated, "large output should be truncated")
|
||||
require.Equal(t, "head_tail", resp.Truncated.Strategy)
|
||||
require.Greater(t, resp.Truncated.OmittedBytes, 0)
|
||||
require.Greater(t, resp.Truncated.OriginalBytes, resp.Truncated.RetainedBytes)
|
||||
|
||||
// Verify the output contains the omission marker.
|
||||
require.Contains(t, resp.Output, "... [omitted")
|
||||
})
|
||||
|
||||
t.Run("StderrCaptured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := newTestAPI(t)
|
||||
|
||||
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
|
||||
Command: "echo stdout-msg && echo stderr-msg >&2",
|
||||
})
|
||||
|
||||
resp := waitForExit(t, handler, id)
|
||||
require.False(t, resp.Running)
|
||||
require.NotNil(t, resp.ExitCode)
|
||||
require.Equal(t, 0, *resp.ExitCode)
|
||||
// Both stdout and stderr should be captured.
|
||||
require.Contains(t, resp.Output, "stdout-msg")
|
||||
require.Contains(t, resp.Output, "stderr-msg")
|
||||
})
|
||||
}
|
||||
@@ -1,309 +0,0 @@
package agentproc

import (
	"fmt"
	"strings"
	"sync"

	"github.com/coder/coder/v2/codersdk/workspacesdk"
)

const (
	// MaxHeadBytes is the number of bytes retained from the
	// beginning of the output for LLM consumption.
	MaxHeadBytes = 16 << 10 // 16KB

	// MaxTailBytes is the number of bytes retained from the
	// end of the output for LLM consumption.
	MaxTailBytes = 16 << 10 // 16KB

	// MaxLineLength is the maximum length of a single line
	// before it is truncated. This prevents minified files
	// or other long single-line output from consuming the
	// entire buffer.
	MaxLineLength = 2048

	// lineTruncationSuffix is appended to lines that exceed
	// MaxLineLength.
	lineTruncationSuffix = " ... [truncated]"
)

// HeadTailBuffer is a thread-safe buffer that captures process
// output and provides head+tail truncation for LLM consumption.
// It implements io.Writer so it can be used directly as
// cmd.Stdout or cmd.Stderr.
//
// The buffer stores up to MaxHeadBytes from the beginning of
// the output and up to MaxTailBytes from the end in a ring
// buffer, keeping total memory usage bounded regardless of
// how much output is written.
type HeadTailBuffer struct {
	mu         sync.Mutex
	head       []byte
	tail       []byte
	tailPos    int
	tailFull   bool
	headFull   bool
	totalBytes int
	maxHead    int
	maxTail    int
}

// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
	return &HeadTailBuffer{
		maxHead: MaxHeadBytes,
		maxTail: MaxTailBytes,
	}
}

// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
	return &HeadTailBuffer{
		maxHead: maxHead,
		maxTail: maxTail,
	}
}

// Write implements io.Writer. It is safe for concurrent use.
// All bytes are accepted; the return value always equals
// len(p) with a nil error.
func (b *HeadTailBuffer) Write(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	n := len(p)
	b.totalBytes += n

	// Fill head buffer if it is not yet full.
	if !b.headFull {
		remaining := b.maxHead - len(b.head)
		if remaining > 0 {
			take := remaining
			if take > len(p) {
				take = len(p)
			}
			b.head = append(b.head, p[:take]...)
			p = p[take:]
			if len(b.head) >= b.maxHead {
				b.headFull = true
			}
		}
		if len(p) == 0 {
			return n, nil
		}
	}

	// Write remaining bytes into the tail ring buffer.
	b.writeTail(p)
	return n, nil
}

// writeTail appends data to the tail ring buffer. The caller
// must hold b.mu.
func (b *HeadTailBuffer) writeTail(p []byte) {
	if b.maxTail <= 0 {
		return
	}

	// Lazily allocate the tail buffer on first use.
	if b.tail == nil {
		b.tail = make([]byte, b.maxTail)
	}

	for len(p) > 0 {
		// Write as many bytes as fit starting at tailPos.
		space := b.maxTail - b.tailPos
		take := space
		if take > len(p) {
			take = len(p)
		}
		copy(b.tail[b.tailPos:b.tailPos+take], p[:take])
		p = p[take:]
		b.tailPos += take
		if b.tailPos >= b.maxTail {
			b.tailPos = 0
			b.tailFull = true
		}
	}
}

// tailBytes returns the current tail contents in order. The
// caller must hold b.mu.
func (b *HeadTailBuffer) tailBytes() []byte {
	if b.tail == nil {
		return nil
	}
	if !b.tailFull {
		// Haven't wrapped yet; data is [0, tailPos).
		return b.tail[:b.tailPos]
	}
	// Wrapped: data is [tailPos, maxTail) + [0, tailPos).
	out := make([]byte, b.maxTail)
	n := copy(out, b.tail[b.tailPos:])
	copy(out[n:], b.tail[:b.tailPos])
	return out
}

// Bytes returns a copy of the raw buffer contents. If no
// truncation has occurred the full output is returned;
// otherwise the head and tail portions are concatenated.
func (b *HeadTailBuffer) Bytes() []byte {
	b.mu.Lock()
	defer b.mu.Unlock()

	tail := b.tailBytes()
	if len(tail) == 0 {
		out := make([]byte, len(b.head))
		copy(out, b.head)
		return out
	}
	out := make([]byte, len(b.head)+len(tail))
	copy(out, b.head)
	copy(out[len(b.head):], tail)
	return out
}

// Len returns the number of bytes currently stored in the
// buffer.
func (b *HeadTailBuffer) Len() int {
	b.mu.Lock()
	defer b.mu.Unlock()

	tailLen := 0
	if b.tailFull {
		tailLen = b.maxTail
	} else if b.tail != nil {
		tailLen = b.tailPos
	}
	return len(b.head) + tailLen
}

// TotalWritten returns the total number of bytes written to
// the buffer, which may exceed the stored capacity.
func (b *HeadTailBuffer) TotalWritten() int {
	b.mu.Lock()
	defer b.mu.Unlock()
	return b.totalBytes
}

// Output returns the truncated output suitable for LLM
// consumption, along with truncation metadata. If the total
// output fits within the head buffer alone, the full output is
// returned with nil truncation info. Otherwise the head and
// tail are joined with an omission marker and long lines are
// truncated.
func (b *HeadTailBuffer) Output() (string, *workspacesdk.ProcessTruncation) {
	b.mu.Lock()
	head := make([]byte, len(b.head))
	copy(head, b.head)
	tail := b.tailBytes()
	total := b.totalBytes
	headFull := b.headFull
	b.mu.Unlock()

	storedLen := len(head) + len(tail)

	// If everything fits, no head/tail split is needed.
	if !headFull || len(tail) == 0 {
		out := truncateLines(string(head))
		if total == 0 {
			return "", nil
		}
		return out, nil
	}

	// We have both head and tail data, meaning the total
	// output exceeded the head capacity. Build the
	// combined output with an omission marker.
	omitted := total - storedLen
	headStr := truncateLines(string(head))
	tailStr := truncateLines(string(tail))

	var sb strings.Builder
	_, _ = sb.WriteString(headStr)
	if omitted > 0 {
		_, _ = sb.WriteString(fmt.Sprintf(
			"\n\n... [omitted %d bytes] ...\n\n",
			omitted,
		))
	} else {
		// Head and tail are contiguous but were stored
		// separately because the head filled up.
		_, _ = sb.WriteString("\n")
	}
	_, _ = sb.WriteString(tailStr)
	result := sb.String()

	return result, &workspacesdk.ProcessTruncation{
		OriginalBytes: total,
		RetainedBytes: len(result),
		OmittedBytes:  omitted,
		Strategy:      "head_tail",
	}
}

// truncateLines scans the input line by line and truncates
// any line longer than MaxLineLength.
func truncateLines(s string) string {
	if len(s) <= MaxLineLength {
		// Fast path: if the entire string is shorter than
		// the max line length, no line can exceed it.
		return s
	}

	var b strings.Builder
	b.Grow(len(s))

	for len(s) > 0 {
		idx := strings.IndexByte(s, '\n')
		var line string
		if idx == -1 {
			line = s
			s = ""
		} else {
			line = s[:idx]
			s = s[idx+1:]
		}

		if len(line) > MaxLineLength {
			// Truncate preserving the suffix length so the
			// total does not exceed a reasonable size.
			cut := MaxLineLength - len(lineTruncationSuffix)
			if cut < 0 {
				cut = 0
			}
			_, _ = b.WriteString(line[:cut])
			_, _ = b.WriteString(lineTruncationSuffix)
		} else {
			_, _ = b.WriteString(line)
		}

		// Re-add the newline unless this was the final
		// segment without a trailing newline.
		if idx != -1 {
			_ = b.WriteByte('\n')
		}
	}

	return b.String()
}

// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.head = nil
	b.tail = nil
	b.tailPos = 0
	b.tailFull = false
	b.headFull = false
	b.totalBytes = 0
}
@@ -1,338 +0,0 @@
package agentproc_test

import (
	"fmt"
	"strings"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/coder/coder/v2/agent/agentproc"
)

func TestHeadTailBuffer_EmptyBuffer(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()
	out, info := buf.Output()
	require.Empty(t, out)
	require.Nil(t, info)
	require.Equal(t, 0, buf.Len())
	require.Equal(t, 0, buf.TotalWritten())
	require.Empty(t, buf.Bytes())
}

func TestHeadTailBuffer_SmallOutput(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()
	data := "hello world\n"
	n, err := buf.Write([]byte(data))
	require.NoError(t, err)
	require.Equal(t, len(data), n)

	out, info := buf.Output()
	require.Equal(t, data, out)
	require.Nil(t, info, "small output should not be truncated")
	require.Equal(t, len(data), buf.Len())
	require.Equal(t, len(data), buf.TotalWritten())
}

func TestHeadTailBuffer_ExactlyHeadSize(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()

	// Build data that is exactly MaxHeadBytes using short
	// lines so that line truncation does not apply.
	line := strings.Repeat("x", 79) + "\n" // 80 bytes per line
	count := agentproc.MaxHeadBytes / len(line)
	pad := agentproc.MaxHeadBytes - (count * len(line))
	data := strings.Repeat(line, count) + strings.Repeat("y", pad)
	require.Equal(t, agentproc.MaxHeadBytes, len(data),
		"test data must be exactly MaxHeadBytes")

	n, err := buf.Write([]byte(data))
	require.NoError(t, err)
	require.Equal(t, agentproc.MaxHeadBytes, n)

	out, info := buf.Output()
	require.Equal(t, data, out)
	require.Nil(t, info, "output fitting in head should not be truncated")
	require.Equal(t, agentproc.MaxHeadBytes, buf.Len())
}

func TestHeadTailBuffer_HeadPlusTailNoOmission(t *testing.T) {
	t.Parallel()

	// Use a small buffer so we can test the boundary where
	// head fills and tail starts but nothing is omitted.
	// With maxHead=10, maxTail=10, writing exactly 20 bytes
	// means head gets 10, tail gets 10, omitted = 0.
	buf := agentproc.NewHeadTailBufferSized(10, 10)

	data := "0123456789abcdefghij" // 20 bytes
	n, err := buf.Write([]byte(data))
	require.NoError(t, err)
	require.Equal(t, 20, n)

	out, info := buf.Output()
	require.NotNil(t, info)
	require.Equal(t, 0, info.OmittedBytes)
	require.Equal(t, "head_tail", info.Strategy)
	// The output should contain both head and tail.
	require.Contains(t, out, "0123456789")
	require.Contains(t, out, "abcdefghij")
}

func TestHeadTailBuffer_LargeOutputTruncation(t *testing.T) {
	t.Parallel()

	// Use small head/tail so truncation is easy to verify.
	buf := agentproc.NewHeadTailBufferSized(10, 10)

	// Write 100 bytes: head=10, tail=10, omitted=80.
	data := strings.Repeat("A", 50) + strings.Repeat("Z", 50)
	n, err := buf.Write([]byte(data))
	require.NoError(t, err)
	require.Equal(t, 100, n)

	out, info := buf.Output()
	require.NotNil(t, info)
	require.Equal(t, 100, info.OriginalBytes)
	require.Equal(t, 80, info.OmittedBytes)
	require.Equal(t, "head_tail", info.Strategy)

	// Head should be first 10 bytes (all A's).
	require.True(t, strings.HasPrefix(out, "AAAAAAAAAA"))
	// Tail should be last 10 bytes (all Z's).
	require.True(t, strings.HasSuffix(out, "ZZZZZZZZZZ"))
	// Omission marker should be present.
	require.Contains(t, out, "... [omitted 80 bytes] ...")

	require.Equal(t, 20, buf.Len())
	require.Equal(t, 100, buf.TotalWritten())
}

func TestHeadTailBuffer_MultiMBStaysBounded(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()

	// Write 5MB of data in chunks.
	chunk := []byte(strings.Repeat("x", 4096) + "\n")
	totalWritten := 0
	for totalWritten < 5*1024*1024 {
		n, err := buf.Write(chunk)
		require.NoError(t, err)
		require.Equal(t, len(chunk), n)
		totalWritten += n
	}

	// Memory should be bounded to head+tail.
	require.LessOrEqual(t, buf.Len(),
		agentproc.MaxHeadBytes+agentproc.MaxTailBytes)
	require.Equal(t, totalWritten, buf.TotalWritten())

	out, info := buf.Output()
	require.NotNil(t, info)
	require.Equal(t, totalWritten, info.OriginalBytes)
	require.Greater(t, info.OmittedBytes, 0)
	require.NotEmpty(t, out)
}

func TestHeadTailBuffer_LongLineTruncation(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()

	// Write a line longer than MaxLineLength.
	longLine := strings.Repeat("m", agentproc.MaxLineLength+500)
	_, err := buf.Write([]byte(longLine + "\n"))
	require.NoError(t, err)

	out, _ := buf.Output()
	lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
	require.Len(t, lines, 1)
	require.LessOrEqual(t, len(lines[0]), agentproc.MaxLineLength)
	require.True(t, strings.HasSuffix(lines[0], "... [truncated]"))
}

func TestHeadTailBuffer_LongLineInTail(t *testing.T) {
	t.Parallel()

	// Use small buffers so we can force data into the tail.
	buf := agentproc.NewHeadTailBufferSized(20, 5000)

	// Fill head with short data.
	_, err := buf.Write([]byte("head data goes here\n"))
	require.NoError(t, err)

	// Now write a very long line into the tail.
	longLine := strings.Repeat("T", agentproc.MaxLineLength+100)
	_, err = buf.Write([]byte(longLine + "\n"))
	require.NoError(t, err)

	out, info := buf.Output()
	require.NotNil(t, info)
	// The long line in the tail should be truncated.
	require.Contains(t, out, "... [truncated]")
}

func TestHeadTailBuffer_ConcurrentWrites(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()

	const goroutines = 10
	const writes = 1000
	var wg sync.WaitGroup
	wg.Add(goroutines)

	for g := range goroutines {
		go func() {
			defer wg.Done()
			line := fmt.Sprintf("goroutine-%d: data\n", g)
			for range writes {
				_, err := buf.Write([]byte(line))
				assert.NoError(t, err)
			}
		}()
	}

	wg.Wait()

	// Verify totals are consistent.
	require.Greater(t, buf.TotalWritten(), 0)
	require.Greater(t, buf.Len(), 0)

	out, _ := buf.Output()
	require.NotEmpty(t, out)
}

func TestHeadTailBuffer_TruncationInfoFields(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBufferSized(10, 10)

	// Write enough to cause omission.
	data := strings.Repeat("D", 50)
	_, err := buf.Write([]byte(data))
	require.NoError(t, err)

	_, info := buf.Output()
	require.NotNil(t, info)
	require.Equal(t, 50, info.OriginalBytes)
	require.Equal(t, 30, info.OmittedBytes)
	require.Equal(t, "head_tail", info.Strategy)
	// RetainedBytes is the length of the formatted output
	// string including the omission marker.
	require.Greater(t, info.RetainedBytes, 0)
}

func TestHeadTailBuffer_MultipleSmallWrites(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()

	// Write one byte at a time.
	expected := "hello world"
	for i := range len(expected) {
		n, err := buf.Write([]byte{expected[i]})
		require.NoError(t, err)
		require.Equal(t, 1, n)
	}

	out, info := buf.Output()
	require.Equal(t, expected, out)
	require.Nil(t, info)
}

func TestHeadTailBuffer_WriteEmptySlice(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()
	n, err := buf.Write([]byte{})
	require.NoError(t, err)
	require.Equal(t, 0, n)
	require.Equal(t, 0, buf.TotalWritten())
}

func TestHeadTailBuffer_Reset(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()
	_, err := buf.Write([]byte("some data"))
	require.NoError(t, err)
	require.Greater(t, buf.Len(), 0)

	buf.Reset()

	require.Equal(t, 0, buf.Len())
	require.Equal(t, 0, buf.TotalWritten())
	out, info := buf.Output()
	require.Empty(t, out)
	require.Nil(t, info)
}

func TestHeadTailBuffer_BytesReturnsCopy(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()
	_, err := buf.Write([]byte("original"))
	require.NoError(t, err)

	b := buf.Bytes()
	require.Equal(t, []byte("original"), b)

	// Mutating the returned slice should not affect the
	// buffer.
	b[0] = 'X'
	require.Equal(t, []byte("original"), buf.Bytes())
}

func TestHeadTailBuffer_RingBufferWraparound(t *testing.T) {
	t.Parallel()

	// Use a tail of 10 bytes and write enough to wrap
	// around multiple times.
	buf := agentproc.NewHeadTailBufferSized(5, 10)

	// Fill head (5 bytes).
	_, err := buf.Write([]byte("HEADD"))
	require.NoError(t, err)

	// Write 25 bytes into tail, wrapping 2.5 times.
	_, err = buf.Write([]byte("0123456789"))
	require.NoError(t, err)
	_, err = buf.Write([]byte("abcdefghij"))
	require.NoError(t, err)
	_, err = buf.Write([]byte("ABCDE"))
	require.NoError(t, err)

	out, info := buf.Output()
	require.NotNil(t, info)
	// Tail should contain the last 10 bytes: "fghijABCDE".
	require.True(t, strings.HasSuffix(out, "fghijABCDE"),
		"expected tail to be last 10 bytes, got: %q", out)
}

func TestHeadTailBuffer_MultipleLinesTruncated(t *testing.T) {
	t.Parallel()

	buf := agentproc.NewHeadTailBuffer()

	short := "short line\n"
	long := strings.Repeat("L", agentproc.MaxLineLength+100) + "\n"
	_, err := buf.Write([]byte(short + long + short))
	require.NoError(t, err)

	out, _ := buf.Output()
	lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
	require.Len(t, lines, 3)
	require.Equal(t, "short line", lines[0])
	require.True(t, strings.HasSuffix(lines[1], "... [truncated]"))
	require.Equal(t, "short line", lines[2])
}
@@ -1,294 +0,0 @@
package agentproc

import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"sync"
	"syscall"
	"time"

	"github.com/google/uuid"
	"golang.org/x/xerrors"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/agent/agentexec"
	"github.com/coder/coder/v2/codersdk/workspacesdk"
	"github.com/coder/quartz"
)

var (
	errProcessNotFound   = xerrors.New("process not found")
	errProcessNotRunning = xerrors.New("process is not running")
)

// process represents a running or completed process.
type process struct {
	mu         sync.Mutex
	id         string
	command    string
	workDir    string
	background bool
	cmd        *exec.Cmd
	cancel     context.CancelFunc
	buf        *HeadTailBuffer
	running    bool
	exitCode   *int
	startedAt  int64
	exitedAt   *int64
	done       chan struct{} // closed when process exits
}

// info returns a snapshot of the process state.
func (p *process) info() workspacesdk.ProcessInfo {
	p.mu.Lock()
	defer p.mu.Unlock()

	return workspacesdk.ProcessInfo{
		ID:         p.id,
		Command:    p.command,
		WorkDir:    p.workDir,
		Background: p.background,
		Running:    p.running,
		ExitCode:   p.exitCode,
		StartedAt:  p.startedAt,
		ExitedAt:   p.exitedAt,
	}
}

// output returns the truncated output from the process buffer
// along with optional truncation metadata.
func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
	return p.buf.Output()
}

// manager tracks processes spawned by the agent.
type manager struct {
	mu        sync.Mutex
	logger    slog.Logger
	execer    agentexec.Execer
	clock     quartz.Clock
	procs     map[string]*process
	closed    bool
	updateEnv func(current []string) (updated []string, err error)
}

// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
	return &manager{
		logger:    logger,
		execer:    execer,
		clock:     quartz.NewReal(),
		procs:     make(map[string]*process),
		updateEnv: updateEnv,
	}
}

// start spawns a new process. Both foreground and background
// processes use a long-lived context so the process survives
// the HTTP request lifecycle. The background flag only affects
// client-side polling behavior.
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
	m.mu.Lock()
	if m.closed {
		m.mu.Unlock()
		return nil, xerrors.New("manager is closed")
	}
	m.mu.Unlock()

	id := uuid.New().String()

	// Use a cancellable context so Close() can terminate
	// all processes. context.Background() is the parent so
	// the process is not tied to any HTTP request.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
	if req.WorkDir != "" {
		cmd.Dir = req.WorkDir
	}
	cmd.Stdin = nil

	// WaitDelay ensures cmd.Wait returns promptly after
	// the process is killed, even if child processes are
	// still holding the stdout/stderr pipes open.
	cmd.WaitDelay = 5 * time.Second

	buf := NewHeadTailBuffer()
	cmd.Stdout = buf
	cmd.Stderr = buf

	// Build the process environment. If the manager has an
	// updateEnv hook (provided by the agent), use it to get the
	// full agent environment including GIT_ASKPASS, CODER_* vars,
	// etc. Otherwise fall back to the current process env.
	baseEnv := os.Environ()
	if m.updateEnv != nil {
		updated, err := m.updateEnv(baseEnv)
		if err != nil {
			m.logger.Warn(
				context.Background(),
				"failed to update command environment, falling back to os env",
				slog.Error(err),
			)
		} else {
			baseEnv = updated
		}
	}

	// Always set cmd.Env explicitly so that req.Env overrides
	// are applied on top of the full agent environment.
	cmd.Env = baseEnv
	for k, v := range req.Env {
		cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
	}

	if err := cmd.Start(); err != nil {
		cancel()
		return nil, xerrors.Errorf("start process: %w", err)
	}

	now := m.clock.Now().Unix()
	proc := &process{
		id:         id,
		command:    req.Command,
		workDir:    req.WorkDir,
		background: req.Background,
		cmd:        cmd,
		cancel:     cancel,
		buf:        buf,
		running:    true,
		startedAt:  now,
		done:       make(chan struct{}),
	}

	m.mu.Lock()
	if m.closed {
		m.mu.Unlock()
		// Manager closed between our check and now. Kill the
		// process we just started.
		cancel()
		_ = cmd.Wait()
		return nil, xerrors.New("manager is closed")
	}
	m.procs[id] = proc
	m.mu.Unlock()

	go func() {
		err := cmd.Wait()
		exitedAt := m.clock.Now().Unix()

		proc.mu.Lock()
		proc.running = false
		proc.exitedAt = &exitedAt
		code := 0
		if err != nil {
			// Extract the exit code from the error.
			var exitErr *exec.ExitError
			if xerrors.As(err, &exitErr) {
				code = exitErr.ExitCode()
			} else {
				// Unknown error; use -1 as a sentinel.
				code = -1
				m.logger.Warn(
					context.Background(),
					"process wait returned non-exit error",
					slog.F("id", id),
					slog.Error(err),
				)
			}
		}
		proc.exitCode = &code
		proc.mu.Unlock()

		close(proc.done)
	}()

	return proc, nil
}

// get returns a process by ID.
func (m *manager) get(id string) (*process, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()
	proc, ok := m.procs[id]
	return proc, ok
}

// list returns info about all tracked processes.
func (m *manager) list() []workspacesdk.ProcessInfo {
	m.mu.Lock()
	defer m.mu.Unlock()

	infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
	for _, proc := range m.procs {
		infos = append(infos, proc.info())
	}
	return infos
}

// signal sends a signal to a running process. It returns
// sentinel errors errProcessNotFound and errProcessNotRunning
// so callers can distinguish failure modes.
func (m *manager) signal(id string, sig string) error {
	m.mu.Lock()
	proc, ok := m.procs[id]
	m.mu.Unlock()

	if !ok {
		return errProcessNotFound
	}

	proc.mu.Lock()
	defer proc.mu.Unlock()

	if !proc.running {
		return errProcessNotRunning
	}

	switch sig {
	case "kill":
		if err := proc.cmd.Process.Kill(); err != nil {
			return xerrors.Errorf("kill process: %w", err)
		}
	case "terminate":
		//nolint:revive // syscall.SIGTERM is portable enough
		// for our supported platforms.
		if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
			return xerrors.Errorf("terminate process: %w", err)
		}
	default:
		return xerrors.Errorf("unsupported signal %q", sig)
	}

	return nil
}

// Close kills all running processes and prevents new ones from
// starting. It cancels each process's context, which causes
// CommandContext to kill the process and its pipe goroutines to
// drain.
func (m *manager) Close() error {
	m.mu.Lock()
	if m.closed {
		m.mu.Unlock()
		return nil
	}
	m.closed = true
	procs := make([]*process, 0, len(m.procs))
	for _, p := range m.procs {
		procs = append(procs, p)
	}
	m.mu.Unlock()

	for _, p := range procs {
		p.cancel()
	}

	// Wait for all processes to exit.
	for _, p := range procs {
		<-p.done
	}

	return nil
}
@@ -1,22 +1,37 @@
package agentsocket_test

import (
	"context"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/google/uuid"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/require"

	"cdr.dev/slog/v3"
	"github.com/coder/coder/v2/agent"
	"github.com/coder/coder/v2/agent/agentsocket"
	"github.com/coder/coder/v2/agent/agenttest"
	agentproto "github.com/coder/coder/v2/agent/proto"
	"github.com/coder/coder/v2/codersdk/agentsdk"
	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/tailnettest"
	"github.com/coder/coder/v2/testutil"
)

func TestServer(t *testing.T) {
	t.Parallel()

	if runtime.GOOS == "windows" {
		t.Skip("agentsocket is not supported on Windows")
	}

	t.Run("StartStop", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(t.TempDir(), "test.sock")
		logger := slog.Make().Leveled(slog.LevelDebug)
		server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
		require.NoError(t, err)
@@ -26,7 +41,7 @@ func TestServer(t *testing.T) {
	t.Run("AlreadyStarted", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(t.TempDir(), "test.sock")
		logger := slog.Make().Leveled(slog.LevelDebug)
		server1, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
		require.NoError(t, err)
@@ -34,4 +49,90 @@ func TestServer(t *testing.T) {
		_, err = agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
		require.ErrorContains(t, err, "create socket")
	})

	t.Run("AutoSocketPath", func(t *testing.T) {
		t.Parallel()

		socketPath := filepath.Join(t.TempDir(), "test.sock")
		logger := slog.Make().Leveled(slog.LevelDebug)
		server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
		require.NoError(t, err)
		require.NoError(t, server.Close())
	})
}

func TestServerWindowsNotSupported(t *testing.T) {
	t.Parallel()

	if runtime.GOOS != "windows" {
		t.Skip("this test only runs on Windows")
	}

	t.Run("NewServer", func(t *testing.T) {
		t.Parallel()

		socketPath := filepath.Join(t.TempDir(), "test.sock")
		logger := slog.Make().Leveled(slog.LevelDebug)
		_, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
		require.ErrorContains(t, err, "agentsocket is not supported on Windows")
	})

	t.Run("NewClient", func(t *testing.T) {
		t.Parallel()

		_, err := agentsocket.NewClient(context.Background(), agentsocket.WithPath("test.sock"))
		require.ErrorContains(t, err, "agentsocket is not supported on Windows")
	})
}

func TestAgentInitializesOnWindowsWithoutSocketServer(t *testing.T) {
	t.Parallel()

	if runtime.GOOS != "windows" {
		t.Skip("this test only runs on Windows")
	}

	ctx := testutil.Context(t, testutil.WaitShort)
	logger := testutil.Logger(t).Named("agent")

	derpMap, _ := tailnettest.RunDERPAndSTUN(t)

	coordinator := tailnet.NewCoordinator(logger)
	t.Cleanup(func() {
		_ = coordinator.Close()
	})

	statsCh := make(chan *agentproto.Stats, 50)
	agentID := uuid.New()
	manifest := agentsdk.Manifest{
		AgentID:       agentID,
		AgentName:     "test-agent",
		WorkspaceName: "test-workspace",
		OwnerName:     "test-user",
		WorkspaceID:   uuid.New(),
		DERPMap:       derpMap,
	}

	client := agenttest.NewClient(t, logger.Named("agenttest"), agentID, manifest, statsCh, coordinator)
	t.Cleanup(client.Close)

	options := agent.Options{
		Client:                 client,
		Filesystem:             afero.NewMemMapFs(),
		Logger:                 logger.Named("agent"),
		ReconnectingPTYTimeout: testutil.WaitShort,
		EnvironmentVariables:   map[string]string{},
		SocketPath:             "",
	}

	agnt := agent.New(options)
	t.Cleanup(func() {
		_ = agnt.Close()
	})

	startup := testutil.TryReceive(ctx, t, client.GetStartup())
	require.NotNil(t, startup, "agent should send startup message")

	err := agnt.Close()
	require.NoError(t, err, "agent should close cleanly")
}

@@ -2,6 +2,8 @@ package agentsocket_test

import (
	"context"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/stretchr/testify/require"
@@ -28,10 +30,14 @@ func newSocketClient(ctx context.Context, t *testing.T, socketPath string) *agen
func TestDRPCAgentSocketService(t *testing.T) {
	t.Parallel()

	if runtime.GOOS == "windows" {
		t.Skip("agentsocket is not supported on Windows")
	}

	t.Run("Ping", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -51,7 +57,7 @@ func TestDRPCAgentSocketService(t *testing.T) {

	t.Run("NewUnit", func(t *testing.T) {
		t.Parallel()
		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -73,7 +79,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("UnitAlreadyStarted", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -103,7 +109,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("UnitAlreadyCompleted", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -142,7 +148,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("UnitNotReady", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -172,7 +178,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("NewUnits", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -197,7 +203,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -232,7 +238,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -274,7 +280,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("UnregisteredUnit", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -293,7 +299,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("UnitNotReady", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),
@@ -317,7 +323,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
	t.Run("UnitReady", func(t *testing.T) {
		t.Parallel()

		socketPath := testutil.AgentSocketPath(t)
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
		ctx := testutil.Context(t, testutil.WaitShort)
		server, err := agentsocket.NewServer(
			slog.Make().Leveled(slog.LevelDebug),

@@ -4,60 +4,19 @@ package agentsocket

import (
	"context"
	"fmt"
	"net"
	"os"
	"os/user"
	"strings"

	"github.com/Microsoft/go-winio"
	"golang.org/x/xerrors"
)

const defaultSocketPath = `\\.\pipe\com.coder.agentsocket`

func createSocket(path string) (net.Listener, error) {
	if path == "" {
		path = defaultSocketPath
	}
	if !strings.HasPrefix(path, `\\.\pipe\`) {
		return nil, xerrors.Errorf("%q is not a valid local socket path", path)
	}

	user, err := user.Current()
	if err != nil {
		return nil, fmt.Errorf("unable to look up current user: %w", err)
	}
	sid := user.Uid

	// SecurityDescriptor is in SDDL format. c.f.
	// https://learn.microsoft.com/en-us/windows/win32/secauthz/security-descriptor-string-format for full details.
	// D: indicates this is a Discretionary Access Control List (DACL), which is Windows-speak for ACLs that allow or
	// deny access (as opposed to SACL, which controls audit logging).
	// P indicates that this DACL is "protected" from being modified through inheritance.
	// () delimit access control entries (ACEs); here we only have one, which allows (A) generic all (GA) access to our
	// specific user's security ID (SID).
	//
	// Note that although Microsoft docs at https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipes warn that
	// named pipes are accessible from remote machines in the general case, the `winio` package sets the flag
	// windows.FILE_PIPE_REJECT_REMOTE_CLIENTS when creating pipes, so connections from remote machines are always
	// denied. This is important because we sort of expect customers to run the Coder agent under a generic user
	// account unless they are very sophisticated. We don't want this socket to cross the boundary of the local machine.
	configuration := &winio.PipeConfig{
		SecurityDescriptor: fmt.Sprintf("D:P(A;;GA;;;%s)", sid),
	}

	listener, err := winio.ListenPipe(path, configuration)
	if err != nil {
		return nil, xerrors.Errorf("failed to open named pipe: %w", err)
	}
	return listener, nil
func createSocket(_ string) (net.Listener, error) {
	return nil, xerrors.New("agentsocket is not supported on Windows")
}

func cleanupSocket(path string) error {
	return os.Remove(path)
func cleanupSocket(_ string) error {
	return nil
}

func dialSocket(ctx context.Context, path string) (net.Conn, error) {
	return winio.DialPipeContext(ctx, path)
func dialSocket(_ context.Context, _ string) (net.Conn, error) {
	return nil, xerrors.New("agentsocket is not supported on Windows")
}

@@ -24,7 +24,6 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
	var o agent.Options
	log := testutil.Logger(t).Named("agent")
	o.Logger = log
	o.SocketPath = testutil.AgentSocketPath(t)

	for _, opt := range opts {
		opt(&o)

@@ -124,12 +124,6 @@ func (c *Client) Close() {
	c.derpMapOnce.Do(func() { close(c.derpMapUpdates) })
}

func (c *Client) ConnectRPC28WithRole(ctx context.Context, _ string) (
	agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
	return c.ConnectRPC28(ctx)
}

func (c *Client) ConnectRPC28(ctx context.Context) (
	agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
@@ -235,10 +229,6 @@ type FakeAgentAPI struct {
	pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
}

func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
	panic("unimplemented")
}

func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
	return f.manifest, nil
}

@@ -28,7 +28,6 @@ func (a *agent) apiHandler() http.Handler {
	})

	r.Mount("/api/v0", a.filesAPI.Routes())
	r.Mount("/api/v0/processes", a.processAPI.Routes())

	if a.devcontainers {
		r.Mount("/api/v0/containers", a.containerAPI.Routes())

@@ -10,7 +10,6 @@ import (
	"testing"

	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"
@@ -70,7 +69,7 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)

@@ -1,286 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// 	protoc-gen-go v1.30.0
// 	protoc        v4.23.4
// source: agent/boundarylogproxy/codec/boundary.proto

package codec

import (
	proto "github.com/coder/coder/v2/agent/proto"
	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
	reflect "reflect"
	sync "sync"
)

const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// BoundaryMessage is the envelope for all TagV2 messages sent over the
// boundary <-> agent unix socket. TagV1 carries a bare
// ReportBoundaryLogsRequest for backwards compatibility; TagV2 wraps
// everything in this envelope so the protocol can be extended with new
// message types without adding more tags.
type BoundaryMessage struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	// Types that are assignable to Msg:
	//
	//	*BoundaryMessage_Logs
	//	*BoundaryMessage_Status
	Msg isBoundaryMessage_Msg `protobuf_oneof:"msg"`
}

func (x *BoundaryMessage) Reset() {
	*x = BoundaryMessage{}
	if protoimpl.UnsafeEnabled {
		mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *BoundaryMessage) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*BoundaryMessage) ProtoMessage() {}

func (x *BoundaryMessage) ProtoReflect() protoreflect.Message {
	mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use BoundaryMessage.ProtoReflect.Descriptor instead.
func (*BoundaryMessage) Descriptor() ([]byte, []int) {
	return file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP(), []int{0}
}

func (m *BoundaryMessage) GetMsg() isBoundaryMessage_Msg {
	if m != nil {
		return m.Msg
	}
	return nil
}

func (x *BoundaryMessage) GetLogs() *proto.ReportBoundaryLogsRequest {
	if x, ok := x.GetMsg().(*BoundaryMessage_Logs); ok {
		return x.Logs
	}
	return nil
}

func (x *BoundaryMessage) GetStatus() *BoundaryStatus {
	if x, ok := x.GetMsg().(*BoundaryMessage_Status); ok {
		return x.Status
	}
	return nil
}

type isBoundaryMessage_Msg interface {
	isBoundaryMessage_Msg()
}

type BoundaryMessage_Logs struct {
	Logs *proto.ReportBoundaryLogsRequest `protobuf:"bytes,1,opt,name=logs,proto3,oneof"`
}

type BoundaryMessage_Status struct {
	Status *BoundaryStatus `protobuf:"bytes,2,opt,name=status,proto3,oneof"`
}

func (*BoundaryMessage_Logs) isBoundaryMessage_Msg() {}

func (*BoundaryMessage_Status) isBoundaryMessage_Msg() {}

// BoundaryStatus carries operational metadata from boundary to the agent.
// The agent records these values as Prometheus metrics. This message is
// never forwarded to coderd.
type BoundaryStatus struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	// Logs dropped because boundary's internal channel buffer was full.
	DroppedChannelFull int64 `protobuf:"varint,1,opt,name=dropped_channel_full,json=droppedChannelFull,proto3" json:"dropped_channel_full,omitempty"`
	// Logs dropped because boundary's batch buffer was full after a
	// failed flush attempt.
	DroppedBatchFull int64 `protobuf:"varint,2,opt,name=dropped_batch_full,json=droppedBatchFull,proto3" json:"dropped_batch_full,omitempty"`
}

func (x *BoundaryStatus) Reset() {
	*x = BoundaryStatus{}
	if protoimpl.UnsafeEnabled {
		mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *BoundaryStatus) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*BoundaryStatus) ProtoMessage() {}

func (x *BoundaryStatus) ProtoReflect() protoreflect.Message {
	mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use BoundaryStatus.ProtoReflect.Descriptor instead.
func (*BoundaryStatus) Descriptor() ([]byte, []int) {
	return file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP(), []int{1}
}

func (x *BoundaryStatus) GetDroppedChannelFull() int64 {
	if x != nil {
		return x.DroppedChannelFull
	}
	return 0
}

func (x *BoundaryStatus) GetDroppedBatchFull() int64 {
	if x != nil {
		return x.DroppedBatchFull
	}
	return 0
}

var File_agent_boundarylogproxy_codec_boundary_proto protoreflect.FileDescriptor

var file_agent_boundarylogproxy_codec_boundary_proto_rawDesc = []byte{
	0x0a, 0x2b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79,
	0x6c, 0x6f, 0x67, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2f, 0x62,
	0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x1f, 0x63,
	0x6f, 0x64, 0x65, 0x72, 0x2e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67,
	0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2e, 0x76, 0x31, 0x1a, 0x17,
	0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x61, 0x67, 0x65, 0x6e,
	0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa4, 0x01, 0x0a, 0x0f, 0x42, 0x6f, 0x75, 0x6e,
	0x64, 0x61, 0x72, 0x79, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3f, 0x0a, 0x04, 0x6c,
	0x6f, 0x67, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x63, 0x6f, 0x64, 0x65,
	0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72,
	0x74, 0x42, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x4c, 0x6f, 0x67, 0x73, 0x52, 0x65, 0x71,
	0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x04, 0x6c, 0x6f, 0x67, 0x73, 0x12, 0x49, 0x0a, 0x06,
	0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2f, 0x2e, 0x63,
	0x6f, 0x64, 0x65, 0x72, 0x2e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67,
	0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2e, 0x76, 0x31, 0x2e, 0x42,
	0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x48, 0x00, 0x52,
	0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0x70,
	0x0a, 0x0e, 0x42, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73,
	0x12, 0x30, 0x0a, 0x14, 0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x5f, 0x63, 0x68, 0x61, 0x6e,
	0x6e, 0x65, 0x6c, 0x5f, 0x66, 0x75, 0x6c, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x12,
	0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x43, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x46, 0x75,
	0x6c, 0x6c, 0x12, 0x2c, 0x0a, 0x12, 0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x5f, 0x62, 0x61,
	0x74, 0x63, 0x68, 0x5f, 0x66, 0x75, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x10,
	0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x42, 0x61, 0x74, 0x63, 0x68, 0x46, 0x75, 0x6c, 0x6c,
	0x42, 0x38, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
	0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x76, 0x32, 0x2f, 0x61, 0x67,
	0x65, 0x6e, 0x74, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67, 0x70,
	0x72, 0x6f, 0x78, 0x79, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
	0x6f, 0x33,
}

var (
	file_agent_boundarylogproxy_codec_boundary_proto_rawDescOnce sync.Once
	file_agent_boundarylogproxy_codec_boundary_proto_rawDescData = file_agent_boundarylogproxy_codec_boundary_proto_rawDesc
)

func file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP() []byte {
	file_agent_boundarylogproxy_codec_boundary_proto_rawDescOnce.Do(func() {
		file_agent_boundarylogproxy_codec_boundary_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_boundarylogproxy_codec_boundary_proto_rawDescData)
	})
	return file_agent_boundarylogproxy_codec_boundary_proto_rawDescData
}

var file_agent_boundarylogproxy_codec_boundary_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_agent_boundarylogproxy_codec_boundary_proto_goTypes = []interface{}{
	(*BoundaryMessage)(nil),                 // 0: coder.boundarylogproxy.codec.v1.BoundaryMessage
	(*BoundaryStatus)(nil),                  // 1: coder.boundarylogproxy.codec.v1.BoundaryStatus
	(*proto.ReportBoundaryLogsRequest)(nil), // 2: coder.agent.v2.ReportBoundaryLogsRequest
}
var file_agent_boundarylogproxy_codec_boundary_proto_depIdxs = []int32{
	2, // 0: coder.boundarylogproxy.codec.v1.BoundaryMessage.logs:type_name -> coder.agent.v2.ReportBoundaryLogsRequest
	1, // 1: coder.boundarylogproxy.codec.v1.BoundaryMessage.status:type_name -> coder.boundarylogproxy.codec.v1.BoundaryStatus
	2, // [2:2] is the sub-list for method output_type
	2, // [2:2] is the sub-list for method input_type
	2, // [2:2] is the sub-list for extension type_name
	2, // [2:2] is the sub-list for extension extendee
	0, // [0:2] is the sub-list for field type_name
}

func init() { file_agent_boundarylogproxy_codec_boundary_proto_init() }
func file_agent_boundarylogproxy_codec_boundary_proto_init() {
	if File_agent_boundarylogproxy_codec_boundary_proto != nil {
		return
	}
	if !protoimpl.UnsafeEnabled {
		file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*BoundaryMessage); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*BoundaryStatus); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0].OneofWrappers = []interface{}{
		(*BoundaryMessage_Logs)(nil),
		(*BoundaryMessage_Status)(nil),
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_agent_boundarylogproxy_codec_boundary_proto_rawDesc,
			NumEnums:      0,
			NumMessages:   2,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_agent_boundarylogproxy_codec_boundary_proto_goTypes,
		DependencyIndexes: file_agent_boundarylogproxy_codec_boundary_proto_depIdxs,
		MessageInfos:      file_agent_boundarylogproxy_codec_boundary_proto_msgTypes,
	}.Build()
	File_agent_boundarylogproxy_codec_boundary_proto = out.File
	file_agent_boundarylogproxy_codec_boundary_proto_rawDesc = nil
	file_agent_boundarylogproxy_codec_boundary_proto_goTypes = nil
	file_agent_boundarylogproxy_codec_boundary_proto_depIdxs = nil
}
@@ -1,29 +0,0 @@
syntax = "proto3";
option go_package = "github.com/coder/coder/v2/agent/boundarylogproxy/codec";

package coder.boundarylogproxy.codec.v1;

import "agent/proto/agent.proto";

// BoundaryMessage is the envelope for all TagV2 messages sent over the
// boundary <-> agent unix socket. TagV1 carries a bare
// ReportBoundaryLogsRequest for backwards compatibility; TagV2 wraps
// everything in this envelope so the protocol can be extended with new
// message types without adding more tags.
message BoundaryMessage {
	oneof msg {
		coder.agent.v2.ReportBoundaryLogsRequest logs = 1;
		BoundaryStatus status = 2;
	}
}

// BoundaryStatus carries operational metadata from boundary to the agent.
// The agent records these values as Prometheus metrics. This message is
// never forwarded to coderd.
message BoundaryStatus {
	// Logs dropped because boundary's internal channel buffer was full.
	int64 dropped_channel_full = 1;
	// Logs dropped because boundary's batch buffer was full after a
	// failed flush attempt.
	int64 dropped_batch_full = 2;
}

@@ -14,23 +14,14 @@ import (
	"io"

	"golang.org/x/xerrors"
	"google.golang.org/protobuf/proto"

	agentproto "github.com/coder/coder/v2/agent/proto"
)

type Tag uint8

const (
	// TagV1 identifies the first revision of the protocol. The payload is a
	// bare ReportBoundaryLogsRequest. This version has a maximum data length
	// of MaxMessageSizeV1.
	// TagV1 identifies the first revision of the protocol. This version has a maximum
	// data length of MaxMessageSizeV1.
	TagV1 Tag = 1

	// TagV2 identifies the second revision of the protocol. The payload is
	// a BoundaryMessage envelope. This version has a maximum data length of
	// MaxMessageSizeV2.
	TagV2 Tag = 2
)

const (
@@ -44,9 +35,6 @@ const (
	// over the wire for the TagV1 tag. While the wire format allows 24 bits for
	// length, TagV1 only uses 15 bits.
	MaxMessageSizeV1 uint32 = 1 << 15

	// MaxMessageSizeV2 is the maximum data length for TagV2.
	MaxMessageSizeV2 = MaxMessageSizeV1
)

var (
@@ -60,9 +48,12 @@ var (
// WriteFrame writes a framed message with the given tag and data. The data
// must not exceed 2^DataLength in length.
func WriteFrame(w io.Writer, tag Tag, data []byte) error {
	maxSize, err := maxSizeForTag(tag)
	if err != nil {
		return err
	var maxSize uint32
	switch tag {
	case TagV1:
		maxSize = MaxMessageSizeV1
	default:
		return xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
	}

	if len(data) > int(maxSize) {
@@ -110,9 +101,12 @@ func ReadFrame(r io.Reader, buf []byte) (Tag, []byte, error) {
	}
	tag := Tag(shifted)

	maxSize, err := maxSizeForTag(tag)
	if err != nil {
		return 0, nil, err
	var maxSize uint32
	switch tag {
	case TagV1:
		maxSize = MaxMessageSizeV1
	default:
		return 0, nil, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
	}

	if length > maxSize {
@@ -131,56 +125,3 @@ func ReadFrame(r io.Reader, buf []byte) (Tag, []byte, error) {

	return tag, buf[:length], nil
}

// maxSizeForTag returns the maximum payload size for the given tag.
func maxSizeForTag(tag Tag) (uint32, error) {
	switch tag {
	case TagV1:
		return MaxMessageSizeV1, nil
	case TagV2:
		return MaxMessageSizeV2, nil
	default:
		return 0, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
	}
}

// ReadMessage reads a framed message and unmarshals it based on tag. The
// returned buf should be passed back on the next call for buffer reuse.
func ReadMessage(r io.Reader, buf []byte) (proto.Message, []byte, error) {
	tag, data, err := ReadFrame(r, buf)
	if err != nil {
		return nil, data, err
	}

	var msg proto.Message
	switch tag {
	case TagV1:
		var req agentproto.ReportBoundaryLogsRequest
		if err := proto.Unmarshal(data, &req); err != nil {
			return nil, data, xerrors.Errorf("unmarshal TagV1: %w", err)
		}
		msg = &req
	case TagV2:
		var envelope BoundaryMessage
		if err := proto.Unmarshal(data, &envelope); err != nil {
			return nil, data, xerrors.Errorf("unmarshal TagV2: %w", err)
		}
		msg = &envelope
	default:
		// maxSizeForTag already rejects unknown tags during ReadFrame,
		// but handle it here for safety.
		return nil, data, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
	}

	return msg, data, nil
}

// WriteMessage marshals a proto message and writes it as a framed message
// with the given tag.
func WriteMessage(w io.Writer, tag Tag, msg proto.Message) error {
	data, err := proto.Marshal(msg)
	if err != nil {
		return xerrors.Errorf("marshal: %w", err)
	}
	return WriteFrame(w, tag, data)
}

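The framing above packs a tag and a payload length into a single header word; the test further down builds one as `bogusTag<<codec.DataLength | dataLength`, and a comment notes the wire format allows 24 bits for length. A minimal standalone sketch of that layout, assuming `DataLength` is 24 and a big-endian 4-byte header (assumptions, since the exact constants and byte order are not shown in this diff):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// dataLengthBits is assumed to mirror the codec's DataLength constant:
// the low 24 bits of the header carry the payload length, the high 8
// bits carry the tag.
const dataLengthBits = 24

// packHeader combines a tag and payload length into one 32-bit header.
func packHeader(tag uint8, length uint32) uint32 {
	return uint32(tag)<<dataLengthBits | length&(1<<dataLengthBits-1)
}

// unpackHeader splits a header back into its tag and length fields.
func unpackHeader(h uint32) (tag uint8, length uint32) {
	return uint8(h >> dataLengthBits), h & (1<<dataLengthBits - 1)
}

func main() {
	h := packHeader(1, 512)
	var wire [4]byte
	binary.BigEndian.PutUint32(wire[:], h) // 4-byte header on the wire
	tag, n := unpackHeader(binary.BigEndian.Uint32(wire[:]))
	fmt.Println(tag, n)
}
```

Reserving the high bits for the tag is what lets `ReadFrame` reject an unsupported version before attempting to read the payload.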
@@ -89,7 +89,7 @@ func TestReadFrameInvalidTag(t *testing.T) {
	// reading the invalid tag.
	const (
		dataLength uint32 = 10
		bogusTag   uint32 = 222
		bogusTag   uint32 = 2
	)
	header := bogusTag<<codec.DataLength | dataLength
	data := make([]byte, 4)
@@ -139,7 +139,7 @@ func TestWriteFrameInvalidTag(t *testing.T) {

	var buf bytes.Buffer
	data := make([]byte, 1)
	const bogusTag = 222
	const bogusTag = 2
	err := codec.WriteFrame(&buf, codec.Tag(bogusTag), data)
	require.ErrorIs(t, err, codec.ErrUnsupportedTag)
}

@@ -1,77 +0,0 @@
package boundarylogproxy

import "github.com/prometheus/client_golang/prometheus"

// Metrics tracks observability for the boundary -> agent -> coderd audit log
// pipeline.
//
// Audit logs from boundary workspaces pass through several async buffers
// before reaching coderd, and any stage can silently drop data. These
// metrics make that loss visible so operators/devs can:
//
//   - Bubble up data loss: a non-zero drop rate means audit logs are being
//     lost, which may have auditing implications.
//   - Identify the bottleneck: the reason label pinpoints where drops
//     occur: boundary's internal buffers, the agent's channel, or the
//     RPC to coderd.
//   - Tune buffer sizes: sustained "buffer_full" drops indicate the
//     agent's channel (or boundary's batch buffer) is too small for the
//     workload. Combined with batches_forwarded_total you can compute a
//     drop rate: drops / (drops + forwards).
//   - Detect batch forwarding issues: "forward_failed" drops increase when
//     the agent cannot reach coderd.
//
// Drops are captured at two stages:
//   - Agent-side: the agent's channel buffer overflows (reason
//     "buffer_full") or the RPC forward to coderd fails (reason
//     "forward_failed").
//   - Boundary-reported: boundary self-reports drops via BoundaryStatus
//     messages (reasons "boundary_channel_full", "boundary_batch_full").
//     These arrive on the next successful flush from boundary.
//
// There are circumstances where metrics could be lost e.g., agent restarts,
// boundary crashes, or the agent shuts down when the DRPC connection is down.
type Metrics struct {
	batchesDropped   *prometheus.CounterVec
	logsDropped      *prometheus.CounterVec
	batchesForwarded prometheus.Counter
}

func newMetrics(registerer prometheus.Registerer) *Metrics {
	batchesDropped := prometheus.NewCounterVec(prometheus.CounterOpts{
		Namespace: "agent",
		Subsystem: "boundary_log_proxy",
		Name:      "batches_dropped_total",
		Help: "Total number of boundary log batches dropped before reaching coderd. " +
			"Reason: buffer_full = the agent's internal buffer is full, meaning boundary is producing logs faster than the agent can forward them to coderd; " +
			"forward_failed = the agent failed to send the batch to coderd, potentially because coderd is unreachable or the connection was interrupted.",
	}, []string{"reason"})
	registerer.MustRegister(batchesDropped)

	logsDropped := prometheus.NewCounterVec(prometheus.CounterOpts{
		Namespace: "agent",
		Subsystem: "boundary_log_proxy",
		Name:      "logs_dropped_total",
		Help: "Total number of individual boundary log entries dropped before reaching coderd. " +
			"Reason: buffer_full = the agent's internal buffer is full; " +
			"forward_failed = the agent failed to send the batch to coderd; " +
			"boundary_channel_full = boundary's internal send channel overflowed, meaning boundary is generating logs faster than it can batch and send them; " +
			"boundary_batch_full = boundary's outgoing batch buffer overflowed after a failed flush, meaning boundary could not write to the agent's socket.",
	}, []string{"reason"})
	registerer.MustRegister(logsDropped)

	batchesForwarded := prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "agent",
		Subsystem: "boundary_log_proxy",
		Name:      "batches_forwarded_total",
		Help: "Total number of boundary log batches successfully forwarded to coderd. " +
			"Compare with batches_dropped_total to compute a drop rate.",
	})
	registerer.MustRegister(batchesForwarded)

	return &Metrics{
		batchesDropped:   batchesDropped,
		logsDropped:      logsDropped,
		batchesForwarded: batchesForwarded,
	}
}
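The Metrics doc comment above gives the drop-rate formula, drops / (drops + forwards), computed from batches_dropped_total and batches_forwarded_total. A minimal sketch of that arithmetic (the `dropRate` helper is illustrative, not part of the package):

```go
package main

import "fmt"

// dropRate computes the fraction of batches lost using the formula from
// the Metrics doc comment: drops / (drops + forwards). It returns 0 when
// no batches have been seen to avoid dividing by zero.
func dropRate(drops, forwards float64) float64 {
	total := drops + forwards
	if total == 0 {
		return 0
	}
	return drops / total
}

func main() {
	// e.g. 5 dropped batches against 95 forwarded ones is a 5% drop rate.
	fmt.Println(dropRate(5, 95))
}
```

In practice the two counter values would come from a Prometheus query rather than in-process variables; the function only shows the ratio the comment describes.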
@@ -11,7 +11,6 @@ import (
	"path/filepath"
	"sync"

	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/xerrors"
	"google.golang.org/protobuf/proto"

@@ -27,13 +26,6 @@ const (
	logBufferSize = 100
)

const (
	droppedReasonBoundaryChannelFull = "boundary_channel_full"
	droppedReasonBoundaryBatchFull   = "boundary_batch_full"
	droppedReasonBufferFull          = "buffer_full"
	droppedReasonForwardFailed       = "forward_failed"
)

// DefaultSocketPath returns the default path for the boundary audit log socket.
func DefaultSocketPath() string {
	return filepath.Join(os.TempDir(), "boundary-audit.sock")
@@ -51,7 +43,6 @@ type Reporter interface {
type Server struct {
	logger     slog.Logger
	socketPath string
	metrics    *Metrics

	listener net.Listener
	cancel   context.CancelFunc
@@ -62,11 +53,10 @@ type Server struct {
}

// NewServer creates a new boundary log proxy server.
func NewServer(logger slog.Logger, socketPath string, registerer prometheus.Registerer) *Server {
func NewServer(logger slog.Logger, socketPath string) *Server {
	return &Server{
		logger:     logger.Named("boundary-log-proxy"),
		socketPath: socketPath,
		metrics:    newMetrics(registerer),
		logs:       make(chan *agentproto.ReportBoundaryLogsRequest, logBufferSize),
	}
}
@@ -110,13 +100,9 @@ func (s *Server) RunForwarder(ctx context.Context, sender Reporter) error {
				s.logger.Warn(ctx, "failed to forward boundary logs",
					slog.Error(err),
					slog.F("log_count", len(req.Logs)))
				s.metrics.batchesDropped.WithLabelValues(droppedReasonForwardFailed).Inc()
				s.metrics.logsDropped.WithLabelValues(droppedReasonForwardFailed).Add(float64(len(req.Logs)))
				// Continue forwarding other logs. The current batch is lost,
				// but the socket stays alive.
				continue
			}
			s.metrics.batchesForwarded.Inc()
		}
	}
}
@@ -153,8 +139,8 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
		_ = conn.Close()
	}()

	// This is intended to be a sane starting point for the read buffer size.
	// It may be grown by codec.ReadMessage if necessary.
	// This is intended to be a sane starting point for the read buffer size. It may be
	// grown by codec.ReadFrame if necessary.
	const initBufSize = 1 << 10
	buf := make([]byte, initBufSize)

@@ -165,59 +151,36 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
		default:
		}

		var err error
		var msg proto.Message
		msg, buf, err = codec.ReadMessage(conn, buf)
		var (
			tag codec.Tag
			err error
		)
		tag, buf, err = codec.ReadFrame(conn, buf)
		switch {
		case errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed):
			return
		case errors.Is(err, codec.ErrUnsupportedTag) || errors.Is(err, codec.ErrMessageTooLarge):
		case err != nil:
			s.logger.Warn(ctx, "read frame error", slog.Error(err))
			return
		case err != nil:
			s.logger.Warn(ctx, "read message error", slog.Error(err))
		}

		if tag != codec.TagV1 {
			s.logger.Warn(ctx, "invalid tag value", slog.F("tag", tag))
			return
		}

		var req agentproto.ReportBoundaryLogsRequest
		if err := proto.Unmarshal(buf, &req); err != nil {
			s.logger.Warn(ctx, "proto unmarshal error", slog.Error(err))
			continue
		}

		s.handleMessage(ctx, msg)
	}
}

func (s *Server) handleMessage(ctx context.Context, msg proto.Message) {
	switch m := msg.(type) {
	case *agentproto.ReportBoundaryLogsRequest:
		s.bufferLogs(ctx, m)
	case *codec.BoundaryMessage:
		switch inner := m.Msg.(type) {
		case *codec.BoundaryMessage_Logs:
			s.bufferLogs(ctx, inner.Logs)
		case *codec.BoundaryMessage_Status:
			s.recordBoundaryStatus(inner.Status)
		select {
		case s.logs <- &req:
		default:
			s.logger.Warn(ctx, "unknown BoundaryMessage variant")
			s.logger.Warn(ctx, "dropping boundary logs, buffer full",
				slog.F("log_count", len(req.Logs)))
		}
	default:
		s.logger.Warn(ctx, "unexpected message type")
	}
}

func (s *Server) recordBoundaryStatus(status *codec.BoundaryStatus) {
	if n := status.DroppedChannelFull; n > 0 {
		s.metrics.logsDropped.WithLabelValues(droppedReasonBoundaryChannelFull).Add(float64(n))
	}
	if n := status.DroppedBatchFull; n > 0 {
		s.metrics.logsDropped.WithLabelValues(droppedReasonBoundaryBatchFull).Add(float64(n))
	}
}

func (s *Server) bufferLogs(ctx context.Context, req *agentproto.ReportBoundaryLogsRequest) {
	select {
	case s.logs <- req:
	default:
		s.logger.Warn(ctx, "dropping boundary logs, buffer full",
			slog.F("log_count", len(req.Logs)))
		s.metrics.batchesDropped.WithLabelValues(droppedReasonBufferFull).Inc()
		s.metrics.logsDropped.WithLabelValues(droppedReasonBufferFull).Add(float64(len(req.Logs)))
	}
}

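bufferLogs above uses a `select` with a `default` case so a full channel drops the batch (and counts it) instead of blocking the reader goroutine. A standalone sketch of that drop-on-full pattern (`tryBuffer` is an illustrative helper, not part of the package):

```go
package main

import "fmt"

// tryBuffer attempts a non-blocking send. The default case fires only
// when the channel's buffer is full, so the caller can record the drop
// and keep going, mirroring bufferLogs in the server above.
func tryBuffer[T any](ch chan T, v T) bool {
	select {
	case ch <- v:
		return true
	default:
		return false
	}
}

func main() {
	ch := make(chan int, 1)
	fmt.Println(tryBuffer(ch, 1)) // buffer has room: true
	fmt.Println(tryBuffer(ch, 2)) // buffer full, value dropped: false
}
```

The trade-off is deliberate: the socket reader never stalls behind a slow forwarder, at the cost of losing batches, which is exactly why the drop counters exist.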
@@ -11,8 +11,8 @@ import (
	"testing"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/coder/coder/v2/agent/boundarylogproxy"
@@ -21,42 +21,20 @@ import (
	"github.com/coder/coder/v2/testutil"
)

// sendLogsV1 writes a bare ReportBoundaryLogsRequest using TagV1, the
// legacy framing that existing boundary deployments use.
func sendLogsV1(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
// sendMessage writes a framed protobuf message to the connection.
func sendMessage(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
	t.Helper()

	err := codec.WriteMessage(conn, codec.TagV1, req)
	data, err := proto.Marshal(req)
	if err != nil {
		t.Errorf("write v1 logs: %s", err)
		//nolint:gocritic // In tests we're not worried about conn being nil.
		t.Errorf("%s marshal req: %s", conn.LocalAddr().String(), err)
	}
}

// sendLogs writes a BoundaryMessage envelope containing logs to the
// connection using TagV2.
func sendLogs(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
	t.Helper()

	msg := &codec.BoundaryMessage{
		Msg: &codec.BoundaryMessage_Logs{Logs: req},
	}
	err := codec.WriteMessage(conn, codec.TagV2, msg)
	err = codec.WriteFrame(conn, codec.TagV1, data)
	if err != nil {
		t.Errorf("write logs: %s", err)
	}
}

// sendStatus writes a BoundaryMessage envelope containing a BoundaryStatus
// to the connection using TagV2.
func sendStatus(t *testing.T, conn net.Conn, status *codec.BoundaryStatus) {
	t.Helper()

	msg := &codec.BoundaryMessage{
		Msg: &codec.BoundaryMessage_Status{Status: status},
	}
	err := codec.WriteMessage(conn, codec.TagV2, msg)
	if err != nil {
		t.Errorf("write status: %s", err)
		//nolint:gocritic // In tests we're not worried about conn being nil.
		t.Errorf("%s write frame: %s", conn.LocalAddr().String(), err)
	}
}

@@ -102,7 +80,7 @@ func TestServer_StartAndClose(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
@@ -121,7 +99,7 @@ func TestServer_ReceiveAndForwardLogs(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
@@ -158,7 +136,7 @@ func TestServer_ReceiveAndForwardLogs(t *testing.T) {
		},
	}

	sendLogs(t, conn, req)
	sendMessage(t, conn, req)

	// Wait for the reporter to receive the log.
	require.Eventually(t, func() bool {
@@ -181,7 +159,7 @@ func TestServer_MultipleMessages(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
@@ -217,7 +195,7 @@ func TestServer_MultipleMessages(t *testing.T) {
				},
			},
		}
		sendLogs(t, conn, req)
		sendMessage(t, conn, req)
	}

	require.Eventually(t, func() bool {
@@ -233,7 +211,7 @@ func TestServer_MultipleConnections(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
@@ -276,7 +254,7 @@ func TestServer_MultipleConnections(t *testing.T) {
				},
			},
		}
			sendLogs(t, conn, req)
			sendMessage(t, conn, req)
		}(i)
	}
	wg.Wait()
@@ -294,7 +272,7 @@ func TestServer_MessageTooLarge(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
@@ -322,7 +300,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
@@ -364,7 +342,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
			},
		},
	}
	sendLogs(t, conn, req1)
	sendMessage(t, conn, req1)

	select {
	case <-reportNotify:
@@ -387,7 +365,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
			},
		},
	}
	sendLogs(t, conn, req2)
	sendMessage(t, conn, req2)

	// Only the second message should be recorded.
	require.Eventually(t, func() bool {
@@ -407,7 +385,7 @@ func TestServer_CloseStopsForwarder(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
@@ -436,7 +414,7 @@ func TestServer_InvalidProtobuf(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
@@ -480,7 +458,7 @@ func TestServer_InvalidProtobuf(t *testing.T) {
			},
		},
	}
	sendLogs(t, conn, req)
	sendMessage(t, conn, req)

	require.Eventually(t, func() bool {
		logs := reporter.getLogs()
@@ -495,7 +473,7 @@ func TestServer_InvalidHeader(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
@@ -545,7 +523,7 @@ func TestServer_AllowRequest(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)

	err := srv.Start()
	require.NoError(t, err)
@@ -581,7 +559,7 @@ func TestServer_AllowRequest(t *testing.T) {
			},
		},
	}
	sendLogs(t, conn, req)
	sendMessage(t, conn, req)

	require.Eventually(t, func() bool {
		logs := reporter.getLogs()
@@ -598,258 +576,3 @@ func TestServer_AllowRequest(t *testing.T) {
	cancel()
	<-forwarderDone
}

func TestServer_TagV1BackwardsCompatibility(t *testing.T) {
	t.Parallel()

	socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
	srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	err := srv.Start()
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, srv.Close()) })

	reporter := &fakeReporter{}

	forwarderDone := make(chan error, 1)
	go func() {
		forwarderDone <- srv.RunForwarder(ctx, reporter)
	}()

	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer conn.Close()

	// Send a TagV1 message (bare ReportBoundaryLogsRequest) to verify
	// the server still handles the legacy framing used by existing
	// boundary deployments.
	v1Req := &agentproto.ReportBoundaryLogsRequest{
		Logs: []*agentproto.BoundaryLog{
			{
				Allowed: true,
				Time:    timestamppb.Now(),
				Resource: &agentproto.BoundaryLog_HttpRequest_{
					HttpRequest: &agentproto.BoundaryLog_HttpRequest{
						Method: "GET",
						Url:    "https://example.com/v1",
					},
				},
			},
		},
	}
	sendLogsV1(t, conn, v1Req)

	require.Eventually(t, func() bool {
		return len(reporter.getLogs()) == 1
	}, testutil.WaitShort, testutil.IntervalFast)

	// Now send a TagV2 message on the same connection to verify both
	// tag versions work interleaved.
	v2Req := &agentproto.ReportBoundaryLogsRequest{
		Logs: []*agentproto.BoundaryLog{
			{
				Allowed: false,
				Time:    timestamppb.Now(),
				Resource: &agentproto.BoundaryLog_HttpRequest_{
					HttpRequest: &agentproto.BoundaryLog_HttpRequest{
						Method: "POST",
						Url:    "https://example.com/v2",
					},
				},
			},
		},
	}
	sendLogs(t, conn, v2Req)

	require.Eventually(t, func() bool {
		return len(reporter.getLogs()) == 2
	}, testutil.WaitShort, testutil.IntervalFast)

	logs := reporter.getLogs()
	require.Equal(t, "https://example.com/v1", logs[0].Logs[0].GetHttpRequest().Url)
	require.Equal(t, "https://example.com/v2", logs[1].Logs[0].GetHttpRequest().Url)

	cancel()
	<-forwarderDone
}

func TestServer_Metrics(t *testing.T) {
	t.Parallel()

	makeReq := func(n int) *agentproto.ReportBoundaryLogsRequest {
		logs := make([]*agentproto.BoundaryLog, n)
		for i := range n {
			logs[i] = &agentproto.BoundaryLog{
				Allowed: true,
				Time:    timestamppb.Now(),
				Resource: &agentproto.BoundaryLog_HttpRequest_{
					HttpRequest: &agentproto.BoundaryLog_HttpRequest{
						Method: "GET",
						Url:    "https://example.com",
					},
				},
			}
		}
		return &agentproto.ReportBoundaryLogsRequest{Logs: logs}
	}

	// BufferFull needs its own setup because it intentionally does not run
	// a forwarder so the channel fills up.
	t.Run("BufferFull", func(t *testing.T) {
		t.Parallel()

		reg := prometheus.NewRegistry()
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
		srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, reg)

		err := srv.Start()
		require.NoError(t, err)
		t.Cleanup(func() { require.NoError(t, srv.Close()) })

		conn, err := net.Dial("unix", socketPath)
		require.NoError(t, err)
		defer conn.Close()

		// Fill the buffer (size 100) without running a forwarder so nothing
		// drains. Then send one more to trigger the drop path.
		for range 101 {
			sendLogs(t, conn, makeReq(1))
		}

		require.Eventually(t, func() bool {
			return getCounterVecValue(t, reg, "agent_boundary_log_proxy_batches_dropped_total", "buffer_full") >= 1
		}, testutil.WaitShort, testutil.IntervalFast)
		require.GreaterOrEqual(t,
			getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "buffer_full"),
			float64(1))
	})

	// The remaining metrics share one server, forwarder, and connection. The
	// phases run sequentially so metrics accumulate.
	t.Run("Forwarding", func(t *testing.T) {
		t.Parallel()

		reg := prometheus.NewRegistry()
		socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
		srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, reg)

		err := srv.Start()
		require.NoError(t, err)
		t.Cleanup(func() { require.NoError(t, srv.Close()) })

		reportNotify := make(chan struct{}, 4)
		reporter := &fakeReporter{
			err:     context.DeadlineExceeded,
			errOnce: true,
			reportCb: func() {
				select {
				case reportNotify <- struct{}{}:
				default:
				}
			},
		}

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		forwarderDone := make(chan error, 1)
		go func() {
			forwarderDone <- srv.RunForwarder(ctx, reporter)
		}()

		conn, err := net.Dial("unix", socketPath)
		require.NoError(t, err)
		defer conn.Close()

		// Phase 1: the first forward errors
		sendLogs(t, conn, makeReq(2))

		select {
		case <-reportNotify:
		case <-time.After(testutil.WaitShort):
			t.Fatal("timed out waiting for forward attempt")
		}

		// The metric is incremented after ReportBoundaryLogs returns, so we
		// need to poll briefly.
		require.Eventually(t, func() bool {
			return getCounterVecValue(t, reg, "agent_boundary_log_proxy_batches_dropped_total", "forward_failed") >= 1
		}, testutil.WaitShort, testutil.IntervalFast)
		require.Equal(t, float64(2),
			getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "forward_failed"))

		// Phase 2: forward succeeds.
		sendLogs(t, conn, makeReq(1))

		require.Eventually(t, func() bool {
			return len(reporter.getLogs()) >= 1
		}, testutil.WaitShort, testutil.IntervalFast)
		require.Equal(t, float64(1),
			getCounterValue(t, reg, "agent_boundary_log_proxy_batches_forwarded_total"))

		// Phase 3: boundary-reported drop counts arrive as a separate BoundaryStatus
		// message, not piggybacked on log batches.
		sendStatus(t, conn, &codec.BoundaryStatus{
			DroppedChannelFull: 5,
			DroppedBatchFull:   3,
		})

		// Status is handled immediately by the reader goroutine, not by the
		// forwarder, so poll metrics directly.
		require.Eventually(t, func() bool {
			return getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_channel_full") >= 5
		}, testutil.WaitShort, testutil.IntervalFast)
		require.Equal(t, float64(5),
			getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_channel_full"))
		require.Equal(t, float64(3),
			getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_batch_full"))

		cancel()
		<-forwarderDone
	})
}

// getCounterVecValue returns the current value of a CounterVec metric filtered
// by the given reason label.
func getCounterVecValue(t *testing.T, reg *prometheus.Registry, name, reason string) float64 {
	t.Helper()

	metrics, err := reg.Gather()
	require.NoError(t, err)

	for _, mf := range metrics {
		if mf.GetName() != name {
			continue
		}
		for _, m := range mf.GetMetric() {
			for _, lp := range m.GetLabel() {
				if lp.GetName() == "reason" && lp.GetValue() == reason {
					return m.GetCounter().GetValue()
				}
			}
		}
	}

	return 0
}

// getCounterValue returns the current value of a Counter metric.
func getCounterValue(t *testing.T, reg *prometheus.Registry, name string) float64 {
	t.Helper()

	metrics, err := reg.Gather()
	require.NoError(t, err)

	for _, mf := range metrics {
		if mf.GetName() != name {
			continue
		}
		for _, m := range mf.GetMetric() {
			return m.GetCounter().GetValue()
		}
	}

	return 0
}

@@ -1,316 +0,0 @@
package filefinder_test

import (
	"context"
	"fmt"
	"math/rand"
	"os"
	"path/filepath"
	"runtime"
	"sync"
	"testing"

	"github.com/stretchr/testify/require"

	"cdr.dev/slog/v3"
	"cdr.dev/slog/v3/sloggers/slogtest"
	"github.com/coder/coder/v2/agent/filefinder"
)

var (
	dirNames = []string{
		"cmd", "internal", "pkg", "api", "auth", "database", "server", "client", "middleware",
		"handler", "config", "utils", "models", "service", "worker", "scheduler", "notification",
		"provisioner", "template", "workspace", "agent", "proxy", "crypto", "telemetry", "billing",
	}
	fileExts = []string{
		".go", ".ts", ".tsx", ".js", ".py", ".sql", ".yaml", ".json", ".md", ".proto", ".sh",
	}
	fileStems = []string{
		"main", "handler", "middleware", "service", "model", "query", "config", "utils", "helpers",
		"types", "interface", "test", "mock", "factory", "builder", "adapter", "observer", "provider",
		"resolver", "schema", "migration", "fixture", "snapshot", "checkpoint",
	}
)

// generateFileTree creates n files under root in a realistic nested directory structure.
func generateFileTree(t testing.TB, root string, n int, seed int64) {
	t.Helper()
	rng := rand.New(rand.NewSource(seed)) //nolint:gosec // deterministic benchmarks

	numDirs := n / 5
	if numDirs < 10 {
		numDirs = 10
	}
	dirs := make([]string, 0, numDirs)
	for i := 0; i < numDirs; i++ {
		depth := rng.Intn(6) + 1
		parts := make([]string, depth)
		for d := 0; d < depth; d++ {
			parts[d] = dirNames[rng.Intn(len(dirNames))]
		}
		dirs = append(dirs, filepath.Join(parts...))
	}

	created := make(map[string]struct{})
	for _, d := range dirs {
		full := filepath.Join(root, d)
		if _, ok := created[full]; ok {
			continue
		}
		require.NoError(t, os.MkdirAll(full, 0o755))
		created[full] = struct{}{}
	}

	for i := 0; i < n; i++ {
		dir := dirs[rng.Intn(len(dirs))]
		stem := fileStems[rng.Intn(len(fileStems))]
		ext := fileExts[rng.Intn(len(fileExts))]
		name := fmt.Sprintf("%s_%d%s", stem, i, ext)
		full := filepath.Join(root, dir, name)
		f, err := os.Create(full)
		require.NoError(t, err)
		_ = f.Close()
	}
}

// buildIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func buildIndex(t testing.TB, root string) *filefinder.Index {
	t.Helper()
	absRoot, err := filepath.Abs(root)
	require.NoError(t, err)
	idx, err := filefinder.BuildTestIndex(absRoot)
	require.NoError(t, err)
	return idx
}

func BenchmarkBuildIndex(b *testing.B) {
	scales := []struct {
		name string
		n    int
	}{
		{"1K", 1_000},
		{"10K", 10_000},
		{"100K", 100_000},
	}

	for _, sc := range scales {
		b.Run(sc.name, func(b *testing.B) {
			if sc.n >= 100_000 && testing.Short() {
				b.Skip("skipping large-scale benchmark")
			}
			dir := b.TempDir()
			generateFileTree(b, dir, sc.n, 42)

			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				idx := buildIndex(b, dir)
				if idx.Len() == 0 {
					b.Fatal("expected non-empty index")
				}
			}
			b.StopTimer()

			idx := buildIndex(b, dir)
			b.ReportMetric(float64(idx.Len())/b.Elapsed().Seconds(), "files/sec")
		})
	}
}

func BenchmarkSearch_ByScale(b *testing.B) {
	queries := []struct {
		name  string
		query string
	}{
		{"exact_basename", "handler.go"},
		{"short_query", "ha"},
		{"fuzzy_basename", "hndlr"},
		{"path_structured", "internal/handler"},
		{"multi_token", "api handler"},
	}
	scales := []struct {
		name string
		n    int
	}{
		{"1K", 1_000},
		{"10K", 10_000},
		{"100K", 100_000},
	}

	for _, sc := range scales {
		b.Run(sc.name, func(b *testing.B) {
			if sc.n >= 100_000 && testing.Short() {
				b.Skip("skipping large-scale benchmark")
			}
			dir := b.TempDir()
			generateFileTree(b, dir, sc.n, 42)
			idx := buildIndex(b, dir)
			snap := idx.Snapshot()
			opts := filefinder.DefaultSearchOptions()

			for _, q := range queries {
				b.Run(q.name, func(b *testing.B) {
					p := filefinder.NewQueryPlanForTest(q.query)
					b.ResetTimer()
					for i := 0; i < b.N; i++ {
						_ = filefinder.SearchSnapshotForTest(p, snap, opts.MaxCandidates)
					}
				})
			}
		})
	}
}

func BenchmarkSearch_ConcurrentReads(b *testing.B) {
	dir := b.TempDir()
	generateFileTree(b, dir, 10_000, 42)

	logger := slogtest.Make(b, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelError)
	ctx := context.Background()
	eng := filefinder.NewEngine(logger)
	require.NoError(b, eng.AddRoot(ctx, dir))
	b.Cleanup(func() { _ = eng.Close() })

	opts := filefinder.DefaultSearchOptions()
	goroutines := []int{1, 4, 16, 64}

	for _, g := range goroutines {
		b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
			b.SetParallelism(g)
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					results, err := eng.Search(ctx, "handler", opts)
					if err != nil {
						b.Fatal(err)
					}
					_ = results
				}
			})
		})
	}
}

func BenchmarkDeltaUpdate(b *testing.B) {
	dir := b.TempDir()
	generateFileTree(b, dir, 10_000, 42)

	addCounts := []int{1, 10, 100}

	for _, count := range addCounts {
		b.Run(fmt.Sprintf("add_%d_files", count), func(b *testing.B) {
			paths := make([]string, count)
			for i := range paths {
				paths[i] = fmt.Sprintf("injected/dir_%d/newfile_%d.go", i%10, i)
			}
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				b.StopTimer()
				idx := buildIndex(b, dir)
				b.StartTimer()
				for _, p := range paths {
					idx.Add(p, 0)
				}
			}
			b.ReportMetric(float64(count), "files_added/op")
		})
	}

	b.Run("search_after_100_additions", func(b *testing.B) {
		idx := buildIndex(b, dir)
		for i := 0; i < 100; i++ {
			idx.Add(fmt.Sprintf("injected/extra/file_%d.go", i), 0)
		}
		snap := idx.Snapshot()
		plan := filefinder.NewQueryPlanForTest("handler")
		opts := filefinder.DefaultSearchOptions()

		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			_ = filefinder.SearchSnapshotForTest(plan, snap, opts.MaxCandidates)
		}
	})
}

func BenchmarkMemoryProfile(b *testing.B) {
	scales := []struct {
		name string
		n    int
	}{
		{"10K", 10_000},
		{"100K", 100_000},
	}

	for _, sc := range scales {
		b.Run(sc.name, func(b *testing.B) {
			if sc.n >= 100_000 && testing.Short() {
				b.Skip("skipping large-scale memory profile")
			}
			dir := b.TempDir()
			generateFileTree(b, dir, sc.n, 42)

			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				idx := buildIndex(b, dir)
				_ = idx.Snapshot()
			}
			b.StopTimer()

			// Report memory stats on the last iteration.
			runtime.GC()
			var before runtime.MemStats
			runtime.ReadMemStats(&before)
			idx := buildIndex(b, dir)
			var after runtime.MemStats
			runtime.ReadMemStats(&after)

			allocDelta := after.TotalAlloc - before.TotalAlloc
			b.ReportMetric(float64(allocDelta)/float64(idx.Len()), "bytes/file")

			runtime.GC()
			runtime.ReadMemStats(&before)
			snap := idx.Snapshot()
			_ = snap
			runtime.GC()
			runtime.ReadMemStats(&after)

			snapAlloc := after.TotalAlloc - before.TotalAlloc
			b.ReportMetric(float64(snapAlloc)/float64(idx.Len()), "snap-bytes/file")
		})
	}
}

func BenchmarkSearch_ConcurrentReads_Throughput(b *testing.B) {
	dir := b.TempDir()
	generateFileTree(b, dir, 10_000, 42)
	idx := buildIndex(b, dir)
	snap := idx.Snapshot()

	goroutines := []int{1, 4, 16, 64}
	plan := filefinder.NewQueryPlanForTest("handler.go")
	maxCands := filefinder.DefaultSearchOptions().MaxCandidates

	for _, g := range goroutines {
		b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
			b.ResetTimer()
			var wg sync.WaitGroup
			perGoroutine := b.N / g
			if perGoroutine < 1 {
				perGoroutine = 1
			}
			for gi := 0; gi < g; gi++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					for j := 0; j < perGoroutine; j++ {
						_ = filefinder.SearchSnapshotForTest(plan, snap, maxCands)
					}
				}()
			}
			wg.Wait()
			totalOps := float64(g * perGoroutine)
			b.ReportMetric(totalOps/b.Elapsed().Seconds(), "searches/sec")
		})
	}
}
@@ -1,125 +0,0 @@
package filefinder

import "strings"

// FileFlag represents the type of filesystem entry.
type FileFlag uint16

const (
	FlagFile    FileFlag = 0
	FlagDir     FileFlag = 1
	FlagSymlink FileFlag = 2
)

type doc struct {
	path    string
	baseOff int
	baseLen int
	depth   int
	flags   uint16
}

// Index is an append-only in-memory file index with snapshot support.
type Index struct {
	docs      []doc
	byGram    map[uint32][]uint32
	byPrefix1 [256][]uint32
	byPrefix2 map[uint16][]uint32
	byPath    map[string]uint32
	deleted   map[uint32]bool
}

// Snapshot is a frozen, read-only view of the index at a point in time.
type Snapshot struct {
	docs      []doc
	deleted   map[uint32]bool
	byGram    map[uint32][]uint32
	byPrefix1 [256][]uint32
	byPrefix2 map[uint16][]uint32
}

// NewIndex creates an empty Index.
func NewIndex() *Index {
	return &Index{
		byGram:    make(map[uint32][]uint32),
		byPrefix2: make(map[uint16][]uint32),
		byPath:    make(map[string]uint32),
		deleted:   make(map[uint32]bool),
	}
}

// Add inserts a path into the index, tombstoning any previous entry.
func (idx *Index) Add(path string, flags uint16) uint32 {
	norm := string(normalizePathBytes([]byte(path)))
	if oldID, ok := idx.byPath[norm]; ok {
		idx.deleted[oldID] = true
	}
	id := uint32(len(idx.docs)) //nolint:gosec // Index will never exceed 2^32 docs.
	baseOff, baseLen := extractBasename([]byte(norm))
	idx.docs = append(idx.docs, doc{
		path: norm, baseOff: baseOff, baseLen: baseLen,
		depth: strings.Count(norm, "/"), flags: flags,
	})
	idx.byPath[norm] = id
	for _, g := range extractTrigrams([]byte(norm)) {
		idx.byGram[g] = append(idx.byGram[g], id)
	}
	if baseLen > 0 {
		basename := []byte(norm[baseOff : baseOff+baseLen])
		p1 := prefix1(basename)
		idx.byPrefix1[p1] = append(idx.byPrefix1[p1], id)
		p2 := prefix2(basename)
		idx.byPrefix2[p2] = append(idx.byPrefix2[p2], id)
	}
	return id
}

// Remove marks the entry for path as deleted.
func (idx *Index) Remove(path string) bool {
	norm := string(normalizePathBytes([]byte(path)))
	id, ok := idx.byPath[norm]
	if !ok {
		return false
	}
	idx.deleted[id] = true
	delete(idx.byPath, norm)
	return true
}

// Has reports whether path exists (not deleted) in the index.
func (idx *Index) Has(path string) bool {
	_, ok := idx.byPath[string(normalizePathBytes([]byte(path)))]
	return ok
}

// Len returns the number of live (non-deleted) documents.
func (idx *Index) Len() int { return len(idx.byPath) }

func copyPostings[K comparable](m map[K][]uint32) map[K][]uint32 {
	cp := make(map[K][]uint32, len(m))
	for k, v := range m {
		cp[k] = v[:len(v):len(v)]
	}
	return cp
}

// Snapshot returns a frozen read-only view of the index.
func (idx *Index) Snapshot() *Snapshot {
	del := make(map[uint32]bool, len(idx.deleted))
	for id := range idx.deleted {
		del[id] = true
	}
	var p1Copy [256][]uint32
	for i, ids := range idx.byPrefix1 {
		if len(ids) > 0 {
			p1Copy[i] = ids[:len(ids):len(ids)]
		}
	}
	return &Snapshot{
		docs:      idx.docs[:len(idx.docs):len(idx.docs)],
		deleted:   del,
		byGram:    copyPostings(idx.byGram),
		byPrefix1: p1Copy,
		byPrefix2: copyPostings(idx.byPrefix2),
	}
}
@@ -1,120 +0,0 @@
package filefinder_test

import (
	"testing"

	"github.com/coder/coder/v2/agent/filefinder"
)

func TestIndex_AddAndLen(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("foo/bar.go", 0)
	idx.Add("foo/baz.go", 0)
	if idx.Len() != 2 {
		t.Fatalf("expected 2, got %d", idx.Len())
	}
}

func TestIndex_Has(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("foo/bar.go", 0)
	if !idx.Has("foo/bar.go") {
		t.Fatal("expected Has to return true")
	}
	if idx.Has("foo/missing.go") {
		t.Fatal("expected Has to return false for missing path")
	}
}

func TestIndex_Remove(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("foo/bar.go", 0)
	if !idx.Remove("foo/bar.go") {
		t.Fatal("expected Remove to return true")
	}
	if idx.Has("foo/bar.go") {
		t.Fatal("expected Has to return false after Remove")
	}
	if idx.Len() != 0 {
		t.Fatalf("expected Len 0 after Remove, got %d", idx.Len())
	}
}

func TestIndex_AddOverwrite(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("foo/bar.go", uint16(filefinder.FlagFile))
	idx.Add("foo/bar.go", uint16(filefinder.FlagDir)) // overwrite
	if idx.Len() != 1 {
		t.Fatalf("expected 1 after overwrite, got %d", idx.Len())
	}
	// The old entry should be tombstoned.
	if !filefinder.IndexIsDeleted(idx, 0) {
		t.Fatal("expected old entry to be deleted")
	}
	if filefinder.IndexIsDeleted(idx, 1) {
		t.Fatal("expected new entry to be live")
	}
}

func TestIndex_Snapshot(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("foo/bar.go", 0)
	idx.Add("foo/baz.go", 0)

	snap := idx.Snapshot()
	if filefinder.SnapshotCount(snap) != 2 {
		t.Fatalf("expected snapshot count 2, got %d", filefinder.SnapshotCount(snap))
	}

	// Adding more docs after snapshot doesn't affect it.
	idx.Add("foo/qux.go", 0)
	if filefinder.SnapshotCount(snap) != 2 {
		t.Fatal("snapshot count should not change after new adds")
	}
}

func TestIndex_TrigramIndex(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("handler.go", 0)

	// "handler.go" should produce trigrams for "handler.go".
	// Check that at least one trigram exists.
	if filefinder.IndexByGramLen(idx) == 0 {
		t.Fatal("expected non-empty trigram index")
	}
}

func TestIndex_PrefixIndex(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("handler.go", 0)

	// basename is "handler.go", first byte is 'h'
	if filefinder.IndexByPrefix1Len(idx, 'h') == 0 {
		t.Fatal("expected prefix1['h'] to be non-empty")
	}
}

func TestIndex_RemoveNonexistent(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	if idx.Remove("nonexistent.go") {
		t.Fatal("expected Remove to return false for missing path")
	}
}

func TestIndex_PathNormalization(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("Foo/Bar.go", 0)
	// Should be findable with lowercase.
	if !idx.Has("foo/bar.go") {
		t.Fatal("expected case-insensitive Has")
	}
}
@@ -1,364 +0,0 @@
// Package filefinder provides an in-memory file index with trigram
// matching, fuzzy search, and filesystem watching. It is designed
// to power file-finding features on workspace agents.
package filefinder

import (
	"context"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"sync"
	"sync/atomic"

	"golang.org/x/xerrors"

	"cdr.dev/slog/v3"
)

// SearchOptions controls search behavior.
type SearchOptions struct {
	Limit         int
	MaxCandidates int
}

// DefaultSearchOptions returns sensible default search options.
func DefaultSearchOptions() SearchOptions {
	return SearchOptions{Limit: 100, MaxCandidates: 10000}
}

type rootSnapshot struct {
	root string
	snap *Snapshot
}

// Engine is the main file finder. Safe for concurrent use.
type Engine struct {
	snap    atomic.Pointer[[]*rootSnapshot]
	logger  slog.Logger
	mu      sync.Mutex
	roots   map[string]*rootState
	eventCh chan rootEvent
	closeCh chan struct{}
	closed  atomic.Bool
	wg      sync.WaitGroup
}
type rootState struct {
	root    string
	index   *Index
	watcher *fsWatcher
	cancel  context.CancelFunc
}
type rootEvent struct {
	root   string
	events []FSEvent
}

// walkRoot performs a full filesystem walk of absRoot and returns
// a populated Index containing all discovered files and directories.
func walkRoot(absRoot string) (*Index, error) {
	idx := NewIndex()
	err := filepath.Walk(absRoot, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return nil //nolint:nilerr
		}
		base := filepath.Base(path)
		if _, skip := skipDirs[base]; skip && info.IsDir() {
			return filepath.SkipDir
		}
		if path == absRoot {
			return nil
		}
		relPath, relErr := filepath.Rel(absRoot, path)
		if relErr != nil {
			return nil //nolint:nilerr
		}
		relPath = filepath.ToSlash(relPath)
		var flags uint16
		if info.IsDir() {
			flags = uint16(FlagDir)
		} else if info.Mode()&os.ModeSymlink != 0 {
			flags = uint16(FlagSymlink)
		}
		idx.Add(relPath, flags)
		return nil
	})
	return idx, err
}

// NewEngine creates a new Engine.
func NewEngine(logger slog.Logger) *Engine {
	e := &Engine{
		logger:  logger,
		roots:   make(map[string]*rootState),
		eventCh: make(chan rootEvent, 256),
		closeCh: make(chan struct{}),
	}
	empty := make([]*rootSnapshot, 0)
	e.snap.Store(&empty)
	e.wg.Add(1)
	go e.start()
	return e
}

// ErrClosed is returned when operations are attempted on a
// closed engine.
var ErrClosed = xerrors.New("engine is closed")

// AddRoot adds a directory root to the engine.
func (e *Engine) AddRoot(ctx context.Context, root string) error {
	absRoot, err := filepath.Abs(root)
	if err != nil {
		return xerrors.Errorf("resolve root: %w", err)
	}
	e.mu.Lock()
	if e.closed.Load() {
		e.mu.Unlock()
		return ErrClosed
	}
	if _, exists := e.roots[absRoot]; exists {
		e.mu.Unlock()
		return nil
	}
	e.mu.Unlock()

	// Walk and create the watcher outside the lock to avoid
	// blocking the event pipeline on filesystem I/O.
	idx, walkErr := walkRoot(absRoot)
	if walkErr != nil {
		return xerrors.Errorf("walk root: %w", walkErr)
	}
	wCtx, wCancel := context.WithCancel(context.Background())
	w, wErr := newFSWatcher(absRoot, e.logger)
	if wErr != nil {
		wCancel()
		return xerrors.Errorf("create watcher: %w", wErr)
	}

	e.mu.Lock()
	// Re-check after re-acquiring the lock: another goroutine
	// may have added this root or closed the engine while we
	// were walking.
	if e.closed.Load() {
		e.mu.Unlock()
		wCancel()
		_ = w.Close()
		return ErrClosed
	}
	if _, exists := e.roots[absRoot]; exists {
		e.mu.Unlock()
		wCancel()
		_ = w.Close()
		return nil
	}
	rs := &rootState{root: absRoot, index: idx, watcher: w, cancel: wCancel}
	e.roots[absRoot] = rs
	w.Start(wCtx)
	e.wg.Add(1)
	go e.forwardEvents(wCtx, absRoot, w)
	e.publishSnapshot()
	fileCount := idx.Len()
	e.mu.Unlock()
	e.logger.Info(ctx, "added root to engine",
		slog.F("root", absRoot),
		slog.F("files", fileCount),
	)
	return nil
}

// RemoveRoot stops watching a root and removes it.
func (e *Engine) RemoveRoot(root string) error {
	absRoot, err := filepath.Abs(root)
	if err != nil {
		return xerrors.Errorf("resolve root: %w", err)
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	rs, exists := e.roots[absRoot]
	if !exists {
		return xerrors.Errorf("root %q not found", absRoot)
	}
	rs.cancel()
	_ = rs.watcher.Close()
	delete(e.roots, absRoot)
	e.publishSnapshot()
	return nil
}

// Search performs a fuzzy file search across all roots.
func (e *Engine) Search(_ context.Context, query string, opts SearchOptions) ([]Result, error) {
	if e.closed.Load() {
		return nil, ErrClosed
	}
	snapPtr := e.snap.Load()
	if snapPtr == nil || len(*snapPtr) == 0 {
		return nil, nil
	}
	roots := *snapPtr
	plan := newQueryPlan(query)
	if len(plan.Normalized) == 0 {
		return nil, nil
	}
	if opts.Limit <= 0 {
		opts.Limit = 100
	}
	if opts.MaxCandidates <= 0 {
		opts.MaxCandidates = 10000
	}
	params := defaultScoreParams()
	var allCands []candidate
	for _, rs := range roots {
		allCands = append(allCands, searchSnapshot(plan, rs.snap, opts.MaxCandidates)...)
	}
	results := mergeAndScore(allCands, plan, params, opts.Limit)
	return results, nil
}

// Close shuts down the engine.
func (e *Engine) Close() error {
	if e.closed.Swap(true) {
		return nil
	}
	close(e.closeCh)
	e.mu.Lock()
	for _, rs := range e.roots {
		rs.cancel()
		_ = rs.watcher.Close()
	}
	e.roots = make(map[string]*rootState)
	e.mu.Unlock()
	e.wg.Wait()
	return nil
}

// Rebuild forces a complete re-walk and re-index of a root.
func (e *Engine) Rebuild(ctx context.Context, root string) error {
	absRoot, err := filepath.Abs(root)
	if err != nil {
		return xerrors.Errorf("resolve root: %w", err)
	}

	// Walk outside the lock to avoid blocking the event
	// pipeline on potentially slow filesystem I/O.
	idx, walkErr := walkRoot(absRoot)
	if walkErr != nil {
		return xerrors.Errorf("rebuild walk: %w", walkErr)
	}

	e.mu.Lock()
	rs, exists := e.roots[absRoot]
	if !exists {
		e.mu.Unlock()
		return xerrors.Errorf("root %q not found", absRoot)
	}
	rs.index = idx
	e.publishSnapshot()
	fileCount := idx.Len()
	e.mu.Unlock()
	e.logger.Info(ctx, "rebuilt root in engine",
		slog.F("root", absRoot),
		slog.F("files", fileCount),
	)
	return nil
}

func (e *Engine) start() {
	defer e.wg.Done()
	for {
		select {
		case <-e.closeCh:
			return
		case re, ok := <-e.eventCh:
			if !ok {
				return
			}
			e.applyEvents(re)
		}
	}
}

func (e *Engine) forwardEvents(ctx context.Context, root string, w *fsWatcher) {
	defer e.wg.Done()
	for {
		select {
		case <-ctx.Done():
			return
		case <-e.closeCh:
			return
		case evts, ok := <-w.Events():
			if !ok {
				return
			}
			select {
			case e.eventCh <- rootEvent{root: root, events: evts}:
			case <-ctx.Done():
				return
			case <-e.closeCh:
				return
			}
		}
	}
}

func (e *Engine) applyEvents(re rootEvent) {
	e.mu.Lock()
	defer e.mu.Unlock()
	rs, exists := e.roots[re.root]
	if !exists {
		return
	}
	changed := false
	for _, ev := range re.events {
		relPath, err := filepath.Rel(rs.root, ev.Path)
		if err != nil {
			continue
		}
		relPath = filepath.ToSlash(relPath)
		switch ev.Op {
		case OpCreate:
			if rs.index.Has(relPath) {
				continue
			}
			var flags uint16
			if ev.IsDir {
				flags = uint16(FlagDir)
			}
			rs.index.Add(relPath, flags)
			changed = true
		case OpRemove, OpRename:
			if rs.index.Remove(relPath) {
				changed = true
			}
			if ev.IsDir || ev.Op == OpRename {
				prefix := strings.ToLower(filepath.ToSlash(relPath)) + "/"
				for path := range rs.index.byPath {
					if strings.HasPrefix(path, prefix) {
						rs.index.Remove(path)
						changed = true
					}
				}
			}
		case OpModify:
		}
	}
	if changed {
		e.publishSnapshot()
	}
}

// publishSnapshot builds and atomically publishes a new snapshot.
// Must be called with e.mu held.
func (e *Engine) publishSnapshot() {
	roots := make([]*rootSnapshot, 0, len(e.roots))
	for _, rs := range e.roots {
		roots = append(roots, &rootSnapshot{
			root: rs.root,
			snap: rs.index.Snapshot(),
		})
	}
	slices.SortFunc(roots, func(a, b *rootSnapshot) int {
		return strings.Compare(a.root, b.root)
	})
	e.snap.Store(&roots)
}
@@ -1,233 +0,0 @@
package filefinder_test

import (
	"context"
	"os"
	"path/filepath"
	"sort"
	"testing"

	"github.com/stretchr/testify/require"

	"cdr.dev/slog/v3"
	"cdr.dev/slog/v3/sloggers/slogtest"
	"github.com/coder/coder/v2/agent/filefinder"
	"github.com/coder/coder/v2/testutil"
)

func newTestEngine(t *testing.T) (*filefinder.Engine, context.Context) {
	t.Helper()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	eng := filefinder.NewEngine(logger)
	t.Cleanup(func() { _ = eng.Close() })
	return eng, context.Background()
}

func requireResultHasPath(t *testing.T, results []filefinder.Result, path string) {
	t.Helper()
	for _, r := range results {
		if r.Path == path {
			return
		}
	}
	t.Errorf("expected %q in results, got %v", path, resultPaths(results))
}

func TestEngine_SearchFindsKnownFile(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "src/main.go", "package main")
	createFile(t, dir, "src/handler.go", "package main")
	createFile(t, dir, "README.md", "# hello")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))

	results, err := eng.Search(ctx, "main.go", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	require.NotEmpty(t, results, "expected to find main.go")
	requireResultHasPath(t, results, "src/main.go")
}

func TestEngine_SearchFuzzyMatch(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "src/controllers/user_handler.go", "package controllers")
	createFile(t, dir, "src/models/user.go", "package models")
	createFile(t, dir, "docs/api.md", "# API")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))

	// "handler" should match "user_handler.go".
	results, err := eng.Search(ctx, "handler", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	// The query is a subsequence of "user_handler.go" so it
	// should appear somewhere in the results.
	requireResultHasPath(t, results, "src/controllers/user_handler.go")
}

func TestEngine_IndexPicksUpNewFile(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "existing.txt", "hello")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))
	createFile(t, dir, "newfile_unique.txt", "world")

	require.Eventually(t, func() bool {
		results, sErr := eng.Search(ctx, "newfile_unique", filefinder.DefaultSearchOptions())
		if sErr != nil {
			return false
		}
		for _, r := range results {
			if r.Path == "newfile_unique.txt" {
				return true
			}
		}
		return false
	}, testutil.WaitShort, testutil.IntervalFast, "expected newfile_unique.txt to appear via watcher")
}

func TestEngine_IndexRemovesDeletedFile(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "deleteme_unique.txt", "goodbye")
	createFile(t, dir, "keeper.txt", "stay")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))

	results, err := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	require.NotEmpty(t, results, "expected to find deleteme_unique.txt initially")

	require.NoError(t, os.Remove(filepath.Join(dir, "deleteme_unique.txt")))

	require.Eventually(t, func() bool {
		results, sErr := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
		if sErr != nil {
			return false
		}
		for _, r := range results {
			if r.Path == "deleteme_unique.txt" {
				return false // still found
			}
		}
		return true
	}, testutil.WaitShort, testutil.IntervalFast, "expected deleteme_unique.txt to disappear after removal")
}

func TestEngine_MultipleRoots(t *testing.T) {
	t.Parallel()
	dir1 := t.TempDir()
	dir2 := t.TempDir()
	createFile(t, dir1, "alpha_unique.go", "package alpha")
	createFile(t, dir2, "beta_unique.go", "package beta")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir1))
	require.NoError(t, eng.AddRoot(ctx, dir2))

	results, err := eng.Search(ctx, "alpha_unique", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	requireResultHasPath(t, results, "alpha_unique.go")

	results, err = eng.Search(ctx, "beta_unique", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	requireResultHasPath(t, results, "beta_unique.go")
}

func TestEngine_EmptyQueryReturnsEmpty(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "something.txt", "data")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))

	results, err := eng.Search(ctx, "", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	require.Empty(t, results, "empty query should return no results")
}

func TestEngine_CloseIsClean(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "file.txt", "data")

	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	ctx := context.Background()
	eng := filefinder.NewEngine(logger)
	require.NoError(t, eng.AddRoot(ctx, dir))
	require.NoError(t, eng.Close())

	_, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
	require.Error(t, err)
}

func TestEngine_AddRootIdempotent(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "file.txt", "data")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))
	require.NoError(t, eng.AddRoot(ctx, dir))

	snapLen := filefinder.EngineSnapLen(eng)
	require.Equal(t, 1, snapLen, "expected exactly one root after duplicate add")
}

func TestEngine_RemoveRoot(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "file.txt", "data")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))

	results, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	require.NotEmpty(t, results)

	require.NoError(t, eng.RemoveRoot(dir))

	results, err = eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	require.Empty(t, results)
}

func TestEngine_Rebuild(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	createFile(t, dir, "original.txt", "data")

	eng, ctx := newTestEngine(t)
	require.NoError(t, eng.AddRoot(ctx, dir))

	createFile(t, dir, "sneaky_rebuild.txt", "hidden")
	require.NoError(t, eng.Rebuild(ctx, dir))

	results, err := eng.Search(ctx, "sneaky_rebuild", filefinder.DefaultSearchOptions())
	require.NoError(t, err)
	requireResultHasPath(t, results, "sneaky_rebuild.txt")
}

// createFile creates a file (and parent dirs) at relPath under dir.
func createFile(t *testing.T, dir, relPath, content string) {
	t.Helper()
	full := filepath.Join(dir, relPath)
	require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755))
	require.NoError(t, os.WriteFile(full, []byte(content), 0o600))
}

func resultPaths(results []filefinder.Result) []string {
	paths := make([]string, len(results))
	for i, r := range results {
		paths[i] = r.Path
	}
	sort.Strings(paths)
	return paths
}
@@ -1,85 +0,0 @@
package filefinder

// Test helpers that need internal access.

// MakeTestSnapshot builds a Snapshot from a list of paths. Useful for
// query-level tests that don't need a real filesystem.
func MakeTestSnapshot(paths []string) *Snapshot {
	idx := NewIndex()
	for _, p := range paths {
		idx.Add(p, 0)
	}
	return idx.Snapshot()
}

// BuildTestIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func BuildTestIndex(root string) (*Index, error) {
	return walkRoot(root)
}

// IndexIsDeleted reports whether the document at id is tombstoned.
func IndexIsDeleted(idx *Index, id uint32) bool {
	return idx.deleted[id]
}

// IndexByGramLen returns the number of entries in the trigram index.
func IndexByGramLen(idx *Index) int {
	return len(idx.byGram)
}

// IndexByPrefix1Len returns the number of posting-list entries for
// the given single-byte prefix.
func IndexByPrefix1Len(idx *Index, b byte) int {
	return len(idx.byPrefix1[b])
}

// SnapshotCount returns the number of documents in a Snapshot.
func SnapshotCount(snap *Snapshot) int {
	return len(snap.docs)
}

// EngineSnapLen returns the number of root snapshots currently held
// by the engine, or -1 if the pointer is nil.
func EngineSnapLen(eng *Engine) int {
	p := eng.snap.Load()
	if p == nil {
		return -1
	}
	return len(*p)
}

// DefaultScoreParamsForTest exposes defaultScoreParams for tests.
var DefaultScoreParamsForTest = defaultScoreParams

// ScoreParamsForTest is a type alias for scoreParams.
type ScoreParamsForTest = scoreParams

// Exported aliases for internal functions used in tests.
var (
	NewQueryPlanForTest           = newQueryPlan
	SearchSnapshotForTest         = searchSnapshot
	IntersectSortedForTest        = intersectSorted
	IntersectAllForTest           = intersectAll
	MergeAndScoreForTest          = mergeAndScore
	NormalizeQueryForTest         = normalizeQuery
	NormalizePathBytesForTest     = normalizePathBytes
	ExtractTrigramsForTest        = extractTrigrams
	ExtractBasenameForTest        = extractBasename
	ExtractSegmentsForTest        = extractSegments
	Prefix1ForTest                = prefix1
	Prefix2ForTest                = prefix2
	IsSubsequenceForTest          = isSubsequence
	LongestContiguousMatchForTest = longestContiguousMatch
	IsBoundaryForTest             = isBoundary
	CountBoundaryHitsForTest      = countBoundaryHits
	EqualFoldASCIIForTest         = equalFoldASCII
	ScorePathForTest              = scorePath
	PackTrigramForTest            = packTrigram
)

// Type aliases for internal types used in tests.
type (
	CandidateForTest = candidate
	QueryPlanForTest = queryPlan
)
@@ -1,299 +0,0 @@
package filefinder

import (
	"container/heap"
	"slices"
	"strings"
)

type candidate struct {
	DocID   uint32
	Path    string
	BaseOff int
	BaseLen int
	Depth   int
	Flags   uint16
}

// Result is a scored search result returned to callers.
type Result struct {
	Path  string
	Score float32
	IsDir bool
}

type queryPlan struct {
	Original   string
	Normalized string
	Tokens     [][]byte
	Trigrams   []uint32
	IsShort    bool
	HasSlash   bool
	BasenameQ  []byte
	DirTokens  [][]byte
}

func newQueryPlan(q string) *queryPlan {
	norm := normalizeQuery(q)
	p := &queryPlan{Original: q, Normalized: norm}
	if len(norm) == 0 {
		p.IsShort = true
		return p
	}
	raw := strings.ReplaceAll(norm, "/", " ")
	parts := strings.Fields(raw)
	p.HasSlash = strings.ContainsRune(norm, '/')
	for _, part := range parts {
		p.Tokens = append(p.Tokens, []byte(part))
	}
	if len(p.Tokens) > 0 {
		p.BasenameQ = p.Tokens[len(p.Tokens)-1]
		if len(p.Tokens) > 1 {
			p.DirTokens = p.Tokens[:len(p.Tokens)-1]
		}
	}
	p.IsShort = true
	for _, tok := range p.Tokens {
		if len(tok) >= 3 {
			p.IsShort = false
			break
		}
	}
	if !p.IsShort {
		p.Trigrams = extractQueryTrigrams(p.Tokens)
	}
	return p
}

func extractQueryTrigrams(tokens [][]byte) []uint32 {
	seen := make(map[uint32]struct{})
	for _, tok := range tokens {
		if len(tok) < 3 {
			continue
		}
		for i := 0; i <= len(tok)-3; i++ {
			seen[packTrigram(tok[i], tok[i+1], tok[i+2])] = struct{}{}
		}
	}
	if len(seen) == 0 {
		return nil
	}
	result := make([]uint32, 0, len(seen))
	for g := range seen {
		result = append(result, g)
	}
	return result
}

func packTrigram(a, b, c byte) uint32 {
	return uint32(toLowerASCII(a))<<16 | uint32(toLowerASCII(b))<<8 | uint32(toLowerASCII(c))
}

// searchSnapshot runs the full search pipeline against a single
// root snapshot: it selects a strategy (prefix, trigram, or
// fuzzy fallback) based on query length, retrieves candidate
// doc IDs, and converts them into candidate structs.
func searchSnapshot(plan *queryPlan, snap *Snapshot, limit int) []candidate {
	if snap == nil || len(snap.docs) == 0 || len(plan.Normalized) == 0 {
		return nil
	}
	var ids []uint32
	if plan.IsShort {
		ids = searchShort(plan, snap)
	} else {
		ids = searchTrigrams(plan, snap)
		if len(ids) == 0 && len(plan.BasenameQ) > 0 {
			ids = searchFuzzyFallback(plan, snap)
		}
	}
	if len(ids) == 0 {
		return nil
	}
	cands := make([]candidate, 0, min(len(ids), limit))
	for _, id := range ids {
		// Bounds check must come first: an out-of-range id would
		// otherwise index snap.deleted before being rejected.
		if int(id) >= len(snap.docs) || snap.deleted[id] {
			continue
		}
		d := snap.docs[id]
		cands = append(cands, candidate{
			DocID: id, Path: d.path, BaseOff: d.baseOff,
			BaseLen: d.baseLen, Depth: d.depth, Flags: d.flags,
		})
		if len(cands) >= limit {
			break
		}
	}
	return cands
}

func searchShort(plan *queryPlan, snap *Snapshot) []uint32 {
	if len(plan.BasenameQ) == 0 {
		return nil
	}
	if len(plan.BasenameQ) >= 2 {
		if ids := snap.byPrefix2[prefix2(plan.BasenameQ)]; len(ids) > 0 {
			return ids
		}
	}
	return snap.byPrefix1[prefix1(plan.BasenameQ)]
}

func searchTrigrams(plan *queryPlan, snap *Snapshot) []uint32 {
	if len(plan.Trigrams) == 0 {
		return nil
	}
	lists := make([][]uint32, 0, len(plan.Trigrams))
	for _, g := range plan.Trigrams {
		ids, ok := snap.byGram[g]
		if !ok || len(ids) == 0 {
			return nil
		}
		lists = append(lists, ids)
	}
	return intersectAll(lists)
}

func searchFuzzyFallback(plan *queryPlan, snap *Snapshot) []uint32 {
	if len(plan.BasenameQ) == 0 {
		return nil
	}
	bucket := snap.byPrefix1[prefix1(plan.BasenameQ)]
	if len(bucket) == 0 {
		return searchSubsequenceScan(plan, snap, 5000)
	}
	var ids []uint32
	for _, id := range bucket {
		if int(id) >= len(snap.docs) || snap.deleted[id] {
			continue
		}
		if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
			ids = append(ids, id)
		}
	}
	if len(ids) == 0 {
		return searchSubsequenceScan(plan, snap, 5000)
	}
	return ids
}

func searchSubsequenceScan(plan *queryPlan, snap *Snapshot, maxCheck int) []uint32 {
	if len(plan.BasenameQ) == 0 {
		return nil
	}
	var ids []uint32
	checked := 0
	for id := 0; id < len(snap.docs) && checked < maxCheck; id++ {
		uid := uint32(id) //nolint:gosec // Snapshot count is bounded well below 2^32.
		if snap.deleted[uid] {
			continue
		}
		checked++
		if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
			ids = append(ids, uid)
		}
	}
	return ids
}

func intersectSorted(a, b []uint32) []uint32 {
	if len(a) == 0 || len(b) == 0 {
		return nil
	}
	var result []uint32
	ai, bi := 0, 0
	for ai < len(a) && bi < len(b) {
		switch {
		case a[ai] < b[bi]:
			ai++
		case a[ai] > b[bi]:
			bi++
		default:
			result = append(result, a[ai])
			ai++
			bi++
		}
	}
	return result
}

func intersectAll(lists [][]uint32) []uint32 {
	if len(lists) == 0 {
		return nil
	}
	if len(lists) == 1 {
		return lists[0]
	}
	slices.SortFunc(lists, func(a, b []uint32) int { return len(a) - len(b) })
	result := lists[0]
	for i := 1; i < len(lists) && len(result) > 0; i++ {
		result = intersectSorted(result, lists[i])
	}
	return result
}

func mergeAndScore(cands []candidate, plan *queryPlan, params scoreParams, topK int) []Result {
	if topK <= 0 || len(cands) == 0 {
		return nil
	}
	query := []byte(plan.Normalized)
	h := &resultHeap{}
	heap.Init(h)
	for i := range cands {
		c := &cands[i]
		s := scorePath([]byte(c.Path), c.BaseOff, c.BaseLen, c.Depth, query, plan.Tokens, params)
		if s <= 0 {
			continue
		}
		// DirTokenHit is applied here rather than in scorePath because
		// it depends on the query plan's directory tokens, which are
		// split from the full query during planning. scorePath operates
		// on raw query bytes without knowledge of token boundaries.
		if len(plan.DirTokens) > 0 {
			segments := extractSegments([]byte(c.Path))
			for _, dt := range plan.DirTokens {
				for _, seg := range segments {
					if equalFoldASCII(seg, dt) {
						s += params.DirTokenHit
						break
					}
				}
			}
		}
		r := Result{Path: c.Path, Score: s, IsDir: c.Flags == uint16(FlagDir)}
		if h.Len() < topK {
			heap.Push(h, r)
		} else if s > (*h)[0].Score {
			(*h)[0] = r
			heap.Fix(h, 0)
		}
	}
	n := h.Len()
	results := make([]Result, n)
	for i := n - 1; i >= 0; i-- {
		v := heap.Pop(h)
		if r, ok := v.(Result); ok {
			results[i] = r
		}
	}
	return results
}

type resultHeap []Result

func (h resultHeap) Len() int           { return len(h) }
func (h resultHeap) Less(i, j int) bool { return h[i].Score < h[j].Score }
func (h resultHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *resultHeap) Push(x any) {
	r, ok := x.(Result)
	if ok {
		*h = append(*h, r)
	}
}

func (h *resultHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[:n-1]
	return x
}
@@ -1,343 +0,0 @@
package filefinder_test

import (
	"slices"
	"testing"

	"github.com/coder/coder/v2/agent/filefinder"
)

func TestNewQueryPlan(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		query      string
		wantNorm   string
		wantShort  bool
		wantSlash  bool
		wantBase   string
		wantTokens []string
		wantDirTok []string
		wantTriCnt int // -1 to skip check
	}{
		{"Simple", "foo", "foo", false, false, "foo", []string{"foo"}, nil, 1},
		{"MultiToken", "foo bar", "foo bar", false, false, "bar", []string{"foo", "bar"}, []string{"foo"}, -1},
		{"Slash", "internal/foo", "internal/foo", false, true, "foo", []string{"internal", "foo"}, []string{"internal"}, -1},
		{"SingleChar", "a", "a", true, false, "a", []string{"a"}, nil, 0},
		{"TwoChars", "ab", "ab", true, false, "ab", []string{"ab"}, nil, -1},
		{"ThreeChars", "abc", "abc", false, false, "abc", []string{"abc"}, nil, 1},
		{"DotPrefix", ".go", ".go", false, false, ".go", []string{".go"}, nil, -1},
		{"UpperCase", "FOO", "foo", false, false, "foo", []string{"foo"}, nil, -1},
		{"Empty", "", "", true, false, "", nil, nil, -1},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			plan := filefinder.NewQueryPlanForTest(tt.query)
			if plan.Normalized != tt.wantNorm {
				t.Errorf("normalized = %q, want %q", plan.Normalized, tt.wantNorm)
			}
			if plan.IsShort != tt.wantShort {
				t.Errorf("isShort = %v, want %v", plan.IsShort, tt.wantShort)
			}
			if plan.HasSlash != tt.wantSlash {
				t.Errorf("hasSlash = %v, want %v", plan.HasSlash, tt.wantSlash)
			}
			if string(plan.BasenameQ) != tt.wantBase {
				t.Errorf("basenameQ = %q, want %q", plan.BasenameQ, tt.wantBase)
			}
			if tt.wantTokens == nil {
				if len(plan.Tokens) != 0 {
					t.Errorf("expected 0 tokens, got %d", len(plan.Tokens))
				}
			} else {
				if len(plan.Tokens) != len(tt.wantTokens) {
					t.Fatalf("tokens len = %d, want %d", len(plan.Tokens), len(tt.wantTokens))
				}
				for i, tok := range plan.Tokens {
					if string(tok) != tt.wantTokens[i] {
						t.Errorf("tokens[%d] = %q, want %q", i, tok, tt.wantTokens[i])
					}
				}
			}
			if tt.wantDirTok != nil {
				if len(plan.DirTokens) != len(tt.wantDirTok) {
					t.Fatalf("dirTokens len = %d, want %d", len(plan.DirTokens), len(tt.wantDirTok))
				}
				for i, tok := range plan.DirTokens {
					if string(tok) != tt.wantDirTok[i] {
						t.Errorf("dirTokens[%d] = %q, want %q", i, tok, tt.wantDirTok[i])
					}
				}
			}
			if tt.wantTriCnt >= 0 && len(plan.Trigrams) != tt.wantTriCnt {
				t.Errorf("trigram count = %d, want %d", len(plan.Trigrams), tt.wantTriCnt)
			}
		})
	}

	// ThreeChars: verify the actual trigram value.
	plan := filefinder.NewQueryPlanForTest("abc")
	if want := filefinder.PackTrigramForTest('a', 'b', 'c'); plan.Trigrams[0] != want {
		t.Errorf("trigram = %x, want %x", plan.Trigrams[0], want)
	}

	// ShortMultiToken: both tokens < 3 chars so isShort should be true.
	plan = filefinder.NewQueryPlanForTest("ab cd")
	if !plan.IsShort {
		t.Error("expected isShort=true when all tokens < 3 chars")
	}
	// One token >= 3 chars, so isShort should be false.
	plan = filefinder.NewQueryPlanForTest("ab cde")
	if plan.IsShort {
		t.Error("expected isShort=false when any token >= 3 chars")
	}
}

func requireCandHasPath(t *testing.T, cands []filefinder.CandidateForTest, path string) {
	t.Helper()
	for _, c := range cands {
		if c.Path == path {
			return
		}
	}
	t.Errorf("expected to find %q in candidates", path)
}

func TestSearchSnapshot_TrigramMatch(t *testing.T) {
	t.Parallel()
	snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
	if len(cands) == 0 {
		t.Fatal("expected at least 1 candidate for 'handler'")
	}
	requireCandHasPath(t, cands, "src/handler.go")
}

func TestSearchSnapshot_ShortQuery(t *testing.T) {
	t.Parallel()
	snap := filefinder.MakeTestSnapshot([]string{"foo.go", "bar.go", "fab.go"})
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("fo"), snap, 100)
	if len(cands) == 0 {
		t.Fatal("expected at least 1 candidate for 'fo'")
	}
	requireCandHasPath(t, cands, "foo.go")
}

func TestSearchSnapshot_FuzzyFallback(t *testing.T) {
	t.Parallel()
	snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("hndlr"), snap, 100)
	if len(cands) == 0 {
		t.Fatal("expected fuzzy fallback to find 'handler.go' for query 'hndlr'")
	}
	requireCandHasPath(t, cands, "src/handler.go")
}

func TestSearchSnapshot_FuzzyFallbackNoFirstCharMatch(t *testing.T) {
	t.Parallel()
	snap := filefinder.MakeTestSnapshot([]string{"src/xylophone.go", "lib/extra.go"})
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("xylo"), snap, 100)
	if len(cands) == 0 {
		t.Fatal("expected at least 1 candidate for 'xylo'")
	}
	requireCandHasPath(t, cands, "src/xylophone.go")
}

func TestSearchSnapshot_NilSnapshot(t *testing.T) {
	t.Parallel()
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("foo"), nil, 100)
	if cands != nil {
		t.Errorf("expected nil for nil snapshot, got %v", cands)
	}
}

func TestSearchSnapshot_EmptyQuery(t *testing.T) {
	t.Parallel()
	snap := filefinder.MakeTestSnapshot([]string{"foo.go"})
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest(""), snap, 100)
	if cands != nil {
		t.Errorf("expected nil for empty query, got %v", cands)
	}
}

func TestSearchSnapshot_DeletedDocsExcluded(t *testing.T) {
	t.Parallel()
	idx := filefinder.NewIndex()
	idx.Add("handler.go", 0)
	idx.Remove("handler.go")
	snap := idx.Snapshot()
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
	for _, c := range cands {
		if c.Path == "handler.go" {
			t.Error("deleted doc should not appear in results")
		}
	}
}

func TestSearchSnapshot_Limit(t *testing.T) {
	t.Parallel()
	paths := make([]string, 50)
	for i := range paths {
		paths[i] = "handler" + string(rune('a'+i%26)) + ".go"
	}
	snap := filefinder.MakeTestSnapshot(paths)
	cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 3)
	if len(cands) > 3 {
		t.Errorf("expected at most 3 candidates, got %d", len(cands))
	}
}

func TestIntersectSorted(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		a, b []uint32
		want []uint32
	}{
		{"both empty", nil, nil, nil},
		{"a empty", nil, []uint32{1, 2}, nil},
		{"b empty", []uint32{1, 2}, nil, nil},
		{"no overlap", []uint32{1, 3, 5}, []uint32{2, 4, 6}, nil},
		{"full overlap", []uint32{1, 2, 3}, []uint32{1, 2, 3}, []uint32{1, 2, 3}},
		{"partial overlap", []uint32{1, 2, 3, 5}, []uint32{2, 4, 5}, []uint32{2, 5}},
		{"single match", []uint32{1, 2, 3}, []uint32{2}, []uint32{2}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.IntersectSortedForTest(tt.a, tt.b)
			if len(tt.want) == 0 {
				if len(got) != 0 {
					t.Errorf("got %v, want empty/nil", got)
				}
				return
			}
			if !slices.Equal(got, tt.want) {
				t.Errorf("got %v, want %v", got, tt.want)
			}
		})
	}
}

func TestIntersectAll(t *testing.T) {
	t.Parallel()
	t.Run("empty", func(t *testing.T) {
		t.Parallel()
		if got := filefinder.IntersectAllForTest(nil); got != nil {
			t.Errorf("got %v, want nil", got)
		}
	})
	t.Run("single", func(t *testing.T) {
		t.Parallel()
		if got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3}}); len(got) != 3 {
			t.Fatalf("len = %d, want 3", len(got))
		}
	})
	t.Run("multiple", func(t *testing.T) {
		t.Parallel()
		got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3, 4, 5}, {2, 3, 5}, {3, 5, 7}})
		if !slices.Equal(got, []uint32{3, 5}) {
			t.Errorf("got %v, want [3 5]", got)
		}
	})
	t.Run("no overlap", func(t *testing.T) {
		t.Parallel()
		if got := filefinder.IntersectAllForTest([][]uint32{{1, 2}, {3, 4}}); got != nil {
			t.Errorf("got %v, want nil", got)
		}
	})
}

func TestMergeAndScore_SortedDescending(t *testing.T) {
	t.Parallel()
	plan := filefinder.NewQueryPlanForTest("foo")
	params := filefinder.DefaultScoreParamsForTest()
	cands := []filefinder.CandidateForTest{
		{DocID: 0, Path: "a/b/c/d/e/foo", BaseOff: 10, BaseLen: 3, Depth: 5},
		{DocID: 1, Path: "src/foo", BaseOff: 4, BaseLen: 3, Depth: 1},
		{DocID: 2, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0},
	}
	results := filefinder.MergeAndScoreForTest(cands, plan, params, 10)
	if len(results) == 0 {
		t.Fatal("expected non-empty results")
	}
	for i := 1; i < len(results); i++ {
		if results[i].Score > results[i-1].Score {
			t.Errorf("results not sorted: [%d].Score=%f > [%d].Score=%f",
				i, results[i].Score, i-1, results[i-1].Score)
		}
	}
}

func TestMergeAndScore_TopKLimit(t *testing.T) {
	t.Parallel()
	plan := filefinder.NewQueryPlanForTest("f")
	params := filefinder.DefaultScoreParamsForTest()
	var cands []filefinder.CandidateForTest
	for i := range 20 {
		p := "f" + string(rune('a'+i))
		cands = append(cands, filefinder.CandidateForTest{DocID: uint32(i), Path: p, BaseOff: 0, BaseLen: len(p), Depth: 0}) //nolint:gosec // test index is tiny
	}
	if results := filefinder.MergeAndScoreForTest(cands, plan, params, 5); len(results) != 5 {
		t.Errorf("expected 5 results, got %d", len(results))
	}
}

func TestMergeAndScore_ZeroTopK(t *testing.T) {
	t.Parallel()
	plan := filefinder.NewQueryPlanForTest("foo")
	cands := []filefinder.CandidateForTest{{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0}}
	if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 0); len(results) != 0 {
		t.Errorf("expected 0 results for topK=0, got %d", len(results))
	}
}

func TestMergeAndScore_NoMatchCandidatesDropped(t *testing.T) {
	t.Parallel()
	plan := filefinder.NewQueryPlanForTest("xyz")
	cands := []filefinder.CandidateForTest{
		{DocID: 0, Path: "abc", BaseOff: 0, BaseLen: 3, Depth: 0},
		{DocID: 1, Path: "def", BaseOff: 0, BaseLen: 3, Depth: 0},
	}
	if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
		t.Errorf("expected 0 results for non-matching candidates, got %d", len(results))
	}
}

func TestMergeAndScore_IsDirFlag(t *testing.T) {
	t.Parallel()
	plan := filefinder.NewQueryPlanForTest("foo")
	cands := []filefinder.CandidateForTest{
		{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0, Flags: uint16(filefinder.FlagDir)},
	}
	results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10)
	if len(results) != 1 {
		t.Fatalf("expected 1 result, got %d", len(results))
	}
	if !results[0].IsDir {
		t.Error("expected IsDir=true for FlagDir candidate")
	}
}

func TestMergeAndScore_EmptyCandidates(t *testing.T) {
	t.Parallel()
	if results := filefinder.MergeAndScoreForTest(nil, filefinder.NewQueryPlanForTest("foo"), filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
		t.Errorf("expected 0 results for nil candidates, got %d", len(results))
	}
}

func TestSearchSnapshot_FuzzyFallbackEndToEnd(t *testing.T) {
	t.Parallel()
	snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/middleware.go", "pkg/config.go"})
	plan := filefinder.NewQueryPlanForTest("hndlr")
	results := filefinder.MergeAndScoreForTest(filefinder.SearchSnapshotForTest(plan, snap, 100), plan, filefinder.DefaultScoreParamsForTest(), 10)
	if len(results) == 0 {
		t.Fatal("expected fuzzy fallback to produce scored results for 'hndlr'")
	}
	if results[0].Path != "src/handler.go" {
		t.Errorf("expected top result 'src/handler.go', got %q", results[0].Path)
	}
}
@@ -1,288 +0,0 @@
package filefinder

import "slices"

func toLowerASCII(b byte) byte {
	if b >= 'A' && b <= 'Z' {
		return b + ('a' - 'A')
	}
	return b
}

func normalizeQuery(q string) string {
	b := make([]byte, 0, len(q))
	prevSpace := true
	for i := 0; i < len(q); i++ {
		c := q[i]
		if c == '\\' {
			c = '/'
		}
		c = toLowerASCII(c)
		if c == ' ' {
			if prevSpace {
				continue
			}
			prevSpace = true
		} else {
			prevSpace = false
		}
		b = append(b, c)
	}
	if len(b) > 0 && b[len(b)-1] == ' ' {
		b = b[:len(b)-1]
	}
	return string(b)
}

func normalizePathBytes(p []byte) []byte {
	j := 0
	prevSlash := false
	for i := 0; i < len(p); i++ {
		c := p[i]
		if c == '\\' {
			c = '/'
		}
		c = toLowerASCII(c)
		if c == '/' {
			if prevSlash {
				continue
			}
			prevSlash = true
		} else {
			prevSlash = false
		}
		p[j] = c
		j++
	}
	return p[:j]
}

// extractTrigrams returns deduplicated, sorted trigrams (three-byte
// contiguous substrings) from s. Trigrams are the primary index key: a
// document matches a query only if every query trigram appears in
// the document, giving O(1) candidate filtering per trigram.
func extractTrigrams(s []byte) []uint32 {
	if len(s) < 3 {
		return nil
	}
	seen := make(map[uint32]struct{}, len(s))
	for i := 0; i <= len(s)-3; i++ {
		b0 := toLowerASCII(s[i])
		b1 := toLowerASCII(s[i+1])
		b2 := toLowerASCII(s[i+2])
		gram := uint32(b0)<<16 | uint32(b1)<<8 | uint32(b2)
		seen[gram] = struct{}{}
	}
	result := make([]uint32, 0, len(seen))
	for g := range seen {
		result = append(result, g)
	}
	slices.Sort(result)
	return result
}

func extractBasename(path []byte) (offset int, length int) {
	end := len(path)
	if end > 0 && path[end-1] == '/' {
		end--
	}
	if end == 0 {
		return 0, 0
	}
	i := end - 1
	for i >= 0 && path[i] != '/' {
		i--
	}
	start := i + 1
	return start, end - start
}

func extractSegments(path []byte) [][]byte {
	var segments [][]byte
	start := 0
	for i := 0; i <= len(path); i++ {
		if i == len(path) || path[i] == '/' {
			if i > start {
				segments = append(segments, path[start:i])
			}
			start = i + 1
		}
	}
	return segments
}

func prefix1(name []byte) byte {
	if len(name) == 0 {
		return 0
	}
	return toLowerASCII(name[0])
}

func prefix2(name []byte) uint16 {
	if len(name) == 0 {
		return 0
	}
	hi := uint16(toLowerASCII(name[0])) << 8
	if len(name) < 2 {
		return hi
	}
	return hi | uint16(toLowerASCII(name[1]))
}

// scoreParams controls the weights for each scoring signal.
type scoreParams struct {
	BasenameMatch  float32
	BasenamePrefix float32
	ExactSegment   float32
	BoundaryHit    float32
	ContiguousRun  float32
	DirTokenHit    float32
	DepthPenalty   float32
	LengthPenalty  float32
}

func defaultScoreParams() scoreParams {
	return scoreParams{
		BasenameMatch:  6.0,
		BasenamePrefix: 3.5,
		ExactSegment:   2.5,
		BoundaryHit:    1.8,
		ContiguousRun:  1.2,
		DirTokenHit:    0.4,
		DepthPenalty:   0.08,
		LengthPenalty:  0.01,
	}
}

func isSubsequence(haystack, needle []byte) bool {
	if len(needle) == 0 {
		return true
	}
	ni := 0
	for _, hb := range haystack {
		if toLowerASCII(hb) == toLowerASCII(needle[ni]) {
			ni++
			if ni == len(needle) {
				return true
			}
		}
	}
	return false
}

func longestContiguousMatch(haystack, needle []byte) int {
	if len(needle) == 0 || len(haystack) == 0 {
		return 0
	}
	best := 0
	ni := 0
	run := 0
	for _, hb := range haystack {
		if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
			run++
			ni++
			if run > best {
				best = run
			}
		} else {
			// The run broke: restart from the beginning of the needle,
			// checking whether the current byte starts a new run.
			run = 0
			ni = 0
			if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
				run = 1
				ni = 1
				if run > best {
					best = run
				}
			}
		}
	}
	return best
}

func isBoundary(b byte) bool {
	return b == '/' || b == '.' || b == '_' || b == '-'
}

func countBoundaryHits(path []byte, query []byte) int {
	if len(query) == 0 || len(path) == 0 {
		return 0
	}
	hits := 0
	qi := 0
	for pi := 0; pi < len(path) && qi < len(query); pi++ {
		atBoundary := pi == 0 || isBoundary(path[pi-1])
		if atBoundary && toLowerASCII(path[pi]) == toLowerASCII(query[qi]) {
			hits++
			qi++
		}
	}
	return hits
}

func equalFoldASCII(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if toLowerASCII(a[i]) != toLowerASCII(b[i]) {
			return false
		}
	}
	return true
}

func hasPrefixFoldASCII(haystack, prefix []byte) bool {
	if len(prefix) > len(haystack) {
		return false
	}
	for i := range prefix {
		if toLowerASCII(haystack[i]) != toLowerASCII(prefix[i]) {
			return false
		}
	}
	return true
}

// scorePath computes a relevance score for a candidate path
// against a query. The score combines several signals:
// basename match, basename prefix, exact segment match,
// word-boundary hits, longest contiguous run, and penalties
// for depth and length. A return value of 0 means no match
// (the query is not a subsequence of the path).
func scorePath(
	path []byte,
	baseOff int,
	baseLen int,
	depth int,
	query []byte,
	queryTokens [][]byte,
	params scoreParams,
) float32 {
	if !isSubsequence(path, query) {
		return 0
	}
	var score float32
	basename := path[baseOff : baseOff+baseLen]
	if isSubsequence(basename, query) {
		score += params.BasenameMatch
	}
	if hasPrefixFoldASCII(basename, query) {
		score += params.BasenamePrefix
	}
	segments := extractSegments(path)
	for _, token := range queryTokens {
		for _, seg := range segments {
			if equalFoldASCII(seg, token) {
				score += params.ExactSegment
				break
			}
		}
	}
	bh := countBoundaryHits(path, query)
	score += float32(bh) * params.BoundaryHit
	lcm := longestContiguousMatch(path, query)
	score += float32(lcm) * params.ContiguousRun
	score -= float32(depth) * params.DepthPenalty
|
||||
score -= float32(len(path)) * params.LengthPenalty
|
||||
return score
|
||||
}
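The gatekeeping signal above is the case-insensitive subsequence check: a path scores at all only if every query byte appears in order. A minimal standalone sketch of that check (helper names here are illustrative, not the package's exported API):

```go
package main

import "fmt"

// lower folds an ASCII byte to lowercase, leaving other bytes untouched.
func lower(b byte) byte {
	if b >= 'A' && b <= 'Z' {
		return b + 32
	}
	return b
}

// isSubseq reports whether needle is a case-insensitive subsequence
// of haystack, mirroring the isSubsequence function above.
func isSubseq(haystack, needle []byte) bool {
	ni := 0
	for i := 0; i < len(haystack) && ni < len(needle); i++ {
		if lower(haystack[i]) == lower(needle[ni]) {
			ni++
		}
	}
	return ni == len(needle)
}

func main() {
	path := []byte("src/internal/handler.go")
	// "sih" hits s(rc), i(nternal), h(andler) scattered across segments.
	fmt.Println(isSubseq(path, []byte("sih")))
	fmt.Println(isSubseq(path, []byte("zzz")))
}
```

Because this check runs once per candidate path, it is the cheap filter in front of the additive signals; the weighted terms only execute when it passes.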
@@ -1,388 +0,0 @@
package filefinder_test

import (
	"slices"
	"testing"

	"github.com/coder/coder/v2/agent/filefinder"
)

func TestNormalizeQuery(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name  string
		input string
		want  string
	}{
		{"empty", "", ""},
		{"leading and trailing spaces", " hello ", "hello"},
		{"multiple internal spaces", "foo bar baz", "foo bar baz"},
		{"uppercase to lower", "FooBar", "foobar"},
		{"backslash to slash", `foo\bar\baz`, "foo/bar/baz"},
		{"mixed case and spaces", " Hello World ", "hello world"},
		{"unicode passthrough", "héllo wörld", "héllo wörld"},
		{"only spaces", " ", ""},
		{"single char", "A", "a"},
		{"slashes preserved", "/foo/bar/", "/foo/bar/"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.NormalizeQueryForTest(tt.input)
			if got != tt.want {
				t.Errorf("normalizeQuery(%q) = %q, want %q", tt.input, got, tt.want)
			}
		})
	}
}

func TestExtractTrigrams(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name  string
		input string
		want  []uint32
	}{
		{"too short", "ab", nil},
		{"exactly three bytes", "abc", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
		{"case insensitive", "ABC", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
		{"deduplication", "aaaa", []uint32{uint32('a')<<16 | uint32('a')<<8 | uint32('a')}},
		{"four bytes produces two trigrams", "abcd", []uint32{
			uint32('a')<<16 | uint32('b')<<8 | uint32('c'),
			uint32('b')<<16 | uint32('c')<<8 | uint32('d'),
		}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.ExtractTrigramsForTest([]byte(tt.input))
			if !slices.Equal(got, tt.want) {
				t.Errorf("extractTrigrams(%q) = %v, want %v", tt.input, got, tt.want)
			}
		})
	}
}

func TestExtractBasename(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		path     string
		wantOff  int
		wantName string
	}{
		{"full path", "/foo/bar/baz.go", 9, "baz.go"},
		{"bare filename", "baz.go", 0, "baz.go"},
		{"trailing slash", "/a/b/", 3, "b"},
		{"root slash", "/", 0, ""},
		{"empty", "", 0, ""},
		{"single dir with slash", "/foo", 1, "foo"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			off, length := filefinder.ExtractBasenameForTest([]byte(tt.path))
			if off != tt.wantOff {
				t.Errorf("extractBasename(%q) offset = %d, want %d", tt.path, off, tt.wantOff)
			}
			gotName := string([]byte(tt.path)[off : off+length])
			if gotName != tt.wantName {
				t.Errorf("extractBasename(%q) name = %q, want %q", tt.path, gotName, tt.wantName)
			}
		})
	}
}

func TestExtractSegments(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		path string
		want []string
	}{
		{"absolute path", "/foo/bar/baz", []string{"foo", "bar", "baz"}},
		{"relative path", "foo/bar", []string{"foo", "bar"}},
		{"trailing slash", "/a/b/", []string{"a", "b"}},
		{"multiple slashes", "//a///b//", []string{"a", "b"}},
		{"empty", "", nil},
		{"single segment", "foo", []string{"foo"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.ExtractSegmentsForTest([]byte(tt.path))
			if len(got) != len(tt.want) {
				t.Fatalf("extractSegments(%q) got %d segments, want %d", tt.path, len(got), len(tt.want))
			}
			for i := range got {
				if string(got[i]) != tt.want[i] {
					t.Errorf("extractSegments(%q)[%d] = %q, want %q", tt.path, i, got[i], tt.want[i])
				}
			}
		})
	}
}

func TestPrefix1(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		in   string
		want byte
	}{
		{"lowercase", "foo", 'f'},
		{"uppercase", "Foo", 'f'},
		{"empty", "", 0},
		{"digit", "1abc", '1'},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.Prefix1ForTest([]byte(tt.in))
			if got != tt.want {
				t.Errorf("prefix1(%q) = %d (%c), want %d (%c)", tt.in, got, got, tt.want, tt.want)
			}
		})
	}
}

func TestPrefix2(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		in   string
		want uint16
	}{
		{"two chars", "ab", uint16('a')<<8 | uint16('b')},
		{"uppercase", "AB", uint16('a')<<8 | uint16('b')},
		{"single char", "A", uint16('a') << 8},
		{"empty", "", 0},
		{"longer string", "Hello", uint16('h')<<8 | uint16('e')},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.Prefix2ForTest([]byte(tt.in))
			if got != tt.want {
				t.Errorf("prefix2(%q) = %d, want %d", tt.in, got, tt.want)
			}
		})
	}
}

func TestNormalizePathBytes(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name  string
		input string
		want  string
	}{
		{"backslash to slash", `C:\Users\test`, "c:/users/test"},
		{"collapse slashes", "//foo///bar//", "/foo/bar/"},
		{"lowercase", "FooBar", "foobar"},
		{"empty", "", ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			buf := []byte(tt.input)
			got := string(filefinder.NormalizePathBytesForTest(buf))
			if got != tt.want {
				t.Errorf("normalizePathBytes(%q) = %q, want %q", tt.input, got, tt.want)
			}
		})
	}
}

func TestIsSubsequence(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		haystack string
		needle   string
		want     bool
	}{
		{"empty needle", "anything", "", true},
		{"empty both", "", "", true},
		{"empty haystack", "", "a", false},
		{"exact match", "abc", "abc", true},
		{"scattered", "axbycz", "abc", true},
		{"prefix", "abcdef", "abc", true},
		{"suffix", "xyzabc", "abc", true},
		{"case insensitive", "AbCdEf", "ace", true},
		{"case insensitive reverse", "abcdef", "ACE", true},
		{"no match", "abcdef", "xyz", false},
		{"partial match", "abcdef", "abz", false},
		{"longer needle", "ab", "abc", false},
		{"single char match", "hello", "l", true},
		{"single char no match", "hello", "z", false},
		{"path like", "src/internal/foo.go", "sif", true},
		{"path like no match", "src/internal/foo.go", "zzz", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.IsSubsequenceForTest([]byte(tt.haystack), []byte(tt.needle))
			if got != tt.want {
				t.Errorf("isSubsequence(%q, %q) = %v, want %v", tt.haystack, tt.needle, got, tt.want)
			}
		})
	}
}

func TestLongestContiguousMatch(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		haystack string
		needle   string
		want     int
	}{
		{"empty needle", "abc", "", 0},
		{"empty haystack", "", "abc", 0},
		{"full match", "abc", "abc", 3},
		{"prefix match", "abcdef", "abc", 3},
		{"middle match", "xxabcyy", "abc", 3},
		{"suffix match", "xxabc", "abc", 3},
		{"partial", "axbc", "abc", 1},
		{"scattered no contiguous", "axbxcx", "abc", 1},
		{"case insensitive", "ABCdef", "abc", 3},
		{"no match", "xyz", "abc", 0},
		{"single char", "abc", "b", 1},
		{"repeated", "aababc", "abc", 3},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.LongestContiguousMatchForTest([]byte(tt.haystack), []byte(tt.needle))
			if got != tt.want {
				t.Errorf("longestContiguousMatch(%q, %q) = %d, want %d", tt.haystack, tt.needle, got, tt.want)
			}
		})
	}
}

func TestIsBoundary(t *testing.T) {
	t.Parallel()
	for _, b := range []byte{'/', '.', '_', '-'} {
		if !filefinder.IsBoundaryForTest(b) {
			t.Errorf("isBoundary(%q) = false, want true", b)
		}
	}
	for _, b := range []byte{'a', 'Z', '0', ' ', '('} {
		if filefinder.IsBoundaryForTest(b) {
			t.Errorf("isBoundary(%q) = true, want false", b)
		}
	}
}

func TestCountBoundaryHits(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name  string
		path  string
		query string
		want  int
	}{
		{"start of string", "foo/bar", "f", 1},
		{"after slash", "foo/bar", "fb", 2},
		{"after dot", "foo.bar", "fb", 2},
		{"after underscore", "foo_bar", "fb", 2},
		{"no hits", "xxxx", "y", 0},
		{"empty query", "foo", "", 0},
		{"empty path", "", "f", 0},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := filefinder.CountBoundaryHitsForTest([]byte(tt.path), []byte(tt.query))
			if got != tt.want {
				t.Errorf("countBoundaryHits(%q, %q) = %d, want %d", tt.path, tt.query, got, tt.want)
			}
		})
	}
}

func TestScorePath_NoSubsequenceReturnsZero(t *testing.T) {
	t.Parallel()
	path := []byte("src/internal/handler.go")
	query := []byte("zzz")
	tokens := [][]byte{[]byte("zzz")}
	params := filefinder.DefaultScoreParamsForTest()
	s := filefinder.ScorePathForTest(path, 13, 10, 2, query, tokens, params)
	if s != 0 {
		t.Errorf("expected 0 for no subsequence match, got %f", s)
	}
}

func TestScorePath_ExactBasenameOverPartial(t *testing.T) {
	t.Parallel()
	params := filefinder.DefaultScoreParamsForTest()
	query := []byte("main")
	tokens := [][]byte{query}
	pathExact := []byte("src/main")
	scoreExact := filefinder.ScorePathForTest(pathExact, 4, 4, 1, query, tokens, params)
	pathPartial := []byte("module/amazing")
	scorePartial := filefinder.ScorePathForTest(pathPartial, 7, 7, 1, query, tokens, params)
	if scoreExact <= scorePartial {
		t.Errorf("exact basename (%f) should score higher than partial (%f)", scoreExact, scorePartial)
	}
}

func TestScorePath_BasenamePrefixOverScattered(t *testing.T) {
	t.Parallel()
	params := filefinder.DefaultScoreParamsForTest()
	query := []byte("han")
	tokens := [][]byte{query}
	pathPrefix := []byte("src/handler.go")
	scorePrefix := filefinder.ScorePathForTest(pathPrefix, 4, 10, 1, query, tokens, params)
	pathScattered := []byte("has/another/thing")
	scoreScattered := filefinder.ScorePathForTest(pathScattered, 12, 5, 2, query, tokens, params)
	if scorePrefix <= scoreScattered {
		t.Errorf("basename prefix (%f) should score higher than scattered (%f)", scorePrefix, scoreScattered)
	}
}

func TestScorePath_ShallowOverDeep(t *testing.T) {
	t.Parallel()
	params := filefinder.DefaultScoreParamsForTest()
	query := []byte("foo")
	tokens := [][]byte{query}
	pathShallow := []byte("src/foo.go")
	scoreShallow := filefinder.ScorePathForTest(pathShallow, 4, 6, 1, query, tokens, params)
	pathDeep := []byte("a/b/c/d/e/foo.go")
	scoreDeep := filefinder.ScorePathForTest(pathDeep, 10, 6, 5, query, tokens, params)
	if scoreShallow <= scoreDeep {
		t.Errorf("shallow path (%f) should score higher than deep (%f)", scoreShallow, scoreDeep)
	}
}

func TestScorePath_ShorterOverLongerSameMatch(t *testing.T) {
	t.Parallel()
	params := filefinder.DefaultScoreParamsForTest()
	query := []byte("foo")
	tokens := [][]byte{query}
	pathShort := []byte("x/foo")
	scoreShort := filefinder.ScorePathForTest(pathShort, 2, 3, 1, query, tokens, params)
	pathLong := []byte("x/foo_extremely_long_suffix_name")
	scoreLong := filefinder.ScorePathForTest(pathLong, 2, 29, 1, query, tokens, params)
	if scoreShort <= scoreLong {
		t.Errorf("shorter path (%f) should score higher than longer (%f)", scoreShort, scoreLong)
	}
}

func BenchmarkScorePath(b *testing.B) {
	path := []byte("src/internal/coderd/database/queries/workspaces.sql")
	query := []byte("workspace")
	tokens := [][]byte{query}
	params := filefinder.DefaultScoreParamsForTest()
	baseOff, baseLen := filefinder.ExtractBasenameForTest(path)
	s := filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
	if s == 0 {
		b.Fatal("expected non-zero score for benchmark path")
	}
	b.ResetTimer()
	for b.Loop() {
		filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
	}
}
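The trigram tests above assert a specific encoding: each three-byte window of the lowercased input is packed into a `uint32` as `b0<<16 | b1<<8 | b2`, with duplicates dropped. A minimal standalone sketch of that packing (the real `extractTrigrams` is internal to the package; the function name here is illustrative):

```go
package main

import "fmt"

// trigrams packs each 3-byte window of the case-folded input into a
// uint32 (b0<<16 | b1<<8 | b2), deduplicating repeated windows. This
// mirrors the encoding the tests above assert against.
func trigrams(s []byte) []uint32 {
	if len(s) < 3 {
		return nil
	}
	lower := func(b byte) byte {
		if b >= 'A' && b <= 'Z' {
			return b + 32
		}
		return b
	}
	seen := make(map[uint32]struct{})
	var out []uint32
	for i := 0; i+2 < len(s); i++ {
		t := uint32(lower(s[i]))<<16 | uint32(lower(s[i+1]))<<8 | uint32(lower(s[i+2]))
		if _, dup := seen[t]; dup {
			continue
		}
		seen[t] = struct{}{}
		out = append(out, t)
	}
	return out
}

func main() {
	// "ABCd" folds to "abcd" and yields two trigrams: "abc" and "bcd".
	fmt.Println(trigrams([]byte("ABCd")))
}
```

Packing into a `uint32` makes trigrams cheap map keys and avoids allocating a string per window, which matters when indexing many paths.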
@@ -1,210 +0,0 @@
package filefinder

import (
	"context"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/fsnotify/fsnotify"

	"cdr.dev/slog/v3"
)

// FSEvent represents a filesystem change event.
type FSEvent struct {
	Op    FSEventOp
	Path  string
	IsDir bool
}

// FSEventOp represents the type of filesystem operation.
type FSEventOp uint8

// Filesystem operations reported by the watcher.
const (
	OpCreate FSEventOp = iota
	OpRemove
	OpRename
	OpModify
)

var skipDirs = map[string]struct{}{
	".git": {}, "node_modules": {}, ".hg": {}, ".svn": {},
	"__pycache__": {}, ".cache": {}, ".venv": {}, "vendor": {}, ".terraform": {},
}

type fsWatcher struct {
	w      *fsnotify.Watcher
	root   string
	events chan []FSEvent
	logger slog.Logger
	mu     sync.Mutex
	closed bool
	done   chan struct{}
}

func newFSWatcher(root string, logger slog.Logger) (*fsWatcher, error) {
	w, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, err
	}
	return &fsWatcher{
		w:      w,
		root:   root,
		events: make(chan []FSEvent, 64),
		logger: logger,
		done:   make(chan struct{}),
	}, nil
}

func (fw *fsWatcher) Start(ctx context.Context) {
	initEvents := fw.addRecursive(fw.root)
	if len(initEvents) > 0 {
		select {
		case fw.events <- initEvents:
		case <-ctx.Done():
			return
		}
	}
	fw.logger.Debug(ctx, "fs watcher started", slog.F("root", fw.root))
	go fw.loop(ctx)
}

func (fw *fsWatcher) Events() <-chan []FSEvent { return fw.events }

func (fw *fsWatcher) Close() error {
	fw.mu.Lock()
	if fw.closed {
		fw.mu.Unlock()
		return nil
	}
	fw.closed = true
	fw.mu.Unlock()
	err := fw.w.Close()
	<-fw.done
	return err
}

func (fw *fsWatcher) loop(ctx context.Context) {
	defer close(fw.done)
	const batchWindow = 50 * time.Millisecond
	var (
		batch  []FSEvent
		seen   = make(map[string]struct{})
		timer  *time.Timer
		timerC <-chan time.Time
	)
	flush := func() {
		if len(batch) == 0 {
			return
		}
		select {
		case fw.events <- batch:
		default:
			fw.logger.Warn(ctx, "fs watcher dropping batch", slog.F("count", len(batch)))
		}
		batch = nil
		seen = make(map[string]struct{})
		if timer != nil {
			timer.Stop()
		}
		timer = nil
		timerC = nil
	}
	addToBatch := func(ev FSEvent) {
		if _, dup := seen[ev.Path]; dup {
			return
		}
		seen[ev.Path] = struct{}{}
		batch = append(batch, ev)
		if timer == nil {
			timer = time.NewTimer(batchWindow)
			timerC = timer.C
		}
	}
	for {
		select {
		case <-ctx.Done():
			flush()
			return
		case ev, ok := <-fw.w.Events:
			if !ok {
				flush()
				return
			}
			fsev := translateEvent(ev)
			if fsev == nil {
				continue
			}
			if fsev.IsDir && fsev.Op == OpCreate {
				for _, s := range fw.addRecursive(fsev.Path) {
					addToBatch(s)
				}
			}
			addToBatch(*fsev)
		case err, ok := <-fw.w.Errors:
			if !ok {
				flush()
				return
			}
			fw.logger.Warn(ctx, "fsnotify watcher error", slog.Error(err))
		case <-timerC:
			flush()
		}
	}
}

func (fw *fsWatcher) addRecursive(dir string) []FSEvent {
	var events []FSEvent
	_ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return nil //nolint:nilerr // best-effort
		}
		base := filepath.Base(path)
		if _, skip := skipDirs[base]; skip && info.IsDir() {
			return filepath.SkipDir
		}
		if info.IsDir() {
			if addErr := fw.w.Add(path); addErr != nil {
				fw.logger.Debug(context.Background(), "failed to add watch",
					slog.F("path", path), slog.Error(addErr))
			}
			if path != dir {
				events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: true})
			}
			return nil
		}
		events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: false})
		return nil
	})
	return events
}

func translateEvent(ev fsnotify.Event) *FSEvent {
	var op FSEventOp
	switch {
	case ev.Op&fsnotify.Create != 0:
		op = OpCreate
	case ev.Op&fsnotify.Remove != 0:
		op = OpRemove
	case ev.Op&fsnotify.Rename != 0:
		op = OpRename
	case ev.Op&fsnotify.Write != 0:
		op = OpModify
	default:
		return nil
	}
	isDir := false
	if op == OpCreate || op == OpModify {
		fi, err := os.Lstat(ev.Name)
		if err == nil {
			isDir = fi.IsDir()
		}
	}
	if isDir {
		if _, skip := skipDirs[filepath.Base(ev.Name)]; skip {
			return nil
		}
	}
	return &FSEvent{Op: op, Path: ev.Name, IsDir: isDir}
}
+330 -544
File diff suppressed because it is too large
+1 -20
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
	}

	repeated DisplayApp display_apps = 6;

	optional bytes id = 7;
}

@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {

message ReportBoundaryLogsResponse {}

// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
message UpdateAppStatusRequest {
	string slug = 1;

	enum AppStatusState {
		WORKING = 0;
		IDLE = 1;
		COMPLETE = 2;
		FAILURE = 3;
	}
	AppStatusState state = 2;

	string message = 3;
	string uri = 4;
}

message UpdateAppStatusResponse {}

service Agent {
	rpc GetManifest(GetManifestRequest) returns (Manifest);
	rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
@@ -530,5 +512,4 @@ service Agent {
	rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
	rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
	rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
	rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
}

@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
	DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
	ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
	ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
	UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}

type drpcAgentClient struct {
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
	return out, nil
}

func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
	out := new(UpdateAppStatusResponse)
	err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
	if err != nil {
		return nil, err
	}
	return out, nil
}

type DRPCAgentServer interface {
	GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
	GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
	DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
	ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
	ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
	UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}

type DRPCAgentUnimplementedServer struct{}
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
	return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}

func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
	return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}

type DRPCAgentDescription struct{}

func (DRPCAgentDescription) NumMethods() int { return 18 }
func (DRPCAgentDescription) NumMethods() int { return 17 }

func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
	switch n {
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
				in1.(*ReportBoundaryLogsRequest),
			)
		}, DRPCAgentServer.ReportBoundaryLogs, true
	case 17:
		return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
			func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
				return srv.(DRPCAgentServer).
					UpdateAppStatus(
						ctx,
						in1.(*UpdateAppStatusRequest),
					)
			}, DRPCAgentServer.UpdateAppStatus, true
	default:
		return "", nil, nil, nil, false
	}
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
	}
	return x.CloseSend()
}

type DRPCAgent_UpdateAppStatusStream interface {
	drpc.Stream
	SendAndClose(*UpdateAppStatusResponse) error
}

type drpcAgent_UpdateAppStatusStream struct {
	drpc.Stream
}

func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
	if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
		return err
	}
	return x.CloseSend()
}

@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
	ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
}

// DRPCAgentClient28 is the Agent API at v2.8. It adds
// - a SubagentId field to the WorkspaceAgentDevcontainer message
// - an Id field to the CreateSubAgentRequest message.
// - UpdateAppStatus RPC.
//
// Compatible with Coder v2.31+
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
// message. Compatible with Coder v2.31+
type DRPCAgentClient28 interface {
	DRPCAgentClient27
	UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
+21 -25
@@ -3,11 +3,11 @@
"enabled": true,
"clientKind": "git",
"useIgnoreFile": true,
"defaultBranch": "main",
"defaultBranch": "main"
},
"files": {
"includes": ["**", "!**/pnpm-lock.yaml"],
"ignoreUnknown": true,
"ignoreUnknown": true
},
"linter": {
"rules": {
@@ -15,18 +15,18 @@
"noSvgWithoutTitle": "off",
"useButtonType": "off",
"useSemanticElements": "off",
"noStaticElementInteractions": "off",
"noStaticElementInteractions": "off"
},
"correctness": {
"noUnusedImports": "warn",
"correctness": {
"noUnusedImports": "warn",
"useUniqueElementIds": "off", // TODO: This is new but we want to fix it
"noNestedComponentDefinitions": "off", // TODO: Investigate, since it is used by shadcn components
"noUnusedVariables": {
"level": "warn",
"noUnusedVariables": {
"level": "warn",
"options": {
"ignoreRestSiblings": true,
},
},
"ignoreRestSiblings": true
}
}
},
"style": {
"noNonNullAssertion": "off",
@@ -45,10 +45,6 @@
"level": "error",
"options": {
"paths": {
"react": {
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
"importNames": ["forwardRef"],
},
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
// "@mui/material/Autocomplete": "Use shadcn/ui Combobox instead.",
@@ -115,10 +111,10 @@
"@emotion/styled": "Use Tailwind CSS instead.",
// "@emotion/cache": "Use Tailwind CSS instead.",
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
"lodash": "Use lodash/<name> instead.",
},
},
},
"lodash": "Use lodash/<name> instead."
}
}
}
},
"suspicious": {
"noArrayIndexKey": "off",
@@ -129,14 +125,14 @@
"noConsole": {
"level": "error",
"options": {
"allow": ["error", "info", "warn"],
},
},
"allow": ["error", "info", "warn"]
}
}
},
"complexity": {
"noImportantStyles": "off", // TODO: check and fix !important styles
},
},
"noImportantStyles": "off" // TODO: check and fix !important styles
}
}
},
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
}
+1
-1
@@ -489,7 +489,7 @@ func workspaceAgent() *serpent.Command {
},
{
Flag: "socket-server-enabled",
Default: "true",
Default: "false",
Env: "CODER_AGENT_SOCKET_SERVER_ENABLED",
Description: "Enable the agent socket server.",
Value: serpent.BoolOf(&socketServerEnabled),

@@ -44,7 +44,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-token", r.AgentToken,
"--agent-url", client.URL.String(),
"--log-dir", logDir,
"--socket-path", testutil.AgentSocketPath(t),
)

clitest.Start(t, inv)
@@ -77,7 +76,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-token", r.AgentToken,
"--agent-url", client.URL.String(),
"--log-dir", logDir,
"--socket-path", testutil.AgentSocketPath(t),
)
// Set the subsystems for the agent.
inv.Environ.Set(agent.EnvAgentSubsystem, fmt.Sprintf("%s,%s", codersdk.AgentSubsystemExectrace, codersdk.AgentSubsystemEnvbox))
@@ -160,7 +158,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-header", "X-Testing=agent",
"--agent-header", "Cool-Header=Ethan was Here!",
"--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
"--socket-path", testutil.AgentSocketPath(t),
)
clitest.Start(t, agentInv)
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
@@ -202,7 +199,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--pprof-address", "",
"--prometheus-address", "",
"--debug-address", "",
"--socket-path", testutil.AgentSocketPath(t),
)

clitest.Start(t, inv)

@@ -30,15 +30,9 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
_, _ = fmt.Fprint(inv.Stdout, "\033[1A")

var defaults []string
defaultSource := defaultValue
if defaultSource == "" {
defaultSource = templateVersionParameter.DefaultValue
}
if defaultSource != "" {
err = json.Unmarshal([]byte(defaultSource), &defaults)
if err != nil {
return "", err
}
err = json.Unmarshal([]byte(templateVersionParameter.DefaultValue), &defaults)
if err != nil {
return "", err
}

values, err := RichMultiSelect(inv, RichMultiSelectOptions{
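The `RichParameter` change above makes an explicitly passed default take precedence over the template's default before JSON-decoding it into a string slice. A standalone sketch of that precedence rule (the `resolveDefaults` helper name is illustrative, not part of the Coder codebase):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// resolveDefaults mimics the precedence in the diff above: an explicit
// default wins; otherwise the template's default is used; an empty
// source yields no defaults at all.
func resolveDefaults(explicit, templateDefault string) ([]string, error) {
	source := explicit
	if source == "" {
		source = templateDefault
	}
	if source == "" {
		return nil, nil
	}
	var defaults []string
	if err := json.Unmarshal([]byte(source), &defaults); err != nil {
		return nil, err
	}
	return defaults, nil
}

func main() {
	d, err := resolveDefaults("", `["vim","emacs"]`)
	fmt.Println(d, err)
}
```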
+45 -50
@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"strings"
"time"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -24,7 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
"github.com/coder/retry"
"github.com/coder/serpent"
)

@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
defer cancel()
defer srv.queue.Close()

cliui.Infof(inv.Stderr, "Failed to watch screen events")
// Start the reporter, watcher, and server. These are all tied to the
// lifetime of the MCP server, which is itself tied to the lifetime of the
// AI agent.
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
}

func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
return
}
go func() {
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err == nil {
retrier.Reset()
loop:
for {
select {
case <-ctx.Done():
for {
select {
case <-ctx.Done():
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
// If the screen is stable, report idle.
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
}
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
break loop
}
}
} else {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
return
}
}
}()
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
// Add tool dependencies.
toolOpts := []func(*toolsdk.Deps){
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
state := codersdk.WorkspaceAppStatusState(args.State)
// The agent does not reliably report idle, so when AgentAPI is
// enabled we override idle to working and let the screen watcher
// detect the real idle via StatusStable. Final states (failure,
// complete) are trusted from the agent since the screen watcher
// cannot produce them.
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
state = codersdk.WorkspaceAppStatusStateWorking
// The agent does not reliably report its status correctly. If AgentAPI
// is enabled, we will always set the status to "working" when we get an
// MCP message, and rely on the screen watcher to eventually catch the
// idle state.
state := codersdk.WorkspaceAppStatusStateWorking
if s.aiAgentAPIClient == nil {
state = codersdk.WorkspaceAppStatusState(args.State)
}
return s.queue.Push(taskReport{
link: args.Link,
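The reporter change above narrows the override: with AgentAPI enabled, only an agent-reported idle is rewritten to working (the screen watcher detects real idle), while final states such as failure and complete pass through untouched. A sketch of that rule with types reduced to strings (not Coder's actual API):

```go
package main

import "fmt"

// reportedState applies the override from the diff: when an AgentAPI
// client is present, "idle" from the agent becomes "working", but
// final states like "failure" and "complete" are trusted as-is.
func reportedState(agentAPIEnabled bool, agentState string) string {
	if agentAPIEnabled && agentState == "idle" {
		return "working"
	}
	return agentState
}

func main() {
	for _, s := range []string{"idle", "working", "failure", "complete"} {
		fmt.Printf("agentapi=on %s -> %s\n", s, reportedState(true, s))
	}
}
```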
+1 -185
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
},
},
},
// We override idle from the agent to working, but trust final states.
// We ignore the state from the agent and assume "working".
{
name: "IgnoreAgentState",
// AI agent reports that it is finished but the summary says it is doing
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
Message: "finished",
},
},
// Agent reports failure; trusted even with AgentAPI enabled.
{
state: codersdk.WorkspaceAppStatusStateFailure,
summary: "something broke",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateFailure,
Message: "something broke",
},
},
// After failure, watcher reports stable -> idle.
{
event: makeStatusEvent(agentapi.StatusStable),
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "something broke",
},
},
},
},
// Final states pass through with AgentAPI enabled.
{
name: "AllowFinalStates",
tests: []test{
{
state: codersdk.WorkspaceAppStatusStateWorking,
summary: "doing work",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "doing work",
},
},
// Agent reports complete; not overridden.
{
state: codersdk.WorkspaceAppStatusStateComplete,
summary: "all done",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateComplete,
Message: "all done",
},
},
},
},
// When AgentAPI is not being used, we accept agent state updates as-is.
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
<-cmdDone
})
}

t.Run("Reconnect", func(t *testing.T) {
t.Parallel()

// Create a test deployment and workspace.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)

r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()

ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))

// Watch the workspace for changes.
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
require.NoError(t, err)
var lastAppStatus codersdk.WorkspaceAppStatus
nextUpdate := func() codersdk.WorkspaceAppStatus {
for {
select {
case <-ctx.Done():
require.FailNow(t, "timed out waiting for status update")
case w, ok := <-watcher:
require.True(t, ok, "watch channel closed")
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
lastAppStatus = *w.LatestAppStatus
return lastAppStatus
}
}
}
}

// Mock AI AgentAPI server that supports disconnect/reconnect.
disconnect := make(chan struct{})
listening := make(chan func(sse codersdk.ServerSentEvent) error)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a cancelable context so we can stop the SSE sender
// goroutine on disconnect without waiting for the HTTP
// serve loop to cancel r.Context().
sseCtx, sseCancel := context.WithCancel(r.Context())
defer sseCancel()
r = r.WithContext(sseCtx)

send, closed, err := httpapi.ServerSentEventSender(w, r)
if err != nil {
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Send initial message so the watcher knows the agent is active.
send(*makeMessageEvent(0, agentapi.RoleAgent))
select {
case listening <- send:
case <-r.Context().Done():
return
}
select {
case <-closed:
case <-disconnect:
sseCancel()
<-closed
}
}))
t.Cleanup(srv.Close)

inv, _ := clitest.New(t,
"exp", "mcp", "server",
"--agent-url", client.URL.String(),
"--agent-token", r.AgentToken,
"--app-status-slug", "vscode",
"--allowed-tools=coder_report_task",
"--ai-agentapi-url", srv.URL,
)
inv = inv.WithContext(ctx)

pty := ptytest.New(t)
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
stderr := ptytest.New(t)
inv.Stderr = stderr.Output()

// Run the MCP server.
clitest.Start(t, inv)

// Initialize.
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
pty.WriteLine(payload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore init response

// Get first sender from the initial SSE connection.
sender := testutil.RequireReceive(ctx, t, listening)

// Self-report a working status via tool call.
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got := nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "doing work", got.Message)

// Watcher sends stable, verify idle is reported.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)

// Disconnect the SSE connection by signaling the handler to return.
testutil.RequireSend(ctx, t, disconnect, struct{}{})

// Wait for the watcher to reconnect and get the new sender.
sender = testutil.RequireReceive(ctx, t, listening)

// After reconnect, self-report a working status again.
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "reconnected", got.Message)

// Verify the watcher still processes events after reconnect.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)

cancel()
})
}
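The Reconnect test above stands up a mock SSE server with `httptest` so it can hand the event sender to the test and force a mid-stream disconnect. A much-reduced stdlib sketch of that pattern, a streaming handler that emits one event and holds the connection open until the client goes away (helper names are illustrative):

```go
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// newSSEServer returns a test server that writes a single SSE event
// and then blocks until the client disconnects, mirroring the mock
// AgentAPI server's "stay open until closed or disconnected" shape.
func newSSEServer(event string) *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		fmt.Fprintf(w, "data: %s\n\n", event)
		flusher.Flush()
		<-r.Context().Done() // hold the stream open until disconnect
	}))
}

// readFirstEvent connects and reads the first "data:" line.
func readFirstEvent(url string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	line, err := bufio.NewReader(resp.Body).ReadString('\n')
	if err != nil {
		return "", err
	}
	return strings.TrimPrefix(strings.TrimSpace(line), "data: "), nil
}

func main() {
	srv := newSSEServer("hello")
	defer srv.Close()
	ev, err := readFirstEvent(srv.URL)
	fmt.Println(ev, err)
}
```

Closing the response body on the client side cancels `r.Context()` in the handler, which is exactly the disconnect signal the real test exploits.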
@@ -29,7 +29,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
templateVersionJobTimeout time.Duration
prebuildWorkspaceTimeout time.Duration
noCleanup bool
provisionerTags []string

tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
@@ -112,16 +111,10 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {

th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())

tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}

for i := range numTemplates {
id := strconv.Itoa(int(i))
cfg := prebuilds.Config{
OrganizationID: me.OrganizationIDs[0],
ProvisionerTags: tags,
NumPresets: int(numPresets),
NumPresetPrebuilds: int(numPresetPrebuilds),
TemplateVersionJobTimeout: templateVersionJobTimeout,
@@ -290,11 +283,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
Description: "Skip cleanup (deletion test) and leave resources intact.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}

tracingFlags.attach(&cmd.Options)
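The `--provisioner-tag` flag above collects repeated `key=value` strings that `ParseProvisionerTags` turns into a tag map. A hypothetical reimplementation of that parsing step, assuming the `key=value` convention (the real helper lives in the Coder CLI and may differ):

```go
package main

import (
	"fmt"
	"strings"
)

// parseTags turns repeated "key=value" CLI arguments into a map.
// An empty value (e.g. "owner=") is allowed, matching tag maps like
// {"owner": "", "scope": "organization"} seen elsewhere in the diff.
func parseTags(raw []string) (map[string]string, error) {
	tags := make(map[string]string)
	for _, r := range raw {
		k, v, ok := strings.Cut(r, "=")
		if !ok || k == "" {
			return nil, fmt.Errorf("invalid tag %q, want key=value", r)
		}
		tags[k] = v
	}
	return tags, nil
}

func main() {
	tags, err := parseTags([]string{"owner=", "scope=organization"})
	fmt.Println(tags, err)
}
```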
+1 -46
@@ -4,9 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"time"

"golang.org/x/xerrors"
@@ -19,29 +16,6 @@ import (
"github.com/coder/serpent"
)

// detectGitRef attempts to resolve the current git branch and remote
// origin URL from the given working directory. These are sent to the
// control plane so it can look up PR/diff status via the GitHub API
// without SSHing into the workspace. Failures are silently ignored
// since this is best-effort.
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
run := func(args ...string) string {
//nolint:gosec
cmd := exec.Command(args[0], args[1:]...)
if workingDirectory != "" {
cmd.Dir = workingDirectory
}
out, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
return branch, remoteOrigin
}

// gitAskpass is used by the Coder agent to automatically authenticate
// with Git providers based on a hostname.
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
@@ -64,21 +38,8 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("create agent client: %w", err)
}

workingDirectory, err := os.Getwd()
if err != nil {
workingDirectory = ""
}

// Detect the current git branch and remote origin so
// the control plane can resolve diffs without needing
// to SSH back into the workspace.
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)

token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
Match: host,
GitBranch: gitBranch,
GitRemoteOrigin: gitRemoteOrigin,
ChatID: inv.Environ.Get("CODER_CHAT_ID"),
Match: host,
})
if err != nil {
var apiError *codersdk.Error
@@ -97,12 +58,6 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("get git token: %w", err)
}
if token.URL != "" {
// This is to help the agent authenticate with Git.
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
_, _ = fmt.Fprintf(inv.Stderr, `You must notify the user to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
return cliui.ErrCanceled
}

if err := openURL(inv, token.URL); err == nil {
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
} else {
+5 -1
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -1,7 +1,6 @@
package cli

import (
"encoding/json"
"fmt"
"strings"

@@ -232,7 +231,7 @@ next:
continue // immutables should not be passed to consecutive builds
}

if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, *tvp) {
if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, tvp.Options) {
continue // do not propagate invalid options
}

@@ -298,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
}

if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
if !tvp.Mutable && action != WorkspaceCreate {
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
}
}
@@ -366,7 +365,7 @@ func (pr *ParameterResolver) isLastBuildParameterInvalidOption(templateVersionPa

for _, buildParameter := range pr.lastBuildParameters {
if buildParameter.Name == templateVersionParameter.Name {
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter)
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter.Options)
}
}
return false
@@ -390,31 +389,8 @@ func findWorkspaceBuildParameter(parameterName string, params []codersdk.Workspa
return nil
}

func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, templateVersionParameter codersdk.TemplateVersionParameter) bool {
// Multi-select parameters store values as a JSON array (e.g.
// '["vim","emacs"]'), so we need to parse the array and validate
// each element individually against the allowed options.
if templateVersionParameter.Type == "list(string)" {
var values []string
if err := json.Unmarshal([]byte(buildParameter.Value), &values); err != nil {
return false
}
for _, v := range values {
found := false
for _, opt := range templateVersionParameter.Options {
if opt.Value == v {
found = true
break
}
}
if !found {
return false
}
}
return true
}

for _, opt := range templateVersionParameter.Options {
func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, options []codersdk.TemplateVersionParameterOption) bool {
for _, opt := range options {
if opt.Value == buildParameter.Value {
return true
}
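The removed branch above handled `list(string)` parameters, whose value is a JSON array that must be parsed so each element can be checked against the allowed options. A stdlib-only sketch of that multi-select validation with the `codersdk` structs simplified to strings:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// isValidOption mirrors the removed multi-select logic: a multi-select
// value is a JSON array (e.g. `["vim","emacs"]`) whose every element
// must be one of the allowed option values; a single-select value is
// checked directly. Invalid JSON is treated as invalid.
func isValidOption(value string, multiSelect bool, allowed []string) bool {
	contains := func(v string) bool {
		for _, a := range allowed {
			if a == v {
				return true
			}
		}
		return false
	}
	if !multiSelect {
		return contains(value)
	}
	var values []string
	if err := json.Unmarshal([]byte(value), &values); err != nil {
		return false
	}
	for _, v := range values {
		if !contains(v) {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(isValidOption(`["vim","emacs"]`, true, []string{"vim", "emacs", "vscode"}))
}
```

Note the empty array `[]` is valid by this rule, which matches the `MultiSelectEmptyArray` case in the deleted test below.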
@@ -1,85 +0,0 @@
package cli

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/coder/coder/v2/codersdk"
)

func TestIsValidTemplateParameterOption(t *testing.T) {
t.Parallel()

options := []codersdk.TemplateVersionParameterOption{
{Name: "Vim", Value: "vim"},
{Name: "Emacs", Value: "emacs"},
{Name: "VS Code", Value: "vscode"},
}

t.Run("SingleSelectValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "vim"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})

t.Run("SingleSelectInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "notepad"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})

t.Run("MultiSelectAllValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","emacs"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})

t.Run("MultiSelectOneInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","notepad"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})

t.Run("MultiSelectEmptyArray", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `[]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})

t.Run("MultiSelectInvalidJSON", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `not-json`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
}
@@ -2,7 +2,6 @@ package cli_test

import (
"bytes"
"cmp"
"context"
"database/sql"
"encoding/json"
@@ -21,6 +20,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk"
)
@@ -35,10 +35,7 @@ func TestProvisioners_Golden(t *testing.T) {
provisioners, err := coderdAPI.Database.GetProvisionerDaemons(systemCtx)
require.NoError(t, err)
slices.SortFunc(provisioners, func(a, b database.ProvisionerDaemon) int {
return cmp.Or(
a.CreatedAt.Compare(b.CreatedAt),
bytes.Compare(a.ID[:], b.ID[:]),
)
return a.CreatedAt.Compare(b.CreatedAt)
})
pIdx := 0
for _, p := range provisioners {
@@ -50,10 +47,7 @@ func TestProvisioners_Golden(t *testing.T) {
jobs, err := coderdAPI.Database.GetProvisionerJobsCreatedAfter(systemCtx, time.Time{})
require.NoError(t, err)
slices.SortFunc(jobs, func(a, b database.ProvisionerJob) int {
return cmp.Or(
a.CreatedAt.Compare(b.CreatedAt),
bytes.Compare(a.ID[:], b.ID[:]),
)
return a.CreatedAt.Compare(b.CreatedAt)
})
jIdx := 0
for _, j := range jobs {
@@ -82,15 +76,11 @@ func TestProvisioners_Golden(t *testing.T) {
firstProvisioner := coderdtest.NewTaggedProvisionerDaemon(t, coderdAPI, "default-provisioner", map[string]string{"owner": "", "scope": "organization"})
t.Cleanup(func() { _ = firstProvisioner.Close() })
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, completeWithAgent())
version = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
require.Equal(t, codersdk.ProvisionerJobSucceeded, version.Job.Status,
"template version import should succeed, got error: %s", version.Job.Error)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)

workspace := coderdtest.CreateWorkspace(t, client, template.ID)
wb := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
require.Equal(t, codersdk.ProvisionerJobSucceeded, wb.Job.Status,
"workspace build job should succeed, got error: %s", wb.Job.Error)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)

// Stop the provisioner so it doesn't grab any more jobs.
firstProvisioner.Close()
@@ -99,17 +89,7 @@ func TestProvisioners_Golden(t *testing.T) {
replace[version.ID.String()] = "00000000-0000-0000-cccc-000000000000"
replace[workspace.LatestBuild.ID.String()] = "00000000-0000-0000-dddd-000000000000"

// Base synthetic times off the latest real job's CreatedAt, not the
// wall clock. Using dbtime.Now() here is racy because NTP clock
// steps can make it return a time before the real jobs' CreatedAt.
systemCtx := dbauthz.AsSystemRestricted(context.Background())
existingJobs, err := coderdAPI.Database.GetProvisionerJobsCreatedAfter(systemCtx, time.Time{})
require.NoError(t, err)
require.NotEmpty(t, existingJobs, "expected at least one provisioner job")
latestJob := slices.MaxFunc(existingJobs, func(a, b database.ProvisionerJob) int {
return a.CreatedAt.Compare(b.CreatedAt)
})
now := latestJob.CreatedAt.Add(time.Second)
now := dbtime.Now()

// Create a provisioner that's working on a job.
pd1 := dbgen.ProvisionerDaemon(t, coderdAPI.Database, database.ProvisionerDaemon{
+4 -15
@@ -884,27 +884,16 @@ func (o *OrganizationContext) Selected(inv *serpent.Invocation, client *codersdk
index := slices.IndexFunc(orgs, func(org codersdk.Organization) bool {
return org.Name == o.FlagSelect || org.ID.String() == o.FlagSelect
})
if index >= 0 {
return orgs[index], nil
}

// Not in membership list - try direct fetch.
// This allows site-wide admins (e.g., Owners) to use orgs they aren't
// members of.
org, err := client.OrganizationByName(inv.Context(), o.FlagSelect)
if err != nil {
if index < 0 {
var names []string
for _, org := range orgs {
names = append(names, org.Name)
}
var sdkErr *codersdk.Error
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusNotFound {
return codersdk.Organization{}, xerrors.Errorf("organization %q not found, are you sure you are a member of this organization? "+
"Valid options for '--org=' are [%s].", o.FlagSelect, strings.Join(names, ", "))
}
return codersdk.Organization{}, xerrors.Errorf("get organization %q: %w", o.FlagSelect, err)
return codersdk.Organization{}, xerrors.Errorf("organization %q not found, are you sure you are a member of this organization? "+
"Valid options for '--org=' are [%s].", o.FlagSelect, strings.Join(names, ", "))
}
return org, nil
return orgs[index], nil
}

if len(orgs) == 1 {
+45 -129
@@ -95,7 +95,6 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/cryptorand"
@@ -137,15 +136,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
if err != nil {
return nil, xerrors.Errorf("parse oidc oauth callback url: %w", err)
}

if vals.OIDC.RedirectURL.String() != "" {
redirectURL, err = vals.OIDC.RedirectURL.Value().Parse("/api/v2/users/oidc/callback")
if err != nil {
return nil, xerrors.Errorf("parse oidc redirect url %q", err)
}
logger.Warn(ctx, "custom OIDC redirect URL used instead of 'access_url', ensure this matches the value configured in your OIDC provider")
}

// If the scopes contain 'groups', we enable group support.
// Do not override any custom value set by the user.
if slice.Contains(vals.OIDC.Scopes, "groups") && vals.OIDC.GroupField == "" {
@@ -617,8 +607,28 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}

extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}

promRegistry := prometheus.NewRegistry()
oauthInstrument := promoauth.NewFactory(promRegistry)
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
externalAuthConfigs, err := externalauth.ConvertConfig(
oauthInstrument,
vals.ExternalAuthConfigs.Value,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range externalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}

realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
if err != nil {
@@ -649,7 +659,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
Pubsub: nil,
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
ExternalAuthConfigs: nil,
ExternalAuthConfigs: externalAuthConfigs,
RealIPConfig: realIPConfig,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
@@ -809,43 +819,9 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("set deployment id: %w", err)
}

extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders

mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
ctx,
options.Logger,
options.Database,
vals,
mergedExternalAuthProviders,
)
if err != nil {
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
}

options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
oauthInstrument,
mergedExternalAuthProviders,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range options.ExternalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}

// Manage push notifications.
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
if experiments.Enabled(codersdk.ExperimentWebPush) || buildinfo.IsDev() {
if experiments.Enabled(codersdk.ExperimentWebPush) {
if !strings.HasPrefix(options.AccessURL.String(), "https://") {
options.Logger.Warn(ctx, "access URL is not HTTPS, so web push notifications may not work on some browsers", slog.F("access_url", options.AccessURL.String()))
}
@@ -959,12 +935,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
options.StatsBatcher = batcher
defer closeBatcher()

wsBuilderMetrics, err := wsbuilder.NewMetrics(options.PrometheusRegistry)
if err != nil {
return xerrors.Errorf("failed to register workspace builder metrics: %w", err)
}
options.WorkspaceBuilderMetrics = wsBuilderMetrics
|
||||
|
||||
// Manage notifications.
|
||||
var (
|
||||
notificationsCfg = options.DeploymentValues.Notifications
|
||||
@@ -1148,7 +1118,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value())
|
||||
defer autobuildTicker.Stop()
|
||||
autobuildExecutor := autobuild.NewExecutor(
|
||||
ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments, coderAPI.WorkspaceBuilderMetrics)
|
||||
ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments)
|
||||
autobuildExecutor.Run()
|
||||
|
||||
jobReaperTicker := time.NewTicker(vals.JobReaperDetectorInterval.Value())
|
||||
@@ -1940,79 +1910,6 @@ type githubOAuth2ConfigParams struct {
|
||||
enterpriseBaseURL string
|
||||
}
|
||||
|
||||
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
|
||||
// We want to enable the default provider only for new deployments, and avoid
|
||||
// enabling it if a deployment was upgraded from an older version.
|
||||
// nolint:gocritic // Requires system privileges
|
||||
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, xerrors.Errorf("get github default eligible: %w", err)
|
||||
}
|
||||
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
|
||||
|
||||
if defaultEligibleNotSet {
|
||||
// nolint:gocritic // User count requires system privileges
|
||||
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
|
||||
if err != nil {
|
||||
return false, xerrors.Errorf("get user count: %w", err)
|
||||
}
|
||||
// We check if a deployment is new by checking if it has any users.
|
||||
defaultEligible = userCount == 0
|
||||
// nolint:gocritic // Requires system privileges
|
||||
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
|
||||
return false, xerrors.Errorf("upsert github default eligible: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultEligible, nil
|
||||
}
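The eligibility check above persists a one-time decision: the first time it runs, a deployment counts as "default eligible" only if it has zero users, and that answer is stored so later upgrades can never flip it. A minimal standalone sketch of this first-run pattern (the `fakeStore` type is hypothetical, standing in for the database, not coderd's real store interface):

```go
package main

import "fmt"

// fakeStore is a hypothetical stand-in for the database: it remembers
// whether the eligibility decision has been persisted yet, and how many
// users the deployment has.
type fakeStore struct {
	set      bool
	eligible bool
	users    int
}

func (s *fakeStore) getEligible() (bool, bool) { return s.eligible, s.set }
func (s *fakeStore) setEligible(v bool)        { s.eligible, s.set = v, true }

// firstRunEligible mirrors the logic in the diff: if no decision has
// been stored, decide based on whether this is a brand-new deployment
// (zero users) and persist the answer so it is sticky across upgrades.
func firstRunEligible(s *fakeStore) bool {
	if v, ok := s.getEligible(); ok {
		return v
	}
	eligible := s.users == 0
	s.setEligible(eligible)
	return eligible
}

func main() {
	fresh := &fakeStore{users: 0}
	fmt.Println(firstRunEligible(fresh)) // new deployment: true
	fresh.users = 10
	fmt.Println(firstRunEligible(fresh)) // decision is sticky: still true

	upgraded := &fakeStore{users: 5}
	fmt.Println(firstRunEligible(upgraded)) // existing deployment: false
}
```

The persisted flag is what distinguishes "new deployment" from "old deployment that happens to have zero users at some later check".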

func maybeAppendDefaultGithubExternalAuthProvider(
ctx context.Context,
logger slog.Logger,
db database.Store,
vals *codersdk.DeploymentValues,
mergedExplicitProviders []codersdk.ExternalAuthConfig,
) ([]codersdk.ExternalAuthConfig, error) {
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "disabled by configuration"),
slog.F("flag", "external-auth-github-default-provider-enable"),
)
return mergedExplicitProviders, nil
}

if len(mergedExplicitProviders) > 0 {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "explicit external auth providers configured"),
slog.F("provider_count", len(mergedExplicitProviders)),
)
return mergedExplicitProviders, nil
}

defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "deployment is not eligible"),
)
return mergedExplicitProviders, nil
}

logger.Info(ctx, "injecting default github external auth provider",
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
)
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
ClientID: GithubOAuth2DefaultProviderClientID,
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
}), nil
}
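The function above applies three suppression checks in a fixed order: the explicit off switch wins, then any explicitly configured providers, then deployment eligibility; only if all pass is the default GitHub provider injected. A compressed sketch of that precedence chain (simplified, hypothetical signature — the real function takes the deployment values and database):

```go
package main

import "fmt"

// decideDefaultProvider mirrors the precedence in the diff and returns
// the reason the default GitHub provider was (or was not) injected.
func decideDefaultProvider(enabled bool, explicitProviders int, eligible bool) string {
	switch {
	case !enabled:
		return "suppressed: disabled by configuration"
	case explicitProviders > 0:
		return "suppressed: explicit external auth providers configured"
	case !eligible:
		return "suppressed: deployment is not eligible"
	default:
		return "injected"
	}
}

func main() {
	fmt.Println(decideDefaultProvider(true, 0, true))  // injected
	fmt.Println(decideDefaultProvider(true, 2, true))  // explicit providers win
	fmt.Println(decideDefaultProvider(false, 0, true)) // off switch wins over everything
}
```

Ordering matters: checking the flag first means an operator opt-out never even touches the database.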

func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
params := githubOAuth2ConfigParams{
accessURL: vals.AccessURL.Value(),
@@ -2037,9 +1934,28 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
return nil, nil //nolint:nilnil
}

defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
// We want to enable it only for new deployments, and avoid enabling it
// if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)

if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return nil, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
}
}

if !defaultEligible {

@@ -53,7 +53,6 @@ import (
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
@@ -303,7 +302,6 @@ func TestServer(t *testing.T) {
"open install.sh: file does not exist",
"telemetry disabled, unable to notify of security issues",
"installed terraform version newer than expected",
"report generator",
}

countLines := func(fullOutput string) int {
@@ -1742,18 +1740,6 @@ func TestServer(t *testing.T) {

// Next, we instruct the same server to display the YAML config
// and then save it.
// Because this is literally the same invocation, DefaultFn sets the
// value of 'Default'. Which triggers a mutually exclusive error
// on the next parse.
// Usually we only parse flags once, so this is not an issue
for _, c := range inv.Command.Children {
if c.Name() == "server" {
for i := range c.Options {
c.Options[i].DefaultFn = nil
}
break
}
}
inv = inv.WithContext(testutil.Context(t, testutil.WaitMedium))
//nolint:gocritic
inv.Args = append(args, "--write-config")
@@ -1807,155 +1793,6 @@ func TestServer(t *testing.T) {
})
}

//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
type testCase struct {
name string
args []string
env map[string]string
createUserPreStart bool
expectedProviders []string
}

run := func(t *testing.T, tc testCase) {
ctx := testutil.Context(t, testutil.WaitLong)

unsetPrefixedEnv := func(prefix string) {
t.Helper()
for _, envVar := range os.Environ() {
envKey, _, found := strings.Cut(envVar, "=")
if !found || !strings.HasPrefix(envKey, prefix) {
continue
}
value, had := os.LookupEnv(envKey)
require.True(t, had)
require.NoError(t, os.Unsetenv(envKey))
keyCopy := envKey
valueCopy := value
t.Cleanup(func() {
// This is for setting/unsetting a number of prefixed env vars.
// t.Setenv doesn't cover this use case.
// nolint:usetesting
_ = os.Setenv(keyCopy, valueCopy)
})
}
}
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
unsetPrefixedEnv("CODER_GITAUTH_")

dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))

const (
existingUserEmail = "existing-user@coder.com"
existingUserUsername = "existing-user"
existingUserPassword = "SomeSecurePassword!"
)
if tc.createUserPreStart {
hashedPassword, err := userpassword.Hash(existingUserPassword)
require.NoError(t, err)
_ = dbgen.User(t, db, database.User{
Email: existingUserEmail,
Username: existingUserUsername,
HashedPassword: []byte(hashedPassword),
})
}

args := []string{
"server",
"--postgres-url", dbURL,
"--http-address", ":0",
"--access-url", "https://example.com",
}
args = append(args, tc.args...)

inv, cfg := clitest.New(t, args...)
for envKey, value := range tc.env {
t.Setenv(envKey, value)
}
clitest.Start(t, inv)

accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)

if tc.createUserPreStart {
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: existingUserEmail,
Password: existingUserPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
} else {
_ = coderdtest.CreateFirstUser(t, client)
}

externalAuthResp, err := client.ListExternalAuths(ctx)
require.NoError(t, err)

gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
for _, provider := range externalAuthResp.Providers {
gotProviders[provider.ID] = provider
}
require.Len(t, gotProviders, len(tc.expectedProviders))

for _, providerID := range tc.expectedProviders {
provider, ok := gotProviders[providerID]
require.Truef(t, ok, "expected provider %q to be configured", providerID)
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
require.True(t, provider.Device)
}
}
}

for _, tc := range []testCase{
{
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
},
{
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
createUserPreStart: true,
expectedProviders: nil,
},
{
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
args: []string{
"--external-auth-github-default-provider-enable=false",
},
expectedProviders: nil,
},
{
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
args: []string{
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
} {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}

//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_Logging_NoParallel(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

+31 -7
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, nil)
orgOwner = coderdtest.CreateFirstUser(t, client)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)

workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()

var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{

+1 -1
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
version := workspace.LatestBuild.TemplateVersionID

if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
version = workspace.TemplateActiveVersionID
if version != workspace.LatestBuild.TemplateVersionID {
action = WorkspaceUpdate
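The one-line change above narrows the condition that switches a build to the template's active version: `TemplateRequireActiveVersion` no longer forces it; only always-on automatic updates or an explicit update do, and if the active version differs from the latest build, the action is promoted to an update. A sketch of the new rule with simplified, hypothetical types (the real function works on `codersdk.Workspace`):

```go
package main

import "fmt"

type action string

const (
	actionStart  action = "start"
	actionUpdate action = "update"
)

// pickVersion mirrors the post-change condition: only AutomaticUpdatesAlways
// or an explicit update selects the template's active version, and a start
// that lands on a different version is promoted to an update.
func pickVersion(autoAlways bool, act action, latest, active string) (string, action) {
	version := latest
	if autoAlways || act == actionUpdate {
		version = active
		if version != latest {
			act = actionUpdate
		}
	}
	return version, act
}

func main() {
	v, a := pickVersion(true, actionStart, "v1", "v2")
	fmt.Println(v, a) // v2 update
	v, a = pickVersion(false, actionStart, "v1", "v2")
	fmt.Println(v, a) // v1 start
}
```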

+4 -4
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
statefilePath := filepath.Join(t.TempDir(), "state")
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
@@ -54,7 +54,7 @@
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
@@ -74,7 +74,7 @@
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")

+7 -9
@@ -1,3 +1,5 @@
//go:build !windows

package cli_test

import (
@@ -5,7 +7,6 @@ import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"time"

@@ -24,15 +25,12 @@ func setupSocketServer(t *testing.T) (path string, cleanup func()) {
t.Helper()

// Use a temporary socket path for each test
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")

// Create parent directory if needed. Not necessary on Windows because named pipes live in an abstract namespace
// not tied to any real files.
if runtime.GOOS != "windows" {
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")
}
// Create parent directory if needed
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")

server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),

@@ -17,8 +17,6 @@ func (r *RootCmd) tasksCommand() *serpent.Command {
r.taskDelete(),
r.taskList(),
r.taskLogs(),
r.taskPause(),
r.taskResume(),
r.taskSend(),
r.taskStatus(),
},

+10 -5
@@ -41,7 +41,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()

setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client // user already has access to their own workspace

inv, root := clitest.New(t, "task", "logs", task.Name, "--output", "json")
output := clitest.Capture(inv)
@@ -64,7 +65,8 @@
t.Parallel()

setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client

inv, root := clitest.New(t, "task", "logs", task.ID.String(), "--output", "json")
output := clitest.Capture(inv)
@@ -87,7 +89,8 @@
t.Parallel()

setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client

inv, root := clitest.New(t, "task", "logs", task.ID.String())
output := clitest.Capture(inv)
@@ -141,7 +144,8 @@
t.Parallel()

setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
userClient := client

inv, root := clitest.New(t, "task", "logs", task.ID.String())
clitest.SetupConfig(t, userClient, root)
@@ -197,7 +201,8 @@
t.Run("SnapshotWithoutLogs_NoSnapshotCaptured", func(t *testing.T) {
t.Parallel()

userClient, task := setupCLITaskTestWithoutSnapshot(t, codersdk.TaskStatusPaused)
client, task := setupCLITaskTestWithoutSnapshot(t, codersdk.TaskStatusPaused)
userClient := client

inv, root := clitest.New(t, "task", "logs", task.Name)
output := clitest.Capture(inv)

@@ -1,90 +0,0 @@
package cli

import (
"fmt"
"time"

"golang.org/x/xerrors"

"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)

func (r *RootCmd) taskPause() *serpent.Command {
cmd := &serpent.Command{
Use: "pause <task>",
Short: "Pause a task",
Long: FormatExamples(
Example{
Description: "Pause a task by name",
Command: "coder task pause my-task",
},
Example{
Description: "Pause another user's task",
Command: "coder task pause alice/my-task",
},
Example{
Description: "Pause a task without confirmation",
Command: "coder task pause my-task --yes",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Options: serpent.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}

task, err := client.TaskByIdentifier(ctx, inv.Args[0])
if err != nil {
return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
}

display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)

if task.Status == codersdk.TaskStatusPaused {
return xerrors.Errorf("task %q is already paused", display)
}

_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Pause task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
IsConfirm: true,
Default: cliui.ConfirmNo,
})
if err != nil {
return err
}

resp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
if err != nil {
return xerrors.Errorf("pause task %q: %w", display, err)
}

if resp.WorkspaceBuild == nil {
return xerrors.Errorf("pause task %q: no workspace build returned", display)
}

err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID)
if err != nil {
return xerrors.Errorf("watch pause build for task %q: %w", display, err)
}

_, _ = fmt.Fprintf(
inv.Stdout,
"\nThe %s task has been paused at %s!\n",
cliui.Keyword(task.Name),
cliui.Timestamp(time.Now()),
)
return nil
},
}
return cmd
}
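The `taskPause` handler above follows a common CLI shape: resolve the target, reject an invalid starting state, ask for confirmation (defaulting to no), then act and watch the result. A framework-free sketch of that control flow, with a simplified hypothetical signature in place of the serpent/cliui plumbing:

```go
package main

import (
	"errors"
	"fmt"
)

// pauseTask sketches the handler's control flow from the diff: refuse a
// task that is already paused, require confirmation, then perform the
// action. The real handler calls the API and watches the workspace build.
func pauseTask(status string, confirmed bool) error {
	if status == "paused" {
		return errors.New("task is already paused")
	}
	if !confirmed {
		return errors.New("aborted by user")
	}
	// ...here the real command calls PauseTask and streams the build...
	return nil
}

func main() {
	fmt.Println(pauseTask("paused", true))  // non-nil: already paused
	fmt.Println(pauseTask("running", false)) // non-nil: declined prompt
	fmt.Println(pauseTask("running", true))  // nil: pause proceeds
}
```

Validating state before prompting means the user is never asked to confirm an operation that would fail anyway.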
@@ -1,144 +0,0 @@
package cli_test

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)

func TestExpTaskPause(t *testing.T) {
t.Parallel()

t.Run("WithYesFlag", func(t *testing.T) {
t.Parallel()

// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)

// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)

// Then: Expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")

updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})

// OtherUserTask verifies that an admin can pause a task owned by
// another user using the "owner/name" identifier format.
t.Run("OtherUserTask", func(t *testing.T) {
t.Parallel()

// Given: A different user's running task
setupCtx := testutil.Context(t, testutil.WaitLong)
adminClient, _, task := setupCLITaskTest(setupCtx, t, nil)

// When: We attempt to pause their task
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "pause", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, adminClient, root)

// Then: We expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")

updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})

t.Run("PromptConfirm", func(t *testing.T) {
t.Parallel()

// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)

// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)

// And: We confirm we want to pause the task
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Pause task")
pty.WriteLine("yes")

// Then: We expect the task to be paused
pty.ExpectMatchContext(ctx, "has been paused")
require.NoError(t, w.Wait())

updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})

t.Run("PromptDecline", func(t *testing.T) {
t.Parallel()

// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)

// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)

// But: We say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Pause task")
pty.WriteLine("no")
require.Error(t, w.Wait())

// Then: We expect the task to not be paused
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.NotEqual(t, codersdk.TaskStatusPaused, updated.Status)
})

t.Run("TaskAlreadyPaused", func(t *testing.T) {
t.Parallel()

// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)

// And: We paused the running task
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, resp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, resp.WorkspaceBuild.ID)

// When: We attempt to pause the task again
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)

// Then: We expect to get an error that the task is already paused
|
||||
err = inv.WithContext(ctx).Run()
|
||||
require.ErrorContains(t, err, "is already paused")
|
||||
})
|
||||
}
|
||||
@@ -1,95 +0,0 @@
package cli

import (
	"fmt"

	"golang.org/x/xerrors"

	"github.com/coder/coder/v2/cli/cliui"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/pretty"
	"github.com/coder/serpent"
)

func (r *RootCmd) taskResume() *serpent.Command {
	var noWait bool

	cmd := &serpent.Command{
		Use:   "resume <task>",
		Short: "Resume a task",
		Long: FormatExamples(
			Example{
				Description: "Resume a task by name",
				Command:     "coder task resume my-task",
			},
			Example{
				Description: "Resume another user's task",
				Command:     "coder task resume alice/my-task",
			},
			Example{
				Description: "Resume a task without confirmation",
				Command:     "coder task resume my-task --yes",
			},
		),
		Middleware: serpent.Chain(
			serpent.RequireNArgs(1),
		),
		Options: serpent.OptionSet{
			{
				Flag:        "no-wait",
				Description: "Return immediately after resuming the task.",
				Value:       serpent.BoolOf(&noWait),
			},
			cliui.SkipPromptOption(),
		},
		Handler: func(inv *serpent.Invocation) error {
			ctx := inv.Context()
			client, err := r.InitClient(inv)
			if err != nil {
				return err
			}

			task, err := client.TaskByIdentifier(ctx, inv.Args[0])
			if err != nil {
				return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
			}

			display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)

			if task.Status == codersdk.TaskStatusError || task.Status == codersdk.TaskStatusUnknown {
				return xerrors.Errorf("task %q is in %s state and cannot be resumed; check the workspace build logs and agent status for details", display, task.Status)
			} else if task.Status != codersdk.TaskStatusPaused {
				return xerrors.Errorf("task %q cannot be resumed (current status: %s)", display, task.Status)
			}

			_, err = cliui.Prompt(inv, cliui.PromptOptions{
				Text:      fmt.Sprintf("Resume task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
				IsConfirm: true,
				Default:   cliui.ConfirmNo,
			})
			if err != nil {
				return err
			}

			resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
			if err != nil {
				return xerrors.Errorf("resume task %q: %w", display, err)
			} else if resp.WorkspaceBuild == nil {
				return xerrors.Errorf("resume task %q: no workspace build returned", display)
			}

			if noWait {
				_, _ = fmt.Fprintf(inv.Stdout, "Resuming task %q in the background.\n", cliui.Keyword(display))
				return nil
			}

			if err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID); err != nil {
				return xerrors.Errorf("watch resume build for task %q: %w", display, err)
			}

			_, _ = fmt.Fprintf(inv.Stdout, "\nThe %s task has been resumed.\n", cliui.Keyword(display))
			return nil
		},
	}
	return cmd
}
@@ -1,183 +0,0 @@
package cli_test

import (
	"context"
	"fmt"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/coder/coder/v2/cli/clitest"
	"github.com/coder/coder/v2/coderd/coderdtest"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/pty/ptytest"
	"github.com/coder/coder/v2/testutil"
)

func TestExpTaskResume(t *testing.T) {
	t.Parallel()

	// pauseTask is a helper that pauses a task and waits for the stop
	// build to complete.
	pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
		t.Helper()

		pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
		require.NoError(t, err)
		require.NotNil(t, pauseResp.WorkspaceBuild)
		coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
	}

	t.Run("WithYesFlag", func(t *testing.T) {
		t.Parallel()

		// Given: A paused task
		setupCtx := testutil.Context(t, testutil.WaitLong)
		_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
		pauseTask(setupCtx, t, userClient, task)

		// When: We attempt to resume the task
		inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
		output := clitest.Capture(inv)
		clitest.SetupConfig(t, userClient, root)

		// Then: We expect the task to be resumed
		ctx := testutil.Context(t, testutil.WaitMedium)
		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
		require.Contains(t, output.Stdout(), "has been resumed")

		updated, err := userClient.TaskByIdentifier(ctx, task.Name)
		require.NoError(t, err)
		require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
	})

	// OtherUserTask verifies that an admin can resume a task owned by
	// another user using the "owner/name" identifier format.
	t.Run("OtherUserTask", func(t *testing.T) {
		t.Parallel()

		// Given: A different user's paused task
		setupCtx := testutil.Context(t, testutil.WaitLong)
		adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
		pauseTask(setupCtx, t, userClient, task)

		// When: We attempt to resume their task
		identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
		inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
		output := clitest.Capture(inv)
		clitest.SetupConfig(t, adminClient, root)

		// Then: We expect the task to be resumed
		ctx := testutil.Context(t, testutil.WaitMedium)
		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
		require.Contains(t, output.Stdout(), "has been resumed")

		updated, err := adminClient.TaskByIdentifier(ctx, identifier)
		require.NoError(t, err)
		require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
	})

	t.Run("NoWait", func(t *testing.T) {
		t.Parallel()

		// Given: A paused task
		setupCtx := testutil.Context(t, testutil.WaitLong)
		_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
		pauseTask(setupCtx, t, userClient, task)

		// When: We attempt to resume the task (and specify no wait)
		inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
		output := clitest.Capture(inv)
		clitest.SetupConfig(t, userClient, root)

		// Then: We expect the task to be resumed in the background
		ctx := testutil.Context(t, testutil.WaitMedium)
		err := inv.WithContext(ctx).Run()
		require.NoError(t, err)
		require.Contains(t, output.Stdout(), "in the background")

		// And: The task to eventually be resumed
		require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
		ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
		coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)

		updated, err := userClient.TaskByIdentifier(ctx, task.Name)
		require.NoError(t, err)
		require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
	})

	t.Run("PromptConfirm", func(t *testing.T) {
		t.Parallel()

		// Given: A paused task
		setupCtx := testutil.Context(t, testutil.WaitLong)
		_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
		pauseTask(setupCtx, t, userClient, task)

		// When: We attempt to resume the task
		inv, root := clitest.New(t, "task", "resume", task.Name)
		clitest.SetupConfig(t, userClient, root)

		// And: We confirm we want to resume the task
		ctx := testutil.Context(t, testutil.WaitMedium)
		inv = inv.WithContext(ctx)
		pty := ptytest.New(t).Attach(inv)
		w := clitest.StartWithWaiter(t, inv)
		pty.ExpectMatchContext(ctx, "Resume task")
		pty.WriteLine("yes")

		// Then: We expect the task to be resumed
		pty.ExpectMatchContext(ctx, "has been resumed")
		require.NoError(t, w.Wait())

		updated, err := userClient.TaskByIdentifier(ctx, task.Name)
		require.NoError(t, err)
		require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
	})

	t.Run("PromptDecline", func(t *testing.T) {
		t.Parallel()

		// Given: A paused task
		setupCtx := testutil.Context(t, testutil.WaitLong)
		_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
		pauseTask(setupCtx, t, userClient, task)

		// When: We attempt to resume the task
		inv, root := clitest.New(t, "task", "resume", task.Name)
		clitest.SetupConfig(t, userClient, root)

		// But: Say no at the confirmation screen
		ctx := testutil.Context(t, testutil.WaitMedium)
		inv = inv.WithContext(ctx)
		pty := ptytest.New(t).Attach(inv)
		w := clitest.StartWithWaiter(t, inv)
		pty.ExpectMatchContext(ctx, "Resume task")
		pty.WriteLine("no")
		require.Error(t, w.Wait())

		// Then: We expect the task to still be paused
		updated, err := userClient.TaskByIdentifier(ctx, task.Name)
		require.NoError(t, err)
		require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
	})

	t.Run("TaskNotPaused", func(t *testing.T) {
		t.Parallel()

		// Given: A running task
		setupCtx := testutil.Context(t, testutil.WaitLong)
		_, userClient, task := setupCLITaskTest(setupCtx, t, nil)

		// When: We attempt to resume the task that is not paused
		inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
		clitest.SetupConfig(t, userClient, root)

		// Then: We expect to get an error that the task is not paused
		ctx := testutil.Context(t, testutil.WaitMedium)
		err := inv.WithContext(ctx).Run()
		require.ErrorContains(t, err, "cannot be resumed")
	})
}
@@ -25,7 +25,8 @@ func Test_TaskSend(t *testing.T) {
	t.Parallel()

	setupCtx := testutil.Context(t, testutil.WaitLong)
-	_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
+	client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
+	userClient := client

	var stdout strings.Builder
	inv, root := clitest.New(t, "task", "send", task.Name, "carry on with the task")
@@ -41,7 +42,8 @@ func Test_TaskSend(t *testing.T) {
	t.Parallel()

	setupCtx := testutil.Context(t, testutil.WaitLong)
-	_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
+	client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
+	userClient := client

	var stdout strings.Builder
	inv, root := clitest.New(t, "task", "send", task.ID.String(), "carry on with the task")
@@ -57,7 +59,8 @@ func Test_TaskSend(t *testing.T) {
	t.Parallel()

	setupCtx := testutil.Context(t, testutil.WaitLong)
-	_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
+	client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
+	userClient := client

	var stdout strings.Builder
	inv, root := clitest.New(t, "task", "send", task.Name, "--stdin")
@@ -110,7 +113,7 @@ func Test_TaskSend(t *testing.T) {
	t.Parallel()

	setupCtx := testutil.Context(t, testutil.WaitLong)
-	_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
+	userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))

	var stdout strings.Builder
	inv, root := clitest.New(t, "task", "send", task.Name, "some task input")
@@ -120,40 +120,6 @@ func Test_Tasks(t *testing.T) {
				require.Equal(t, logs[2].Type, codersdk.TaskLogTypeOutput, "third message should be an output")
			},
		},
-		{
-			name:    "pause task",
-			cmdArgs: []string{"task", "pause", taskName, "--yes"},
-			assertFn: func(stdout string, userClient *codersdk.Client) {
-				require.Contains(t, stdout, "has been paused", "pause output should confirm task was paused")
-			},
-		},
-		{
-			name:    "get task status after pause",
-			cmdArgs: []string{"task", "status", taskName, "--output", "json"},
-			assertFn: func(stdout string, userClient *codersdk.Client) {
-				var task codersdk.Task
-				require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
-				require.Equal(t, taskName, task.Name, "task name should match")
-				require.Equal(t, codersdk.TaskStatusPaused, task.Status, "task should be paused")
-			},
-		},
-		{
-			name:    "resume task",
-			cmdArgs: []string{"task", "resume", taskName, "--yes"},
-			assertFn: func(stdout string, userClient *codersdk.Client) {
-				require.Contains(t, stdout, "has been resumed", "resume output should confirm task was resumed")
-			},
-		},
-		{
-			name:    "get task status after resume",
-			cmdArgs: []string{"task", "status", taskName, "--output", "json"},
-			assertFn: func(stdout string, userClient *codersdk.Client) {
-				var task codersdk.Task
-				require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
-				require.Equal(t, taskName, task.Name, "task name should match")
-				require.Equal(t, codersdk.TaskStatusInitializing, task.Status, "task should be initializing after resume")
-			},
-		},
		{
			name:    "delete task",
			cmdArgs: []string{"task", "delete", taskName, "--yes"},
@@ -272,17 +238,17 @@ func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Mes
// setupCLITaskTest creates a test workspace with an AI task template and agent,
// with a fake agent API configured with the provided set of handlers.
// Returns the user client and workspace.
-func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (ownerClient *codersdk.Client, memberClient *codersdk.Client, task codersdk.Task) {
+func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Task) {
	t.Helper()

-	ownerClient = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
-	owner := coderdtest.CreateFirstUser(t, ownerClient)
-	userClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
+	client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
+	owner := coderdtest.CreateFirstUser(t, client)
+	userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)

	fakeAPI := startFakeAgentAPI(t, agentAPIHandlers)

	authToken := uuid.NewString()
-	template := createAITaskTemplate(t, ownerClient, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
+	template := createAITaskTemplate(t, client, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))

	wantPrompt := "test prompt"
	task, err := userClient.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
@@ -296,17 +262,17 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
	require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
	workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
	require.NoError(t, err)
-	coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, workspace.LatestBuild.ID)
+	coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)

-	agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(authToken))
-	_ = agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
+	agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
+	_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
		o.Client = agentClient
	})

-	coderdtest.NewWorkspaceAgentWaiter(t, userClient, workspace.ID).
+	coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).
		WaitFor(coderdtest.AgentsReady)

-	return ownerClient, userClient, task
+	return userClient, task
}

// setupCLITaskTestWithSnapshot creates a task in the specified status with a log snapshot.
@@ -139,10 +139,8 @@ func (r *RootCmd) templateVersionsList() *serpent.Command {
type templateVersionRow struct {
	// For json format:
	TemplateVersion codersdk.TemplateVersion `table:"-"`
-	ActiveJSON      bool                     `json:"active" table:"-"`

	// For table format:
	ID        string    `json:"-" table:"id"`
	Name      string    `json:"-" table:"name,default_sort"`
	CreatedAt time.Time `json:"-" table:"created at"`
	CreatedBy string    `json:"-" table:"created by"`
@@ -168,8 +166,6 @@ func templateVersionsToRows(activeVersionID uuid.UUID, templateVersions ...coder

	rows[i] = templateVersionRow{
		TemplateVersion: templateVersion,
-		ActiveJSON:      templateVersion.ID == activeVersionID,
		ID:              templateVersion.ID.String(),
		Name:            templateVersion.Name,
		CreatedAt:       templateVersion.CreatedAt,
		CreatedBy:       templateVersion.CreatedBy.Username,

@@ -1,9 +1,7 @@
package cli_test

import (
-	"bytes"
	"context"
-	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
@@ -42,33 +40,6 @@ func TestTemplateVersions(t *testing.T) {
		pty.ExpectMatch(version.CreatedBy.Username)
		pty.ExpectMatch("Active")
	})
-
-	t.Run("ListVersionsJSON", func(t *testing.T) {
-		t.Parallel()
-		client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
-		owner := coderdtest.CreateFirstUser(t, client)
-		member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
-		version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
-		_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
-		template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
-
-		inv, root := clitest.New(t, "templates", "versions", "list", template.Name, "--output", "json")
-		clitest.SetupConfig(t, member, root)
-
-		var stdout bytes.Buffer
-		inv.Stdout = &stdout
-
-		require.NoError(t, inv.Run())
-
-		var rows []struct {
-			TemplateVersion codersdk.TemplateVersion `json:"TemplateVersion"`
-			Active          bool                     `json:"active"`
-		}
-		require.NoError(t, json.Unmarshal(stdout.Bytes(), &rows))
-		require.Len(t, rows, 1)
-		assert.Equal(t, version.ID, rows[0].TemplateVersion.ID)
-		assert.True(t, rows[0].Active)
-	})
}

func TestTemplateVersionsPromote(t *testing.T) {
@@ -74,7 +74,7 @@ OPTIONS:
      --socket-path string, $CODER_AGENT_SOCKET_PATH
          Specify the path for the agent socket.

-     --socket-server-enabled bool, $CODER_AGENT_SOCKET_SERVER_ENABLED (default: true)
+     --socket-server-enabled bool, $CODER_AGENT_SOCKET_SERVER_ENABLED (default: false)
          Enable the agent socket server.

      --ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
@@ -49,9 +49,10 @@ OPTIONS:
          security purposes if a --wildcard-access-url is configured.

      --disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
-         Disable workspace sharing. Workspace ACL checking is disabled and only
-         owners can have ssh, apps and terminal access to workspaces. Access
-         based on the 'owner' role is also allowed unless disabled via
+         Disable workspace sharing (requires the "workspace-sharing" experiment
+         to be enabled). Workspace ACL checking is disabled and only owners can
+         have ssh, apps and terminal access to workspaces. Access based on the
+         'owner' role is also allowed unless disabled via
          --disable-owner-workspace-access.

      --swagger-enable bool, $CODER_SWAGGER_ENABLE
@@ -62,9 +63,6 @@ OPTIONS:
          Separate multiple experiments with commas, or enter '*' to opt-in to
          all available experiments.

-     --external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
-         Enable the default GitHub external auth provider managed by Coder.
-
      --postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
          Type of auth to use when connecting to postgres. For AWS RDS, using
          IAM authentication (awsiamrds) is recommended.
@@ -175,22 +173,19 @@ AI BRIDGE OPTIONS:
          exporting these records to external SIEM or observability systems.

AI BRIDGE PROXY OPTIONS:
-     --aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
-         Path to the CA certificate file for AI Bridge Proxy.
-
      --aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
          Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
          provider requests.

-     --aibridge-proxy-key-file string, $CODER_AIBRIDGE_PROXY_KEY_FILE
-         Path to the CA private key file for AI Bridge Proxy.
-
      --aibridge-proxy-listen-addr string, $CODER_AIBRIDGE_PROXY_LISTEN_ADDR (default: :8888)
          The address the AI Bridge Proxy will listen on.

+     --aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
+         Path to the CA certificate file used to intercept (MITM) HTTPS traffic
+         from AI clients. This CA must be trusted by AI clients for the proxy
+         to decrypt their requests.
+
+     --aibridge-proxy-key-file string, $CODER_AIBRIDGE_PROXY_KEY_FILE
+         Path to the CA private key file used to intercept (MITM) HTTPS traffic
+         from AI clients.
+
      --aibridge-proxy-upstream string, $CODER_AIBRIDGE_PROXY_UPSTREAM
          URL of an upstream HTTP proxy to chain tunneled (non-allowlisted)
          requests through. Format: http://[user:pass@]host:port or
@@ -388,19 +383,13 @@ NETWORKING OPTIONS:
      --samesite-auth-cookie lax|none, $CODER_SAMESITE_AUTH_COOKIE (default: lax)
          Controls the 'SameSite' property is set on browser session cookies.

-     --secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE (default: false)
+     --secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE
          Controls if the 'Secure' property is set on browser session cookies.

      --wildcard-access-url string, $CODER_WILDCARD_ACCESS_URL
          Specifies the wildcard hostname to use for workspace applications in
          the form "*.example.com".

-     --host-prefix-cookie bool, $CODER_HOST_PREFIX_COOKIE (default: false)
-         Recommended to be enabled. Enables `__Host-` prefix for cookies to
-         guarantee they are only set by the right domain. This change is
-         disruptive to any workspaces built before release 2.31, requiring a
-         workspace restart.
-
NETWORKING / DERP OPTIONS:
  Most Coder deployments never have to think about DERP because all connections
  between workspaces and users are peer-to-peer. However, when Coder cannot
@@ -12,8 +12,6 @@ SUBCOMMANDS:
    delete Delete tasks
    list   List tasks
    logs   Show a task's logs
-   pause  Pause a task
-   resume Resume a task
    send   Send input to a task
    status Show the status of a task.
@@ -1,25 +0,0 @@
coder v0.0.0-devel

USAGE:
  coder task pause [flags] <task>

  Pause a task

  - Pause a task by name:

     $ coder task pause my-task

  - Pause another user's task:

     $ coder task pause alice/my-task

  - Pause a task without confirmation:

     $ coder task pause my-task --yes

OPTIONS:
  -y, --yes bool
          Bypass confirmation prompts.

———
Run `coder --help` for a list of global options.
@@ -1,28 +0,0 @@
coder v0.0.0-devel

USAGE:
  coder task resume [flags] <task>

  Resume a task

  - Resume a task by name:

     $ coder task resume my-task

  - Resume another user's task:

     $ coder task resume alice/my-task

  - Resume a task without confirmation:

     $ coder task resume my-task --yes

OPTIONS:
      --no-wait bool
          Return immediately after resuming the task.

  -y, --yes bool
          Bypass confirmation prompts.

———
Run `coder --help` for a list of global options.
@@ -9,7 +9,7 @@ OPTIONS:
  -O, --org string, $CODER_ORGANIZATION
          Select which organization (uuid or name) to use.

-  -c, --column [id|name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
+  -c, --column [name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
          Columns to display in table output.

      --include-archived bool
@@ -27,7 +27,7 @@ USAGE:
SUBCOMMANDS:
    create Create a token
    list   List tokens
-   remove Expire or delete a token
+   remove Delete a token
    view   Display detailed information about a token

———

@@ -15,10 +15,6 @@ OPTIONS:
  -c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
          Columns to display in table output.

-     --include-expired bool
-         Include expired tokens in the output. By default, expired tokens are
-         hidden.
-
  -o, --output table|json (default: table)
          Output format.
@@ -1,19 +1,11 @@
coder v0.0.0-devel

USAGE:
-  coder tokens remove [flags] <name|id|token>
+  coder tokens remove <name|id|token>

-  Expire or delete a token
+  Delete a token

  Aliases: delete, rm

-  Remove a token by expiring it. Use --delete to permanently hard-delete the
-  token instead.
-
-OPTIONS:
-  --delete bool
-      Permanently delete the token instead of expiring it. This removes the
-      audit trail.

———
Run `coder --help` for a list of global options.
@@ -176,16 +176,11 @@ networking:
|
||||
# (default: <unset>, type: string-array)
|
||||
proxyTrustedOrigins: []
|
||||
# Controls if the 'Secure' property is set on browser session cookies.
|
||||
# (default: false, type: bool)
|
||||
# (default: <unset>, type: bool)
|
||||
secureAuthCookie: false
|
||||
```diff
 # Controls the 'SameSite' property is set on browser session cookies.
 # (default: lax, type: enum[lax|none])
 sameSiteAuthCookie: lax
 # Recommended to be enabled. Enables `__Host-` prefix for cookies to guarantee
 # they are only set by the right domain. This change is disruptive to any
 # workspaces built before release 2.31, requiring a workspace restart.
 # (default: false, type: bool)
 hostPrefixCookie: false
 # Whether Coder only allows connections to workspaces via the browser.
 # (default: <unset>, type: bool)
 browserOnly: false
@@ -422,11 +417,6 @@ oidc:
 # an insecure OIDC configuration. It is not recommended to use this flag.
 # (default: <unset>, type: bool)
 dangerousSkipIssuerChecks: false
-# Optional override of the default redirect url which uses the deployment's access
-# url. Useful in situations where a deployment has more than 1 domain. Using this
-# setting can also break OIDC, so use with caution.
-# (default: <unset>, type: url)
-oidc-redirect-url:
 # Telemetry is critical to our ability to improve Coder. We strip all personal
 # information before sending data to our servers. Please only disable telemetry
 # when required by your organization's security policy.
@@ -524,10 +514,10 @@ disablePathApps: false
 # workspaces.
 # (default: <unset>, type: bool)
 disableOwnerWorkspaceAccess: false
-# Disable workspace sharing. Workspace ACL checking is disabled and only owners
-# can have ssh, apps and terminal access to workspaces. Access based on the
-# 'owner' role is also allowed unless disabled via
-# --disable-owner-workspace-access.
+# Disable workspace sharing (requires the "workspace-sharing" experiment to be
+# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
+# and terminal access to workspaces. Access based on the 'owner' role is also
+# allowed unless disabled via --disable-owner-workspace-access.
 # (default: <unset>, type: bool)
 disableWorkspaceSharing: false
 # These options change the behavior of how clients interact with the Coder.
@@ -564,9 +554,6 @@ supportLinks: []
 # External Authentication providers.
 # (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
 externalAuthProviders: []
-# Enable the default GitHub external auth provider managed by Coder.
-# (default: true, type: bool)
-externalAuthGithubDefaultProviderEnable: true
 # Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
 # default, this will pick the best available wgtunnel server hosted by Coder. e.g.
 # "tunnel.example.com".
@@ -830,13 +817,10 @@ aibridgeproxy:
 # The address the AI Bridge Proxy will listen on.
 # (default: :8888, type: string)
 listen_addr: :8888
-# Path to the CA certificate file used to intercept (MITM) HTTPS traffic from AI
-# clients. This CA must be trusted by AI clients for the proxy to decrypt their
-# requests.
+# Path to the CA certificate file for AI Bridge Proxy.
 # (default: <unset>, type: string)
 cert_file: ""
-# Path to the CA private key file used to intercept (MITM) HTTPS traffic from AI
-# clients.
+# Path to the CA private key file for AI Bridge Proxy.
 # (default: <unset>, type: string)
 key_file: ""
 # Comma-separated list of AI provider domains for which HTTPS traffic will be
```
+14 -37
```diff
@@ -218,10 +218,9 @@ func (r *RootCmd) listTokens() *serpent.Command {
 	}

 	var (
-		all            bool
-		includeExpired bool
-		displayTokens  []tokenListRow
-		formatter      = cliui.NewOutputFormatter(
+		all           bool
+		displayTokens []tokenListRow
+		formatter     = cliui.NewOutputFormatter(
 			cliui.TableFormat([]tokenListRow{}, defaultCols),
 			cliui.JSONFormat(),
 		)
@@ -241,8 +240,7 @@ func (r *RootCmd) listTokens() *serpent.Command {
 	}

 	tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
-		IncludeAll:     all,
-		IncludeExpired: includeExpired,
+		IncludeAll: all,
 	})
 	if err != nil {
 		return xerrors.Errorf("list tokens: %w", err)
@@ -276,12 +274,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
 			Description: "Specifies whether all users' tokens will be listed or not (must have Owner role to see all tokens).",
 			Value:       serpent.BoolOf(&all),
 		},
-		{
-			Name:        "include-expired",
-			Flag:        "include-expired",
-			Description: "Include expired tokens in the output. By default, expired tokens are hidden.",
-			Value:       serpent.BoolOf(&includeExpired),
-		},
 	}

 	formatter.AttachOptions(&cmd.Options)
@@ -331,13 +323,10 @@ func (r *RootCmd) viewToken() *serpent.Command {
 }

 func (r *RootCmd) removeToken() *serpent.Command {
-	var deleteToken bool
 	cmd := &serpent.Command{
 		Use:     "remove <name|id|token>",
 		Aliases: []string{"delete"},
-		Short:   "Expire or delete a token",
-		Long: "Remove a token by expiring it. Use --delete to permanently hard-" +
-			"delete the token instead.",
+		Short:   "Delete a token",
 		Middleware: serpent.Chain(
 			serpent.RequireNArgs(1),
 		),
@@ -349,7 +338,7 @@ func (r *RootCmd) removeToken() *serpent.Command {

 			token, err := client.APIKeyByName(inv.Context(), codersdk.Me, inv.Args[0])
 			if err != nil {
-				// If it's a token, we need to extract the ID.
+				// If it's a token, we need to extract the ID
 				maybeID := strings.Split(inv.Args[0], "-")[0]
 				token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
 				if err != nil {
@@ -357,29 +346,17 @@ func (r *RootCmd) removeToken() *serpent.Command {
 				}
 			}

-			if deleteToken {
-				err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
-				if err != nil {
-					return xerrors.Errorf("delete api key: %w", err)
-				}
-				cliui.Infof(inv.Stdout, "Token has been deleted.")
-				return nil
-			}
-
-			err = client.ExpireAPIKey(inv.Context(), codersdk.Me, token.ID)
+			err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
 			if err != nil {
-				return xerrors.Errorf("expire api key: %w", err)
+				return xerrors.Errorf("delete api key: %w", err)
 			}
-			cliui.Infof(inv.Stdout, "Token has been expired.")
-			return nil
-		},
-	}
-
-	cmd.Options = serpent.OptionSet{
-		{
-			Flag:        "delete",
-			Description: "Permanently delete the token instead of expiring it. This removes the audit trail.",
-			Value:       serpent.BoolOf(&deleteToken),
+			cliui.Infof(
+				inv.Stdout,
+				"Token has been deleted.",
+			)
+
+			return nil
 		},
 	}
```