Compare commits

..

50 Commits

Author SHA1 Message Date
Charlie Voiselle eee13c42a4 docs(cli): reference --oidc-group-mapping flag name instead of 'legacy'
Issue: Used internal nomenclature instead of user-facing flag name

The previous fix referenced 'legacy group name mapping' but users don't
know what that means - it's an internal implementation detail. Users
configure this via the --oidc-group-mapping flag.

Changed to: 'This filter is applied after the oidc-group-mapping.'

This directly references the flag name users would actually use, making
the relationship clear and actionable. Users can now understand:
- The regex filter applies to group names
- Group names may have been transformed by --oidc-group-mapping first
- They need to write regex patterns that match the mapped names

Example: If --oidc-group-mapping transforms 'developers' to 'dev-team',
the regex in --oidc-group-regex-filter will match against 'dev-team'.
2026-02-09 16:14:00 -05:00
Charlie Voiselle 65b48c0f84 docs(cli): fix help text for --oidc-group-regex-filter (clarify mapping order)
Issue: Removed ordering information when it was actually helpful

The previous correction removed the sentence about filter order to avoid
confusion, but this actually made the description LESS clear. Users need
to understand that the regex filter operates on group names AFTER any
legacy name mapping has been applied.

Example: If IdP sends 'developers' and LegacyNameMapping renames it to
'dev-team', the regex filter will match against 'dev-team', not 'developers'.

Changed to: 'This filter is applied after legacy group name mapping.'

This clarifies:
1. It's the LEGACY mapping (name→name) not the new Mapping (name→IDs)
2. The regex operates on potentially-renamed group names
3. The filter happens before the final ID mapping

Code reference: coderd/idpsync/group.go lines 379-398
- Line 380: LegacyNameMapping (name → name)
- Line 386: RegexFilter (on the potentially renamed name)
- Line 392: Mapping (name → []uuid.UUID)
2026-02-09 16:13:59 -05:00
Charlie Voiselle 30cdf29e52 docs(cli): fix help text for --oidc-group-regex-filter (final correction)
Issue: Previous description incorrectly stated filter order

The correction commit stated 'This filter is applied after the group mapping'
but the actual code order in coderd/idpsync/group.go lines 379-398 shows:
1. Legacy group mappings
2. Regex filter
3. (New) group mapping

Since the filter order is complex and the description was causing confusion,
removed the last sentence entirely. The first two sentences clearly explain
what the flag does without introducing incorrect ordering claims.

This follows the verification report's recommendation to remove the
confusing last sentence.
2026-02-09 16:13:59 -05:00
Charlie Voiselle b1d2bb6d71 docs(cli): fix help text for --external-auth-providers
Issue: Clarity - vague description

Changed 'External Authentication providers.' to 'Configure external authentication providers for Git and other services.' to explain what these providers are actually used for.
2026-02-09 16:13:59 -05:00
Charlie Voiselle 94bad2a956 docs(cli): fix help text for --workspace-prebuilds-reconciliation-backoff-lookback-period
Issue: Clarity - unclear purpose

Changed 'Interval to look back to determine number of failed prebuilds, which influences backoff' to 'Time period to look back when counting failed prebuilds to calculate the backoff delay' to clarify this determines the time window for counting failures.
2026-02-09 16:13:59 -05:00
Charlie Voiselle 111714c7ed docs(cli): fix help text for --workspace-prebuilds-reconciliation-backoff-interval
Issue: Clarity - confusing wording about backoff behavior

Changed 'Interval to increase reconciliation backoff by when prebuilds fail, after which a retry attempt is made' to 'Amount of time to add to the reconciliation backoff delay after each prebuild failure, before the next retry attempt is made' to clarify this is an incremental addition to the backoff delay.
2026-02-09 16:13:58 -05:00
Charlie Voiselle 1f9c516c5c docs(cli): fix help text for --workspace-prebuilds-failure-hard-limit
Issue: Clarity - unclear what 'hits the hard limit' means

Changed 'before a preset hits the hard limit' to 'before a preset is considered hard-limited and stops automatic prebuild creation' to explain what actually happens when the limit is reached.
2026-02-09 16:13:58 -05:00
Charlie Voiselle 3645c65bb2 docs(cli): fix help text for --workspace-hostname-suffix
Issue: Clarity - incomplete example hostname

Changed 'in SSH config and Coder Connect on Coder Desktop' to 'for SSH connections and Coder Connect' for conciseness. Updated the example from 'myworkspace.coder' to the full format 'agent.workspace.owner.coder' to show the complete hostname structure.
2026-02-09 16:13:58 -05:00
Charlie Voiselle d3d2d2fb1e docs(cli): fix help text for --workspace-agent-logs-retention
Issue: Clarity - ambiguous scope

Changed 'Logs from the latest build are always retained' to 'Logs from the latest build for each workspace are always retained' to clarify that this applies per-workspace, not just one latest build globally.
2026-02-09 16:13:58 -05:00
Charlie Voiselle 086fb1f5d5 docs(cli): fix help text for --block-direct-connections
Issue: Clarity - imprecise wording about STUN behavior

Clarified that 'Workspace agents' (not 'Workspaces') reach out to STUN servers, changed 'get their address' to 'discover their address', and simplified 'until they are restarted after this change has been made' to just 'until they are restarted'.
2026-02-09 16:13:58 -05:00
Charlie Voiselle a73a535a5b docs(cli): fix help text for --proxy-health-interval
Issue: Clarity - awkward phrasing

Changed 'in which coderd should be checking' to 'at which coderd checks' for more concise, natural phrasing.
2026-02-09 16:13:57 -05:00
Charlie Voiselle 96e01c3018 docs(cli): fix help text for --email-tls-cert-key-file
Issue: Clarity - vague description

Changed 'Certificate key file to use' to 'Private key file for the client certificate' to clarify this is the private key that pairs with --email-tls-cert-file.
2026-02-09 16:13:57 -05:00
Charlie Voiselle 6b10a0359b docs(cli): fix help text for --email-tls-cert-file
Issue: Clarity - vague description

Changed 'Certificate file to use' to 'Client certificate file for mutual TLS authentication' to clarify what this certificate is for and when it's needed.
2026-02-09 16:13:57 -05:00
Charlie Voiselle b62583ad4b docs(cli): fix help text for --oidc-user-role-default
Issue: Clarity - ambiguous relationship between defaults and synced roles

Added 'in addition to synced roles' to clarify that these defaults don't replace synced roles. Also clarified that 'member' is always assigned 'regardless of this setting' to avoid confusion about whether this setting affects the member role.
2026-02-09 16:13:57 -05:00
Charlie Voiselle 3d6727a2cb docs(cli): fix help text for --oidc-group-field
Issue: Clarity - unclear structure

Reordered to put the primary purpose first: 'OIDC claim field to use as the user's groups' before the conditional requirement. This makes the description more scannable and understandable.
2026-02-09 16:13:56 -05:00
Charlie Voiselle b163962a14 docs(cli): fix help text for --aibridge-circuit-breaker-interval
Issue: Clarity - confusing technical jargon

Changed 'Cyclic period of the closed state for clearing internal failure counts' to 'Time window for counting failures before resetting the failure count in the closed state' to explain what the interval actually does in clearer terms.
2026-02-09 16:13:56 -05:00
Charlie Voiselle 9aca4ea27c docs(cli): fix help text for --aibridge-circuit-breaker-enabled
Issue: Clarity - ambiguous error code description

Changed '(429, 503, 529 overloaded)' to '(HTTP 429, 503, 529)' and added 'and overload errors' to clarify that these are HTTP status codes and what they represent.
2026-02-09 16:13:56 -05:00
Charlie Voiselle b0c10131ea docs(cli): fix help text for --aibridge-retention
Issue: Clarity - wordy phrasing

Simplified 'Length of time to retain data such as interceptions and all related records (token, prompt, tool use)' to 'How long to retain AI Bridge data including interceptions, tokens, prompts, and tool usage records' for more natural, clearer phrasing.
2026-02-09 16:13:56 -05:00
Charlie Voiselle c8c7e13e96 docs(cli): fix help text for --aibridge-inject-coder-mcp-tools
Issue: Clarity - awkward phrasing and formatting

Changed 'Whether to inject' to 'Enable injection of' for consistency with other boolean flags. Simplified the requirements clause and changed double quotes to single quotes for consistency.
2026-02-09 16:13:55 -05:00
Charlie Voiselle 249b7ea38e docs(cli): fix help text for --aibridge-enabled
Issue: Clarity - unclear technical jargon

Changed 'Whether to start an in-memory aibridged instance' to 'Enable the embedded AI Bridge service to intercept and record AI provider requests' to explain what the feature actually does in user-friendly terms.
2026-02-09 16:13:55 -05:00
Charlie Voiselle 1333096e25 docs(cli): fix help text for --oidc-group-regex-filter (correction)
Issue: Previous fix introduced confusing circular wording

The previous commit incorrectly changed the ending to 'after the group mapping and regex filter' which is nonsensical since this flag configures THE regex filter itself. Reverted to the correct wording: 'after the group mapping'.

The only valid changes from the original are:
- Added comma after 'If provided'
- Simplified 'allows for filtering' to 'allows filtering'
2026-02-09 16:13:55 -05:00
Charlie Voiselle 54bc9324dd docs(cli): fix help text for --samesite-auth-cookie
Issue: Grammar - missing word

Added missing 'if' to read 'Controls if the SameSite property is set' instead of 'Controls the SameSite property is set'.
2026-02-09 16:13:55 -05:00
Charlie Voiselle 109e5f2b19 docs(cli): fix help text for --enable-authz-recordings
Issue: Grammar - acronym capitalization

Capitalized 'API' (Application Programming Interface) - should always be uppercase.
2026-02-09 16:13:55 -05:00
Charlie Voiselle ee176b4207 docs(cli): fix help text for --ssh-config-options
Issue: Grammar - missing space after period

Added missing space after period between sentences: 'commas.' + 'Using' → 'commas. ' + 'Using'.
2026-02-09 16:13:54 -05:00
Charlie Voiselle 7e1e16be33 docs(cli): fix help text for --prometheus-address
Issue: Grammar - proper noun capitalization

Capitalized 'Prometheus' as it's a proper noun.
2026-02-09 16:13:54 -05:00
Charlie Voiselle 5cfe8082ce docs(cli): fix help text for --prometheus-enable
Issue: Grammar - proper noun capitalization

Capitalized 'Prometheus' as it's a proper noun (the name of the monitoring system).
2026-02-09 16:13:54 -05:00
Charlie Voiselle 6b7f672834 docs(cli): fix help text for --allow-custom-quiet-hours
Issue: Grammar - awkward phrasing

Changed 'for workspaces to stop in' to 'for when workspaces are stopped' for more natural phrasing.
2026-02-09 16:13:53 -05:00
Charlie Voiselle c55f6252a1 docs(cli): fix help text for --tls-client-ca-file
Issue: Grammar - missing article

Added missing article 'the' before 'client' to read 'authenticity of the client'.
2026-02-09 16:13:53 -05:00
Charlie Voiselle 842553b677 docs(cli): fix help text for --tls-ciphers
Issue: Grammar - missing verb

Fixed missing 'are' in 'that allowed to be used' → 'that are allowed to be used'.
2026-02-09 16:13:53 -05:00
Charlie Voiselle 05a771ba77 docs(cli): fix help text for --derp-server-stun-addresses
Issue: Grammar - incorrect possessive

Fixed "it's" (contraction of "it is") → "its" (possessive). Should be 'Each STUN server will get its own DERP region'.
2026-02-09 16:13:53 -05:00
Charlie Voiselle 70a0d42e65 docs(cli): fix help text for --derp-server-region-name
Issue: Grammar - malformed sentence

Fixed malformed sentence 'Region name that for' → 'Region name to use for'. The original was missing a verb.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 6b1d73b466 docs(cli): fix help text for --notifications-store-sync-buffer-size
Issue: Grammar - typo

Fixed typo: 'change' → 'chance'. Same typo as in --notifications-store-sync-interval.
2026-02-09 16:13:52 -05:00
Charlie Voiselle d7b9596145 docs(cli): fix help text for --notifications-store-sync-interval
Issue: Grammar - typo

Fixed typo: 'change' → 'chance'. The sentence should read 'the lower the chance of state inconsistency'.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 7a0aa1a40a docs(cli): fix help text for --oidc-signups-disabled-text
Issue: Grammar - awkward phrasing

Changed 'The custom text to show on the error page informing about disabled OIDC signups' to 'Custom text to show on the error page when OIDC signups are disabled' for clearer, more direct phrasing. Removed unnecessary 'The' article.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 4d8ea43e11 docs(cli): fix help text for --oidc-icon-url
Issue: Grammar - redundant phrasing

Changed 'URL pointing to the icon' to 'URL of the icon'. The phrase 'pointing to' is redundant since a URL inherently points to a resource.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 6fddae98f6 docs(cli): fix help text for --oidc-group-regex-filter
Issue: Grammar - missing comma + simplification + filter order clarification

Added missing comma after 'If provided'. Simplified 'allows for filtering' to 'allows filtering'. Clarified filter order to match the actual implementation.
2026-02-09 16:13:51 -05:00
Charlie Voiselle e33fbb6087 docs(cli): fix help text for --oidc-group-mapping
Issue: Grammar - subject-verb agreement + awkward phrasing

Changed 'the group in Coder it should map to' to 'the groups in Coder they should map to' for proper plural agreement. Also simplified 'for when' to 'when'.
2026-02-09 16:13:51 -05:00
Charlie Voiselle 2337393e13 docs(cli): fix help text for --oidc-client-cert-file
Issue: Grammar - incorrect acronym capitalization

Changed 'Pem' to 'PEM', 'oauth2' to 'OAuth2', and 'x509' to 'X.509'. These are standard capitalizations for these acronyms and standards.
2026-02-09 16:13:51 -05:00
Charlie Voiselle d7357a1b0a docs(cli): fix help text for --oidc-client-key-file
Issue: Grammar - incorrect acronym capitalization

Changed 'Pem' to 'PEM' (Privacy Enhanced Mail), 'oauth2' to 'OAuth2', and 'IDP' to 'IdP' (Identity Provider). These are standard capitalizations for these acronyms.
2026-02-09 16:13:51 -05:00
Charlie Voiselle afbf1af29c docs(cli): fix help text for --oauth2-github-allow-everyone
Issue: Grammar - unclear and run-on sentence

Changed 'Allow all logins, setting this option means...' to 'Allow all GitHub users to authenticate. When enabled, allowed orgs and teams must be empty.' This separates the run-on sentence and clarifies what 'all logins' means (all GitHub users).
2026-02-09 16:13:50 -05:00
Charlie Voiselle 1d834c747c docs(cli): fix help text for --aibridge-circuit-breaker-failure-threshold
Issue: Grammar - subject-verb agreement

Changed 'triggers' to 'trigger' for correct subject-verb agreement. 'Number' is the subject, which takes the singular form, but 'failures' is the head of the relative clause 'that trigger...', making 'trigger' (plural) correct.
2026-02-09 16:13:50 -05:00
Charlie Voiselle a80edec752 docs(cli): fix help text for --aibridge-bedrock-access-key-secret
Issue: Grammar - wordy and redundant phrasing

Simplified from 'The access key secret to use with the access key to authenticate against' to 'AWS secret access key for authenticating with'. Uses standard AWS terminology and eliminates redundancy.
2026-02-09 16:13:50 -05:00
Charlie Voiselle 2a6473e8c6 docs(cli): fix help text for --aibridge-bedrock-access-key
Issue: Grammar - awkward phrasing

Changed 'The access key to authenticate against' to 'AWS access key for authenticating with' for consistency and clarity. Uses standard AWS terminology.
2026-02-09 16:13:50 -05:00
Charlie Voiselle 1f9c0b9b7f docs(cli): fix help text for --aibridge-anthropic-key
Issue: Grammar - awkward phrasing

Changed 'The key to authenticate against' to 'API key for authenticating with' for consistency with --aibridge-openai-key and more natural phrasing.
2026-02-09 16:13:49 -05:00
Charlie Voiselle 5494afabd8 docs(cli): fix help text for --aibridge-openai-key
Issue: Grammar - awkward phrasing

Changed 'The key to authenticate against' to 'API key for authenticating with' for more natural, concise phrasing. This matches standard API documentation conventions.
2026-02-09 16:13:49 -05:00
Charlie Voiselle 07c6e86a50 docs(cli): fix help text for --notifications-email-hello
Issue: Factually incorrect description of SMTP HELO/EHLO (deprecated alias)

Same issue as --email-hello. This is a deprecated alias but still needs the correct description. The HELO/EHLO command identifies the client to the server, not the server itself.

Fix: Clarified this identifies 'this client to the SMTP server'.
2026-02-09 16:13:49 -05:00
Charlie Voiselle b543821a1c docs(cli): fix help text for --email-hello
Issue: Factually incorrect description of SMTP HELO/EHLO

The description incorrectly stated this identifies 'the SMTP server' when it actually identifies the CLIENT to the server. The HELO/EHLO command is how the client introduces itself to the SMTP server during connection.

Fix: Clarified this identifies 'this client to the SMTP server' which accurately reflects the SMTP protocol.
2026-02-09 16:13:49 -05:00
Charlie Voiselle e8b7045a9b docs(cli): fix help text for --pprof-enable
Issue: Factually incorrect terminology

The description incorrectly stated pprof serves 'metrics' when it actually serves profiling data (CPU profiles, memory profiles, goroutines, etc.). Metrics are Prometheus's domain, not pprof's.

Fix: Changed 'metrics' to 'profiling endpoints' to accurately describe what pprof provides.
2026-02-09 16:13:48 -05:00
Charlie Voiselle 2571089528 docs(cli): fix help text for --oidc-user-role-mapping
Issue: Factually incorrect (confuses roles with groups) + grammar error

The description incorrectly stated this maps to 'groups in Coder' when it actually maps to site ROLES (member, admin, etc.). Also had a grammar error: 'will ignored' should be 'will be ignored'.

Fix: Corrected to clarify this maps OIDC role names to Coder role names, and fixed the grammar error.
2026-02-09 16:13:48 -05:00
Charlie Voiselle 1fb733fe1e docs(cli): fix help text for --oidc-allowed-groups
Issue: Factually incorrect filter order

The description incorrectly stated that the check is applied 'after the group mapping and before the regex filter'. This is wrong.

Fix: Updated to reflect actual behavior where the check is applied BEFORE any group mapping or filtering. Also clarified the positive case (users WITH at least one matching group are allowed) instead of the confusing double-negative phrasing.
2026-02-09 16:13:48 -05:00
952 changed files with 14487 additions and 85887 deletions
-249
View File
@@ -1,249 +0,0 @@
# Modern Go (1.181.26)
Reference for writing idiomatic Go. Covers what changed, what it
replaced, and what to reach for. Respect the project's `go.mod` `go`
line: don't emit features from a version newer than what the module
declares. Check `go.mod` before writing code.
## How modern Go thinks differently
**Generics** (1.18): Design reusable code with type parameters instead
of `interface{}` casts, code generation, or the `sort.Interface`
pattern. Use `any` for unconstrained types, `comparable` for map keys
and equality, `cmp.Ordered` for sortable types. Type inference usually
makes explicit type arguments unnecessary (improved in 1.21).
**Per-iteration loop variables** (1.22): Each loop iteration gets its
own variable copy. Closures inside loops capture the correct value. The
`v := v` shadow trick is dead. Remove it when you see it.
**Iterators** (1.23): `iter.Seq[V]` and `iter.Seq2[K,V]` are the
standard iterator types. Containers expose `.All()` methods returning
these. Combined with `slices.Collect`, `slices.Sorted`, `maps.Keys`,
etc., they replace ad-hoc "loop and append" code with composable,
lazy pipelines. When a sequence is consumed only once, prefer an
iterator over materializing a slice.
**Error trees** (1.201.26): Errors compose as trees, not chains.
`errors.Join` aggregates multiple errors. `fmt.Errorf` accepts multiple
`%w` verbs. `errors.Is`/`As` traverse the full tree. Custom error
types that wrap multiple causes must implement `Unwrap() []error` (the
slice form), not `Unwrap() error`, or tree traversal won't find the
children. `errors.AsType[T]` (1.26) is the type-safe way to match
error types. Propagate cancellation reasons with
`context.WithCancelCause`.
**Structured logging** (1.21): `log/slog` is the standard structured
logger. This project uses `cdr.dev/slog/v3` instead, which has a
different API. Do not use `log/slog` directly.
## Replace these patterns
The left column reflects common patterns from pre-1.22 Go. Write the
right column instead. The "Since" column tells you the minimum `go`
directive version required in `go.mod`.
| Old pattern | Modern replacement | Since |
|---|---|---|
| `interface{}` | `any` | 1.18 |
| `v := v` inside loops | remove it | 1.22 |
| `for i := 0; i < n; i++` | `for i := range n` | 1.22 |
| `for i := 0; i < b.N; i++` (benchmarks) | `for b.Loop()` (correct timing, future-proof) | 1.24 |
| `sort.Slice(s, func(i,j int) bool{…})` | `slices.SortFunc(s, cmpFn)` | 1.21 |
| `wg.Add(1); go func(){ defer wg.Done(); … }()` | `wg.Go(func(){…})` | 1.25 |
| `func ptr[T any](v T) *T { return &v }` | `new(expr)` e.g. `new(time.Now())` | 1.26 |
| `var target *E; errors.As(err, &target)` | `t, ok := errors.AsType[*E](err)` | 1.26 |
| Custom multi-error type | `errors.Join(err1, err2, …)` | 1.20 |
| Single `%w` for multiple causes | `fmt.Errorf("…: %w, %w", e1, e2)` | 1.20 |
| `rand.Seed(time.Now().UnixNano())` | delete it (auto-seeded); prefer `math/rand/v2` | 1.20/1.22 |
| `sync.Once` + captured variable | `sync.OnceValue(func() T {…})` / `OnceValues` | 1.21 |
| Custom `min`/`max` helpers | `min(a, b)` / `max(a, b)` builtins (any ordered type) | 1.21 |
| `for k := range m { delete(m, k) }` | `clear(m)` (also zeroes slices) | 1.21 |
| Index+slice or `SplitN(s, sep, 2)` | `strings.Cut(s, sep)` / `bytes.Cut` | 1.18 |
| `TrimPrefix` + check if anything was trimmed | `strings.CutPrefix` / `CutSuffix` (returns ok bool) | 1.20 |
| `strings.Split` + loop when no slice is needed | `strings.SplitSeq` / `Lines` / `FieldsSeq` (iterator, no alloc) | 1.24 |
| `"2006-01-02"` / `"2006-01-02 15:04:05"` / `"15:04:05"` | `time.DateOnly` / `time.DateTime` / `time.TimeOnly` | 1.20 |
| Manual `Before`/`After`/`Equal` chains for comparison | `time.Time.Compare` (returns -1/0/+1; works with `slices.SortFunc`) | 1.20 |
| Loop collecting map keys into slice | `slices.Sorted(maps.Keys(m))` | 1.23 |
| `fmt.Sprintf` + append to `[]byte` | `fmt.Appendf(buf, …)` (also `Append`, `Appendln`) | 1.18 |
| `reflect.TypeOf((*T)(nil)).Elem()` | `reflect.TypeFor[T]()` | 1.22 |
| `*(*[4]byte)(slice)` unsafe cast | `[4]byte(slice)` direct conversion | 1.20 |
| `atomic.LoadInt64` / `StoreInt64` | `atomic.Int64` (also `Bool`, `Uint64`, `Pointer[T]`) | 1.19 |
| `crypto/rand.Read(buf)` + hex/base64 encode | `crypto/rand.Text()` (one call) | 1.24 |
| Checking `crypto/rand.Read` error | don't: return is always nil | 1.24 |
| `time.Sleep` in tests | `testing/synctest` (deterministic fake clock) | 1.24/1.25 |
| `json:",omitempty"` on zero-value structs like `time.Time{}` | `json:",omitzero"` (uses `IsZero()` method) | 1.24 |
| `strings.Title` | `golang.org/x/text/cases` | 1.18 |
| `net.IP` in new code | `net/netip.Addr` (immutable, comparable, lighter) | 1.18 |
| `tools.go` with blank imports | `tool` directive in `go.mod` | 1.24 |
| `runtime.SetFinalizer` | `runtime.AddCleanup` (multiple per object, no pointer cycles) | 1.24 |
| `httputil.ReverseProxy.Director` | `.Rewrite` hook + `ProxyRequest` (Director deprecated in 1.26) | 1.20 |
| `sql.NullString`, `sql.NullInt64`, etc. | `sql.Null[T]` | 1.22 |
| Manual `ctx, cancel := context.WithCancel(…)` + `t.Cleanup(cancel)` | `t.Context()` (auto-canceled when test ends) | 1.24 |
| `if d < 0 { d = -d }` on durations | `d.Abs()` (handles `math.MinInt64`) | 1.19 |
| Implement only `TextMarshaler` | also implement `TextAppender` for alloc-free marshaling | 1.24 |
| Custom `Unwrap() error` on multi-cause errors | `Unwrap() []error` (slice form; required for tree traversal) | 1.20 |
## New capabilities
These enable things that weren't practical before. Reach for them in the
described situations.
| What | Since | When to use it |
|---|---|---|
| `cmp.Or(a, b, c)` | 1.22 | Defaults/fallback chains: returns first non-zero value. Replaces verbose `if a != "" { return a }` cascades. |
| `context.WithoutCancel(ctx)` | 1.21 | Background work that must outlive the request (e.g. async cleanup after HTTP response). Derived context keeps parent's values but ignores cancellation. |
| `context.AfterFunc(ctx, fn)` | 1.21 | Register cleanup that fires on context cancellation without spawning a goroutine that blocks on `<-ctx.Done()`. |
| `context.WithCancelCause` / `Cause` | 1.20 | When callers need to know WHY a context was canceled, not just that it was. Retrieve cause with `context.Cause(ctx)`. |
| `context.WithDeadlineCause` / `WithTimeoutCause` | 1.21 | Attach a domain-specific error to deadline/timeout expiry (e.g. distinguish "DB query timed out" from "HTTP request timed out"). |
| `errors.ErrUnsupported` | 1.21 | Standard sentinel for "not supported." Use instead of per-package custom sentinels. Check with `errors.Is`. |
| `http.ResponseController` | 1.20 | Per-request flush, hijack, and deadline control without type-asserting `ResponseWriter` to `http.Flusher` or `http.Hijacker`. |
| Enhanced `ServeMux` routing | 1.22 | `"GET /items/{id}"` patterns in `http.ServeMux`. Access with `r.PathValue("id")`. Wildcards: `{name}`, catch-all: `{path...}`, exact: `{$}`. Eliminates many third-party router dependencies. |
| `os.Root` / `OpenRoot` | 1.24 | Confined directory access that prevents symlink escape. 1.25 adds `MkdirAll`, `ReadFile`, `WriteFile` for real use. |
| `os.CopyFS` | 1.23 | Copy an entire `fs.FS` to local filesystem in one call. |
| `os/signal.NotifyContext` with cause | 1.26 | Cancellation cause identifies which signal (SIGTERM vs SIGINT) triggered shutdown. |
| `io/fs.SkipAll` / `filepath.SkipAll` | 1.20 | Return from `WalkDir` callback to stop walking entirely. Cleaner than a sentinel error. |
| `GOMEMLIMIT` env / `debug.SetMemoryLimit` | 1.19 | Soft memory limit for GC. Use alongside or instead of `GOGC` in memory-constrained containers. |
| `net/url.JoinPath` | 1.19 | Join URL path segments correctly. Replaces error-prone string concatenation. |
| `go test -skip` | 1.20 | Skip tests matching a pattern. Useful when running a subset of a large test suite. |
## Key packages
### `slices` (1.21, iterators added 1.23)
Replaces `sort.Slice`, manual search loops, and manual contains checks.
Search: `Contains`, `ContainsFunc`, `Index`, `IndexFunc`,
`BinarySearch`, `BinarySearchFunc`.
Sort: `Sort`, `SortFunc`, `SortStableFunc`, `IsSorted`, `IsSortedFunc`,
`Min`, `MinFunc`, `Max`, `MaxFunc`.
Transform: `Clone`, `Compact`, `CompactFunc`, `Grow`, `Clip`,
`Concat` (1.22), `Repeat` (1.23), `Reverse`, `Insert`, `Delete`,
`Replace`.
Compare: `Equal`, `EqualFunc`, `Compare`.
Iterators (1.23): `All`, `Values`, `Backward`, `Collect`, `AppendSeq`,
`Sorted`, `SortedFunc`, `SortedStableFunc`, `Chunk`.
### `maps` (1.21, iterators added 1.23)
Core: `Clone`, `Copy`, `Equal`, `EqualFunc`, `DeleteFunc`.
Iterators (1.23): `All`, `Keys`, `Values`, `Insert`, `Collect`.
### `cmp` (1.21, `Or` added 1.22)
`Ordered` constraint for any ordered type. `Compare(a, b)` returns
-1/0/+1. `Less(a, b)` returns bool. `Or(vals...)` returns first
non-zero value.
### `iter` (1.23)
`Seq[V]` is `func(yield func(V) bool)`. `Seq2[K,V]` is
`func(yield func(K, V) bool)`. Return these from your container's
`.All()` methods. Consume with `for v := range seq` or pass to
`slices.Collect`, `slices.Sorted`, `maps.Collect`, etc.
### `math/rand/v2` (1.22)
Replaces `math/rand`. `IntN` not `Intn`. Generic `N[T]()` for any
integer type. Default source is `ChaCha8` (crypto-quality). No global
`Seed`. Use `rand.New(source)` for reproducible sequences.
### `log/slog` (1.21)
`slog.Info`, `slog.Warn`, `slog.Error`, `slog.Debug` with key-value
pairs. `slog.With(attrs...)` for logger with preset fields.
`slog.GroupAttrs` (1.25) for clean group creation. Implement
`slog.Handler` for custom backends.
**Note:** This project uses `cdr.dev/slog/v3`, not `log/slog`. The
API is different. Read existing code for usage patterns.
## Pitfalls
Things that are easy to get wrong, even when you know the modern API
exists. Check your output against these.
**Version misuse.** The replacement table has a "Since" column. If the
project's `go.mod` says `go 1.22`, you cannot use `wg.Go` (1.25),
`errors.AsType` (1.26), `new(expr)` (1.26), `b.Loop()` (1.24), or
`testing/synctest` (1.24). Fall back to the older pattern. Always
check before reaching for a replacement.
**`slices.Sort` vs `slices.SortFunc`.** `slices.Sort` requires
`cmp.Ordered` types (int, string, float64, etc.). For structs, custom
types, or multi-field sorting, use `slices.SortFunc` with a comparator
function. Using `slices.Sort` on a non-ordered type is a compile error.
**`for range n` still binds the index.** `for range n` discards the
index. If you need it, write `for i := range n`. Writing
`for range n` and then trying to use `i` inside the loop is a compile
error.
**Don't hand-roll iterators when the stdlib returns one.** Functions
like `maps.Keys`, `slices.Values`, `strings.SplitSeq`, and
`strings.Lines` already return `iter.Seq` or `iter.Seq2`. Don't
reimplement them. Compose with `slices.Collect`, `slices.Sorted`, etc.
**Don't mix `math/rand` and `math/rand/v2`.** They have different
function names (`Intn` vs `IntN`) and different default sources. Pick
one per package. Prefer v2 for new code. The v1 global source is
auto-seeded since 1.20, so delete `rand.Seed` calls either way.
**Iterator protocol.** When implementing `iter.Seq`, you must respect
the `yield` return value. If `yield` returns `false`, stop iteration
immediately and return. Ignoring it violates the contract and causes
panics when consumers break out of `for range` loops early.
**`errors.Join` with nil.** `errors.Join` skips nil arguments. This is
intentional and useful for aggregating optional errors, but don't
assume the result is always non-nil. `errors.Join(nil, nil)` returns
nil.
**`cmp.Or` evaluates all arguments.** Unlike a chain of `if`
statements, `cmp.Or(a(), b(), c())` calls all three functions. If any
have side effects or are expensive, use `if`/`else` instead.
**Timer channel semantics changed in 1.23.** Code that checks
`len(timer.C)` to see if a value is pending no longer works (channel
capacity is 0). Use a non-blocking `select` receive instead:
`select { case <-timer.C: default: }`.
**`context.WithoutCancel` still propagates values.** The derived
context inherits all values from the parent. If any middleware stores
request-scoped state (deadlines, trace IDs) via `context.WithValue`,
the background work sees it. This is usually desired but can be
surprising if the values hold references that should not outlive the
request.
## Behavioral changes that affect code
- **Timers** (1.23): unstopped `Timer`/`Ticker` are GC'd immediately.
Channels are unbuffered: no stale values after `Reset`/`Stop`. You no
longer need `defer t.Stop()` to prevent leaks.
- **Error tree traversal** (1.20): `errors.Is`/`As` follow
`Unwrap() []error`, not just `Unwrap() error`. Multi-error types must
expose the slice form for child errors to be found.
- **`math/rand` auto-seeded** (1.20): the global RNG is auto-seeded.
`rand.Seed` is a no-op in 1.24+. Don't call it.
- **GODEBUG compat** (1.21): behavioral changes are gated by `go.mod`'s
`go` line. Upgrading the version opts into new defaults.
- **Build tags** (1.18): `//go:build` is the only syntax. `// +build`
is gone.
- **Tool install** (1.18): `go get` no longer builds. Use
`go install pkg@version`.
- **Doc comments** (1.19): support `[links]`, lists, and headings.
- **`go test -skip`** (1.20): skip tests by name pattern from the
command line.
- **`go fix ./...` modernizers** (1.26): auto-rewrites code to use
newer idioms. Run after Go version upgrades.
## Transparent improvements (no code changes)
Swiss Tables maps, Green Tea GC, PGO, faster `io.ReadAll`,
stack-allocated slices, reduced cgo overhead, container-aware
GOMAXPROCS. Free on upgrade.
-4
View File
@@ -1,4 +0,0 @@
# All artifacts of the build processed are dumped here.
# Ignore it for docker context, as all Dockerfiles should build their own
# binaries.
build
+6 -3
View File
@@ -4,7 +4,10 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.25.7"
default: "1.25.6"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
use-cache:
description: "Whether to use the cache."
default: "true"
@@ -12,9 +15,9 @@ runs:
using: "composite"
steps:
- name: Setup Go
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5.6.0
uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2
with:
go-version: ${{ inputs.version }}
go-version: ${{ inputs.use-preinstalled-go == 'false' && inputs.version || '' }}
cache: ${{ inputs.use-cache }}
- name: Install gotestsum
+1 -1
View File
@@ -7,5 +7,5 @@ runs:
- name: Install Terraform
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
with:
terraform_version: 1.14.5
terraform_version: 1.14.1
terraform_wrapper: false
+21 -25
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -329,7 +329,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -381,7 +381,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -422,6 +422,10 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
# Runners have Go baked-in and Go will automatically
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
use-cache: true
- name: Setup Terraform
@@ -485,14 +489,6 @@ jobs:
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
- name: Increase PTY limit (macOS)
if: runner.os == 'macOS'
shell: bash
run: |
# Increase PTY limit to avoid exhaustion during tests.
# Default is 511; 999 is the maximum value on CI runner.
sudo sysctl -w kern.tty.ptmx_max=999
- name: Test with PostgreSQL Database (Linux)
if: runner.os == 'Linux'
uses: ./.github/actions/test-go-pg
@@ -582,7 +578,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -644,7 +640,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -716,7 +712,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -743,7 +739,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -776,7 +772,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -856,7 +852,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -937,7 +933,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1009,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1124,7 +1120,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1179,7 +1175,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1576,7 +1572,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
-21
View File
@@ -1,21 +0,0 @@
# This workflow triggers a Vercel deploy hook which builds+deploys coder.com
# (a Next.js app), to keep coder.com/docs URLs in sync with docs/manifest.json
#
# https://vercel.com/docs/deploy-hooks#triggering-a-deploy-hook
name: Update coder.com/docs
on:
push:
branches:
- main
paths:
- "docs/manifest.json"
jobs:
deploy-docs:
runs-on: ubuntu-latest
steps:
- name: Deploy docs site
run: |
curl -X POST "${{ secrets.DEPLOY_DOCS_VERCEL_WEBHOOK }}"
+3 -3
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -65,7 +65,7 @@ jobs:
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+3 -3
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -58,11 +58,11 @@ jobs:
run: mkdir base-build-context
- name: Install depot.dev CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0
# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
with:
project: wl5hnrrkns
context: base-build-context
+4 -4
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -75,7 +75,7 @@ jobs:
BRANCH_NAME: ${{ steps.branch-name.outputs.current_branch }}
- name: Set up Depot CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
@@ -88,7 +88,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and push Non-Nix image
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
with:
project: b4q6ltmpzh
token: ${{ secrets.DEPOT_TOKEN }}
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+6 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -64,6 +64,11 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
# Runners have Go baked-in and Go will automatically
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
- name: Setup Terraform
uses: ./.github/actions/setup-tf
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+5 -5
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+1 -1
View File
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+37 -5
View File
@@ -158,7 +158,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -386,12 +386,12 @@ jobs:
- name: Install depot.dev CLI
if: steps.image-base-tag.outputs.tag != ''
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0
# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
if: steps.image-base-tag.outputs.tag != ''
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
with:
project: wl5hnrrkns
context: base-build-context
@@ -796,7 +796,7 @@ jobs:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -872,7 +872,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -955,3 +955,35 @@ jobs:
# different repo.
GH_TOKEN: ${{ secrets.CDRCI_GITHUB_TOKEN }}
VERSION: ${{ needs.release.outputs.version }}
# publish-sqlc pushes the latest schema to sqlc cloud.
# At present these pushes cannot be tagged, so the last push is always the latest.
publish-sqlc:
name: "Publish to schema sqlc cloud"
runs-on: "ubuntu-latest"
needs: release
if: ${{ !inputs.dry_run }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 1
persist-credentials: false
# We need golang to run the migration main.go
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Setup sqlc
uses: ./.github/actions/setup-sqlc
- name: Push schema to sqlc cloud
# Don't block a release on this
continue-on-error: true
run: |
make sqlc-push
+1 -1
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+3 -3
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -69,7 +69,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
+4 -4
View File
@@ -18,12 +18,12 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
- name: stale
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+1 -1
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
-3
View File
@@ -98,6 +98,3 @@ AGENTS.local.md
# Ignore plans written by AI agents.
PLAN.md
# Ignore any dev licenses
license.txt
+3 -2
View File
@@ -198,12 +198,13 @@ reviewer time and clutters the diff.
**Don't delete existing comments** that explain non-obvious behavior. These
comments preserve important context about why code works a certain way.
**When adding tests for new behavior**, read existing tests first to understand what's covered. Add new cases for uncovered behavior. Edit existing tests as needed, but don't change what they verify.
**When adding tests for new behavior**, add new test cases instead of modifying
existing ones. This preserves coverage for the original behavior and makes it
clear what the new test covers.
## Detailed Development Guides
@.claude/docs/ARCHITECTURE.md
@.claude/docs/GO.md
@.claude/docs/OAUTH2.md
@.claude/docs/TESTING.md
@.claude/docs/TROUBLESHOOTING.md
+9 -34
View File
@@ -427,7 +427,6 @@ SITE_GEN_FILES := \
site/src/api/typesGenerated.ts \
site/src/api/rbacresourcesGenerated.ts \
site/src/api/countriesGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
site/src/theme/icons.json
site/out/index.html: \
@@ -655,7 +654,6 @@ GEN_FILES := \
tailnet/proto/tailnet.pb.go \
agent/proto/agent.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
vpn/vpn.pb.go \
@@ -711,7 +709,6 @@ gen/mark-fresh:
provisionersdk/proto/provisioner.pb.go \
provisionerd/proto/provisionerd.pb.go \
agent/agentsocket/proto/agentsocket.pb.go \
agent/boundarylogproxy/codec/boundary.pb.go \
vpn/vpn.pb.go \
enterprise/aibridged/proto/aibridged.pb.go \
coderd/database/dump.sql \
@@ -722,7 +719,6 @@ gen/mark-fresh:
coderd/rbac/scopes_constants_gen.go \
site/src/api/rbacresourcesGenerated.ts \
site/src/api/countriesGenerated.ts \
site/src/api/chatModelOptionsGenerated.json \
docs/admin/integrations/prometheus.md \
docs/reference/cli/index.md \
docs/admin/security/audit-logs.md \
@@ -847,12 +843,6 @@ vpn/vpn.pb.go: vpn/vpn.proto
--go_opt=paths=source_relative \
./vpn/vpn.proto
agent/boundarylogproxy/codec/boundary.pb.go: agent/boundarylogproxy/codec/boundary.proto agent/proto/agent.proto
protoc \
--go_out=. \
--go_opt=paths=source_relative \
./agent/boundarylogproxy/codec/boundary.proto
enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged.proto
protoc \
--go_out=. \
@@ -864,7 +854,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
go run -C ./scripts/apitypings main.go > $@
./scripts/biome_format.sh src/api/typesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
touch "$@"
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
@@ -873,7 +863,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
go run ./scripts/gensite/ -icons "$@"
./scripts/biome_format.sh src/theme/icons.json
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
touch "$@"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
@@ -911,22 +901,15 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/typegen/main.go rbac typescript > "$@"
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
touch "$@"
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
go run scripts/typegen/main.go countries > "$@"
./scripts/biome_format.sh src/api/countriesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
touch "$@"
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go
go run ./scripts/modeloptionsgen/main.go | tail -n +2 > "$@"
cd site && pnpm biome format --write src/api/chatModelOptionsGenerated.json
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
go run ./scripts/metricsdocgen/scanner > $@
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics
go run scripts/metricsdocgen/main.go
pnpm exec markdownlint-cli2 --fix ./docs/admin/integrations/prometheus.md
pnpm exec markdown-table-formatter ./docs/admin/integrations/prometheus.md
@@ -964,11 +947,11 @@ coderd/apidoc/.gen: \
touch "$@"
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
./scripts/biome_format.sh ../docs/manifest.json
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
touch "$@"
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
touch "$@"
update-golden-files:
@@ -1013,19 +996,11 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
touch "$@"
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
fi
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
touch "$@"
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
fi
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
touch "$@"
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
+5 -25
View File
@@ -41,7 +41,6 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
@@ -112,12 +111,6 @@ type Client interface {
ConnectRPC28(ctx context.Context) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
// ConnectRPC28WithRole is like ConnectRPC28 but sends an explicit
// role query parameter to the server. The workspace agent should
// use role "agent" to enable connection monitoring.
ConnectRPC28WithRole(ctx context.Context, role string) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
tailnet.DERPMapRewriter
agentsdk.RefreshableSessionTokenProvider
}
@@ -303,8 +296,7 @@ type agent struct {
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
filesAPI *agentfiles.API
processAPI *agentproc.API
filesAPI *agentfiles.API
socketServerEnabled bool
socketPath string
@@ -377,7 +369,6 @@ func (a *agent) init() {
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
@@ -410,7 +401,7 @@ func (a *agent) initSocketServer() {
agentsocket.WithPath(a.socketPath),
)
if err != nil {
a.logger.Error(a.hardCtx, "failed to create socket server", slog.Error(err), slog.F("path", a.socketPath))
a.logger.Warn(a.hardCtx, "failed to create socket server", slog.Error(err), slog.F("path", a.socketPath))
return
}
@@ -420,12 +411,7 @@ func (a *agent) initSocketServer() {
// startBoundaryLogProxyServer starts the boundary log proxy socket server.
func (a *agent) startBoundaryLogProxyServer() {
if a.boundaryLogProxySocketPath == "" {
a.logger.Warn(a.hardCtx, "boundary log proxy socket path not defined; not starting proxy")
return
}
proxy := boundarylogproxy.NewServer(a.logger, a.boundaryLogProxySocketPath, a.prometheusRegistry)
proxy := boundarylogproxy.NewServer(a.logger, a.boundaryLogProxySocketPath)
if err := proxy.Start(); err != nil {
a.logger.Warn(a.hardCtx, "failed to start boundary log proxy", slog.Error(err))
return
@@ -1011,10 +997,8 @@ func (a *agent) run() (retErr error) {
return xerrors.Errorf("refresh token: %w", err)
}
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs.
// We pass role "agent" to enable connection monitoring on the server, which tracks
// the agent's connectivity state (first_connected_at, last_connected_at, disconnected_at).
aAPI, tAPI, err := a.client.ConnectRPC28WithRole(a.hardCtx, "agent")
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
aAPI, tAPI, err := a.client.ConnectRPC28(a.hardCtx)
if err != nil {
return err
}
@@ -2038,10 +2022,6 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
}
if err := a.processAPI.Close(); err != nil {
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
}
if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {
-1
View File
@@ -29,7 +29,6 @@ func (api *API) Routes() http.Handler {
r.Post("/list-directory", api.HandleLS)
r.Get("/read-file", api.HandleReadFile)
r.Get("/read-file-lines", api.HandleReadFileLines)
r.Post("/write-file", api.HandleWriteFile)
r.Post("/edit-files", api.HandleEditFiles)
+7 -283
View File
@@ -10,10 +10,11 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"github.com/icholy/replace"
"github.com/spf13/afero"
"golang.org/x/text/transform"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -22,22 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// ReadFileLinesResponse is the JSON response for the line-based file reader.
type ReadFileLinesResponse struct {
// Success indicates whether the read was successful.
Success bool `json:"success"`
// FileSize is the original file size in bytes.
FileSize int64 `json:"file_size,omitempty"`
// TotalLines is the total number of lines in the file.
TotalLines int `json:"total_lines,omitempty"`
// LinesRead is the count of lines returned in this response.
LinesRead int `json:"lines_read,omitempty"`
// Content is the line-numbered file content.
Content string `json:"content,omitempty"`
// Error is the error message when success is false.
Error string `json:"error,omitempty"`
}
type HTTPResponseCode = int
func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
@@ -118,166 +103,6 @@ func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path str
return 0, nil
}
func (api *API) HandleReadFileLines(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
offset := parser.PositiveInt64(query, 1, "offset")
limit := parser.PositiveInt64(query, 0, "limit")
maxFileSize := parser.PositiveInt64(query, workspacesdk.DefaultMaxFileSize, "max_file_size")
maxLineBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxLineBytes, "max_line_bytes")
maxResponseLines := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseLines, "max_response_lines")
maxResponseBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseBytes, "max_response_bytes")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}
resp := api.readFileLines(ctx, path, offset, limit, workspacesdk.ReadFileLinesLimits{
MaxFileSize: maxFileSize,
MaxLineBytes: int(maxLineBytes),
MaxResponseLines: int(maxResponseLines),
MaxResponseBytes: int(maxResponseBytes),
})
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func (api *API) readFileLines(_ context.Context, path string, offset, limit int64, limits workspacesdk.ReadFileLinesLimits) ReadFileLinesResponse {
errResp := func(msg string) ReadFileLinesResponse {
return ReadFileLinesResponse{Success: false, Error: msg}
}
if !filepath.IsAbs(path) {
return errResp(fmt.Sprintf("file path must be absolute: %q", path))
}
f, err := api.filesystem.Open(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return errResp(fmt.Sprintf("file does not exist: %s", path))
}
if errors.Is(err, os.ErrPermission) {
return errResp(fmt.Sprintf("permission denied: %s", path))
}
return errResp(fmt.Sprintf("open file: %s", err))
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return errResp(fmt.Sprintf("stat file: %s", err))
}
if stat.IsDir() {
return errResp(fmt.Sprintf("not a file: %s", path))
}
fileSize := stat.Size()
if fileSize > limits.MaxFileSize {
return errResp(fmt.Sprintf(
"file is %d bytes which exceeds the maximum of %d bytes. Use grep, sed, or awk to extract the content you need, or use offset and limit to read a portion.",
fileSize, limits.MaxFileSize,
))
}
// Read the entire file (up to MaxFileSize).
data, err := io.ReadAll(f)
if err != nil {
return errResp(fmt.Sprintf("read file: %s", err))
}
// Split into lines.
content := string(data)
// Handle empty file.
if content == "" {
return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: 0,
LinesRead: 0,
Content: "",
}
}
lines := strings.Split(content, "\n")
totalLines := len(lines)
// offset is 1-based line number.
if offset < 1 {
offset = 1
}
if offset > int64(totalLines) {
return errResp(fmt.Sprintf(
"offset %d is beyond the file length of %d lines",
offset, totalLines,
))
}
// Default limit.
if limit <= 0 {
limit = int64(limits.MaxResponseLines)
}
startIdx := int(offset - 1) // convert to 0-based
endIdx := startIdx + int(limit)
if endIdx > totalLines {
endIdx = totalLines
}
var numbered []string
totalBytesAccumulated := 0
for i := startIdx; i < endIdx; i++ {
line := lines[i]
// Per-line truncation.
if len(line) > limits.MaxLineBytes {
line = line[:limits.MaxLineBytes] + "... [truncated]"
}
// Format with 1-based line number.
numberedLine := fmt.Sprintf("%d\t%s", i+1, line)
lineBytes := len(numberedLine)
// Check total byte budget.
newTotal := totalBytesAccumulated + lineBytes
if len(numbered) > 0 {
newTotal++ // account for \n joiner
}
if newTotal > limits.MaxResponseBytes {
return errResp(fmt.Sprintf(
"output would exceed %d bytes. Read less at a time using offset and limit parameters.",
limits.MaxResponseBytes,
))
}
// Check line count.
if len(numbered) >= limits.MaxResponseLines {
return errResp(fmt.Sprintf(
"output would exceed %d lines. Read less at a time using offset and limit parameters.",
limits.MaxResponseLines,
))
}
numbered = append(numbered, numberedLine)
totalBytesAccumulated = newTotal
}
return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: totalLines,
LinesRead: len(numbered),
Content: strings.Join(numbered, "\n"),
}
}
func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -420,21 +245,9 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
}
data, err := io.ReadAll(f)
if err != nil {
return http.StatusInternalServerError, xerrors.Errorf("read %s: %w", path, err)
}
content := string(data)
for _, edit := range edits {
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
transforms := make([]transform.Transformer, len(edits))
for i, edit := range edits {
transforms[i] = replace.String(edit.Search, edit.Replace)
}
// Create an adjacent file to ensure it will be on the same device and can be
@@ -445,7 +258,8 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
}
defer tmpfile.Close()
if _, err := tmpfile.Write([]byte(content)); err != nil {
_, err = io.Copy(tmpfile, replace.Chain(f, transforms...))
if err != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
@@ -459,93 +273,3 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace its first
// occurrence with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
//
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1 exact substring (replace all occurrences).
if strings.Contains(content, search) {
return strings.ReplaceAll(content, search, replace), true
}
// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3 trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return content, false
}
// seekLines scans contentLines looking for a contiguous subsequence that matches
// searchLines according to the provided `eq` function. It returns the start and
// end (exclusive) indices into contentLines of the match.
func seekLines(contentLines, searchLines []string, eq func(a, b string) bool) (start, end int, ok bool) {
if len(searchLines) == 0 {
return 0, 0, true
}
if len(searchLines) > len(contentLines) {
return 0, 0, false
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
return i, i + len(searchLines), true
}
return 0, 0, false
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
var b strings.Builder
for _, l := range contentLines[:start] {
_, _ = b.WriteString(l)
}
_, _ = b.WriteString(replacement)
for _, l := range contentLines[end:] {
_, _ = b.WriteString(l)
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
-285
View File
@@ -649,106 +649,6 @@ func TestEditFiles(t *testing.T) {
filepath.Join(tmpdir, "file3"): "edited3 3",
},
},
{
name: "TrailingWhitespace",
contents: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "foo \nbar\t\t\nbaz"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "trailing-ws"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo\nbar\nbaz",
Replace: "replaced",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "replaced"},
},
{
name: "TabsVsSpaces",
contents: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tfoo()\n\t}"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "tabs-vs-spaces"),
Edits: []workspacesdk.FileEdit{
{
// Search uses spaces but file uses tabs.
Search: " if true {\n foo()\n }",
Replace: "\tif true {\n\t\tbar()\n\t}",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tbar()\n\t}"},
},
{
name: "DifferentIndentDepth",
contents: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tnested()"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "indent-depth"),
Edits: []workspacesdk.FileEdit{
{
// Search has wrong indent depth (1 tab instead of 3).
Search: "\tdeep()\n\tnested()",
Replace: "\t\t\tdeep()\n\t\t\tchanged()",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tchanged()"},
},
{
name: "ExactMatchPreferred",
contents: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "hello world"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "exact-preferred"),
Edits: []workspacesdk.FileEdit{
{
Search: "hello world",
Replace: "goodbye world",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchStillSucceeds",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "no-match"),
Edits: []workspacesdk.FileEdit{
{
Search: "this does not exist in the file",
Replace: "whatever",
},
},
},
},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "mixed-ws"),
Edits: []workspacesdk.FileEdit{
{
// Search uses spaces, file uses tabs.
Search: " result := compute()\n fmt.Println(result)\n",
Replace: "\tresult := compute()\n\tlog.Println(result)\n",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tlog.Println(result)\n}"},
},
{
name: "MultiError",
contents: map[string]string{
@@ -837,188 +737,3 @@ func TestEditFiles(t *testing.T) {
})
}
}
func TestReadFileLines(t *testing.T) {
t.Parallel()
tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms-lines")
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
fs := newTestFs(afero.NewMemMapFs(), func(call, file string) error {
if file == noPermsFilePath {
return os.ErrPermission
}
return nil
})
api := agentfiles.NewAPI(logger, fs)
dirPath := filepath.Join(tmpdir, "a-directory-lines")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)
emptyFilePath := filepath.Join(tmpdir, "empty-file")
err = afero.WriteFile(fs, emptyFilePath, []byte(""), 0o644)
require.NoError(t, err)
basicFilePath := filepath.Join(tmpdir, "basic-file")
err = afero.WriteFile(fs, basicFilePath, []byte("line1\nline2\nline3"), 0o644)
require.NoError(t, err)
longLine := string(bytes.Repeat([]byte("x"), 1025))
longLineFilePath := filepath.Join(tmpdir, "long-line-file")
err = afero.WriteFile(fs, longLineFilePath, []byte(longLine), 0o644)
require.NoError(t, err)
largeFilePath := filepath.Join(tmpdir, "large-file")
err = afero.WriteFile(fs, largeFilePath, bytes.Repeat([]byte("x"), 1<<20+1), 0o644)
require.NoError(t, err)
tests := []struct {
name string
path string
offset int64
limit int64
expSuccess bool
expError string
expContent string
expTotal int
expRead int
expSize int64
// useCodersdk is set for cases where the handler returns
// codersdk.Response (query param validation) instead of ReadFileLinesResponse.
useCodersdk bool
}{
{
name: "NoPath",
path: "",
useCodersdk: true,
expError: "is required",
},
{
name: "RelativePath",
path: "relative/path",
expError: "file path must be absolute",
},
{
name: "NonExistent",
path: filepath.Join(tmpdir, "does-not-exist"),
expError: "file does not exist",
},
{
name: "IsDir",
path: dirPath,
expError: "not a file",
},
{
name: "NoPermissions",
path: noPermsFilePath,
expError: "permission denied",
},
{
name: "EmptyFile",
path: emptyFilePath,
expSuccess: true,
expTotal: 0,
expRead: 0,
expSize: 0,
},
{
name: "BasicRead",
path: basicFilePath,
expSuccess: true,
expContent: "1\tline1\n2\tline2\n3\tline3",
expTotal: 3,
expRead: 3,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Offset2",
path: basicFilePath,
offset: 2,
expSuccess: true,
expContent: "2\tline2\n3\tline3",
expTotal: 3,
expRead: 2,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Limit1",
path: basicFilePath,
limit: 1,
expSuccess: true,
expContent: "1\tline1",
expTotal: 3,
expRead: 1,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Offset2Limit1",
path: basicFilePath,
offset: 2,
limit: 1,
expSuccess: true,
expContent: "2\tline2",
expTotal: 3,
expRead: 1,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "OffsetBeyondFile",
path: basicFilePath,
offset: 100,
expError: "offset 100 is beyond the file length of 3 lines",
},
{
name: "LongLineTruncation",
path: longLineFilePath,
expSuccess: true,
expContent: "1\t" + string(bytes.Repeat([]byte("x"), 1024)) + "... [truncated]",
expTotal: 1,
expRead: 1,
expSize: 1025,
},
{
name: "LargeFile",
path: largeFilePath,
expError: "exceeds the maximum",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/read-file-lines?path=%s&offset=%d&limit=%d", tt.path, tt.offset, tt.limit), nil)
api.Routes().ServeHTTP(w, r)
if tt.useCodersdk {
// Query param validation errors return codersdk.Response.
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), tt.expError)
return
}
var resp agentfiles.ReadFileLinesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
if tt.expSuccess {
require.Equal(t, http.StatusOK, w.Code)
require.True(t, resp.Success)
require.Equal(t, tt.expContent, resp.Content)
require.Equal(t, tt.expTotal, resp.TotalLines)
require.Equal(t, tt.expRead, resp.LinesRead)
require.Equal(t, tt.expSize, resp.FileSize)
} else {
require.Equal(t, http.StatusOK, w.Code)
require.False(t, resp.Success)
require.Contains(t, resp.Error, tt.expError)
}
})
}
}
-175
View File
@@ -1,175 +0,0 @@
package agentproc
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
manager *manager
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv),
}
}
// Close shuts down the process manager, killing all running
// processes.
func (api *API) Close() error {
return api.manager.Close()
}
// Routes returns the HTTP handler for process-related routes.
func (api *API) Routes() http.Handler {
r := chi.NewRouter()
r.Post("/start", api.handleStartProcess)
r.Get("/list", api.handleListProcesses)
r.Get("/{id}/output", api.handleProcessOutput)
r.Post("/{id}/signal", api.handleSignalProcess)
return r
}
// handleStartProcess starts a new process.
func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req workspacesdk.StartProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Request body must be valid JSON.",
Detail: err.Error(),
})
return
}
if req.Command == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Command is required.",
})
return
}
proc, err := api.manager.start(req)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start process.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
ID: proc.id,
Started: true,
})
}
// handleListProcesses lists all tracked processes.
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
infos := api.manager.list()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
Processes: infos,
})
}
// handleProcessOutput returns the output of a process.
func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id")
proc, ok := api.manager.get(id)
if !ok {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
output, truncated := proc.output()
info := proc.info()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ProcessOutputResponse{
Output: output,
Truncated: truncated,
Running: info.Running,
ExitCode: info.ExitCode,
})
}
// handleSignalProcess sends a signal to a running process.
func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id")
var req workspacesdk.SignalProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Request body must be valid JSON.",
Detail: err.Error(),
})
return
}
if req.Signal == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Signal is required.",
})
return
}
if req.Signal != "kill" && req.Signal != "terminate" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf(
"Unsupported signal %q. Use \"kill\" or \"terminate\".",
req.Signal,
),
})
return
}
if err := api.manager.signal(id, req.Signal); err != nil {
switch {
case errors.Is(err, errProcessNotFound):
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
case errors.Is(err, errProcessNotRunning):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: fmt.Sprintf(
"Process %q is not running.", id,
),
})
default:
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to signal process.",
Detail: err.Error(),
})
}
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: fmt.Sprintf(
"Signal %q sent to process %q.", req.Signal, id,
),
})
}
-691
View File
@@ -1,691 +0,0 @@
package agentproc_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
// postStart sends a POST /start request and returns the recorder.
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
body, err := json.Marshal(req)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
handler.ServeHTTP(w, r)
return w
}
// getList sends a GET /list request and returns the recorder.
func getList(t *testing.T, handler http.Handler) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
handler.ServeHTTP(w, r)
return w
}
// getOutput sends a GET /{id}/output request and returns the
// recorder.
func getOutput(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/%s/output", id), nil)
handler.ServeHTTP(w, r)
return w
}
// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
body, err := json.Marshal(req)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("/%s/signal", id), bytes.NewReader(body))
handler.ServeHTTP(w, r)
return w
}
// newTestAPI creates a new API with a test logger and default
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
t.Helper()
return newTestAPIWithUpdateEnv(t, nil)
}
// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv)
t.Cleanup(func() {
_ = api.Close()
})
return api.Routes()
}
// waitForExit polls the output endpoint until the process is
// no longer running or the context expires.
func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.ProcessOutputResponse {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
t.Fatal("timed out waiting for process to exit")
case <-ticker.C:
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
if !resp.Running {
return resp
}
}
}
}
// startAndGetID is a helper that starts a process and returns
// the process ID.
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
t.Helper()
w := postStart(t, handler, req)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
return resp.ID
}
func TestStartProcess(t *testing.T) {
t.Parallel()
t.Run("ForegroundCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello",
})
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
})
t.Run("BackgroundCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "echo background",
Background: true,
})
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
})
t.Run("EmptyCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Command is required")
})
t.Run("MalformedJSON", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", strings.NewReader("{invalid json"))
handler.ServeHTTP(w, r)
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "valid JSON")
})
t.Run("CustomWorkDir", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
tmpDir := t.TempDir()
// Write a marker file to verify the command ran in
// the correct directory. Comparing pwd output is
// unreliable on Windows where Git Bash returns POSIX
// paths.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "touch marker.txt && ls marker.txt",
WorkDir: tmpDir,
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("CustomEnv", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Use a unique env var name to avoid collisions in
// parallel tests.
envKey := fmt.Sprintf("TEST_PROC_ENV_%d", time.Now().UnixNano())
envVal := "custom_value_12345"
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
Env: map[string]string{envKey: envVal},
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, strings.TrimSpace(resp.Output), envVal)
})
t.Run("UpdateEnvHook", func(t *testing.T) {
t.Parallel()
envKey := fmt.Sprintf("TEST_UPDATE_ENV_%d", time.Now().UnixNano())
envVal := "injected_by_hook"
handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
return append(current, fmt.Sprintf("%s=%s", envKey, envVal)), nil
})
// The process should see the variable even though it
// was not passed in req.Env.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, strings.TrimSpace(resp.Output), envVal)
})
t.Run("UpdateEnvHookOverriddenByReqEnv", func(t *testing.T) {
t.Parallel()
envKey := fmt.Sprintf("TEST_OVERRIDE_%d", time.Now().UnixNano())
hookVal := "from_hook"
reqVal := "from_request"
handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
return append(current, fmt.Sprintf("%s=%s", envKey, hookVal)), nil
})
// req.Env should take precedence over the hook.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
Env: map[string]string{envKey: reqVal},
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// When duplicate env vars exist, shells use the last
// value. Since req.Env is appended after the hook,
// the request value wins.
require.Contains(t, strings.TrimSpace(resp.Output), reqVal)
})
}
func TestListProcesses(t *testing.T) {
t.Parallel()
t.Run("NoProcesses", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.NotNil(t, resp.Processes)
require.Empty(t, resp.Processes)
})
t.Run("MixedRunningAndExited", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a process that exits quickly.
exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, exitedID)
// Start a long-running process.
runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
// List should contain both.
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Len(t, resp.Processes, 2)
procMap := make(map[string]workspacesdk.ProcessInfo)
for _, p := range resp.Processes {
procMap[p.ID] = p
}
exited, ok := procMap[exitedID]
require.True(t, ok, "exited process should be in list")
require.False(t, exited.Running)
require.NotNil(t, exited.ExitCode)
running, ok := procMap[runningID]
require.True(t, ok, "running process should be in list")
require.True(t, running.Running)
// Clean up the long-running process.
sw := postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, sw.Code)
})
}
func TestProcessOutput(t *testing.T) {
t.Parallel()
t.Run("ExitedProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello-output",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "hello-output")
})
t.Run("RunningProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Running)
// Kill and wait for the process so cleanup does
// not hang.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
waitForExit(t, handler, id)
})
t.Run("NonexistentProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := getOutput(t, handler, "nonexistent-id-12345")
require.Equal(t, http.StatusNotFound, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "not found")
})
}
func TestSignalProcess(t *testing.T) {
t.Parallel()
t.Run("KillRunning", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, w.Code)
// Verify the process exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
})
t.Run("TerminateRunning", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("SIGTERM is not supported on Windows")
}
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "terminate",
})
require.Equal(t, http.StatusOK, w.Code)
// Verify the process exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
})
t.Run("NonexistentProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postSignal(t, handler, "nonexistent-id-12345", workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("AlreadyExitedProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
// Wait for exit first.
waitForExit(t, handler, id)
// Signaling an exited process should return 409
// Conflict via the errProcessNotRunning sentinel.
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
assert.Equal(t, http.StatusConflict, w.Code,
"expected 409 for signaling exited process, got %d", w.Code)
})
t.Run("EmptySignal", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Signal is required")
// Clean up.
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
})
t.Run("InvalidSignal", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "SIGFOO",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Unsupported signal")
// Clean up.
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
})
}
func TestProcessLifecycle(t *testing.T) {
t.Parallel()
t.Run("StartWaitCheckOutput", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo lifecycle-test && echo second-line",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "lifecycle-test")
require.Contains(t, resp.Output, "second-line")
})
t.Run("NonZeroExitCode", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "exit 42",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 42, *resp.ExitCode)
})
t.Run("StartSignalVerifyExit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a long-running background process.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
// Verify it's running.
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var running workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&running)
require.NoError(t, err)
require.True(t, running.Running)
// Signal it.
sw := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, sw.Code)
// Verify it exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
})
t.Run("OutputExceedsBuffer", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Generate output that exceeds MaxHeadBytes +
// MaxTailBytes. Each line is ~100 chars, and we
// need more than 32KB total (16KB head + 16KB
// tail).
lineCount := (agentproc.MaxHeadBytes+agentproc.MaxTailBytes)/50 + 500
cmd := fmt.Sprintf(
"for i in $(seq 1 %d); do echo \"line-$i-padding-to-make-this-longer-than-fifty-characters-total\"; done",
lineCount,
)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: cmd,
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// The output should be truncated with head/tail
// strategy metadata.
require.NotNil(t, resp.Truncated, "large output should be truncated")
require.Equal(t, "head_tail", resp.Truncated.Strategy)
require.Greater(t, resp.Truncated.OmittedBytes, 0)
require.Greater(t, resp.Truncated.OriginalBytes, resp.Truncated.RetainedBytes)
// Verify the output contains the omission marker.
require.Contains(t, resp.Output, "... [omitted")
})
t.Run("StderrCaptured", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo stdout-msg && echo stderr-msg >&2",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// Both stdout and stderr should be captured.
require.Contains(t, resp.Output, "stdout-msg")
require.Contains(t, resp.Output, "stderr-msg")
})
}
-309
View File
@@ -1,309 +0,0 @@
package agentproc
import (
"fmt"
"strings"
"sync"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// MaxHeadBytes is the number of bytes retained from the
// beginning of the output for LLM consumption.
MaxHeadBytes = 16 << 10 // 16KB
// MaxTailBytes is the number of bytes retained from the
// end of the output for LLM consumption.
MaxTailBytes = 16 << 10 // 16KB
// MaxLineLength is the maximum length of a single line
// before it is truncated. This prevents minified files
// or other long single-line output from consuming the
// entire buffer.
MaxLineLength = 2048
// lineTruncationSuffix is appended to lines that exceed
// MaxLineLength.
lineTruncationSuffix = " ... [truncated]"
)
// HeadTailBuffer is a thread-safe buffer that captures process
// output and provides head+tail truncation for LLM consumption.
// It implements io.Writer so it can be used directly as
// cmd.Stdout or cmd.Stderr.
//
// The buffer stores up to MaxHeadBytes from the beginning of
// the output and up to MaxTailBytes from the end in a ring
// buffer, keeping total memory usage bounded regardless of
// how much output is written.
type HeadTailBuffer struct {
mu sync.Mutex
head []byte
tail []byte
tailPos int
tailFull bool
headFull bool
totalBytes int
maxHead int
maxTail int
}
// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
return &HeadTailBuffer{
maxHead: MaxHeadBytes,
maxTail: MaxTailBytes,
}
}
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
return &HeadTailBuffer{
maxHead: maxHead,
maxTail: maxTail,
}
}
// Write implements io.Writer. It is safe for concurrent use.
// All bytes are accepted; the return value always equals
// len(p) with a nil error.
func (b *HeadTailBuffer) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
b.mu.Lock()
defer b.mu.Unlock()
n := len(p)
b.totalBytes += n
// Fill head buffer if it is not yet full.
if !b.headFull {
remaining := b.maxHead - len(b.head)
if remaining > 0 {
take := remaining
if take > len(p) {
take = len(p)
}
b.head = append(b.head, p[:take]...)
p = p[take:]
if len(b.head) >= b.maxHead {
b.headFull = true
}
}
if len(p) == 0 {
return n, nil
}
}
// Write remaining bytes into the tail ring buffer.
b.writeTail(p)
return n, nil
}
// writeTail appends data to the tail ring buffer. The caller
// must hold b.mu.
func (b *HeadTailBuffer) writeTail(p []byte) {
if b.maxTail <= 0 {
return
}
// Lazily allocate the tail buffer on first use.
if b.tail == nil {
b.tail = make([]byte, b.maxTail)
}
for len(p) > 0 {
// Write as many bytes as fit starting at tailPos.
space := b.maxTail - b.tailPos
take := space
if take > len(p) {
take = len(p)
}
copy(b.tail[b.tailPos:b.tailPos+take], p[:take])
p = p[take:]
b.tailPos += take
if b.tailPos >= b.maxTail {
b.tailPos = 0
b.tailFull = true
}
}
}
// tailBytes returns the current tail contents in order. The
// caller must hold b.mu.
func (b *HeadTailBuffer) tailBytes() []byte {
if b.tail == nil {
return nil
}
if !b.tailFull {
// Haven't wrapped yet; data is [0, tailPos).
return b.tail[:b.tailPos]
}
// Wrapped: data is [tailPos, maxTail) + [0, tailPos).
out := make([]byte, b.maxTail)
n := copy(out, b.tail[b.tailPos:])
copy(out[n:], b.tail[:b.tailPos])
return out
}
// Bytes returns a copy of the raw buffer contents. If no
// truncation has occurred the full output is returned;
// otherwise the head and tail portions are concatenated.
func (b *HeadTailBuffer) Bytes() []byte {
b.mu.Lock()
defer b.mu.Unlock()
tail := b.tailBytes()
if len(tail) == 0 {
out := make([]byte, len(b.head))
copy(out, b.head)
return out
}
out := make([]byte, len(b.head)+len(tail))
copy(out, b.head)
copy(out[len(b.head):], tail)
return out
}
// Len returns the number of bytes currently stored in the
// buffer.
func (b *HeadTailBuffer) Len() int {
b.mu.Lock()
defer b.mu.Unlock()
tailLen := 0
if b.tailFull {
tailLen = b.maxTail
} else if b.tail != nil {
tailLen = b.tailPos
}
return len(b.head) + tailLen
}
// TotalWritten returns the total number of bytes written to
// the buffer, which may exceed the stored capacity.
func (b *HeadTailBuffer) TotalWritten() int {
b.mu.Lock()
defer b.mu.Unlock()
return b.totalBytes
}
// Output returns the truncated output suitable for LLM
// consumption, along with truncation metadata. If the total
// output fits within the head buffer alone, the full output is
// returned with nil truncation info. Otherwise the head and
// tail are joined with an omission marker and long lines are
// truncated.
func (b *HeadTailBuffer) Output() (string, *workspacesdk.ProcessTruncation) {
b.mu.Lock()
head := make([]byte, len(b.head))
copy(head, b.head)
tail := b.tailBytes()
total := b.totalBytes
headFull := b.headFull
b.mu.Unlock()
storedLen := len(head) + len(tail)
// If everything fits, no head/tail split is needed.
if !headFull || len(tail) == 0 {
out := truncateLines(string(head))
if total == 0 {
return "", nil
}
return out, nil
}
// We have both head and tail data, meaning the total
// output exceeded the head capacity. Build the
// combined output with an omission marker.
omitted := total - storedLen
headStr := truncateLines(string(head))
tailStr := truncateLines(string(tail))
var sb strings.Builder
_, _ = sb.WriteString(headStr)
if omitted > 0 {
_, _ = sb.WriteString(fmt.Sprintf(
"\n\n... [omitted %d bytes] ...\n\n",
omitted,
))
} else {
// Head and tail are contiguous but were stored
// separately because the head filled up.
_, _ = sb.WriteString("\n")
}
_, _ = sb.WriteString(tailStr)
result := sb.String()
return result, &workspacesdk.ProcessTruncation{
OriginalBytes: total,
RetainedBytes: len(result),
OmittedBytes: omitted,
Strategy: "head_tail",
}
}
// truncateLines scans the input line by line and truncates
// any line longer than MaxLineLength.
func truncateLines(s string) string {
if len(s) <= MaxLineLength {
// Fast path: if the entire string is shorter than
// the max line length, no line can exceed it.
return s
}
var b strings.Builder
b.Grow(len(s))
for len(s) > 0 {
idx := strings.IndexByte(s, '\n')
var line string
if idx == -1 {
line = s
s = ""
} else {
line = s[:idx]
s = s[idx+1:]
}
if len(line) > MaxLineLength {
// Truncate preserving the suffix length so the
// total does not exceed a reasonable size.
cut := MaxLineLength - len(lineTruncationSuffix)
if cut < 0 {
cut = 0
}
_, _ = b.WriteString(line[:cut])
_, _ = b.WriteString(lineTruncationSuffix)
} else {
_, _ = b.WriteString(line)
}
// Re-add the newline unless this was the final
// segment without a trailing newline.
if idx != -1 {
_ = b.WriteByte('\n')
}
}
return b.String()
}
// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
b.mu.Lock()
defer b.mu.Unlock()
b.head = nil
b.tail = nil
b.tailPos = 0
b.tailFull = false
b.headFull = false
b.totalBytes = 0
}
-338
View File
@@ -1,338 +0,0 @@
package agentproc_test
import (
"fmt"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentproc"
)
func TestHeadTailBuffer_EmptyBuffer(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
out, info := buf.Output()
require.Empty(t, out)
require.Nil(t, info)
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.TotalWritten())
require.Empty(t, buf.Bytes())
}
func TestHeadTailBuffer_SmallOutput(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
data := "hello world\n"
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, len(data), n)
out, info := buf.Output()
require.Equal(t, data, out)
require.Nil(t, info, "small output should not be truncated")
require.Equal(t, len(data), buf.Len())
require.Equal(t, len(data), buf.TotalWritten())
}
func TestHeadTailBuffer_ExactlyHeadSize(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Build data that is exactly MaxHeadBytes using short
// lines so that line truncation does not apply.
line := strings.Repeat("x", 79) + "\n" // 80 bytes per line
count := agentproc.MaxHeadBytes / len(line)
pad := agentproc.MaxHeadBytes - (count * len(line))
data := strings.Repeat(line, count) + strings.Repeat("y", pad)
require.Equal(t, agentproc.MaxHeadBytes, len(data),
"test data must be exactly MaxHeadBytes")
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, agentproc.MaxHeadBytes, n)
out, info := buf.Output()
require.Equal(t, data, out)
require.Nil(t, info, "output fitting in head should not be truncated")
require.Equal(t, agentproc.MaxHeadBytes, buf.Len())
}
func TestHeadTailBuffer_HeadPlusTailNoOmission(t *testing.T) {
t.Parallel()
// Use a small buffer so we can test the boundary where
// head fills and tail starts but nothing is omitted.
// With maxHead=10, maxTail=10, writing exactly 20 bytes
// means head gets 10, tail gets 10, omitted = 0.
buf := agentproc.NewHeadTailBufferSized(10, 10)
data := "0123456789abcdefghij" // 20 bytes
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, 20, n)
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 0, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// The output should contain both head and tail.
require.Contains(t, out, "0123456789")
require.Contains(t, out, "abcdefghij")
}
func TestHeadTailBuffer_LargeOutputTruncation(t *testing.T) {
t.Parallel()
// Use small head/tail so truncation is easy to verify.
buf := agentproc.NewHeadTailBufferSized(10, 10)
// Write 100 bytes: head=10, tail=10, omitted=80.
data := strings.Repeat("A", 50) + strings.Repeat("Z", 50)
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, 100, n)
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 100, info.OriginalBytes)
require.Equal(t, 80, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// Head should be first 10 bytes (all A's).
require.True(t, strings.HasPrefix(out, "AAAAAAAAAA"))
// Tail should be last 10 bytes (all Z's).
require.True(t, strings.HasSuffix(out, "ZZZZZZZZZZ"))
// Omission marker should be present.
require.Contains(t, out, "... [omitted 80 bytes] ...")
require.Equal(t, 20, buf.Len())
require.Equal(t, 100, buf.TotalWritten())
}
func TestHeadTailBuffer_MultiMBStaysBounded(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write 5MB of data in chunks.
chunk := []byte(strings.Repeat("x", 4096) + "\n")
totalWritten := 0
for totalWritten < 5*1024*1024 {
n, err := buf.Write(chunk)
require.NoError(t, err)
require.Equal(t, len(chunk), n)
totalWritten += n
}
// Memory should be bounded to head+tail.
require.LessOrEqual(t, buf.Len(),
agentproc.MaxHeadBytes+agentproc.MaxTailBytes)
require.Equal(t, totalWritten, buf.TotalWritten())
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, totalWritten, info.OriginalBytes)
require.Greater(t, info.OmittedBytes, 0)
require.NotEmpty(t, out)
}
func TestHeadTailBuffer_LongLineTruncation(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write a line longer than MaxLineLength.
longLine := strings.Repeat("m", agentproc.MaxLineLength+500)
_, err := buf.Write([]byte(longLine + "\n"))
require.NoError(t, err)
out, _ := buf.Output()
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
require.Len(t, lines, 1)
require.LessOrEqual(t, len(lines[0]), agentproc.MaxLineLength)
require.True(t, strings.HasSuffix(lines[0], "... [truncated]"))
}
func TestHeadTailBuffer_LongLineInTail(t *testing.T) {
t.Parallel()
// Use small buffers so we can force data into the tail.
buf := agentproc.NewHeadTailBufferSized(20, 5000)
// Fill head with short data.
_, err := buf.Write([]byte("head data goes here\n"))
require.NoError(t, err)
// Now write a very long line into the tail.
longLine := strings.Repeat("T", agentproc.MaxLineLength+100)
_, err = buf.Write([]byte(longLine + "\n"))
require.NoError(t, err)
out, info := buf.Output()
require.NotNil(t, info)
// The long line in the tail should be truncated.
require.Contains(t, out, "... [truncated]")
}
func TestHeadTailBuffer_ConcurrentWrites(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
const goroutines = 10
const writes = 1000
var wg sync.WaitGroup
wg.Add(goroutines)
for g := range goroutines {
go func() {
defer wg.Done()
line := fmt.Sprintf("goroutine-%d: data\n", g)
for range writes {
_, err := buf.Write([]byte(line))
assert.NoError(t, err)
}
}()
}
wg.Wait()
// Verify totals are consistent.
require.Greater(t, buf.TotalWritten(), 0)
require.Greater(t, buf.Len(), 0)
out, _ := buf.Output()
require.NotEmpty(t, out)
}
func TestHeadTailBuffer_TruncationInfoFields(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBufferSized(10, 10)
// Write enough to cause omission.
data := strings.Repeat("D", 50)
_, err := buf.Write([]byte(data))
require.NoError(t, err)
_, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 50, info.OriginalBytes)
require.Equal(t, 30, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// RetainedBytes is the length of the formatted output
// string including the omission marker.
require.Greater(t, info.RetainedBytes, 0)
}
func TestHeadTailBuffer_MultipleSmallWrites(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write one byte at a time.
expected := "hello world"
for i := range len(expected) {
n, err := buf.Write([]byte{expected[i]})
require.NoError(t, err)
require.Equal(t, 1, n)
}
out, info := buf.Output()
require.Equal(t, expected, out)
require.Nil(t, info)
}
func TestHeadTailBuffer_WriteEmptySlice(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
n, err := buf.Write([]byte{})
require.NoError(t, err)
require.Equal(t, 0, n)
require.Equal(t, 0, buf.TotalWritten())
}
func TestHeadTailBuffer_Reset(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
_, err := buf.Write([]byte("some data"))
require.NoError(t, err)
require.Greater(t, buf.Len(), 0)
buf.Reset()
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.TotalWritten())
out, info := buf.Output()
require.Empty(t, out)
require.Nil(t, info)
}
func TestHeadTailBuffer_BytesReturnsCopy(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
_, err := buf.Write([]byte("original"))
require.NoError(t, err)
b := buf.Bytes()
require.Equal(t, []byte("original"), b)
// Mutating the returned slice should not affect the
// buffer.
b[0] = 'X'
require.Equal(t, []byte("original"), buf.Bytes())
}
func TestHeadTailBuffer_RingBufferWraparound(t *testing.T) {
t.Parallel()
// Use a tail of 10 bytes and write enough to wrap
// around multiple times.
buf := agentproc.NewHeadTailBufferSized(5, 10)
// Fill head (5 bytes).
_, err := buf.Write([]byte("HEADD"))
require.NoError(t, err)
// Write 25 bytes into tail, wrapping 2.5 times.
_, err = buf.Write([]byte("0123456789"))
require.NoError(t, err)
_, err = buf.Write([]byte("abcdefghij"))
require.NoError(t, err)
_, err = buf.Write([]byte("ABCDE"))
require.NoError(t, err)
out, info := buf.Output()
require.NotNil(t, info)
// Tail should contain the last 10 bytes: "fghijABCDE".
require.True(t, strings.HasSuffix(out, "fghijABCDE"),
"expected tail to be last 10 bytes, got: %q", out)
}
func TestHeadTailBuffer_MultipleLinesTruncated(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
short := "short line\n"
long := strings.Repeat("L", agentproc.MaxLineLength+100) + "\n"
_, err := buf.Write([]byte(short + long + short))
require.NoError(t, err)
out, _ := buf.Output()
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
require.Len(t, lines, 3)
require.Equal(t, "short line", lines[0])
require.True(t, strings.HasSuffix(lines[1], "... [truncated]"))
require.Equal(t, "short line", lines[2])
}
-294
View File
@@ -1,294 +0,0 @@
package agentproc
import (
"context"
"fmt"
"os"
"os/exec"
"sync"
"syscall"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
var (
errProcessNotFound = xerrors.New("process not found")
errProcessNotRunning = xerrors.New("process is not running")
)
// process represents a running or completed process.
type process struct {
mu sync.Mutex
id string
command string
workDir string
background bool
cmd *exec.Cmd
cancel context.CancelFunc
buf *HeadTailBuffer
running bool
exitCode *int
startedAt int64
exitedAt *int64
done chan struct{} // closed when process exits
}
// info returns a snapshot of the process state.
func (p *process) info() workspacesdk.ProcessInfo {
p.mu.Lock()
defer p.mu.Unlock()
return workspacesdk.ProcessInfo{
ID: p.id,
Command: p.command,
WorkDir: p.workDir,
Background: p.background,
Running: p.running,
ExitCode: p.exitCode,
StartedAt: p.startedAt,
ExitedAt: p.exitedAt,
}
}
// output returns the truncated output from the process buffer
// along with optional truncation metadata.
func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
return p.buf.Output()
}
// manager tracks processes spawned by the agent.
type manager struct {
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
}
// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
return &manager{
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
}
}
// start spawns a new process. Both foreground and background
// processes use a long-lived context so the process survives
// the HTTP request lifecycle. The background flag only affects
// client-side polling behavior.
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
return nil, xerrors.New("manager is closed")
}
m.mu.Unlock()
id := uuid.New().String()
// Use a cancellable context so Close() can terminate
// all processes. context.Background() is the parent so
// the process is not tied to any HTTP request.
ctx, cancel := context.WithCancel(context.Background())
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
if req.WorkDir != "" {
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
// still holding the stdout/stderr pipes open.
cmd.WaitDelay = 5 * time.Second
buf := NewHeadTailBuffer()
cmd.Stdout = buf
cmd.Stderr = buf
// Build the process environment. If the manager has an
// updateEnv hook (provided by the agent), use it to get the
// full agent environment including GIT_ASKPASS, CODER_* vars,
// etc. Otherwise fall back to the current process env.
baseEnv := os.Environ()
if m.updateEnv != nil {
updated, err := m.updateEnv(baseEnv)
if err != nil {
m.logger.Warn(
context.Background(),
"failed to update command environment, falling back to os env",
slog.Error(err),
)
} else {
baseEnv = updated
}
}
// Always set cmd.Env explicitly so that req.Env overrides
// are applied on top of the full agent environment.
cmd.Env = baseEnv
for k, v := range req.Env {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
}
if err := cmd.Start(); err != nil {
cancel()
return nil, xerrors.Errorf("start process: %w", err)
}
now := m.clock.Now().Unix()
proc := &process{
id: id,
command: req.Command,
workDir: req.WorkDir,
background: req.Background,
cmd: cmd,
cancel: cancel,
buf: buf,
running: true,
startedAt: now,
done: make(chan struct{}),
}
m.mu.Lock()
if m.closed {
m.mu.Unlock()
// Manager closed between our check and now. Kill the
// process we just started.
cancel()
_ = cmd.Wait()
return nil, xerrors.New("manager is closed")
}
m.procs[id] = proc
m.mu.Unlock()
go func() {
err := cmd.Wait()
exitedAt := m.clock.Now().Unix()
proc.mu.Lock()
proc.running = false
proc.exitedAt = &exitedAt
code := 0
if err != nil {
// Extract the exit code from the error.
var exitErr *exec.ExitError
if xerrors.As(err, &exitErr) {
code = exitErr.ExitCode()
} else {
// Unknown error; use -1 as a sentinel.
code = -1
m.logger.Warn(
context.Background(),
"process wait returned non-exit error",
slog.F("id", id),
slog.Error(err),
)
}
}
proc.exitCode = &code
proc.mu.Unlock()
close(proc.done)
}()
return proc, nil
}
// get returns a process by ID.
func (m *manager) get(id string) (*process, bool) {
m.mu.Lock()
defer m.mu.Unlock()
proc, ok := m.procs[id]
return proc, ok
}
// list returns info about all tracked processes.
func (m *manager) list() []workspacesdk.ProcessInfo {
m.mu.Lock()
defer m.mu.Unlock()
infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
for _, proc := range m.procs {
infos = append(infos, proc.info())
}
return infos
}
// signal sends a signal to a running process. It returns
// sentinel errors errProcessNotFound and errProcessNotRunning
// so callers can distinguish failure modes.
func (m *manager) signal(id string, sig string) error {
m.mu.Lock()
proc, ok := m.procs[id]
m.mu.Unlock()
if !ok {
return errProcessNotFound
}
proc.mu.Lock()
defer proc.mu.Unlock()
if !proc.running {
return errProcessNotRunning
}
switch sig {
case "kill":
if err := proc.cmd.Process.Kill(); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
return xerrors.Errorf("unsupported signal %q", sig)
}
return nil
}
// Close kills all running processes and prevents new ones from
// starting. It cancels each process's context, which causes
// CommandContext to kill the process and its pipe goroutines to
// drain.
func (m *manager) Close() error {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
return nil
}
m.closed = true
procs := make([]*process, 0, len(m.procs))
for _, p := range m.procs {
procs = append(procs, p)
}
m.mu.Unlock()
for _, p := range procs {
p.cancel()
}
// Wait for all processes to exit.
for _, p := range procs {
<-p.done
}
return nil
}
+103 -2
View File
@@ -1,22 +1,37 @@
package agentsocket_test
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
func TestServer(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("agentsocket is not supported on Windows")
}
t.Run("StartStop", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
@@ -26,7 +41,7 @@ func TestServer(t *testing.T) {
t.Run("AlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server1, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
@@ -34,4 +49,90 @@ func TestServer(t *testing.T) {
_, err = agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.ErrorContains(t, err, "create socket")
})
t.Run("AutoSocketPath", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
require.NoError(t, server.Close())
})
}
func TestServerWindowsNotSupported(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" {
t.Skip("this test only runs on Windows")
}
t.Run("NewServer", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
_, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
})
t.Run("NewClient", func(t *testing.T) {
t.Parallel()
_, err := agentsocket.NewClient(context.Background(), agentsocket.WithPath("test.sock"))
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
})
}
func TestAgentInitializesOnWindowsWithoutSocketServer(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" {
t.Skip("this test only runs on Windows")
}
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t).Named("agent")
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
statsCh := make(chan *agentproto.Stats, 50)
agentID := uuid.New()
manifest := agentsdk.Manifest{
AgentID: agentID,
AgentName: "test-agent",
WorkspaceName: "test-workspace",
OwnerName: "test-user",
WorkspaceID: uuid.New(),
DERPMap: derpMap,
}
client := agenttest.NewClient(t, logger.Named("agenttest"), agentID, manifest, statsCh, coordinator)
t.Cleanup(client.Close)
options := agent.Options{
Client: client,
Filesystem: afero.NewMemMapFs(),
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: testutil.WaitShort,
EnvironmentVariables: map[string]string{},
SocketPath: "",
}
agnt := agent.New(options)
t.Cleanup(func() {
_ = agnt.Close()
})
startup := testutil.TryReceive(ctx, t, client.GetStartup())
require.NotNil(t, startup, "agent should send startup message")
err := agnt.Close()
require.NoError(t, err, "agent should close cleanly")
}
+17 -11
View File
@@ -2,6 +2,8 @@ package agentsocket_test
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/require"
@@ -28,10 +30,14 @@ func newSocketClient(ctx context.Context, t *testing.T, socketPath string) *agen
func TestDRPCAgentSocketService(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("agentsocket is not supported on Windows")
}
t.Run("Ping", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -51,7 +57,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("NewUnit", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -73,7 +79,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitAlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -103,7 +109,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -142,7 +148,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitNotReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -172,7 +178,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("NewUnits", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -197,7 +203,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -232,7 +238,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -274,7 +280,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnregisteredUnit", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -293,7 +299,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitNotReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -317,7 +323,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
+6 -47
View File
@@ -4,60 +4,19 @@ package agentsocket
import (
"context"
"fmt"
"net"
"os"
"os/user"
"strings"
"github.com/Microsoft/go-winio"
"golang.org/x/xerrors"
)
const defaultSocketPath = `\\.\pipe\com.coder.agentsocket`
func createSocket(path string) (net.Listener, error) {
if path == "" {
path = defaultSocketPath
}
if !strings.HasPrefix(path, `\\.\pipe\`) {
return nil, xerrors.Errorf("%q is not a valid local socket path", path)
}
user, err := user.Current()
if err != nil {
return nil, fmt.Errorf("unable to look up current user: %w", err)
}
sid := user.Uid
// SecurityDescriptor is in SDDL format. c.f.
// https://learn.microsoft.com/en-us/windows/win32/secauthz/security-descriptor-string-format for full details.
// D: indicates this is a Discretionary Access Control List (DACL), which is Windows-speak for ACLs that allow or
// deny access (as opposed to SACL which controls audit logging).
// P indicates that this DACL is "protected" from being modified thru inheritance
// () delimit access control entries (ACEs), here we only have one, which, allows (A) generic all (GA) access to our
// specific user's security ID (SID).
//
// Note that although Microsoft docs at https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipes warns that
// named pipes are accessible from remote machines in the general case, the `winio` package sets the flag
// windows.FILE_PIPE_REJECT_REMOTE_CLIENTS when creating pipes, so connections from remote machines are always
// denied. This is important because we sort of expect customers to run the Coder agent under a generic user
// account unless they are very sophisticated. We don't want this socket to cross the boundary of the local machine.
configuration := &winio.PipeConfig{
SecurityDescriptor: fmt.Sprintf("D:P(A;;GA;;;%s)", sid),
}
listener, err := winio.ListenPipe(path, configuration)
if err != nil {
return nil, xerrors.Errorf("failed to open named pipe: %w", err)
}
return listener, nil
func createSocket(_ string) (net.Listener, error) {
return nil, xerrors.New("agentsocket is not supported on Windows")
}
func cleanupSocket(path string) error {
return os.Remove(path)
func cleanupSocket(_ string) error {
return nil
}
func dialSocket(ctx context.Context, path string) (net.Conn, error) {
return winio.DialPipeContext(ctx, path)
func dialSocket(_ context.Context, _ string) (net.Conn, error) {
return nil, xerrors.New("agentsocket is not supported on Windows")
}
-1
View File
@@ -24,7 +24,6 @@ func New(t testing.TB, coderURL *url.URL, agentToken string, opts ...func(*agent
var o agent.Options
log := testutil.Logger(t).Named("agent")
o.Logger = log
o.SocketPath = testutil.AgentSocketPath(t)
for _, opt := range opts {
opt(&o)
-10
View File
@@ -124,12 +124,6 @@ func (c *Client) Close() {
c.derpMapOnce.Do(func() { close(c.derpMapUpdates) })
}
func (c *Client) ConnectRPC28WithRole(ctx context.Context, _ string) (
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
return c.ConnectRPC28(ctx)
}
func (c *Client) ConnectRPC28(ctx context.Context) (
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
@@ -235,10 +229,6 @@ type FakeAgentAPI struct {
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
}
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
panic("unimplemented")
}
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}
-1
View File
@@ -28,7 +28,6 @@ func (a *agent) apiHandler() http.Handler {
})
r.Mount("/api/v0", a.filesAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
+1 -2
View File
@@ -10,7 +10,6 @@ import (
"testing"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -70,7 +69,7 @@ func TestBoundaryLogs_EndToEnd(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
-286
View File
@@ -1,286 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.30.0
// protoc v4.23.4
// source: agent/boundarylogproxy/codec/boundary.proto
package codec
import (
proto "github.com/coder/coder/v2/agent/proto"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// BoundaryMessage is the envelope for all TagV2 messages sent over the
// boundary <-> agent unix socket. TagV1 carries a bare
// ReportBoundaryLogsRequest for backwards compatibility; TagV2 wraps
// everything in this envelope so the protocol can be extended with new
// message types without adding more tags.
type BoundaryMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Msg:
//
// *BoundaryMessage_Logs
// *BoundaryMessage_Status
Msg isBoundaryMessage_Msg `protobuf_oneof:"msg"`
}
func (x *BoundaryMessage) Reset() {
*x = BoundaryMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *BoundaryMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*BoundaryMessage) ProtoMessage() {}
func (x *BoundaryMessage) ProtoReflect() protoreflect.Message {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use BoundaryMessage.ProtoReflect.Descriptor instead.
func (*BoundaryMessage) Descriptor() ([]byte, []int) {
return file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP(), []int{0}
}
func (m *BoundaryMessage) GetMsg() isBoundaryMessage_Msg {
if m != nil {
return m.Msg
}
return nil
}
func (x *BoundaryMessage) GetLogs() *proto.ReportBoundaryLogsRequest {
if x, ok := x.GetMsg().(*BoundaryMessage_Logs); ok {
return x.Logs
}
return nil
}
func (x *BoundaryMessage) GetStatus() *BoundaryStatus {
if x, ok := x.GetMsg().(*BoundaryMessage_Status); ok {
return x.Status
}
return nil
}
type isBoundaryMessage_Msg interface {
isBoundaryMessage_Msg()
}
type BoundaryMessage_Logs struct {
Logs *proto.ReportBoundaryLogsRequest `protobuf:"bytes,1,opt,name=logs,proto3,oneof"`
}
type BoundaryMessage_Status struct {
Status *BoundaryStatus `protobuf:"bytes,2,opt,name=status,proto3,oneof"`
}
func (*BoundaryMessage_Logs) isBoundaryMessage_Msg() {}
func (*BoundaryMessage_Status) isBoundaryMessage_Msg() {}
// BoundaryStatus carries operational metadata from boundary to the agent.
// The agent records these values as Prometheus metrics. This message is
// never forwarded to coderd.
type BoundaryStatus struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Logs dropped because boundary's internal channel buffer was full.
DroppedChannelFull int64 `protobuf:"varint,1,opt,name=dropped_channel_full,json=droppedChannelFull,proto3" json:"dropped_channel_full,omitempty"`
// Logs dropped because boundary's batch buffer was full after a
// failed flush attempt.
DroppedBatchFull int64 `protobuf:"varint,2,opt,name=dropped_batch_full,json=droppedBatchFull,proto3" json:"dropped_batch_full,omitempty"`
}
func (x *BoundaryStatus) Reset() {
*x = BoundaryStatus{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *BoundaryStatus) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*BoundaryStatus) ProtoMessage() {}
func (x *BoundaryStatus) ProtoReflect() protoreflect.Message {
mi := &file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use BoundaryStatus.ProtoReflect.Descriptor instead.
func (*BoundaryStatus) Descriptor() ([]byte, []int) {
return file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP(), []int{1}
}
func (x *BoundaryStatus) GetDroppedChannelFull() int64 {
if x != nil {
return x.DroppedChannelFull
}
return 0
}
func (x *BoundaryStatus) GetDroppedBatchFull() int64 {
if x != nil {
return x.DroppedBatchFull
}
return 0
}
var File_agent_boundarylogproxy_codec_boundary_proto protoreflect.FileDescriptor
var file_agent_boundarylogproxy_codec_boundary_proto_rawDesc = []byte{
0x0a, 0x2b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79,
0x6c, 0x6f, 0x67, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2f, 0x62,
0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x1f, 0x63,
0x6f, 0x64, 0x65, 0x72, 0x2e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67,
0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2e, 0x76, 0x31, 0x1a, 0x17,
0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x61, 0x67, 0x65, 0x6e,
0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa4, 0x01, 0x0a, 0x0f, 0x42, 0x6f, 0x75, 0x6e,
0x64, 0x61, 0x72, 0x79, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3f, 0x0a, 0x04, 0x6c,
0x6f, 0x67, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x63, 0x6f, 0x64, 0x65,
0x72, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x76, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72,
0x74, 0x42, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x4c, 0x6f, 0x67, 0x73, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x04, 0x6c, 0x6f, 0x67, 0x73, 0x12, 0x49, 0x0a, 0x06,
0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2f, 0x2e, 0x63,
0x6f, 0x64, 0x65, 0x72, 0x2e, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67,
0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2e, 0x76, 0x31, 0x2e, 0x42,
0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x48, 0x00, 0x52,
0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, 0x70,
0x0a, 0x0e, 0x42, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73,
0x12, 0x30, 0x0a, 0x14, 0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x5f, 0x63, 0x68, 0x61, 0x6e,
0x6e, 0x65, 0x6c, 0x5f, 0x66, 0x75, 0x6c, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x12,
0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x43, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x46, 0x75,
0x6c, 0x6c, 0x12, 0x2c, 0x0a, 0x12, 0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x5f, 0x62, 0x61,
0x74, 0x63, 0x68, 0x5f, 0x66, 0x75, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x10,
0x64, 0x72, 0x6f, 0x70, 0x70, 0x65, 0x64, 0x42, 0x61, 0x74, 0x63, 0x68, 0x46, 0x75, 0x6c, 0x6c,
0x42, 0x38, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
0x6f, 0x64, 0x65, 0x72, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x72, 0x2f, 0x76, 0x32, 0x2f, 0x61, 0x67,
0x65, 0x6e, 0x74, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x6c, 0x6f, 0x67, 0x70,
0x72, 0x6f, 0x78, 0x79, 0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (
file_agent_boundarylogproxy_codec_boundary_proto_rawDescOnce sync.Once
file_agent_boundarylogproxy_codec_boundary_proto_rawDescData = file_agent_boundarylogproxy_codec_boundary_proto_rawDesc
)
func file_agent_boundarylogproxy_codec_boundary_proto_rawDescGZIP() []byte {
file_agent_boundarylogproxy_codec_boundary_proto_rawDescOnce.Do(func() {
file_agent_boundarylogproxy_codec_boundary_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_boundarylogproxy_codec_boundary_proto_rawDescData)
})
return file_agent_boundarylogproxy_codec_boundary_proto_rawDescData
}
var file_agent_boundarylogproxy_codec_boundary_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_agent_boundarylogproxy_codec_boundary_proto_goTypes = []interface{}{
(*BoundaryMessage)(nil), // 0: coder.boundarylogproxy.codec.v1.BoundaryMessage
(*BoundaryStatus)(nil), // 1: coder.boundarylogproxy.codec.v1.BoundaryStatus
(*proto.ReportBoundaryLogsRequest)(nil), // 2: coder.agent.v2.ReportBoundaryLogsRequest
}
var file_agent_boundarylogproxy_codec_boundary_proto_depIdxs = []int32{
2, // 0: coder.boundarylogproxy.codec.v1.BoundaryMessage.logs:type_name -> coder.agent.v2.ReportBoundaryLogsRequest
1, // 1: coder.boundarylogproxy.codec.v1.BoundaryMessage.status:type_name -> coder.boundarylogproxy.codec.v1.BoundaryStatus
2, // [2:2] is the sub-list for method output_type
2, // [2:2] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_agent_boundarylogproxy_codec_boundary_proto_init() }
func file_agent_boundarylogproxy_codec_boundary_proto_init() {
if File_agent_boundarylogproxy_codec_boundary_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*BoundaryMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*BoundaryStatus); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_agent_boundarylogproxy_codec_boundary_proto_msgTypes[0].OneofWrappers = []interface{}{
(*BoundaryMessage_Logs)(nil),
(*BoundaryMessage_Status)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_agent_boundarylogproxy_codec_boundary_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_agent_boundarylogproxy_codec_boundary_proto_goTypes,
DependencyIndexes: file_agent_boundarylogproxy_codec_boundary_proto_depIdxs,
MessageInfos: file_agent_boundarylogproxy_codec_boundary_proto_msgTypes,
}.Build()
File_agent_boundarylogproxy_codec_boundary_proto = out.File
file_agent_boundarylogproxy_codec_boundary_proto_rawDesc = nil
file_agent_boundarylogproxy_codec_boundary_proto_goTypes = nil
file_agent_boundarylogproxy_codec_boundary_proto_depIdxs = nil
}
@@ -1,29 +0,0 @@
syntax = "proto3";
option go_package = "github.com/coder/coder/v2/agent/boundarylogproxy/codec";
package coder.boundarylogproxy.codec.v1;
import "agent/proto/agent.proto";
// BoundaryMessage is the envelope for all TagV2 messages sent over the
// boundary <-> agent unix socket. TagV1 carries a bare
// ReportBoundaryLogsRequest for backwards compatibility; TagV2 wraps
// everything in this envelope so the protocol can be extended with new
// message types without adding more tags.
message BoundaryMessage {
oneof msg {
coder.agent.v2.ReportBoundaryLogsRequest logs = 1;
BoundaryStatus status = 2;
}
}
// BoundaryStatus carries operational metadata from boundary to the agent.
// The agent records these values as Prometheus metrics. This message is
// never forwarded to coderd.
message BoundaryStatus {
// Logs dropped because boundary's internal channel buffer was full.
int64 dropped_channel_full = 1;
// Logs dropped because boundary's batch buffer was full after a
// failed flush attempt.
int64 dropped_batch_full = 2;
}
+14 -73
View File
@@ -14,23 +14,14 @@ import (
"io"
"golang.org/x/xerrors"
"google.golang.org/protobuf/proto"
agentproto "github.com/coder/coder/v2/agent/proto"
)
type Tag uint8
const (
// TagV1 identifies the first revision of the protocol. The payload is a
// bare ReportBoundaryLogsRequest. This version has a maximum data length
// of MaxMessageSizeV1.
// TagV1 identifies the first revision of the protocol. This version has a maximum
// data length of MaxMessageSizeV1.
TagV1 Tag = 1
// TagV2 identifies the second revision of the protocol. The payload is
// a BoundaryMessage envelope. This version has a maximum data length of
// MaxMessageSizeV2.
TagV2 Tag = 2
)
const (
@@ -44,9 +35,6 @@ const (
// over the wire for the TagV1 tag. While the wire format allows 24 bits for
// length, TagV1 only uses 15 bits.
MaxMessageSizeV1 uint32 = 1 << 15
// MaxMessageSizeV2 is the maximum data length for TagV2.
MaxMessageSizeV2 = MaxMessageSizeV1
)
var (
@@ -60,9 +48,12 @@ var (
// WriteFrame writes a framed message with the given tag and data. The data
// must not exceed 2^DataLength in length.
func WriteFrame(w io.Writer, tag Tag, data []byte) error {
maxSize, err := maxSizeForTag(tag)
if err != nil {
return err
var maxSize uint32
switch tag {
case TagV1:
maxSize = MaxMessageSizeV1
default:
return xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
}
if len(data) > int(maxSize) {
@@ -110,9 +101,12 @@ func ReadFrame(r io.Reader, buf []byte) (Tag, []byte, error) {
}
tag := Tag(shifted)
maxSize, err := maxSizeForTag(tag)
if err != nil {
return 0, nil, err
var maxSize uint32
switch tag {
case TagV1:
maxSize = MaxMessageSizeV1
default:
return 0, nil, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
}
if length > maxSize {
@@ -131,56 +125,3 @@ func ReadFrame(r io.Reader, buf []byte) (Tag, []byte, error) {
return tag, buf[:length], nil
}
// maxSizeForTag returns the maximum payload size for the given tag.
func maxSizeForTag(tag Tag) (uint32, error) {
switch tag {
case TagV1:
return MaxMessageSizeV1, nil
case TagV2:
return MaxMessageSizeV2, nil
default:
return 0, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
}
}
// ReadMessage reads a framed message and unmarshals it based on tag. The
// returned buf should be passed back on the next call for buffer reuse.
func ReadMessage(r io.Reader, buf []byte) (proto.Message, []byte, error) {
tag, data, err := ReadFrame(r, buf)
if err != nil {
return nil, data, err
}
var msg proto.Message
switch tag {
case TagV1:
var req agentproto.ReportBoundaryLogsRequest
if err := proto.Unmarshal(data, &req); err != nil {
return nil, data, xerrors.Errorf("unmarshal TagV1: %w", err)
}
msg = &req
case TagV2:
var envelope BoundaryMessage
if err := proto.Unmarshal(data, &envelope); err != nil {
return nil, data, xerrors.Errorf("unmarshal TagV2: %w", err)
}
msg = &envelope
default:
// maxSizeForTag already rejects unknown tags during ReadFrame,
// but handle it here for safety.
return nil, data, xerrors.Errorf("%w: %d", ErrUnsupportedTag, tag)
}
return msg, data, nil
}
// WriteMessage marshals a proto message and writes it as a framed message
// with the given tag.
func WriteMessage(w io.Writer, tag Tag, msg proto.Message) error {
data, err := proto.Marshal(msg)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
return WriteFrame(w, tag, data)
}
+2 -2
View File
@@ -89,7 +89,7 @@ func TestReadFrameInvalidTag(t *testing.T) {
// reading the invalid tag.
const (
dataLength uint32 = 10
bogusTag uint32 = 222
bogusTag uint32 = 2
)
header := bogusTag<<codec.DataLength | dataLength
data := make([]byte, 4)
@@ -139,7 +139,7 @@ func TestWriteFrameInvalidTag(t *testing.T) {
var buf bytes.Buffer
data := make([]byte, 1)
const bogusTag = 222
const bogusTag = 2
err := codec.WriteFrame(&buf, codec.Tag(bogusTag), data)
require.ErrorIs(t, err, codec.ErrUnsupportedTag)
}
-77
View File
@@ -1,77 +0,0 @@
package boundarylogproxy
import "github.com/prometheus/client_golang/prometheus"
// Metrics tracks observability for the boundary -> agent -> coderd audit log
// pipeline.
//
// Audit logs from boundary workspaces pass through several async buffers
// before reaching coderd, and any stage can silently drop data. These
// metrics make that loss visible so operators/devs can:
//
// - Bubble up data loss: a non-zero drop rate means audit logs are being
// lost, which may have auditing implications.
// - Identify the bottleneck: the reason label pinpoints where drops
// occur: boundary's internal buffers, the agent's channel, or the
// RPC to coderd.
// - Tune buffer sizes: sustained "buffer_full" drops indicate the
// agent's channel (or boundary's batch buffer) is too small for the
// workload. Combined with batches_forwarded_total you can compute a
// drop rate: drops / (drops + forwards).
// - Detect batch forwarding issues: "forward_failed" drops increase when
// the agent cannot reach coderd.
//
// Drops are captured at two stages:
// - Agent-side: the agent's channel buffer overflows (reason
// "buffer_full") or the RPC forward to coderd fails (reason
// "forward_failed").
// - Boundary-reported: boundary self-reports drops via BoundaryStatus
// messages (reasons "boundary_channel_full", "boundary_batch_full").
// These arrive on the next successful flush from boundary.
//
// There are circumstances where metrics could be lost e.g., agent restarts,
// boundary crashes, or the agent shuts down when the DRPC connection is down.
type Metrics struct {
batchesDropped *prometheus.CounterVec
logsDropped *prometheus.CounterVec
batchesForwarded prometheus.Counter
}
func newMetrics(registerer prometheus.Registerer) *Metrics {
batchesDropped := prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "boundary_log_proxy",
Name: "batches_dropped_total",
Help: "Total number of boundary log batches dropped before reaching coderd. " +
"Reason: buffer_full = the agent's internal buffer is full, meaning boundary is producing logs faster than the agent can forward them to coderd; " +
"forward_failed = the agent failed to send the batch to coderd, potentially because coderd is unreachable or the connection was interrupted.",
}, []string{"reason"})
registerer.MustRegister(batchesDropped)
logsDropped := prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "boundary_log_proxy",
Name: "logs_dropped_total",
Help: "Total number of individual boundary log entries dropped before reaching coderd. " +
"Reason: buffer_full = the agent's internal buffer is full; " +
"forward_failed = the agent failed to send the batch to coderd; " +
"boundary_channel_full = boundary's internal send channel overflowed, meaning boundary is generating logs faster than it can batch and send them; " +
"boundary_batch_full = boundary's outgoing batch buffer overflowed after a failed flush, meaning boundary could not write to the agent's socket.",
}, []string{"reason"})
registerer.MustRegister(logsDropped)
batchesForwarded := prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "agent",
Subsystem: "boundary_log_proxy",
Name: "batches_forwarded_total",
Help: "Total number of boundary log batches successfully forwarded to coderd. " +
"Compare with batches_dropped_total to compute a drop rate.",
})
registerer.MustRegister(batchesForwarded)
return &Metrics{
batchesDropped: batchesDropped,
logsDropped: logsDropped,
batchesForwarded: batchesForwarded,
}
}
+23 -60
View File
@@ -11,7 +11,6 @@ import (
"path/filepath"
"sync"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"google.golang.org/protobuf/proto"
@@ -27,13 +26,6 @@ const (
logBufferSize = 100
)
const (
droppedReasonBoundaryChannelFull = "boundary_channel_full"
droppedReasonBoundaryBatchFull = "boundary_batch_full"
droppedReasonBufferFull = "buffer_full"
droppedReasonForwardFailed = "forward_failed"
)
// DefaultSocketPath returns the default path for the boundary audit log socket.
func DefaultSocketPath() string {
return filepath.Join(os.TempDir(), "boundary-audit.sock")
@@ -51,7 +43,6 @@ type Reporter interface {
type Server struct {
logger slog.Logger
socketPath string
metrics *Metrics
listener net.Listener
cancel context.CancelFunc
@@ -62,11 +53,10 @@ type Server struct {
}
// NewServer creates a new boundary log proxy server.
func NewServer(logger slog.Logger, socketPath string, registerer prometheus.Registerer) *Server {
func NewServer(logger slog.Logger, socketPath string) *Server {
return &Server{
logger: logger.Named("boundary-log-proxy"),
socketPath: socketPath,
metrics: newMetrics(registerer),
logs: make(chan *agentproto.ReportBoundaryLogsRequest, logBufferSize),
}
}
@@ -110,13 +100,9 @@ func (s *Server) RunForwarder(ctx context.Context, sender Reporter) error {
s.logger.Warn(ctx, "failed to forward boundary logs",
slog.Error(err),
slog.F("log_count", len(req.Logs)))
s.metrics.batchesDropped.WithLabelValues(droppedReasonForwardFailed).Inc()
s.metrics.logsDropped.WithLabelValues(droppedReasonForwardFailed).Add(float64(len(req.Logs)))
// Continue forwarding other logs. The current batch is lost,
// but the socket stays alive.
continue
}
s.metrics.batchesForwarded.Inc()
}
}
}
@@ -153,8 +139,8 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
_ = conn.Close()
}()
// This is intended to be a sane starting point for the read buffer size.
// It may be grown by codec.ReadMessage if necessary.
// This is intended to be a sane starting point for the read buffer size. It may be
// grown by codec.ReadFrame if necessary.
const initBufSize = 1 << 10
buf := make([]byte, initBufSize)
@@ -165,59 +151,36 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
default:
}
var err error
var msg proto.Message
msg, buf, err = codec.ReadMessage(conn, buf)
var (
tag codec.Tag
err error
)
tag, buf, err = codec.ReadFrame(conn, buf)
switch {
case errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed):
return
case errors.Is(err, codec.ErrUnsupportedTag) || errors.Is(err, codec.ErrMessageTooLarge):
case err != nil:
s.logger.Warn(ctx, "read frame error", slog.Error(err))
return
case err != nil:
s.logger.Warn(ctx, "read message error", slog.Error(err))
}
if tag != codec.TagV1 {
s.logger.Warn(ctx, "invalid tag value", slog.F("tag", tag))
return
}
var req agentproto.ReportBoundaryLogsRequest
if err := proto.Unmarshal(buf, &req); err != nil {
s.logger.Warn(ctx, "proto unmarshal error", slog.Error(err))
continue
}
s.handleMessage(ctx, msg)
}
}
func (s *Server) handleMessage(ctx context.Context, msg proto.Message) {
switch m := msg.(type) {
case *agentproto.ReportBoundaryLogsRequest:
s.bufferLogs(ctx, m)
case *codec.BoundaryMessage:
switch inner := m.Msg.(type) {
case *codec.BoundaryMessage_Logs:
s.bufferLogs(ctx, inner.Logs)
case *codec.BoundaryMessage_Status:
s.recordBoundaryStatus(inner.Status)
select {
case s.logs <- &req:
default:
s.logger.Warn(ctx, "unknown BoundaryMessage variant")
s.logger.Warn(ctx, "dropping boundary logs, buffer full",
slog.F("log_count", len(req.Logs)))
}
default:
s.logger.Warn(ctx, "unexpected message type")
}
}
func (s *Server) recordBoundaryStatus(status *codec.BoundaryStatus) {
if n := status.DroppedChannelFull; n > 0 {
s.metrics.logsDropped.WithLabelValues(droppedReasonBoundaryChannelFull).Add(float64(n))
}
if n := status.DroppedBatchFull; n > 0 {
s.metrics.logsDropped.WithLabelValues(droppedReasonBoundaryBatchFull).Add(float64(n))
}
}
func (s *Server) bufferLogs(ctx context.Context, req *agentproto.ReportBoundaryLogsRequest) {
select {
case s.logs <- req:
default:
s.logger.Warn(ctx, "dropping boundary logs, buffer full",
slog.F("log_count", len(req.Logs)))
s.metrics.batchesDropped.WithLabelValues(droppedReasonBufferFull).Inc()
s.metrics.logsDropped.WithLabelValues(droppedReasonBufferFull).Add(float64(len(req.Logs)))
}
}
+26 -303
View File
@@ -11,8 +11,8 @@ import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/coder/coder/v2/agent/boundarylogproxy"
@@ -21,42 +21,20 @@ import (
"github.com/coder/coder/v2/testutil"
)
// sendLogsV1 writes a bare ReportBoundaryLogsRequest using TagV1, the
// legacy framing that existing boundary deployments use.
func sendLogsV1(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
// sendMessage writes a framed protobuf message to the connection.
func sendMessage(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
t.Helper()
err := codec.WriteMessage(conn, codec.TagV1, req)
data, err := proto.Marshal(req)
if err != nil {
t.Errorf("write v1 logs: %s", err)
//nolint:gocritic // In tests we're not worried about conn being nil.
t.Errorf("%s marshal req: %s", conn.LocalAddr().String(), err)
}
}
// sendLogs writes a BoundaryMessage envelope containing logs to the
// connection using TagV2.
func sendLogs(t *testing.T, conn net.Conn, req *agentproto.ReportBoundaryLogsRequest) {
t.Helper()
msg := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Logs{Logs: req},
}
err := codec.WriteMessage(conn, codec.TagV2, msg)
err = codec.WriteFrame(conn, codec.TagV1, data)
if err != nil {
t.Errorf("write logs: %s", err)
}
}
// sendStatus writes a BoundaryMessage envelope containing a BoundaryStatus
// to the connection using TagV2.
func sendStatus(t *testing.T, conn net.Conn, status *codec.BoundaryStatus) {
t.Helper()
msg := &codec.BoundaryMessage{
Msg: &codec.BoundaryMessage_Status{Status: status},
}
err := codec.WriteMessage(conn, codec.TagV2, msg)
if err != nil {
t.Errorf("write status: %s", err)
//nolint:gocritic // In tests we're not worried about conn being nil.
t.Errorf("%s write frame: %s", conn.LocalAddr().String(), err)
}
}
@@ -102,7 +80,7 @@ func TestServer_StartAndClose(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
@@ -121,7 +99,7 @@ func TestServer_ReceiveAndForwardLogs(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -158,7 +136,7 @@ func TestServer_ReceiveAndForwardLogs(t *testing.T) {
},
}
sendLogs(t, conn, req)
sendMessage(t, conn, req)
// Wait for the reporter to receive the log.
require.Eventually(t, func() bool {
@@ -181,7 +159,7 @@ func TestServer_MultipleMessages(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -217,7 +195,7 @@ func TestServer_MultipleMessages(t *testing.T) {
},
},
}
sendLogs(t, conn, req)
sendMessage(t, conn, req)
}
require.Eventually(t, func() bool {
@@ -233,7 +211,7 @@ func TestServer_MultipleConnections(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -276,7 +254,7 @@ func TestServer_MultipleConnections(t *testing.T) {
},
},
}
sendLogs(t, conn, req)
sendMessage(t, conn, req)
}(i)
}
wg.Wait()
@@ -294,7 +272,7 @@ func TestServer_MessageTooLarge(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
@@ -322,7 +300,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
@@ -364,7 +342,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
},
},
}
sendLogs(t, conn, req1)
sendMessage(t, conn, req1)
select {
case <-reportNotify:
@@ -387,7 +365,7 @@ func TestServer_ForwarderContinuesAfterError(t *testing.T) {
},
},
}
sendLogs(t, conn, req2)
sendMessage(t, conn, req2)
// Only the second message should be recorded.
require.Eventually(t, func() bool {
@@ -407,7 +385,7 @@ func TestServer_CloseStopsForwarder(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
@@ -436,7 +414,7 @@ func TestServer_InvalidProtobuf(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
@@ -480,7 +458,7 @@ func TestServer_InvalidProtobuf(t *testing.T) {
},
},
}
sendLogs(t, conn, req)
sendMessage(t, conn, req)
require.Eventually(t, func() bool {
logs := reporter.getLogs()
@@ -495,7 +473,7 @@ func TestServer_InvalidHeader(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
@@ -545,7 +523,7 @@ func TestServer_AllowRequest(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath)
err := srv.Start()
require.NoError(t, err)
@@ -581,7 +559,7 @@ func TestServer_AllowRequest(t *testing.T) {
},
},
}
sendLogs(t, conn, req)
sendMessage(t, conn, req)
require.Eventually(t, func() bool {
logs := reporter.getLogs()
@@ -598,258 +576,3 @@ func TestServer_AllowRequest(t *testing.T) {
cancel()
<-forwarderDone
}
func TestServer_TagV1BackwardsCompatibility(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, prometheus.NewRegistry())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := srv.Start()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
reporter := &fakeReporter{}
forwarderDone := make(chan error, 1)
go func() {
forwarderDone <- srv.RunForwarder(ctx, reporter)
}()
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Send a TagV1 message (bare ReportBoundaryLogsRequest) to verify
// the server still handles the legacy framing used by existing
// boundary deployments.
v1Req := &agentproto.ReportBoundaryLogsRequest{
Logs: []*agentproto.BoundaryLog{
{
Allowed: true,
Time: timestamppb.Now(),
Resource: &agentproto.BoundaryLog_HttpRequest_{
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
Method: "GET",
Url: "https://example.com/v1",
},
},
},
},
}
sendLogsV1(t, conn, v1Req)
require.Eventually(t, func() bool {
return len(reporter.getLogs()) == 1
}, testutil.WaitShort, testutil.IntervalFast)
// Now send a TagV2 message on the same connection to verify both
// tag versions work interleaved.
v2Req := &agentproto.ReportBoundaryLogsRequest{
Logs: []*agentproto.BoundaryLog{
{
Allowed: false,
Time: timestamppb.Now(),
Resource: &agentproto.BoundaryLog_HttpRequest_{
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
Method: "POST",
Url: "https://example.com/v2",
},
},
},
},
}
sendLogs(t, conn, v2Req)
require.Eventually(t, func() bool {
return len(reporter.getLogs()) == 2
}, testutil.WaitShort, testutil.IntervalFast)
logs := reporter.getLogs()
require.Equal(t, "https://example.com/v1", logs[0].Logs[0].GetHttpRequest().Url)
require.Equal(t, "https://example.com/v2", logs[1].Logs[0].GetHttpRequest().Url)
cancel()
<-forwarderDone
}
func TestServer_Metrics(t *testing.T) {
t.Parallel()
makeReq := func(n int) *agentproto.ReportBoundaryLogsRequest {
logs := make([]*agentproto.BoundaryLog, n)
for i := range n {
logs[i] = &agentproto.BoundaryLog{
Allowed: true,
Time: timestamppb.Now(),
Resource: &agentproto.BoundaryLog_HttpRequest_{
HttpRequest: &agentproto.BoundaryLog_HttpRequest{
Method: "GET",
Url: "https://example.com",
},
},
}
}
return &agentproto.ReportBoundaryLogsRequest{Logs: logs}
}
// BufferFull needs its own setup because it intentionally does not run
// a forwarder so the channel fills up.
t.Run("BufferFull", func(t *testing.T) {
t.Parallel()
reg := prometheus.NewRegistry()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, reg)
err := srv.Start()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Fill the buffer (size 100) without running a forwarder so nothing
// drains. Then send one more to trigger the drop path.
for range 101 {
sendLogs(t, conn, makeReq(1))
}
require.Eventually(t, func() bool {
return getCounterVecValue(t, reg, "agent_boundary_log_proxy_batches_dropped_total", "buffer_full") >= 1
}, testutil.WaitShort, testutil.IntervalFast)
require.GreaterOrEqual(t,
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "buffer_full"),
float64(1))
})
// The remaining metrics share one server, forwarder, and connection. The
// phases run sequentially so metrics accumulate.
t.Run("Forwarding", func(t *testing.T) {
t.Parallel()
reg := prometheus.NewRegistry()
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "boundary.sock")
srv := boundarylogproxy.NewServer(testutil.Logger(t), socketPath, reg)
err := srv.Start()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, srv.Close()) })
reportNotify := make(chan struct{}, 4)
reporter := &fakeReporter{
err: context.DeadlineExceeded,
errOnce: true,
reportCb: func() {
select {
case reportNotify <- struct{}{}:
default:
}
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
forwarderDone := make(chan error, 1)
go func() {
forwarderDone <- srv.RunForwarder(ctx, reporter)
}()
conn, err := net.Dial("unix", socketPath)
require.NoError(t, err)
defer conn.Close()
// Phase 1: the first forward errors
sendLogs(t, conn, makeReq(2))
select {
case <-reportNotify:
case <-time.After(testutil.WaitShort):
t.Fatal("timed out waiting for forward attempt")
}
// The metric is incremented after ReportBoundaryLogs returns, so we
// need to poll briefly.
require.Eventually(t, func() bool {
return getCounterVecValue(t, reg, "agent_boundary_log_proxy_batches_dropped_total", "forward_failed") >= 1
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, float64(2),
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "forward_failed"))
// Phase 2: forward succeeds.
sendLogs(t, conn, makeReq(1))
require.Eventually(t, func() bool {
return len(reporter.getLogs()) >= 1
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, float64(1),
getCounterValue(t, reg, "agent_boundary_log_proxy_batches_forwarded_total"))
// Phase 3: boundary-reported drop counts arrive as a separate BoundaryStatus
// message, not piggybacked on log batches.
sendStatus(t, conn, &codec.BoundaryStatus{
DroppedChannelFull: 5,
DroppedBatchFull: 3,
})
// Status is handled immediately by the reader goroutine, not by the
// forwarder, so poll metrics directly.
require.Eventually(t, func() bool {
return getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_channel_full") >= 5
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, float64(5),
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_channel_full"))
require.Equal(t, float64(3),
getCounterVecValue(t, reg, "agent_boundary_log_proxy_logs_dropped_total", "boundary_batch_full"))
cancel()
<-forwarderDone
})
}
// getCounterVecValue returns the current value of a CounterVec metric filtered
// by the given reason label.
func getCounterVecValue(t *testing.T, reg *prometheus.Registry, name, reason string) float64 {
t.Helper()
metrics, err := reg.Gather()
require.NoError(t, err)
for _, mf := range metrics {
if mf.GetName() != name {
continue
}
for _, m := range mf.GetMetric() {
for _, lp := range m.GetLabel() {
if lp.GetName() == "reason" && lp.GetValue() == reason {
return m.GetCounter().GetValue()
}
}
}
}
return 0
}
// getCounterValue returns the current value of a Counter metric.
func getCounterValue(t *testing.T, reg *prometheus.Registry, name string) float64 {
t.Helper()
metrics, err := reg.Gather()
require.NoError(t, err)
for _, mf := range metrics {
if mf.GetName() != name {
continue
}
for _, m := range mf.GetMetric() {
return m.GetCounter().GetValue()
}
}
return 0
}
-316
View File
@@ -1,316 +0,0 @@
package filefinder_test
import (
"context"
"fmt"
"math/rand"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/filefinder"
)
var (
dirNames = []string{
"cmd", "internal", "pkg", "api", "auth", "database", "server", "client", "middleware",
"handler", "config", "utils", "models", "service", "worker", "scheduler", "notification",
"provisioner", "template", "workspace", "agent", "proxy", "crypto", "telemetry", "billing",
}
fileExts = []string{
".go", ".ts", ".tsx", ".js", ".py", ".sql", ".yaml", ".json", ".md", ".proto", ".sh",
}
fileStems = []string{
"main", "handler", "middleware", "service", "model", "query", "config", "utils", "helpers",
"types", "interface", "test", "mock", "factory", "builder", "adapter", "observer", "provider",
"resolver", "schema", "migration", "fixture", "snapshot", "checkpoint",
}
)
// generateFileTree creates n files under root in a realistic nested directory structure.
func generateFileTree(t testing.TB, root string, n int, seed int64) {
t.Helper()
rng := rand.New(rand.NewSource(seed)) //nolint:gosec // deterministic benchmarks
numDirs := n / 5
if numDirs < 10 {
numDirs = 10
}
dirs := make([]string, 0, numDirs)
for i := 0; i < numDirs; i++ {
depth := rng.Intn(6) + 1
parts := make([]string, depth)
for d := 0; d < depth; d++ {
parts[d] = dirNames[rng.Intn(len(dirNames))]
}
dirs = append(dirs, filepath.Join(parts...))
}
created := make(map[string]struct{})
for _, d := range dirs {
full := filepath.Join(root, d)
if _, ok := created[full]; ok {
continue
}
require.NoError(t, os.MkdirAll(full, 0o755))
created[full] = struct{}{}
}
for i := 0; i < n; i++ {
dir := dirs[rng.Intn(len(dirs))]
stem := fileStems[rng.Intn(len(fileStems))]
ext := fileExts[rng.Intn(len(fileExts))]
name := fmt.Sprintf("%s_%d%s", stem, i, ext)
full := filepath.Join(root, dir, name)
f, err := os.Create(full)
require.NoError(t, err)
_ = f.Close()
}
}
// buildIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func buildIndex(t testing.TB, root string) *filefinder.Index {
t.Helper()
absRoot, err := filepath.Abs(root)
require.NoError(t, err)
idx, err := filefinder.BuildTestIndex(absRoot)
require.NoError(t, err)
return idx
}
func BenchmarkBuildIndex(b *testing.B) {
scales := []struct {
name string
n int
}{
{"1K", 1_000},
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale benchmark")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
b.ResetTimer()
for i := 0; i < b.N; i++ {
idx := buildIndex(b, dir)
if idx.Len() == 0 {
b.Fatal("expected non-empty index")
}
}
b.StopTimer()
idx := buildIndex(b, dir)
b.ReportMetric(float64(idx.Len())/b.Elapsed().Seconds(), "files/sec")
})
}
}
func BenchmarkSearch_ByScale(b *testing.B) {
queries := []struct {
name string
query string
}{
{"exact_basename", "handler.go"},
{"short_query", "ha"},
{"fuzzy_basename", "hndlr"},
{"path_structured", "internal/handler"},
{"multi_token", "api handler"},
}
scales := []struct {
name string
n int
}{
{"1K", 1_000},
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale benchmark")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
idx := buildIndex(b, dir)
snap := idx.Snapshot()
opts := filefinder.DefaultSearchOptions()
for _, q := range queries {
b.Run(q.name, func(b *testing.B) {
p := filefinder.NewQueryPlanForTest(q.query)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = filefinder.SearchSnapshotForTest(p, snap, opts.MaxCandidates)
}
})
}
})
}
}
func BenchmarkSearch_ConcurrentReads(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
logger := slogtest.Make(b, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelError)
ctx := context.Background()
eng := filefinder.NewEngine(logger)
require.NoError(b, eng.AddRoot(ctx, dir))
b.Cleanup(func() { _ = eng.Close() })
opts := filefinder.DefaultSearchOptions()
goroutines := []int{1, 4, 16, 64}
for _, g := range goroutines {
b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
b.SetParallelism(g)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
results, err := eng.Search(ctx, "handler", opts)
if err != nil {
b.Fatal(err)
}
_ = results
}
})
})
}
}
func BenchmarkDeltaUpdate(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
addCounts := []int{1, 10, 100}
for _, count := range addCounts {
b.Run(fmt.Sprintf("add_%d_files", count), func(b *testing.B) {
paths := make([]string, count)
for i := range paths {
paths[i] = fmt.Sprintf("injected/dir_%d/newfile_%d.go", i%10, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
idx := buildIndex(b, dir)
b.StartTimer()
for _, p := range paths {
idx.Add(p, 0)
}
}
b.ReportMetric(float64(count), "files_added/op")
})
}
b.Run("search_after_100_additions", func(b *testing.B) {
idx := buildIndex(b, dir)
for i := 0; i < 100; i++ {
idx.Add(fmt.Sprintf("injected/extra/file_%d.go", i), 0)
}
snap := idx.Snapshot()
plan := filefinder.NewQueryPlanForTest("handler")
opts := filefinder.DefaultSearchOptions()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = filefinder.SearchSnapshotForTest(plan, snap, opts.MaxCandidates)
}
})
}
func BenchmarkMemoryProfile(b *testing.B) {
scales := []struct {
name string
n int
}{
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale memory profile")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
b.ResetTimer()
for i := 0; i < b.N; i++ {
idx := buildIndex(b, dir)
_ = idx.Snapshot()
}
b.StopTimer()
// Report memory stats on the last iteration.
runtime.GC()
var before runtime.MemStats
runtime.ReadMemStats(&before)
idx := buildIndex(b, dir)
var after runtime.MemStats
runtime.ReadMemStats(&after)
allocDelta := after.TotalAlloc - before.TotalAlloc
b.ReportMetric(float64(allocDelta)/float64(idx.Len()), "bytes/file")
runtime.GC()
runtime.ReadMemStats(&before)
snap := idx.Snapshot()
_ = snap
runtime.GC()
runtime.ReadMemStats(&after)
snapAlloc := after.TotalAlloc - before.TotalAlloc
b.ReportMetric(float64(snapAlloc)/float64(idx.Len()), "snap-bytes/file")
})
}
}
func BenchmarkSearch_ConcurrentReads_Throughput(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
idx := buildIndex(b, dir)
snap := idx.Snapshot()
goroutines := []int{1, 4, 16, 64}
plan := filefinder.NewQueryPlanForTest("handler.go")
maxCands := filefinder.DefaultSearchOptions().MaxCandidates
for _, g := range goroutines {
b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
b.ResetTimer()
var wg sync.WaitGroup
perGoroutine := b.N / g
if perGoroutine < 1 {
perGoroutine = 1
}
for gi := 0; gi < g; gi++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < perGoroutine; j++ {
_ = filefinder.SearchSnapshotForTest(plan, snap, maxCands)
}
}()
}
wg.Wait()
totalOps := float64(g * perGoroutine)
b.ReportMetric(totalOps/b.Elapsed().Seconds(), "searches/sec")
})
}
}
-125
View File
@@ -1,125 +0,0 @@
package filefinder
import "strings"
// FileFlag represents the type of filesystem entry.
type FileFlag uint16
const (
FlagFile FileFlag = 0
FlagDir FileFlag = 1
FlagSymlink FileFlag = 2
)
type doc struct {
path string
baseOff int
baseLen int
depth int
flags uint16
}
// Index is an append-only in-memory file index with snapshot support.
type Index struct {
docs []doc
byGram map[uint32][]uint32
byPrefix1 [256][]uint32
byPrefix2 map[uint16][]uint32
byPath map[string]uint32
deleted map[uint32]bool
}
// Snapshot is a frozen, read-only view of the index at a point in time.
type Snapshot struct {
docs []doc
deleted map[uint32]bool
byGram map[uint32][]uint32
byPrefix1 [256][]uint32
byPrefix2 map[uint16][]uint32
}
// NewIndex creates an empty Index.
func NewIndex() *Index {
return &Index{
byGram: make(map[uint32][]uint32),
byPrefix2: make(map[uint16][]uint32),
byPath: make(map[string]uint32),
deleted: make(map[uint32]bool),
}
}
// Add inserts a path into the index, tombstoning any previous entry.
func (idx *Index) Add(path string, flags uint16) uint32 {
norm := string(normalizePathBytes([]byte(path)))
if oldID, ok := idx.byPath[norm]; ok {
idx.deleted[oldID] = true
}
id := uint32(len(idx.docs)) //nolint:gosec // Index will never exceed 2^32 docs.
baseOff, baseLen := extractBasename([]byte(norm))
idx.docs = append(idx.docs, doc{
path: norm, baseOff: baseOff, baseLen: baseLen,
depth: strings.Count(norm, "/"), flags: flags,
})
idx.byPath[norm] = id
for _, g := range extractTrigrams([]byte(norm)) {
idx.byGram[g] = append(idx.byGram[g], id)
}
if baseLen > 0 {
basename := []byte(norm[baseOff : baseOff+baseLen])
p1 := prefix1(basename)
idx.byPrefix1[p1] = append(idx.byPrefix1[p1], id)
p2 := prefix2(basename)
idx.byPrefix2[p2] = append(idx.byPrefix2[p2], id)
}
return id
}
// Remove marks the entry for path as deleted.
func (idx *Index) Remove(path string) bool {
norm := string(normalizePathBytes([]byte(path)))
id, ok := idx.byPath[norm]
if !ok {
return false
}
idx.deleted[id] = true
delete(idx.byPath, norm)
return true
}
// Has reports whether path exists (not deleted) in the index.
func (idx *Index) Has(path string) bool {
_, ok := idx.byPath[string(normalizePathBytes([]byte(path)))]
return ok
}
// Len returns the number of live (non-deleted) documents.
func (idx *Index) Len() int { return len(idx.byPath) }
func copyPostings[K comparable](m map[K][]uint32) map[K][]uint32 {
cp := make(map[K][]uint32, len(m))
for k, v := range m {
cp[k] = v[:len(v):len(v)]
}
return cp
}
// Snapshot returns a frozen read-only view of the index.
func (idx *Index) Snapshot() *Snapshot {
del := make(map[uint32]bool, len(idx.deleted))
for id := range idx.deleted {
del[id] = true
}
var p1Copy [256][]uint32
for i, ids := range idx.byPrefix1 {
if len(ids) > 0 {
p1Copy[i] = ids[:len(ids):len(ids)]
}
}
return &Snapshot{
docs: idx.docs[:len(idx.docs):len(idx.docs)],
deleted: del,
byGram: copyPostings(idx.byGram),
byPrefix1: p1Copy,
byPrefix2: copyPostings(idx.byPrefix2),
}
}
-120
View File
@@ -1,120 +0,0 @@
package filefinder_test
import (
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestIndex_AddAndLen(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
idx.Add("foo/baz.go", 0)
if idx.Len() != 2 {
t.Fatalf("expected 2, got %d", idx.Len())
}
}
func TestIndex_Has(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
if !idx.Has("foo/bar.go") {
t.Fatal("expected Has to return true")
}
if idx.Has("foo/missing.go") {
t.Fatal("expected Has to return false for missing path")
}
}
func TestIndex_Remove(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
if !idx.Remove("foo/bar.go") {
t.Fatal("expected Remove to return true")
}
if idx.Has("foo/bar.go") {
t.Fatal("expected Has to return false after Remove")
}
if idx.Len() != 0 {
t.Fatalf("expected Len 0 after Remove, got %d", idx.Len())
}
}
func TestIndex_AddOverwrite(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", uint16(filefinder.FlagFile))
idx.Add("foo/bar.go", uint16(filefinder.FlagDir)) // overwrite
if idx.Len() != 1 {
t.Fatalf("expected 1 after overwrite, got %d", idx.Len())
}
// The old entry should be tombstoned.
if !filefinder.IndexIsDeleted(idx, 0) {
t.Fatal("expected old entry to be deleted")
}
if filefinder.IndexIsDeleted(idx, 1) {
t.Fatal("expected new entry to be live")
}
}
func TestIndex_Snapshot(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
idx.Add("foo/baz.go", 0)
snap := idx.Snapshot()
if filefinder.SnapshotCount(snap) != 2 {
t.Fatalf("expected snapshot count 2, got %d", filefinder.SnapshotCount(snap))
}
// Adding more docs after snapshot doesn't affect it.
idx.Add("foo/qux.go", 0)
if filefinder.SnapshotCount(snap) != 2 {
t.Fatal("snapshot count should not change after new adds")
}
}
func TestIndex_TrigramIndex(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
// "handler.go" should produce trigrams for "handler.go".
// Check that at least one trigram exists.
if filefinder.IndexByGramLen(idx) == 0 {
t.Fatal("expected non-empty trigram index")
}
}
func TestIndex_PrefixIndex(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
// basename is "handler.go", first byte is 'h'
if filefinder.IndexByPrefix1Len(idx, 'h') == 0 {
t.Fatal("expected prefix1['h'] to be non-empty")
}
}
func TestIndex_RemoveNonexistent(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
if idx.Remove("nonexistent.go") {
t.Fatal("expected Remove to return false for missing path")
}
}
func TestIndex_PathNormalization(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("Foo/Bar.go", 0)
// Should be findable with lowercase.
if !idx.Has("foo/bar.go") {
t.Fatal("expected case-insensitive Has")
}
}
-364
View File
@@ -1,364 +0,0 @@
// Package filefinder provides an in-memory file index with trigram
// matching, fuzzy search, and filesystem watching. It is designed
// to power file-finding features on workspace agents.
package filefinder
import (
"context"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// SearchOptions controls search behavior.
type SearchOptions struct {
Limit int
MaxCandidates int
}
// DefaultSearchOptions returns sensible default search options.
func DefaultSearchOptions() SearchOptions {
return SearchOptions{Limit: 100, MaxCandidates: 10000}
}
type rootSnapshot struct {
root string
snap *Snapshot
}
// Engine is the main file finder. Safe for concurrent use.
type Engine struct {
snap atomic.Pointer[[]*rootSnapshot]
logger slog.Logger
mu sync.Mutex
roots map[string]*rootState
eventCh chan rootEvent
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
type rootState struct {
root string
index *Index
watcher *fsWatcher
cancel context.CancelFunc
}
type rootEvent struct {
root string
events []FSEvent
}
// walkRoot performs a full filesystem walk of absRoot and returns
// a populated Index containing all discovered files and directories.
func walkRoot(absRoot string) (*Index, error) {
idx := NewIndex()
err := filepath.Walk(absRoot, func(path string, info os.FileInfo, walkErr error) error {
if walkErr != nil {
return nil //nolint:nilerr
}
base := filepath.Base(path)
if _, skip := skipDirs[base]; skip && info.IsDir() {
return filepath.SkipDir
}
if path == absRoot {
return nil
}
relPath, relErr := filepath.Rel(absRoot, path)
if relErr != nil {
return nil //nolint:nilerr
}
relPath = filepath.ToSlash(relPath)
var flags uint16
if info.IsDir() {
flags = uint16(FlagDir)
} else if info.Mode()&os.ModeSymlink != 0 {
flags = uint16(FlagSymlink)
}
idx.Add(relPath, flags)
return nil
})
return idx, err
}
// NewEngine creates a new Engine.
func NewEngine(logger slog.Logger) *Engine {
e := &Engine{
logger: logger,
roots: make(map[string]*rootState),
eventCh: make(chan rootEvent, 256),
closeCh: make(chan struct{}),
}
empty := make([]*rootSnapshot, 0)
e.snap.Store(&empty)
e.wg.Add(1)
go e.start()
return e
}
// ErrClosed is returned when operations are attempted on a
// closed engine.
var ErrClosed = xerrors.New("engine is closed")
// AddRoot adds a directory root to the engine.
func (e *Engine) AddRoot(ctx context.Context, root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
e.mu.Lock()
if e.closed.Load() {
e.mu.Unlock()
return ErrClosed
}
if _, exists := e.roots[absRoot]; exists {
e.mu.Unlock()
return nil
}
e.mu.Unlock()
// Walk and create the watcher outside the lock to avoid
// blocking the event pipeline on filesystem I/O.
idx, walkErr := walkRoot(absRoot)
if walkErr != nil {
return xerrors.Errorf("walk root: %w", walkErr)
}
wCtx, wCancel := context.WithCancel(context.Background())
w, wErr := newFSWatcher(absRoot, e.logger)
if wErr != nil {
wCancel()
return xerrors.Errorf("create watcher: %w", wErr)
}
e.mu.Lock()
// Re-check after re-acquiring the lock: another goroutine
// may have added this root or closed the engine while we
// were walking.
if e.closed.Load() {
e.mu.Unlock()
wCancel()
_ = w.Close()
return ErrClosed
}
if _, exists := e.roots[absRoot]; exists {
e.mu.Unlock()
wCancel()
_ = w.Close()
return nil
}
rs := &rootState{root: absRoot, index: idx, watcher: w, cancel: wCancel}
e.roots[absRoot] = rs
w.Start(wCtx)
e.wg.Add(1)
go e.forwardEvents(wCtx, absRoot, w)
e.publishSnapshot()
fileCount := idx.Len()
e.mu.Unlock()
e.logger.Info(ctx, "added root to engine",
slog.F("root", absRoot),
slog.F("files", fileCount),
)
return nil
}
// RemoveRoot stops watching a root and removes it.
func (e *Engine) RemoveRoot(root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
e.mu.Lock()
defer e.mu.Unlock()
rs, exists := e.roots[absRoot]
if !exists {
return xerrors.Errorf("root %q not found", absRoot)
}
rs.cancel()
_ = rs.watcher.Close()
delete(e.roots, absRoot)
e.publishSnapshot()
return nil
}
// Search performs a fuzzy file search across all roots.
func (e *Engine) Search(_ context.Context, query string, opts SearchOptions) ([]Result, error) {
if e.closed.Load() {
return nil, ErrClosed
}
snapPtr := e.snap.Load()
if snapPtr == nil || len(*snapPtr) == 0 {
return nil, nil
}
roots := *snapPtr
plan := newQueryPlan(query)
if len(plan.Normalized) == 0 {
return nil, nil
}
if opts.Limit <= 0 {
opts.Limit = 100
}
if opts.MaxCandidates <= 0 {
opts.MaxCandidates = 10000
}
params := defaultScoreParams()
var allCands []candidate
for _, rs := range roots {
allCands = append(allCands, searchSnapshot(plan, rs.snap, opts.MaxCandidates)...)
}
results := mergeAndScore(allCands, plan, params, opts.Limit)
return results, nil
}
// Close shuts down the engine.
func (e *Engine) Close() error {
if e.closed.Swap(true) {
return nil
}
close(e.closeCh)
e.mu.Lock()
for _, rs := range e.roots {
rs.cancel()
_ = rs.watcher.Close()
}
e.roots = make(map[string]*rootState)
e.mu.Unlock()
e.wg.Wait()
return nil
}
// Rebuild forces a complete re-walk and re-index of a root.
func (e *Engine) Rebuild(ctx context.Context, root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
// Walk outside the lock to avoid blocking the event
// pipeline on potentially slow filesystem I/O.
idx, walkErr := walkRoot(absRoot)
if walkErr != nil {
return xerrors.Errorf("rebuild walk: %w", walkErr)
}
e.mu.Lock()
rs, exists := e.roots[absRoot]
if !exists {
e.mu.Unlock()
return xerrors.Errorf("root %q not found", absRoot)
}
rs.index = idx
e.publishSnapshot()
fileCount := idx.Len()
e.mu.Unlock()
e.logger.Info(ctx, "rebuilt root in engine",
slog.F("root", absRoot),
slog.F("files", fileCount),
)
return nil
}
func (e *Engine) start() {
defer e.wg.Done()
for {
select {
case <-e.closeCh:
return
case re, ok := <-e.eventCh:
if !ok {
return
}
e.applyEvents(re)
}
}
}
func (e *Engine) forwardEvents(ctx context.Context, root string, w *fsWatcher) {
defer e.wg.Done()
for {
select {
case <-ctx.Done():
return
case <-e.closeCh:
return
case evts, ok := <-w.Events():
if !ok {
return
}
select {
case e.eventCh <- rootEvent{root: root, events: evts}:
case <-ctx.Done():
return
case <-e.closeCh:
return
}
}
}
}
func (e *Engine) applyEvents(re rootEvent) {
e.mu.Lock()
defer e.mu.Unlock()
rs, exists := e.roots[re.root]
if !exists {
return
}
changed := false
for _, ev := range re.events {
relPath, err := filepath.Rel(rs.root, ev.Path)
if err != nil {
continue
}
relPath = filepath.ToSlash(relPath)
switch ev.Op {
case OpCreate:
if rs.index.Has(relPath) {
continue
}
var flags uint16
if ev.IsDir {
flags = uint16(FlagDir)
}
rs.index.Add(relPath, flags)
changed = true
case OpRemove, OpRename:
if rs.index.Remove(relPath) {
changed = true
}
if ev.IsDir || ev.Op == OpRename {
prefix := strings.ToLower(filepath.ToSlash(relPath)) + "/"
for path := range rs.index.byPath {
if strings.HasPrefix(path, prefix) {
rs.index.Remove(path)
changed = true
}
}
}
case OpModify:
}
}
if changed {
e.publishSnapshot()
}
}
// publishSnapshot builds and atomically publishes a new snapshot.
// Must be called with e.mu held.
func (e *Engine) publishSnapshot() {
roots := make([]*rootSnapshot, 0, len(e.roots))
for _, rs := range e.roots {
roots = append(roots, &rootSnapshot{
root: rs.root,
snap: rs.index.Snapshot(),
})
}
slices.SortFunc(roots, func(a, b *rootSnapshot) int {
return strings.Compare(a.root, b.root)
})
e.snap.Store(&roots)
}
-233
View File
@@ -1,233 +0,0 @@
package filefinder_test
import (
"context"
"os"
"path/filepath"
"sort"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/filefinder"
"github.com/coder/coder/v2/testutil"
)
func newTestEngine(t *testing.T) (*filefinder.Engine, context.Context) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
eng := filefinder.NewEngine(logger)
t.Cleanup(func() { _ = eng.Close() })
return eng, context.Background()
}
func requireResultHasPath(t *testing.T, results []filefinder.Result, path string) {
t.Helper()
for _, r := range results {
if r.Path == path {
return
}
}
t.Errorf("expected %q in results, got %v", path, resultPaths(results))
}
func TestEngine_SearchFindsKnownFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "src/main.go", "package main")
createFile(t, dir, "src/handler.go", "package main")
createFile(t, dir, "README.md", "# hello")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "main.go", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results, "expected to find main.go")
requireResultHasPath(t, results, "src/main.go")
}
func TestEngine_SearchFuzzyMatch(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "src/controllers/user_handler.go", "package controllers")
createFile(t, dir, "src/models/user.go", "package models")
createFile(t, dir, "docs/api.md", "# API")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
// "handler" should match "user_handler.go".
results, err := eng.Search(ctx, "handler", filefinder.DefaultSearchOptions())
require.NoError(t, err)
// The query is a subsequence of "user_handler.go" so it
// should appear somewhere in the results.
requireResultHasPath(t, results, "src/controllers/user_handler.go")
}
func TestEngine_IndexPicksUpNewFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "existing.txt", "hello")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
createFile(t, dir, "newfile_unique.txt", "world")
require.Eventually(t, func() bool {
results, sErr := eng.Search(ctx, "newfile_unique", filefinder.DefaultSearchOptions())
if sErr != nil {
return false
}
for _, r := range results {
if r.Path == "newfile_unique.txt" {
return true
}
}
return false
}, testutil.WaitShort, testutil.IntervalFast, "expected newfile_unique.txt to appear via watcher")
}
func TestEngine_IndexRemovesDeletedFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "deleteme_unique.txt", "goodbye")
createFile(t, dir, "keeper.txt", "stay")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results, "expected to find deleteme_unique.txt initially")
require.NoError(t, os.Remove(filepath.Join(dir, "deleteme_unique.txt")))
require.Eventually(t, func() bool {
results, sErr := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
if sErr != nil {
return false
}
for _, r := range results {
if r.Path == "deleteme_unique.txt" {
return false // still found
}
}
return true
}, testutil.WaitShort, testutil.IntervalFast, "expected deleteme_unique.txt to disappear after removal")
}
func TestEngine_MultipleRoots(t *testing.T) {
t.Parallel()
dir1 := t.TempDir()
dir2 := t.TempDir()
createFile(t, dir1, "alpha_unique.go", "package alpha")
createFile(t, dir2, "beta_unique.go", "package beta")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir1))
require.NoError(t, eng.AddRoot(ctx, dir2))
results, err := eng.Search(ctx, "alpha_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "alpha_unique.go")
results, err = eng.Search(ctx, "beta_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "beta_unique.go")
}
func TestEngine_EmptyQueryReturnsEmpty(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "something.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.Empty(t, results, "empty query should return no results")
}
func TestEngine_CloseIsClean(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := context.Background()
eng := filefinder.NewEngine(logger)
require.NoError(t, eng.AddRoot(ctx, dir))
require.NoError(t, eng.Close())
_, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.Error(t, err)
}
func TestEngine_AddRootIdempotent(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
require.NoError(t, eng.AddRoot(ctx, dir))
snapLen := filefinder.EngineSnapLen(eng)
require.Equal(t, 1, snapLen, "expected exactly one root after duplicate add")
}
func TestEngine_RemoveRoot(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results)
require.NoError(t, eng.RemoveRoot(dir))
results, err = eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.Empty(t, results)
}
func TestEngine_Rebuild(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "original.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
createFile(t, dir, "sneaky_rebuild.txt", "hidden")
require.NoError(t, eng.Rebuild(ctx, dir))
results, err := eng.Search(ctx, "sneaky_rebuild", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "sneaky_rebuild.txt")
}
// createFile creates a file (and parent dirs) at relPath under dir.
func createFile(t *testing.T, dir, relPath, content string) {
t.Helper()
full := filepath.Join(dir, relPath)
require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755))
require.NoError(t, os.WriteFile(full, []byte(content), 0o600))
}
func resultPaths(results []filefinder.Result) []string {
paths := make([]string, len(results))
for i, r := range results {
paths[i] = r.Path
}
sort.Strings(paths)
return paths
}
-85
View File
@@ -1,85 +0,0 @@
package filefinder
// Test helpers that need internal access.
// MakeTestSnapshot builds a Snapshot from a list of paths. Useful for
// query-level tests that don't need a real filesystem.
func MakeTestSnapshot(paths []string) *Snapshot {
idx := NewIndex()
for _, p := range paths {
idx.Add(p, 0)
}
return idx.Snapshot()
}
// BuildTestIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func BuildTestIndex(root string) (*Index, error) {
return walkRoot(root)
}
// IndexIsDeleted reports whether the document at id is tombstoned.
func IndexIsDeleted(idx *Index, id uint32) bool {
return idx.deleted[id]
}
// IndexByGramLen returns the number of entries in the trigram index.
func IndexByGramLen(idx *Index) int {
return len(idx.byGram)
}
// IndexByPrefix1Len returns the number of posting-list entries for
// the given single-byte prefix.
func IndexByPrefix1Len(idx *Index, b byte) int {
return len(idx.byPrefix1[b])
}
// SnapshotCount returns the number of documents in a Snapshot.
func SnapshotCount(snap *Snapshot) int {
return len(snap.docs)
}
// EngineSnapLen returns the number of root snapshots currently held
// by the engine, or -1 if the pointer is nil.
func EngineSnapLen(eng *Engine) int {
p := eng.snap.Load()
if p == nil {
return -1
}
return len(*p)
}
// DefaultScoreParamsForTest exposes defaultScoreParams for tests.
var DefaultScoreParamsForTest = defaultScoreParams
// ScoreParamsForTest is a type alias for scoreParams.
type ScoreParamsForTest = scoreParams
// Exported aliases for internal functions used in tests.
var (
NewQueryPlanForTest = newQueryPlan
SearchSnapshotForTest = searchSnapshot
IntersectSortedForTest = intersectSorted
IntersectAllForTest = intersectAll
MergeAndScoreForTest = mergeAndScore
NormalizeQueryForTest = normalizeQuery
NormalizePathBytesForTest = normalizePathBytes
ExtractTrigramsForTest = extractTrigrams
ExtractBasenameForTest = extractBasename
ExtractSegmentsForTest = extractSegments
Prefix1ForTest = prefix1
Prefix2ForTest = prefix2
IsSubsequenceForTest = isSubsequence
LongestContiguousMatchForTest = longestContiguousMatch
IsBoundaryForTest = isBoundary
CountBoundaryHitsForTest = countBoundaryHits
EqualFoldASCIIForTest = equalFoldASCII
ScorePathForTest = scorePath
PackTrigramForTest = packTrigram
)
// Type aliases for internal types used in tests.
type (
CandidateForTest = candidate
QueryPlanForTest = queryPlan
)
-299
View File
@@ -1,299 +0,0 @@
package filefinder
import (
"container/heap"
"slices"
"strings"
)
type candidate struct {
DocID uint32
Path string
BaseOff int
BaseLen int
Depth int
Flags uint16
}
// Result is a scored search result returned to callers.
type Result struct {
Path string
Score float32
IsDir bool
}
type queryPlan struct {
Original string
Normalized string
Tokens [][]byte
Trigrams []uint32
IsShort bool
HasSlash bool
BasenameQ []byte
DirTokens [][]byte
}
func newQueryPlan(q string) *queryPlan {
norm := normalizeQuery(q)
p := &queryPlan{Original: q, Normalized: norm}
if len(norm) == 0 {
p.IsShort = true
return p
}
raw := strings.ReplaceAll(norm, "/", " ")
parts := strings.Fields(raw)
p.HasSlash = strings.ContainsRune(norm, '/')
for _, part := range parts {
p.Tokens = append(p.Tokens, []byte(part))
}
if len(p.Tokens) > 0 {
p.BasenameQ = p.Tokens[len(p.Tokens)-1]
if len(p.Tokens) > 1 {
p.DirTokens = p.Tokens[:len(p.Tokens)-1]
}
}
p.IsShort = true
for _, tok := range p.Tokens {
if len(tok) >= 3 {
p.IsShort = false
break
}
}
if !p.IsShort {
p.Trigrams = extractQueryTrigrams(p.Tokens)
}
return p
}
func extractQueryTrigrams(tokens [][]byte) []uint32 {
seen := make(map[uint32]struct{})
for _, tok := range tokens {
if len(tok) < 3 {
continue
}
for i := 0; i <= len(tok)-3; i++ {
seen[packTrigram(tok[i], tok[i+1], tok[i+2])] = struct{}{}
}
}
if len(seen) == 0 {
return nil
}
result := make([]uint32, 0, len(seen))
for g := range seen {
result = append(result, g)
}
return result
}
func packTrigram(a, b, c byte) uint32 {
return uint32(toLowerASCII(a))<<16 | uint32(toLowerASCII(b))<<8 | uint32(toLowerASCII(c))
}
// searchSnapshot runs the full search pipeline against a single
// root snapshot: it selects a strategy (prefix, trigram, or
// fuzzy fallback) based on query length, retrieves candidate
// doc IDs, and converts them into candidate structs.
func searchSnapshot(plan *queryPlan, snap *Snapshot, limit int) []candidate {
if snap == nil || len(snap.docs) == 0 || len(plan.Normalized) == 0 {
return nil
}
var ids []uint32
if plan.IsShort {
ids = searchShort(plan, snap)
} else {
ids = searchTrigrams(plan, snap)
if len(ids) == 0 && len(plan.BasenameQ) > 0 {
ids = searchFuzzyFallback(plan, snap)
}
}
if len(ids) == 0 {
return nil
}
cands := make([]candidate, 0, min(len(ids), limit))
for _, id := range ids {
if snap.deleted[id] || int(id) >= len(snap.docs) {
continue
}
d := snap.docs[id]
cands = append(cands, candidate{
DocID: id, Path: d.path, BaseOff: d.baseOff,
BaseLen: d.baseLen, Depth: d.depth, Flags: d.flags,
})
if len(cands) >= limit {
break
}
}
return cands
}
func searchShort(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
if len(plan.BasenameQ) >= 2 {
if ids := snap.byPrefix2[prefix2(plan.BasenameQ)]; len(ids) > 0 {
return ids
}
}
return snap.byPrefix1[prefix1(plan.BasenameQ)]
}
func searchTrigrams(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.Trigrams) == 0 {
return nil
}
lists := make([][]uint32, 0, len(plan.Trigrams))
for _, g := range plan.Trigrams {
ids, ok := snap.byGram[g]
if !ok || len(ids) == 0 {
return nil
}
lists = append(lists, ids)
}
return intersectAll(lists)
}
func searchFuzzyFallback(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
bucket := snap.byPrefix1[prefix1(plan.BasenameQ)]
if len(bucket) == 0 {
return searchSubsequenceScan(plan, snap, 5000)
}
var ids []uint32
for _, id := range bucket {
if snap.deleted[id] || int(id) >= len(snap.docs) {
continue
}
if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
ids = append(ids, id)
}
}
if len(ids) == 0 {
return searchSubsequenceScan(plan, snap, 5000)
}
return ids
}
func searchSubsequenceScan(plan *queryPlan, snap *Snapshot, maxCheck int) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
var ids []uint32
checked := 0
for id := 0; id < len(snap.docs) && checked < maxCheck; id++ {
uid := uint32(id) //nolint:gosec // Snapshot count is bounded well below 2^32.
if snap.deleted[uid] {
continue
}
checked++
if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
ids = append(ids, uid)
}
}
return ids
}
func intersectSorted(a, b []uint32) []uint32 {
if len(a) == 0 || len(b) == 0 {
return nil
}
var result []uint32
ai, bi := 0, 0
for ai < len(a) && bi < len(b) {
switch {
case a[ai] < b[bi]:
ai++
case a[ai] > b[bi]:
bi++
default:
result = append(result, a[ai])
ai++
bi++
}
}
return result
}
func intersectAll(lists [][]uint32) []uint32 {
if len(lists) == 0 {
return nil
}
if len(lists) == 1 {
return lists[0]
}
slices.SortFunc(lists, func(a, b []uint32) int { return len(a) - len(b) })
result := lists[0]
for i := 1; i < len(lists) && len(result) > 0; i++ {
result = intersectSorted(result, lists[i])
}
return result
}
func mergeAndScore(cands []candidate, plan *queryPlan, params scoreParams, topK int) []Result {
if topK <= 0 || len(cands) == 0 {
return nil
}
query := []byte(plan.Normalized)
h := &resultHeap{}
heap.Init(h)
for i := range cands {
c := &cands[i]
s := scorePath([]byte(c.Path), c.BaseOff, c.BaseLen, c.Depth, query, plan.Tokens, params)
if s <= 0 {
continue
}
// DirTokenHit is applied here rather than in scorePath because
// it depends on the query plan's directory tokens, which are
// split from the full query during planning. scorePath operates
// on raw query bytes without knowledge of token boundaries.
if len(plan.DirTokens) > 0 {
segments := extractSegments([]byte(c.Path))
for _, dt := range plan.DirTokens {
for _, seg := range segments {
if equalFoldASCII(seg, dt) {
s += params.DirTokenHit
break
}
}
}
}
r := Result{Path: c.Path, Score: s, IsDir: c.Flags == uint16(FlagDir)}
if h.Len() < topK {
heap.Push(h, r)
} else if s > (*h)[0].Score {
(*h)[0] = r
heap.Fix(h, 0)
}
}
n := h.Len()
results := make([]Result, n)
for i := n - 1; i >= 0; i-- {
v := heap.Pop(h)
if r, ok := v.(Result); ok {
results[i] = r
}
}
return results
}
type resultHeap []Result
func (h resultHeap) Len() int { return len(h) }
func (h resultHeap) Less(i, j int) bool { return h[i].Score < h[j].Score }
func (h resultHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *resultHeap) Push(x interface{}) {
r, ok := x.(Result)
if ok {
*h = append(*h, r)
}
}
func (h *resultHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[:n-1]
return x
}
-343
View File
@@ -1,343 +0,0 @@
package filefinder_test
import (
"slices"
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestNewQueryPlan(t *testing.T) {
t.Parallel()
tests := []struct {
name string
query string
wantNorm string
wantShort bool
wantSlash bool
wantBase string
wantTokens []string
wantDirTok []string
wantTriCnt int // -1 to skip check
}{
{"Simple", "foo", "foo", false, false, "foo", []string{"foo"}, nil, 1},
{"MultiToken", "foo bar", "foo bar", false, false, "bar", []string{"foo", "bar"}, []string{"foo"}, -1},
{"Slash", "internal/foo", "internal/foo", false, true, "foo", []string{"internal", "foo"}, []string{"internal"}, -1},
{"SingleChar", "a", "a", true, false, "a", []string{"a"}, nil, 0},
{"TwoChars", "ab", "ab", true, false, "ab", []string{"ab"}, nil, -1},
{"ThreeChars", "abc", "abc", false, false, "abc", []string{"abc"}, nil, 1},
{"DotPrefix", ".go", ".go", false, false, ".go", []string{".go"}, nil, -1},
{"UpperCase", "FOO", "foo", false, false, "foo", []string{"foo"}, nil, -1},
{"Empty", "", "", true, false, "", nil, nil, -1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest(tt.query)
if plan.Normalized != tt.wantNorm {
t.Errorf("normalized = %q, want %q", plan.Normalized, tt.wantNorm)
}
if plan.IsShort != tt.wantShort {
t.Errorf("isShort = %v, want %v", plan.IsShort, tt.wantShort)
}
if plan.HasSlash != tt.wantSlash {
t.Errorf("hasSlash = %v, want %v", plan.HasSlash, tt.wantSlash)
}
if string(plan.BasenameQ) != tt.wantBase {
t.Errorf("basenameQ = %q, want %q", plan.BasenameQ, tt.wantBase)
}
if tt.wantTokens == nil {
if len(plan.Tokens) != 0 {
t.Errorf("expected 0 tokens, got %d", len(plan.Tokens))
}
} else {
if len(plan.Tokens) != len(tt.wantTokens) {
t.Fatalf("tokens len = %d, want %d", len(plan.Tokens), len(tt.wantTokens))
}
for i, tok := range plan.Tokens {
if string(tok) != tt.wantTokens[i] {
t.Errorf("tokens[%d] = %q, want %q", i, tok, tt.wantTokens[i])
}
}
}
if tt.wantDirTok != nil {
if len(plan.DirTokens) != len(tt.wantDirTok) {
t.Fatalf("dirTokens len = %d, want %d", len(plan.DirTokens), len(tt.wantDirTok))
}
for i, tok := range plan.DirTokens {
if string(tok) != tt.wantDirTok[i] {
t.Errorf("dirTokens[%d] = %q, want %q", i, tok, tt.wantDirTok[i])
}
}
}
if tt.wantTriCnt >= 0 && len(plan.Trigrams) != tt.wantTriCnt {
t.Errorf("trigram count = %d, want %d", len(plan.Trigrams), tt.wantTriCnt)
}
})
}
// ThreeChars: verify the actual trigram value.
plan := filefinder.NewQueryPlanForTest("abc")
if want := filefinder.PackTrigramForTest('a', 'b', 'c'); plan.Trigrams[0] != want {
t.Errorf("trigram = %x, want %x", plan.Trigrams[0], want)
}
// ShortMultiToken: both tokens < 3 chars so isShort should be true.
plan = filefinder.NewQueryPlanForTest("ab cd")
if !plan.IsShort {
t.Error("expected isShort=true when all tokens < 3 chars")
}
// One token >= 3 chars, so isShort should be false.
plan = filefinder.NewQueryPlanForTest("ab cde")
if plan.IsShort {
t.Error("expected isShort=false when any token >= 3 chars")
}
}
func requireCandHasPath(t *testing.T, cands []filefinder.CandidateForTest, path string) {
t.Helper()
for _, c := range cands {
if c.Path == path {
return
}
}
t.Errorf("expected to find %q in candidates", path)
}
func TestSearchSnapshot_TrigramMatch(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'handler'")
}
requireCandHasPath(t, cands, "src/handler.go")
}
func TestSearchSnapshot_ShortQuery(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"foo.go", "bar.go", "fab.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("fo"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'fo'")
}
requireCandHasPath(t, cands, "foo.go")
}
func TestSearchSnapshot_FuzzyFallback(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("hndlr"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected fuzzy fallback to find 'handler.go' for query 'hndlr'")
}
requireCandHasPath(t, cands, "src/handler.go")
}
func TestSearchSnapshot_FuzzyFallbackNoFirstCharMatch(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/xylophone.go", "lib/extra.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("xylo"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'xylo'")
}
requireCandHasPath(t, cands, "src/xylophone.go")
}
func TestSearchSnapshot_NilSnapshot(t *testing.T) {
t.Parallel()
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("foo"), nil, 100)
if cands != nil {
t.Errorf("expected nil for nil snapshot, got %v", cands)
}
}
func TestSearchSnapshot_EmptyQuery(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"foo.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest(""), snap, 100)
if cands != nil {
t.Errorf("expected nil for empty query, got %v", cands)
}
}
func TestSearchSnapshot_DeletedDocsExcluded(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
idx.Remove("handler.go")
snap := idx.Snapshot()
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
for _, c := range cands {
if c.Path == "handler.go" {
t.Error("deleted doc should not appear in results")
}
}
}
func TestSearchSnapshot_Limit(t *testing.T) {
t.Parallel()
paths := make([]string, 50)
for i := range paths {
paths[i] = "handler" + string(rune('a'+i%26)) + ".go"
}
snap := filefinder.MakeTestSnapshot(paths)
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 3)
if len(cands) > 3 {
t.Errorf("expected at most 3 candidates, got %d", len(cands))
}
}
func TestIntersectSorted(t *testing.T) {
t.Parallel()
tests := []struct {
name string
a, b []uint32
want []uint32
}{
{"both empty", nil, nil, nil},
{"a empty", nil, []uint32{1, 2}, nil},
{"b empty", []uint32{1, 2}, nil, nil},
{"no overlap", []uint32{1, 3, 5}, []uint32{2, 4, 6}, nil},
{"full overlap", []uint32{1, 2, 3}, []uint32{1, 2, 3}, []uint32{1, 2, 3}},
{"partial overlap", []uint32{1, 2, 3, 5}, []uint32{2, 4, 5}, []uint32{2, 5}},
{"single match", []uint32{1, 2, 3}, []uint32{2}, []uint32{2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.IntersectSortedForTest(tt.a, tt.b)
if len(tt.want) == 0 {
if len(got) != 0 {
t.Errorf("got %v, want empty/nil", got)
}
return
}
if !slices.Equal(got, tt.want) {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestIntersectAll(t *testing.T) {
t.Parallel()
t.Run("empty", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest(nil); got != nil {
t.Errorf("got %v, want nil", got)
}
})
t.Run("single", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3}}); len(got) != 3 {
t.Fatalf("len = %d, want 3", len(got))
}
})
t.Run("multiple", func(t *testing.T) {
t.Parallel()
got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3, 4, 5}, {2, 3, 5}, {3, 5, 7}})
if !slices.Equal(got, []uint32{3, 5}) {
t.Errorf("got %v, want [3 5]", got)
}
})
t.Run("no overlap", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest([][]uint32{{1, 2}, {3, 4}}); got != nil {
t.Errorf("got %v, want nil", got)
}
})
}
func TestMergeAndScore_SortedDescending(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
params := filefinder.DefaultScoreParamsForTest()
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "a/b/c/d/e/foo", BaseOff: 10, BaseLen: 3, Depth: 5},
{DocID: 1, Path: "src/foo", BaseOff: 4, BaseLen: 3, Depth: 1},
{DocID: 2, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0},
}
results := filefinder.MergeAndScoreForTest(cands, plan, params, 10)
if len(results) == 0 {
t.Fatal("expected non-empty results")
}
for i := 1; i < len(results); i++ {
if results[i].Score > results[i-1].Score {
t.Errorf("results not sorted: [%d].Score=%f > [%d].Score=%f",
i, results[i].Score, i-1, results[i-1].Score)
}
}
}
func TestMergeAndScore_TopKLimit(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("f")
params := filefinder.DefaultScoreParamsForTest()
var cands []filefinder.CandidateForTest
for i := range 20 {
p := "f" + string(rune('a'+i))
cands = append(cands, filefinder.CandidateForTest{DocID: uint32(i), Path: p, BaseOff: 0, BaseLen: len(p), Depth: 0}) //nolint:gosec // test index is tiny
}
if results := filefinder.MergeAndScoreForTest(cands, plan, params, 5); len(results) != 5 {
t.Errorf("expected 5 results, got %d", len(results))
}
}
func TestMergeAndScore_ZeroTopK(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
cands := []filefinder.CandidateForTest{{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0}}
if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 0); len(results) != 0 {
t.Errorf("expected 0 results for topK=0, got %d", len(results))
}
}
func TestMergeAndScore_NoMatchCandidatesDropped(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("xyz")
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "abc", BaseOff: 0, BaseLen: 3, Depth: 0},
{DocID: 1, Path: "def", BaseOff: 0, BaseLen: 3, Depth: 0},
}
if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
t.Errorf("expected 0 results for non-matching candidates, got %d", len(results))
}
}
func TestMergeAndScore_IsDirFlag(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0, Flags: uint16(filefinder.FlagDir)},
}
results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10)
if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results))
}
if !results[0].IsDir {
t.Error("expected IsDir=true for FlagDir candidate")
}
}
func TestMergeAndScore_EmptyCandidates(t *testing.T) {
t.Parallel()
if results := filefinder.MergeAndScoreForTest(nil, filefinder.NewQueryPlanForTest("foo"), filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
t.Errorf("expected 0 results for nil candidates, got %d", len(results))
}
}
func TestSearchSnapshot_FuzzyFallbackEndToEnd(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/middleware.go", "pkg/config.go"})
plan := filefinder.NewQueryPlanForTest("hndlr")
results := filefinder.MergeAndScoreForTest(filefinder.SearchSnapshotForTest(plan, snap, 100), plan, filefinder.DefaultScoreParamsForTest(), 10)
if len(results) == 0 {
t.Fatal("expected fuzzy fallback to produce scored results for 'hndlr'")
}
if results[0].Path != "src/handler.go" {
t.Errorf("expected top result 'src/handler.go', got %q", results[0].Path)
}
}
-288
View File
@@ -1,288 +0,0 @@
package filefinder
import "slices"
func toLowerASCII(b byte) byte {
if b >= 'A' && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
func normalizeQuery(q string) string {
b := make([]byte, 0, len(q))
prevSpace := true
for i := 0; i < len(q); i++ {
c := q[i]
if c == '\\' {
c = '/'
}
c = toLowerASCII(c)
if c == ' ' {
if prevSpace {
continue
}
prevSpace = true
} else {
prevSpace = false
}
b = append(b, c)
}
if len(b) > 0 && b[len(b)-1] == ' ' {
b = b[:len(b)-1]
}
return string(b)
}
func normalizePathBytes(p []byte) []byte {
j := 0
prevSlash := false
for i := 0; i < len(p); i++ {
c := p[i]
if c == '\\' {
c = '/'
}
c = toLowerASCII(c)
if c == '/' {
if prevSlash {
continue
}
prevSlash = true
} else {
prevSlash = false
}
p[j] = c
j++
}
return p[:j]
}
// extractTrigrams returns deduplicated, sorted trigrams (three-byte
// subsequences) from s. Trigrams are the primary index key: a
// document matches a query only if every query trigram appears in
// the document, giving O(1) candidate filtering per trigram.
func extractTrigrams(s []byte) []uint32 {
if len(s) < 3 {
return nil
}
seen := make(map[uint32]struct{}, len(s))
for i := 0; i <= len(s)-3; i++ {
b0 := toLowerASCII(s[i])
b1 := toLowerASCII(s[i+1])
b2 := toLowerASCII(s[i+2])
gram := uint32(b0)<<16 | uint32(b1)<<8 | uint32(b2)
seen[gram] = struct{}{}
}
result := make([]uint32, 0, len(seen))
for g := range seen {
result = append(result, g)
}
slices.Sort(result)
return result
}
func extractBasename(path []byte) (offset int, length int) {
end := len(path)
if end > 0 && path[end-1] == '/' {
end--
}
if end == 0 {
return 0, 0
}
i := end - 1
for i >= 0 && path[i] != '/' {
i--
}
start := i + 1
return start, end - start
}
func extractSegments(path []byte) [][]byte {
var segments [][]byte
start := 0
for i := 0; i <= len(path); i++ {
if i == len(path) || path[i] == '/' {
if i > start {
segments = append(segments, path[start:i])
}
start = i + 1
}
}
return segments
}
func prefix1(name []byte) byte {
if len(name) == 0 {
return 0
}
return toLowerASCII(name[0])
}
func prefix2(name []byte) uint16 {
if len(name) == 0 {
return 0
}
hi := uint16(toLowerASCII(name[0])) << 8
if len(name) < 2 {
return hi
}
return hi | uint16(toLowerASCII(name[1]))
}
// scoreParams controls the weights for each scoring signal.
type scoreParams struct {
BasenameMatch float32
BasenamePrefix float32
ExactSegment float32
BoundaryHit float32
ContiguousRun float32
DirTokenHit float32
DepthPenalty float32
LengthPenalty float32
}
func defaultScoreParams() scoreParams {
return scoreParams{
BasenameMatch: 6.0,
BasenamePrefix: 3.5,
ExactSegment: 2.5,
BoundaryHit: 1.8,
ContiguousRun: 1.2,
DirTokenHit: 0.4,
DepthPenalty: 0.08,
LengthPenalty: 0.01,
}
}
func isSubsequence(haystack, needle []byte) bool {
if len(needle) == 0 {
return true
}
ni := 0
for _, hb := range haystack {
if toLowerASCII(hb) == toLowerASCII(needle[ni]) {
ni++
if ni == len(needle) {
return true
}
}
}
return false
}
func longestContiguousMatch(haystack, needle []byte) int {
if len(needle) == 0 || len(haystack) == 0 {
return 0
}
best := 0
ni := 0
run := 0
for _, hb := range haystack {
if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
run++
ni++
if run > best {
best = run
}
} else {
run = 0
ni = 0
if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
run = 1
ni = 1
if run > best {
best = run
}
}
}
}
return best
}
func isBoundary(b byte) bool {
return b == '/' || b == '.' || b == '_' || b == '-'
}
func countBoundaryHits(path []byte, query []byte) int {
if len(query) == 0 || len(path) == 0 {
return 0
}
hits := 0
qi := 0
for pi := 0; pi < len(path) && qi < len(query); pi++ {
atBoundary := pi == 0 || isBoundary(path[pi-1])
if atBoundary && toLowerASCII(path[pi]) == toLowerASCII(query[qi]) {
hits++
qi++
}
}
return hits
}
func equalFoldASCII(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if toLowerASCII(a[i]) != toLowerASCII(b[i]) {
return false
}
}
return true
}
func hasPrefixFoldASCII(haystack, prefix []byte) bool {
if len(prefix) > len(haystack) {
return false
}
for i := range prefix {
if toLowerASCII(haystack[i]) != toLowerASCII(prefix[i]) {
return false
}
}
return true
}
// scorePath computes a relevance score for a candidate path
// against a query. The score combines several signals:
// basename match, basename prefix, exact segment match,
// word-boundary hits, longest contiguous run, and penalties
// for depth and length. A return value of 0 means no match
// (the query is not a subsequence of the path).
func scorePath(
path []byte,
baseOff int,
baseLen int,
depth int,
query []byte,
queryTokens [][]byte,
params scoreParams,
) float32 {
if !isSubsequence(path, query) {
return 0
}
var score float32
basename := path[baseOff : baseOff+baseLen]
if isSubsequence(basename, query) {
score += params.BasenameMatch
}
if hasPrefixFoldASCII(basename, query) {
score += params.BasenamePrefix
}
segments := extractSegments(path)
for _, token := range queryTokens {
for _, seg := range segments {
if equalFoldASCII(seg, token) {
score += params.ExactSegment
break
}
}
}
bh := countBoundaryHits(path, query)
score += float32(bh) * params.BoundaryHit
lcm := longestContiguousMatch(path, query)
score += float32(lcm) * params.ContiguousRun
score -= float32(depth) * params.DepthPenalty
score -= float32(len(path)) * params.LengthPenalty
return score
}
-388
View File
@@ -1,388 +0,0 @@
package filefinder_test
import (
"slices"
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestNormalizeQuery(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{"empty", "", ""},
{"leading and trailing spaces", " hello ", "hello"},
{"multiple internal spaces", "foo bar baz", "foo bar baz"},
{"uppercase to lower", "FooBar", "foobar"},
{"backslash to slash", `foo\bar\baz`, "foo/bar/baz"},
{"mixed case and spaces", " Hello World ", "hello world"},
{"unicode passthrough", "héllo wörld", "héllo wörld"},
{"only spaces", " ", ""},
{"single char", "A", "a"},
{"slashes preserved", "/foo/bar/", "/foo/bar/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.NormalizeQueryForTest(tt.input)
if got != tt.want {
t.Errorf("normalizeQuery(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestExtractTrigrams(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want []uint32
}{
{"too short", "ab", nil},
{"exactly three bytes", "abc", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
{"case insensitive", "ABC", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
{"deduplication", "aaaa", []uint32{uint32('a')<<16 | uint32('a')<<8 | uint32('a')}},
{"four bytes produces two trigrams", "abcd", []uint32{
uint32('a')<<16 | uint32('b')<<8 | uint32('c'),
uint32('b')<<16 | uint32('c')<<8 | uint32('d'),
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.ExtractTrigramsForTest([]byte(tt.input))
if !slices.Equal(got, tt.want) {
t.Errorf("extractTrigrams(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestExtractBasename(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
wantOff int
wantName string
}{
{"full path", "/foo/bar/baz.go", 9, "baz.go"},
{"bare filename", "baz.go", 0, "baz.go"},
{"trailing slash", "/a/b/", 3, "b"},
{"root slash", "/", 0, ""},
{"empty", "", 0, ""},
{"single dir with slash", "/foo", 1, "foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
off, length := filefinder.ExtractBasenameForTest([]byte(tt.path))
if off != tt.wantOff {
t.Errorf("extractBasename(%q) offset = %d, want %d", tt.path, off, tt.wantOff)
}
gotName := string([]byte(tt.path)[off : off+length])
if gotName != tt.wantName {
t.Errorf("extractBasename(%q) name = %q, want %q", tt.path, gotName, tt.wantName)
}
})
}
}
func TestExtractSegments(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
want []string
}{
{"absolute path", "/foo/bar/baz", []string{"foo", "bar", "baz"}},
{"relative path", "foo/bar", []string{"foo", "bar"}},
{"trailing slash", "/a/b/", []string{"a", "b"}},
{"multiple slashes", "//a///b//", []string{"a", "b"}},
{"empty", "", nil},
{"single segment", "foo", []string{"foo"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.ExtractSegmentsForTest([]byte(tt.path))
if len(got) != len(tt.want) {
t.Fatalf("extractSegments(%q) got %d segments, want %d", tt.path, len(got), len(tt.want))
}
for i := range got {
if string(got[i]) != tt.want[i] {
t.Errorf("extractSegments(%q)[%d] = %q, want %q", tt.path, i, got[i], tt.want[i])
}
}
})
}
}
func TestPrefix1(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want byte
}{
{"lowercase", "foo", 'f'},
{"uppercase", "Foo", 'f'},
{"empty", "", 0},
{"digit", "1abc", '1'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.Prefix1ForTest([]byte(tt.in))
if got != tt.want {
t.Errorf("prefix1(%q) = %d (%c), want %d (%c)", tt.in, got, got, tt.want, tt.want)
}
})
}
}
func TestPrefix2(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want uint16
}{
{"two chars", "ab", uint16('a')<<8 | uint16('b')},
{"uppercase", "AB", uint16('a')<<8 | uint16('b')},
{"single char", "A", uint16('a') << 8},
{"empty", "", 0},
{"longer string", "Hello", uint16('h')<<8 | uint16('e')},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.Prefix2ForTest([]byte(tt.in))
if got != tt.want {
t.Errorf("prefix2(%q) = %d, want %d", tt.in, got, tt.want)
}
})
}
}
func TestNormalizePathBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{"backslash to slash", `C:\Users\test`, "c:/users/test"},
{"collapse slashes", "//foo///bar//", "/foo/bar/"},
{"lowercase", "FooBar", "foobar"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
buf := []byte(tt.input)
got := string(filefinder.NormalizePathBytesForTest(buf))
if got != tt.want {
t.Errorf("normalizePathBytes(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestIsSubsequence(t *testing.T) {
t.Parallel()
tests := []struct {
name string
haystack string
needle string
want bool
}{
{"empty needle", "anything", "", true},
{"empty both", "", "", true},
{"empty haystack", "", "a", false},
{"exact match", "abc", "abc", true},
{"scattered", "axbycz", "abc", true},
{"prefix", "abcdef", "abc", true},
{"suffix", "xyzabc", "abc", true},
{"case insensitive", "AbCdEf", "ace", true},
{"case insensitive reverse", "abcdef", "ACE", true},
{"no match", "abcdef", "xyz", false},
{"partial match", "abcdef", "abz", false},
{"longer needle", "ab", "abc", false},
{"single char match", "hello", "l", true},
{"single char no match", "hello", "z", false},
{"path like", "src/internal/foo.go", "sif", true},
{"path like no match", "src/internal/foo.go", "zzz", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.IsSubsequenceForTest([]byte(tt.haystack), []byte(tt.needle))
if got != tt.want {
t.Errorf("isSubsequence(%q, %q) = %v, want %v", tt.haystack, tt.needle, got, tt.want)
}
})
}
}
func TestLongestContiguousMatch(t *testing.T) {
t.Parallel()
tests := []struct {
name string
haystack string
needle string
want int
}{
{"empty needle", "abc", "", 0},
{"empty haystack", "", "abc", 0},
{"full match", "abc", "abc", 3},
{"prefix match", "abcdef", "abc", 3},
{"middle match", "xxabcyy", "abc", 3},
{"suffix match", "xxabc", "abc", 3},
{"partial", "axbc", "abc", 1},
{"scattered no contiguous", "axbxcx", "abc", 1},
{"case insensitive", "ABCdef", "abc", 3},
{"no match", "xyz", "abc", 0},
{"single char", "abc", "b", 1},
{"repeated", "aababc", "abc", 3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.LongestContiguousMatchForTest([]byte(tt.haystack), []byte(tt.needle))
if got != tt.want {
t.Errorf("longestContiguousMatch(%q, %q) = %d, want %d", tt.haystack, tt.needle, got, tt.want)
}
})
}
}
func TestIsBoundary(t *testing.T) {
t.Parallel()
for _, b := range []byte{'/', '.', '_', '-'} {
if !filefinder.IsBoundaryForTest(b) {
t.Errorf("isBoundary(%q) = false, want true", b)
}
}
for _, b := range []byte{'a', 'Z', '0', ' ', '('} {
if filefinder.IsBoundaryForTest(b) {
t.Errorf("isBoundary(%q) = true, want false", b)
}
}
}
func TestCountBoundaryHits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
query string
want int
}{
{"start of string", "foo/bar", "f", 1},
{"after slash", "foo/bar", "fb", 2},
{"after dot", "foo.bar", "fb", 2},
{"after underscore", "foo_bar", "fb", 2},
{"no hits", "xxxx", "y", 0},
{"empty query", "foo", "", 0},
{"empty path", "", "f", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.CountBoundaryHitsForTest([]byte(tt.path), []byte(tt.query))
if got != tt.want {
t.Errorf("countBoundaryHits(%q, %q) = %d, want %d", tt.path, tt.query, got, tt.want)
}
})
}
}
func TestScorePath_NoSubsequenceReturnsZero(t *testing.T) {
t.Parallel()
path := []byte("src/internal/handler.go")
query := []byte("zzz")
tokens := [][]byte{[]byte("zzz")}
params := filefinder.DefaultScoreParamsForTest()
s := filefinder.ScorePathForTest(path, 13, 10, 2, query, tokens, params)
if s != 0 {
t.Errorf("expected 0 for no subsequence match, got %f", s)
}
}
func TestScorePath_ExactBasenameOverPartial(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("main")
tokens := [][]byte{query}
pathExact := []byte("src/main")
scoreExact := filefinder.ScorePathForTest(pathExact, 4, 4, 1, query, tokens, params)
pathPartial := []byte("module/amazing")
scorePartial := filefinder.ScorePathForTest(pathPartial, 7, 7, 1, query, tokens, params)
if scoreExact <= scorePartial {
t.Errorf("exact basename (%f) should score higher than partial (%f)", scoreExact, scorePartial)
}
}
func TestScorePath_BasenamePrefixOverScattered(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("han")
tokens := [][]byte{query}
pathPrefix := []byte("src/handler.go")
scorePrefix := filefinder.ScorePathForTest(pathPrefix, 4, 10, 1, query, tokens, params)
pathScattered := []byte("has/another/thing")
scoreScattered := filefinder.ScorePathForTest(pathScattered, 12, 5, 2, query, tokens, params)
if scorePrefix <= scoreScattered {
t.Errorf("basename prefix (%f) should score higher than scattered (%f)", scorePrefix, scoreScattered)
}
}
func TestScorePath_ShallowOverDeep(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("foo")
tokens := [][]byte{query}
pathShallow := []byte("src/foo.go")
scoreShallow := filefinder.ScorePathForTest(pathShallow, 4, 6, 1, query, tokens, params)
pathDeep := []byte("a/b/c/d/e/foo.go")
scoreDeep := filefinder.ScorePathForTest(pathDeep, 10, 6, 5, query, tokens, params)
if scoreShallow <= scoreDeep {
t.Errorf("shallow path (%f) should score higher than deep (%f)", scoreShallow, scoreDeep)
}
}
func TestScorePath_ShorterOverLongerSameMatch(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("foo")
tokens := [][]byte{query}
pathShort := []byte("x/foo")
scoreShort := filefinder.ScorePathForTest(pathShort, 2, 3, 1, query, tokens, params)
pathLong := []byte("x/foo_extremely_long_suffix_name")
scoreLong := filefinder.ScorePathForTest(pathLong, 2, 29, 1, query, tokens, params)
if scoreShort <= scoreLong {
t.Errorf("shorter path (%f) should score higher than longer (%f)", scoreShort, scoreLong)
}
}
func BenchmarkScorePath(b *testing.B) {
path := []byte("src/internal/coderd/database/queries/workspaces.sql")
query := []byte("workspace")
tokens := [][]byte{query}
params := filefinder.DefaultScoreParamsForTest()
baseOff, baseLen := filefinder.ExtractBasenameForTest(path)
s := filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
if s == 0 {
b.Fatal("expected non-zero score for benchmark path")
}
b.ResetTimer()
for b.Loop() {
filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
}
}
-210
View File
@@ -1,210 +0,0 @@
package filefinder
import (
"context"
"os"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"cdr.dev/slog/v3"
)
// FSEvent represents a filesystem change event.
type FSEvent struct {
Op FSEventOp
Path string
IsDir bool
}
// FSEventOp represents the type of filesystem operation.
type FSEventOp uint8
// Filesystem operations reported by the watcher.
const (
OpCreate FSEventOp = iota
OpRemove
OpRename
OpModify
)
var skipDirs = map[string]struct{}{
".git": {}, "node_modules": {}, ".hg": {}, ".svn": {},
"__pycache__": {}, ".cache": {}, ".venv": {}, "vendor": {}, ".terraform": {},
}
type fsWatcher struct {
w *fsnotify.Watcher
root string
events chan []FSEvent
logger slog.Logger
mu sync.Mutex
closed bool
done chan struct{}
}
func newFSWatcher(root string, logger slog.Logger) (*fsWatcher, error) {
w, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return &fsWatcher{
w: w,
root: root,
events: make(chan []FSEvent, 64),
logger: logger,
done: make(chan struct{}),
}, nil
}
func (fw *fsWatcher) Start(ctx context.Context) {
initEvents := fw.addRecursive(fw.root)
if len(initEvents) > 0 {
select {
case fw.events <- initEvents:
case <-ctx.Done():
return
}
}
fw.logger.Debug(ctx, "fs watcher started", slog.F("root", fw.root))
go fw.loop(ctx)
}
func (fw *fsWatcher) Events() <-chan []FSEvent { return fw.events }
func (fw *fsWatcher) Close() error {
fw.mu.Lock()
if fw.closed {
fw.mu.Unlock()
return nil
}
fw.closed = true
fw.mu.Unlock()
err := fw.w.Close()
<-fw.done
return err
}
func (fw *fsWatcher) loop(ctx context.Context) {
defer close(fw.done)
const batchWindow = 50 * time.Millisecond
var (
batch []FSEvent
seen = make(map[string]struct{})
timer *time.Timer
timerC <-chan time.Time
)
flush := func() {
if len(batch) == 0 {
return
}
select {
case fw.events <- batch:
default:
fw.logger.Warn(ctx, "fs watcher dropping batch", slog.F("count", len(batch)))
}
batch = nil
seen = make(map[string]struct{})
if timer != nil {
timer.Stop()
}
timer = nil
timerC = nil
}
addToBatch := func(ev FSEvent) {
if _, dup := seen[ev.Path]; dup {
return
}
seen[ev.Path] = struct{}{}
batch = append(batch, ev)
if timer == nil {
timer = time.NewTimer(batchWindow)
timerC = timer.C
}
}
for {
select {
case <-ctx.Done():
flush()
return
case ev, ok := <-fw.w.Events:
if !ok {
flush()
return
}
fsev := translateEvent(ev)
if fsev == nil {
continue
}
if fsev.IsDir && fsev.Op == OpCreate {
for _, s := range fw.addRecursive(fsev.Path) {
addToBatch(s)
}
}
addToBatch(*fsev)
case err, ok := <-fw.w.Errors:
if !ok {
flush()
return
}
fw.logger.Warn(ctx, "fsnotify watcher error", slog.Error(err))
case <-timerC:
flush()
}
}
}
func (fw *fsWatcher) addRecursive(dir string) []FSEvent {
var events []FSEvent
_ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil //nolint:nilerr // best-effort
}
base := filepath.Base(path)
if _, skip := skipDirs[base]; skip && info.IsDir() {
return filepath.SkipDir
}
if info.IsDir() {
if addErr := fw.w.Add(path); addErr != nil {
fw.logger.Debug(context.Background(), "failed to add watch",
slog.F("path", path), slog.Error(addErr))
}
if path != dir {
events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: true})
}
return nil
}
events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: false})
return nil
})
return events
}
func translateEvent(ev fsnotify.Event) *FSEvent {
var op FSEventOp
switch {
case ev.Op&fsnotify.Create != 0:
op = OpCreate
case ev.Op&fsnotify.Remove != 0:
op = OpRemove
case ev.Op&fsnotify.Rename != 0:
op = OpRename
case ev.Op&fsnotify.Write != 0:
op = OpModify
default:
return nil
}
isDir := false
if op == OpCreate || op == OpModify {
fi, err := os.Lstat(ev.Name)
if err == nil {
isDir = fi.IsDir()
}
}
if isDir {
if _, skip := skipDirs[filepath.Base(ev.Name)]; skip {
return nil
}
}
return &FSEvent{Op: op, Path: ev.Name, IsDir: isDir}
}
+330 -544
View File
File diff suppressed because it is too large Load Diff
+1 -20
View File
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
}
repeated DisplayApp display_apps = 6;
optional bytes id = 7;
}
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
message ReportBoundaryLogsResponse {}
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
message UpdateAppStatusRequest {
string slug = 1;
enum AppStatusState {
WORKING = 0;
IDLE = 1;
COMPLETE = 2;
FAILURE = 3;
}
AppStatusState state = 2;
string message = 3;
string uri = 4;
}
message UpdateAppStatusResponse {}
service Agent {
rpc GetManifest(GetManifestRequest) returns (Manifest);
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
@@ -530,5 +512,4 @@ service Agent {
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
}
+1 -41
View File
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type drpcAgentClient struct {
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
return out, nil
}
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
out := new(UpdateAppStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentServer interface {
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type DRPCAgentUnimplementedServer struct{}
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentDescription struct{}
func (DRPCAgentDescription) NumMethods() int { return 18 }
func (DRPCAgentDescription) NumMethods() int { return 17 }
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*ReportBoundaryLogsRequest),
)
}, DRPCAgentServer.ReportBoundaryLogs, true
case 17:
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
UpdateAppStatus(
ctx,
in1.(*UpdateAppStatusRequest),
)
}, DRPCAgentServer.UpdateAppStatus, true
default:
return "", nil, nil, nil, false
}
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
}
return x.CloseSend()
}
type DRPCAgent_UpdateAppStatusStream interface {
drpc.Stream
SendAndClose(*UpdateAppStatusResponse) error
}
type drpcAgent_UpdateAppStatusStream struct {
drpc.Stream
}
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
+3 -7
View File
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
}
// DRPCAgentClient28 is the Agent API at v2.8. It adds
// - a SubagentId field to the WorkspaceAgentDevcontainer message
// - an Id field to the CreateSubAgentRequest message.
// - UpdateAppStatus RPC.
//
// Compatible with Coder v2.31+
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
// message. Compatible with Coder v2.31+
type DRPCAgentClient28 interface {
DRPCAgentClient27
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
+21 -25
View File
@@ -3,11 +3,11 @@
"enabled": true,
"clientKind": "git",
"useIgnoreFile": true,
"defaultBranch": "main",
"defaultBranch": "main"
},
"files": {
"includes": ["**", "!**/pnpm-lock.yaml"],
"ignoreUnknown": true,
"ignoreUnknown": true
},
"linter": {
"rules": {
@@ -15,18 +15,18 @@
"noSvgWithoutTitle": "off",
"useButtonType": "off",
"useSemanticElements": "off",
"noStaticElementInteractions": "off",
"noStaticElementInteractions": "off"
},
"correctness": {
"noUnusedImports": "warn",
"correctness": {
"noUnusedImports": "warn",
"useUniqueElementIds": "off", // TODO: This is new but we want to fix it
"noNestedComponentDefinitions": "off", // TODO: Investigate, since it is used by shadcn components
"noUnusedVariables": {
"level": "warn",
"noUnusedVariables": {
"level": "warn",
"options": {
"ignoreRestSiblings": true,
},
},
"ignoreRestSiblings": true
}
}
},
"style": {
"noNonNullAssertion": "off",
@@ -45,10 +45,6 @@
"level": "error",
"options": {
"paths": {
"react": {
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
"importNames": ["forwardRef"],
},
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
// "@mui/material/Autocomplete": "Use shadcn/ui Combobox instead.",
@@ -115,10 +111,10 @@
"@emotion/styled": "Use Tailwind CSS instead.",
// "@emotion/cache": "Use Tailwind CSS instead.",
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
"lodash": "Use lodash/<name> instead.",
},
},
},
"lodash": "Use lodash/<name> instead."
}
}
}
},
"suspicious": {
"noArrayIndexKey": "off",
@@ -129,14 +125,14 @@
"noConsole": {
"level": "error",
"options": {
"allow": ["error", "info", "warn"],
},
},
"allow": ["error", "info", "warn"]
}
}
},
"complexity": {
"noImportantStyles": "off", // TODO: check and fix !important styles
},
},
"noImportantStyles": "off" // TODO: check and fix !important styles
}
}
},
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
}
+1 -1
View File
@@ -489,7 +489,7 @@ func workspaceAgent() *serpent.Command {
},
{
Flag: "socket-server-enabled",
Default: "true",
Default: "false",
Env: "CODER_AGENT_SOCKET_SERVER_ENABLED",
Description: "Enable the agent socket server.",
Value: serpent.BoolOf(&socketServerEnabled),
-4
View File
@@ -44,7 +44,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-token", r.AgentToken,
"--agent-url", client.URL.String(),
"--log-dir", logDir,
"--socket-path", testutil.AgentSocketPath(t),
)
clitest.Start(t, inv)
@@ -77,7 +76,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-token", r.AgentToken,
"--agent-url", client.URL.String(),
"--log-dir", logDir,
"--socket-path", testutil.AgentSocketPath(t),
)
// Set the subsystems for the agent.
inv.Environ.Set(agent.EnvAgentSubsystem, fmt.Sprintf("%s,%s", codersdk.AgentSubsystemExectrace, codersdk.AgentSubsystemEnvbox))
@@ -160,7 +158,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--agent-header", "X-Testing=agent",
"--agent-header", "Cool-Header=Ethan was Here!",
"--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow",
"--socket-path", testutil.AgentSocketPath(t),
)
clitest.Start(t, agentInv)
coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).
@@ -202,7 +199,6 @@ func TestWorkspaceAgent(t *testing.T) {
"--pprof-address", "",
"--prometheus-address", "",
"--debug-address", "",
"--socket-path", testutil.AgentSocketPath(t),
)
clitest.Start(t, inv)
+3 -9
View File
@@ -30,15 +30,9 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
_, _ = fmt.Fprint(inv.Stdout, "\033[1A")
var defaults []string
defaultSource := defaultValue
if defaultSource == "" {
defaultSource = templateVersionParameter.DefaultValue
}
if defaultSource != "" {
err = json.Unmarshal([]byte(defaultSource), &defaults)
if err != nil {
return "", err
}
err = json.Unmarshal([]byte(templateVersionParameter.DefaultValue), &defaults)
if err != nil {
return "", err
}
values, err := RichMultiSelect(inv, RichMultiSelectOptions{
+45 -50
View File
@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"strings"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -24,7 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
"github.com/coder/retry"
"github.com/coder/serpent"
)
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
defer cancel()
defer srv.queue.Close()
cliui.Infof(inv.Stderr, "Failed to watch screen events")
// Start the reporter, watcher, and server. These are all tied to the
// lifetime of the MCP server, which is itself tied to the lifetime of the
// AI agent.
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
}
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
return
}
go func() {
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err == nil {
retrier.Reset()
loop:
for {
select {
case <-ctx.Done():
for {
select {
case <-ctx.Done():
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
// If the screen is stable, report idle.
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
}
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
break loop
}
}
} else {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
return
}
}
}()
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
// Add tool dependencies.
toolOpts := []func(*toolsdk.Deps){
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
state := codersdk.WorkspaceAppStatusState(args.State)
// The agent does not reliably report idle, so when AgentAPI is
// enabled we override idle to working and let the screen watcher
// detect the real idle via StatusStable. Final states (failure,
// complete) are trusted from the agent since the screen watcher
// cannot produce them.
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
state = codersdk.WorkspaceAppStatusStateWorking
// The agent does not reliably report its status correctly. If AgentAPI
// is enabled, we will always set the status to "working" when we get an
// MCP message, and rely on the screen watcher to eventually catch the
// idle state.
state := codersdk.WorkspaceAppStatusStateWorking
if s.aiAgentAPIClient == nil {
state = codersdk.WorkspaceAppStatusState(args.State)
}
return s.queue.Push(taskReport{
link: args.Link,
+1 -185
View File
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
},
},
},
// We override idle from the agent to working, but trust final states.
// We ignore the state from the agent and assume "working".
{
name: "IgnoreAgentState",
// AI agent reports that it is finished but the summary says it is doing
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
Message: "finished",
},
},
// Agent reports failure; trusted even with AgentAPI enabled.
{
state: codersdk.WorkspaceAppStatusStateFailure,
summary: "something broke",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateFailure,
Message: "something broke",
},
},
// After failure, watcher reports stable -> idle.
{
event: makeStatusEvent(agentapi.StatusStable),
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "something broke",
},
},
},
},
// Final states pass through with AgentAPI enabled.
{
name: "AllowFinalStates",
tests: []test{
{
state: codersdk.WorkspaceAppStatusStateWorking,
summary: "doing work",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "doing work",
},
},
// Agent reports complete; not overridden.
{
state: codersdk.WorkspaceAppStatusStateComplete,
summary: "all done",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateComplete,
Message: "all done",
},
},
},
},
// When AgentAPI is not being used, we accept agent state updates as-is.
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
<-cmdDone
})
}
t.Run("Reconnect", func(t *testing.T) {
t.Parallel()
// Create a test deployment and workspace.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
// Watch the workspace for changes.
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
require.NoError(t, err)
var lastAppStatus codersdk.WorkspaceAppStatus
nextUpdate := func() codersdk.WorkspaceAppStatus {
for {
select {
case <-ctx.Done():
require.FailNow(t, "timed out waiting for status update")
case w, ok := <-watcher:
require.True(t, ok, "watch channel closed")
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
lastAppStatus = *w.LatestAppStatus
return lastAppStatus
}
}
}
}
// Mock AI AgentAPI server that supports disconnect/reconnect.
disconnect := make(chan struct{})
listening := make(chan func(sse codersdk.ServerSentEvent) error)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a cancelable context so we can stop the SSE sender
// goroutine on disconnect without waiting for the HTTP
// serve loop to cancel r.Context().
sseCtx, sseCancel := context.WithCancel(r.Context())
defer sseCancel()
r = r.WithContext(sseCtx)
send, closed, err := httpapi.ServerSentEventSender(w, r)
if err != nil {
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Send initial message so the watcher knows the agent is active.
send(*makeMessageEvent(0, agentapi.RoleAgent))
select {
case listening <- send:
case <-r.Context().Done():
return
}
select {
case <-closed:
case <-disconnect:
sseCancel()
<-closed
}
}))
t.Cleanup(srv.Close)
inv, _ := clitest.New(t,
"exp", "mcp", "server",
"--agent-url", client.URL.String(),
"--agent-token", r.AgentToken,
"--app-status-slug", "vscode",
"--allowed-tools=coder_report_task",
"--ai-agentapi-url", srv.URL,
)
inv = inv.WithContext(ctx)
pty := ptytest.New(t)
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
stderr := ptytest.New(t)
inv.Stderr = stderr.Output()
// Run the MCP server.
clitest.Start(t, inv)
// Initialize.
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
pty.WriteLine(payload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore init response
// Get first sender from the initial SSE connection.
sender := testutil.RequireReceive(ctx, t, listening)
// Self-report a working status via tool call.
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got := nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "doing work", got.Message)
// Watcher sends stable, verify idle is reported.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
// Disconnect the SSE connection by signaling the handler to return.
testutil.RequireSend(ctx, t, disconnect, struct{}{})
// Wait for the watcher to reconnect and get the new sender.
sender = testutil.RequireReceive(ctx, t, listening)
// After reconnect, self-report a working status again.
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "reconnected", got.Message)
// Verify the watcher still processes events after reconnect.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
cancel()
})
}
-12
View File
@@ -29,7 +29,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
templateVersionJobTimeout time.Duration
prebuildWorkspaceTimeout time.Duration
noCleanup bool
provisionerTags []string
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
@@ -112,16 +111,10 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
for i := range numTemplates {
id := strconv.Itoa(int(i))
cfg := prebuilds.Config{
OrganizationID: me.OrganizationIDs[0],
ProvisionerTags: tags,
NumPresets: int(numPresets),
NumPresetPrebuilds: int(numPresetPrebuilds),
TemplateVersionJobTimeout: templateVersionJobTimeout,
@@ -290,11 +283,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
Description: "Skip cleanup (deletion test) and leave resources intact.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}
tracingFlags.attach(&cmd.Options)
+1 -46
View File
@@ -4,9 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"time"
"golang.org/x/xerrors"
@@ -19,29 +16,6 @@ import (
"github.com/coder/serpent"
)
// detectGitRef attempts to resolve the current git branch and remote
// origin URL from the given working directory. These are sent to the
// control plane so it can look up PR/diff status via the GitHub API
// without SSHing into the workspace. Failures are silently ignored
// since this is best-effort.
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
run := func(args ...string) string {
//nolint:gosec
cmd := exec.Command(args[0], args[1:]...)
if workingDirectory != "" {
cmd.Dir = workingDirectory
}
out, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
return branch, remoteOrigin
}
// gitAskpass is used by the Coder agent to automatically authenticate
// with Git providers based on a hostname.
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
@@ -64,21 +38,8 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("create agent client: %w", err)
}
workingDirectory, err := os.Getwd()
if err != nil {
workingDirectory = ""
}
// Detect the current git branch and remote origin so
// the control plane can resolve diffs without needing
// to SSH back into the workspace.
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
Match: host,
GitBranch: gitBranch,
GitRemoteOrigin: gitRemoteOrigin,
ChatID: inv.Environ.Get("CODER_CHAT_ID"),
Match: host,
})
if err != nil {
var apiError *codersdk.Error
@@ -97,12 +58,6 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("get git token: %w", err)
}
if token.URL != "" {
// This is to help the agent authenticate with Git.
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
_, _ = fmt.Fprintf(inv.Stderr, `You must notify the user to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
return cliui.ErrCanceled
}
if err := openURL(inv, token.URL); err == nil {
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
} else {
+5 -1
View File
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+5 -29
View File
@@ -1,7 +1,6 @@
package cli
import (
"encoding/json"
"fmt"
"strings"
@@ -232,7 +231,7 @@ next:
continue // immutables should not be passed to consecutive builds
}
if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, *tvp) {
if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, tvp.Options) {
continue // do not propagate invalid options
}
@@ -298,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
}
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
if !tvp.Mutable && action != WorkspaceCreate {
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
}
}
@@ -366,7 +365,7 @@ func (pr *ParameterResolver) isLastBuildParameterInvalidOption(templateVersionPa
for _, buildParameter := range pr.lastBuildParameters {
if buildParameter.Name == templateVersionParameter.Name {
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter)
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter.Options)
}
}
return false
@@ -390,31 +389,8 @@ func findWorkspaceBuildParameter(parameterName string, params []codersdk.Workspa
return nil
}
func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, templateVersionParameter codersdk.TemplateVersionParameter) bool {
// Multi-select parameters store values as a JSON array (e.g.
// '["vim","emacs"]'), so we need to parse the array and validate
// each element individually against the allowed options.
if templateVersionParameter.Type == "list(string)" {
var values []string
if err := json.Unmarshal([]byte(buildParameter.Value), &values); err != nil {
return false
}
for _, v := range values {
found := false
for _, opt := range templateVersionParameter.Options {
if opt.Value == v {
found = true
break
}
}
if !found {
return false
}
}
return true
}
for _, opt := range templateVersionParameter.Options {
func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, options []codersdk.TemplateVersionParameterOption) bool {
for _, opt := range options {
if opt.Value == buildParameter.Value {
return true
}
-85
View File
@@ -1,85 +0,0 @@
package cli
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/codersdk"
)
func TestIsValidTemplateParameterOption(t *testing.T) {
t.Parallel()
options := []codersdk.TemplateVersionParameterOption{
{Name: "Vim", Value: "vim"},
{Name: "Emacs", Value: "emacs"},
{Name: "VS Code", Value: "vscode"},
}
t.Run("SingleSelectValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "vim"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("SingleSelectInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "notepad"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectAllValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","emacs"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectOneInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","notepad"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectEmptyArray", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `[]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectInvalidJSON", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `not-json`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
}
+6 -26
View File
@@ -2,7 +2,6 @@ package cli_test
import (
"bytes"
"cmp"
"context"
"database/sql"
"encoding/json"
@@ -21,6 +20,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk"
)
@@ -35,10 +35,7 @@ func TestProvisioners_Golden(t *testing.T) {
provisioners, err := coderdAPI.Database.GetProvisionerDaemons(systemCtx)
require.NoError(t, err)
slices.SortFunc(provisioners, func(a, b database.ProvisionerDaemon) int {
return cmp.Or(
a.CreatedAt.Compare(b.CreatedAt),
bytes.Compare(a.ID[:], b.ID[:]),
)
return a.CreatedAt.Compare(b.CreatedAt)
})
pIdx := 0
for _, p := range provisioners {
@@ -50,10 +47,7 @@ func TestProvisioners_Golden(t *testing.T) {
jobs, err := coderdAPI.Database.GetProvisionerJobsCreatedAfter(systemCtx, time.Time{})
require.NoError(t, err)
slices.SortFunc(jobs, func(a, b database.ProvisionerJob) int {
return cmp.Or(
a.CreatedAt.Compare(b.CreatedAt),
bytes.Compare(a.ID[:], b.ID[:]),
)
return a.CreatedAt.Compare(b.CreatedAt)
})
jIdx := 0
for _, j := range jobs {
@@ -82,15 +76,11 @@ func TestProvisioners_Golden(t *testing.T) {
firstProvisioner := coderdtest.NewTaggedProvisionerDaemon(t, coderdAPI, "default-provisioner", map[string]string{"owner": "", "scope": "organization"})
t.Cleanup(func() { _ = firstProvisioner.Close() })
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, completeWithAgent())
version = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
require.Equal(t, codersdk.ProvisionerJobSucceeded, version.Job.Status,
"template version import should succeed, got error: %s", version.Job.Error)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
wb := coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
require.Equal(t, codersdk.ProvisionerJobSucceeded, wb.Job.Status,
"workspace build job should succeed, got error: %s", wb.Job.Error)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// Stop the provisioner so it doesn't grab any more jobs.
firstProvisioner.Close()
@@ -99,17 +89,7 @@ func TestProvisioners_Golden(t *testing.T) {
replace[version.ID.String()] = "00000000-0000-0000-cccc-000000000000"
replace[workspace.LatestBuild.ID.String()] = "00000000-0000-0000-dddd-000000000000"
// Base synthetic times off the latest real job's CreatedAt, not the
// wall clock. Using dbtime.Now() here is racy because NTP clock
// steps can make it return a time before the real jobs' CreatedAt.
systemCtx := dbauthz.AsSystemRestricted(context.Background())
existingJobs, err := coderdAPI.Database.GetProvisionerJobsCreatedAfter(systemCtx, time.Time{})
require.NoError(t, err)
require.NotEmpty(t, existingJobs, "expected at least one provisioner job")
latestJob := slices.MaxFunc(existingJobs, func(a, b database.ProvisionerJob) int {
return a.CreatedAt.Compare(b.CreatedAt)
})
now := latestJob.CreatedAt.Add(time.Second)
now := dbtime.Now()
// Create a provisioner that's working on a job.
pd1 := dbgen.ProvisionerDaemon(t, coderdAPI.Database, database.ProvisionerDaemon{
+4 -15
View File
@@ -884,27 +884,16 @@ func (o *OrganizationContext) Selected(inv *serpent.Invocation, client *codersdk
index := slices.IndexFunc(orgs, func(org codersdk.Organization) bool {
return org.Name == o.FlagSelect || org.ID.String() == o.FlagSelect
})
if index >= 0 {
return orgs[index], nil
}
// Not in membership list - try direct fetch.
// This allows site-wide admins (e.g., Owners) to use orgs they aren't
// members of.
org, err := client.OrganizationByName(inv.Context(), o.FlagSelect)
if err != nil {
if index < 0 {
var names []string
for _, org := range orgs {
names = append(names, org.Name)
}
var sdkErr *codersdk.Error
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusNotFound {
return codersdk.Organization{}, xerrors.Errorf("organization %q not found, are you sure you are a member of this organization? "+
"Valid options for '--org=' are [%s].", o.FlagSelect, strings.Join(names, ", "))
}
return codersdk.Organization{}, xerrors.Errorf("get organization %q: %w", o.FlagSelect, err)
return codersdk.Organization{}, xerrors.Errorf("organization %q not found, are you sure you are a member of this organization? "+
"Valid options for '--org=' are [%s].", o.FlagSelect, strings.Join(names, ", "))
}
return org, nil
return orgs[index], nil
}
if len(orgs) == 1 {
+45 -129
View File
@@ -95,7 +95,6 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/cryptorand"
@@ -137,15 +136,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
if err != nil {
return nil, xerrors.Errorf("parse oidc oauth callback url: %w", err)
}
if vals.OIDC.RedirectURL.String() != "" {
redirectURL, err = vals.OIDC.RedirectURL.Value().Parse("/api/v2/users/oidc/callback")
if err != nil {
return nil, xerrors.Errorf("parse oidc redirect url %q", err)
}
logger.Warn(ctx, "custom OIDC redirect URL used instead of 'access_url', ensure this matches the value configured in your OIDC provider")
}
// If the scopes contain 'groups', we enable group support.
// Do not override any custom value set by the user.
if slice.Contains(vals.OIDC.Scopes, "groups") && vals.OIDC.GroupField == "" {
@@ -617,8 +607,28 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
promRegistry := prometheus.NewRegistry()
oauthInstrument := promoauth.NewFactory(promRegistry)
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
externalAuthConfigs, err := externalauth.ConvertConfig(
oauthInstrument,
vals.ExternalAuthConfigs.Value,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range externalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
if err != nil {
@@ -649,7 +659,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
Pubsub: nil,
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
ExternalAuthConfigs: nil,
ExternalAuthConfigs: externalAuthConfigs,
RealIPConfig: realIPConfig,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
@@ -809,43 +819,9 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("set deployment id: %w", err)
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
ctx,
options.Logger,
options.Database,
vals,
mergedExternalAuthProviders,
)
if err != nil {
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
}
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
oauthInstrument,
mergedExternalAuthProviders,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range options.ExternalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
// Manage push notifications.
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
if experiments.Enabled(codersdk.ExperimentWebPush) || buildinfo.IsDev() {
if experiments.Enabled(codersdk.ExperimentWebPush) {
if !strings.HasPrefix(options.AccessURL.String(), "https://") {
options.Logger.Warn(ctx, "access URL is not HTTPS, so web push notifications may not work on some browsers", slog.F("access_url", options.AccessURL.String()))
}
@@ -959,12 +935,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
options.StatsBatcher = batcher
defer closeBatcher()
wsBuilderMetrics, err := wsbuilder.NewMetrics(options.PrometheusRegistry)
if err != nil {
return xerrors.Errorf("failed to register workspace builder metrics: %w", err)
}
options.WorkspaceBuilderMetrics = wsBuilderMetrics
// Manage notifications.
var (
notificationsCfg = options.DeploymentValues.Notifications
@@ -1148,7 +1118,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value())
defer autobuildTicker.Stop()
autobuildExecutor := autobuild.NewExecutor(
ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments, coderAPI.WorkspaceBuilderMetrics)
ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments)
autobuildExecutor.Run()
jobReaperTicker := time.NewTicker(vals.JobReaperDetectorInterval.Value())
@@ -1940,79 +1910,6 @@ type githubOAuth2ConfigParams struct {
enterpriseBaseURL string
}
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
// We want to enable the default provider only for new deployments, and avoid
// enabling it if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return false, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return false, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
return defaultEligible, nil
}
func maybeAppendDefaultGithubExternalAuthProvider(
ctx context.Context,
logger slog.Logger,
db database.Store,
vals *codersdk.DeploymentValues,
mergedExplicitProviders []codersdk.ExternalAuthConfig,
) ([]codersdk.ExternalAuthConfig, error) {
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "disabled by configuration"),
slog.F("flag", "external-auth-github-default-provider-enable"),
)
return mergedExplicitProviders, nil
}
if len(mergedExplicitProviders) > 0 {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "explicit external auth providers configured"),
slog.F("provider_count", len(mergedExplicitProviders)),
)
return mergedExplicitProviders, nil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "deployment is not eligible"),
)
return mergedExplicitProviders, nil
}
logger.Info(ctx, "injecting default github external auth provider",
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
)
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
ClientID: GithubOAuth2DefaultProviderClientID,
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
}), nil
}
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
params := githubOAuth2ConfigParams{
accessURL: vals.AccessURL.Value(),
@@ -2037,9 +1934,28 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
return nil, nil //nolint:nilnil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
// We want to enable it only for new deployments, and avoid enabling it
// if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return nil, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
if !defaultEligible {
-163
View File
@@ -53,7 +53,6 @@ import (
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
@@ -303,7 +302,6 @@ func TestServer(t *testing.T) {
"open install.sh: file does not exist",
"telemetry disabled, unable to notify of security issues",
"installed terraform version newer than expected",
"report generator",
}
countLines := func(fullOutput string) int {
@@ -1742,18 +1740,6 @@ func TestServer(t *testing.T) {
// Next, we instruct the same server to display the YAML config
// and then save it.
// Because this is literally the same invocation, DefaultFn sets the
// value of 'Default'. Which triggers a mutually exclusive error
// on the next parse.
// Usually we only parse flags once, so this is not an issue
for _, c := range inv.Command.Children {
if c.Name() == "server" {
for i := range c.Options {
c.Options[i].DefaultFn = nil
}
break
}
}
inv = inv.WithContext(testutil.Context(t, testutil.WaitMedium))
//nolint:gocritic
inv.Args = append(args, "--write-config")
@@ -1807,155 +1793,6 @@ func TestServer(t *testing.T) {
})
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
type testCase struct {
name string
args []string
env map[string]string
createUserPreStart bool
expectedProviders []string
}
run := func(t *testing.T, tc testCase) {
ctx := testutil.Context(t, testutil.WaitLong)
unsetPrefixedEnv := func(prefix string) {
t.Helper()
for _, envVar := range os.Environ() {
envKey, _, found := strings.Cut(envVar, "=")
if !found || !strings.HasPrefix(envKey, prefix) {
continue
}
value, had := os.LookupEnv(envKey)
require.True(t, had)
require.NoError(t, os.Unsetenv(envKey))
keyCopy := envKey
valueCopy := value
t.Cleanup(func() {
// This is for setting/unsetting a number of prefixed env vars.
// t.Setenv doesn't cover this use case.
// nolint:usetesting
_ = os.Setenv(keyCopy, valueCopy)
})
}
}
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
unsetPrefixedEnv("CODER_GITAUTH_")
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
const (
existingUserEmail = "existing-user@coder.com"
existingUserUsername = "existing-user"
existingUserPassword = "SomeSecurePassword!"
)
if tc.createUserPreStart {
hashedPassword, err := userpassword.Hash(existingUserPassword)
require.NoError(t, err)
_ = dbgen.User(t, db, database.User{
Email: existingUserEmail,
Username: existingUserUsername,
HashedPassword: []byte(hashedPassword),
})
}
args := []string{
"server",
"--postgres-url", dbURL,
"--http-address", ":0",
"--access-url", "https://example.com",
}
args = append(args, tc.args...)
inv, cfg := clitest.New(t, args...)
for envKey, value := range tc.env {
t.Setenv(envKey, value)
}
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
if tc.createUserPreStart {
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: existingUserEmail,
Password: existingUserPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
} else {
_ = coderdtest.CreateFirstUser(t, client)
}
externalAuthResp, err := client.ListExternalAuths(ctx)
require.NoError(t, err)
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
for _, provider := range externalAuthResp.Providers {
gotProviders[provider.ID] = provider
}
require.Len(t, gotProviders, len(tc.expectedProviders))
for _, providerID := range tc.expectedProviders {
provider, ok := gotProviders[providerID]
require.Truef(t, ok, "expected provider %q to be configured", providerID)
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
require.True(t, provider.Device)
}
}
}
for _, tc := range []testCase{
{
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
},
{
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
createUserPreStart: true,
expectedProviders: nil,
},
{
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
args: []string{
"--external-auth-github-default-provider-enable=false",
},
expectedProviders: nil,
},
{
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
args: []string{
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
} {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_Logging_NoParallel(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+31 -7
View File
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
orgOwner = coderdtest.CreateFirstUser(t, client)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
version := workspace.LatestBuild.TemplateVersionID
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
version = workspace.TemplateActiveVersionID
if version != workspace.LatestBuild.TemplateVersionID {
action = WorkspaceUpdate
+4 -4
View File
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
statefilePath := filepath.Join(t.TempDir(), "state")
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
+7 -9
View File
@@ -1,3 +1,5 @@
//go:build !windows
package cli_test
import (
@@ -5,7 +7,6 @@ import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"time"
@@ -24,15 +25,12 @@ func setupSocketServer(t *testing.T) (path string, cleanup func()) {
t.Helper()
// Use a temporary socket path for each test
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
// Create parent directory if needed. Not necessary on Windows because named pipes live in an abstract namespace
// not tied to any real files.
if runtime.GOOS != "windows" {
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")
}
// Create parent directory if needed
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
-2
View File
@@ -17,8 +17,6 @@ func (r *RootCmd) tasksCommand() *serpent.Command {
r.taskDelete(),
r.taskList(),
r.taskLogs(),
r.taskPause(),
r.taskResume(),
r.taskSend(),
r.taskStatus(),
},
+10 -5
View File
@@ -41,7 +41,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client // user already has access to their own workspace
inv, root := clitest.New(t, "task", "logs", task.Name, "--output", "json")
output := clitest.Capture(inv)
@@ -64,7 +65,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
inv, root := clitest.New(t, "task", "logs", task.ID.String(), "--output", "json")
output := clitest.Capture(inv)
@@ -87,7 +89,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
inv, root := clitest.New(t, "task", "logs", task.ID.String())
output := clitest.Capture(inv)
@@ -141,7 +144,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
userClient := client
inv, root := clitest.New(t, "task", "logs", task.ID.String())
clitest.SetupConfig(t, userClient, root)
@@ -197,7 +201,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Run("SnapshotWithoutLogs_NoSnapshotCaptured", func(t *testing.T) {
t.Parallel()
userClient, task := setupCLITaskTestWithoutSnapshot(t, codersdk.TaskStatusPaused)
client, task := setupCLITaskTestWithoutSnapshot(t, codersdk.TaskStatusPaused)
userClient := client
inv, root := clitest.New(t, "task", "logs", task.Name)
output := clitest.Capture(inv)
-90
View File
@@ -1,90 +0,0 @@
package cli
import (
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func (r *RootCmd) taskPause() *serpent.Command {
cmd := &serpent.Command{
Use: "pause <task>",
Short: "Pause a task",
Long: FormatExamples(
Example{
Description: "Pause a task by name",
Command: "coder task pause my-task",
},
Example{
Description: "Pause another user's task",
Command: "coder task pause alice/my-task",
},
Example{
Description: "Pause a task without confirmation",
Command: "coder task pause my-task --yes",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Options: serpent.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
task, err := client.TaskByIdentifier(ctx, inv.Args[0])
if err != nil {
return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
}
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
if task.Status == codersdk.TaskStatusPaused {
return xerrors.Errorf("task %q is already paused", display)
}
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Pause task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
IsConfirm: true,
Default: cliui.ConfirmNo,
})
if err != nil {
return err
}
resp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
if err != nil {
return xerrors.Errorf("pause task %q: %w", display, err)
}
if resp.WorkspaceBuild == nil {
return xerrors.Errorf("pause task %q: no workspace build returned", display)
}
err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID)
if err != nil {
return xerrors.Errorf("watch pause build for task %q: %w", display, err)
}
_, _ = fmt.Fprintf(
inv.Stdout,
"\nThe %s task has been paused at %s!\n",
cliui.Keyword(task.Name),
cliui.Timestamp(time.Now()),
)
return nil
},
}
return cmd
}
-144
View File
@@ -1,144 +0,0 @@
package cli_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
func TestExpTaskPause(t *testing.T) {
t.Parallel()
t.Run("WithYesFlag", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: Expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
// OtherUserTask verifies that an admin can pause a task owned by
// another user using the "owner/name" identifier format.
t.Run("OtherUserTask", func(t *testing.T) {
t.Parallel()
// Given: A different user's running task
setupCtx := testutil.Context(t, testutil.WaitLong)
adminClient, _, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause their task
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "pause", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, adminClient, root)
// Then: We expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("PromptConfirm", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)
// And: We confirm we want to pause the task
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Pause task")
pty.WriteLine("yes")
// Then: We expect the task to be paused
pty.ExpectMatchContext(ctx, "has been paused")
require.NoError(t, w.Wait())
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("PromptDecline", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)
// But: We say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Pause task")
pty.WriteLine("no")
require.Error(t, w.Wait())
// Then: We expect the task to not be paused
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.NotEqual(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("TaskAlreadyPaused", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// And: We paused the running task
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, resp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, resp.WorkspaceBuild.ID)
// When: We attempt to pause the task again
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)
// Then: We expect to get an error that the task is already paused
err = inv.WithContext(ctx).Run()
require.ErrorContains(t, err, "is already paused")
})
}
-95
View File
@@ -1,95 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func (r *RootCmd) taskResume() *serpent.Command {
var noWait bool
cmd := &serpent.Command{
Use: "resume <task>",
Short: "Resume a task",
Long: FormatExamples(
Example{
Description: "Resume a task by name",
Command: "coder task resume my-task",
},
Example{
Description: "Resume another user's task",
Command: "coder task resume alice/my-task",
},
Example{
Description: "Resume a task without confirmation",
Command: "coder task resume my-task --yes",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Options: serpent.OptionSet{
{
Flag: "no-wait",
Description: "Return immediately after resuming the task.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
task, err := client.TaskByIdentifier(ctx, inv.Args[0])
if err != nil {
return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
}
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
if task.Status == codersdk.TaskStatusError || task.Status == codersdk.TaskStatusUnknown {
return xerrors.Errorf("task %q is in %s state and cannot be resumed; check the workspace build logs and agent status for details", display, task.Status)
} else if task.Status != codersdk.TaskStatusPaused {
return xerrors.Errorf("task %q cannot be resumed (current status: %s)", display, task.Status)
}
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Resume task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
IsConfirm: true,
Default: cliui.ConfirmNo,
})
if err != nil {
return err
}
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
if err != nil {
return xerrors.Errorf("resume task %q: %w", display, err)
} else if resp.WorkspaceBuild == nil {
return xerrors.Errorf("resume task %q: no workspace build returned", display)
}
if noWait {
_, _ = fmt.Fprintf(inv.Stdout, "Resuming task %q in the background.\n", cliui.Keyword(display))
return nil
}
if err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID); err != nil {
return xerrors.Errorf("watch resume build for task %q: %w", display, err)
}
_, _ = fmt.Fprintf(inv.Stdout, "\nThe %s task has been resumed.\n", cliui.Keyword(display))
return nil
},
}
return cmd
}
-183
View File
@@ -1,183 +0,0 @@
package cli_test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
func TestExpTaskResume(t *testing.T) {
t.Parallel()
// pauseTask is a helper that pauses a task and waits for the stop
// build to complete.
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
t.Helper()
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, pauseResp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
}
t.Run("WithYesFlag", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
// OtherUserTask verifies that an admin can resume a task owned by
// another user using the "owner/name" identifier format.
t.Run("OtherUserTask", func(t *testing.T) {
t.Parallel()
// Given: A different user's paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume their task
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, adminClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task (and specify no wait)
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed in the background
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "in the background")
// And: The task to eventually be resumed
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("PromptConfirm", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// And: We confirm we want to resume the task
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Resume task")
pty.WriteLine("yes")
// Then: We expect the task to be resumed
pty.ExpectMatchContext(ctx, "has been resumed")
require.NoError(t, w.Wait())
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("PromptDecline", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// But: Say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Resume task")
pty.WriteLine("no")
require.Error(t, w.Wait())
// Then: We expect the task to still be paused
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("TaskNotPaused", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to resume the task that is not paused
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)
// Then: We expect to get an error that the task is not paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.ErrorContains(t, err, "cannot be resumed")
})
}
+7 -4
View File
@@ -25,7 +25,8 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.Name, "carry on with the task")
@@ -41,7 +42,8 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.ID.String(), "carry on with the task")
@@ -57,7 +59,8 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.Name, "--stdin")
@@ -110,7 +113,7 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.Name, "some task input")
+10 -44
View File
@@ -120,40 +120,6 @@ func Test_Tasks(t *testing.T) {
require.Equal(t, logs[2].Type, codersdk.TaskLogTypeOutput, "third message should be an output")
},
},
{
name: "pause task",
cmdArgs: []string{"task", "pause", taskName, "--yes"},
assertFn: func(stdout string, userClient *codersdk.Client) {
require.Contains(t, stdout, "has been paused", "pause output should confirm task was paused")
},
},
{
name: "get task status after pause",
cmdArgs: []string{"task", "status", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var task codersdk.Task
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
require.Equal(t, taskName, task.Name, "task name should match")
require.Equal(t, codersdk.TaskStatusPaused, task.Status, "task should be paused")
},
},
{
name: "resume task",
cmdArgs: []string{"task", "resume", taskName, "--yes"},
assertFn: func(stdout string, userClient *codersdk.Client) {
require.Contains(t, stdout, "has been resumed", "resume output should confirm task was resumed")
},
},
{
name: "get task status after resume",
cmdArgs: []string{"task", "status", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var task codersdk.Task
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
require.Equal(t, taskName, task.Name, "task name should match")
require.Equal(t, codersdk.TaskStatusInitializing, task.Status, "task should be initializing after resume")
},
},
{
name: "delete task",
cmdArgs: []string{"task", "delete", taskName, "--yes"},
@@ -272,17 +238,17 @@ func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Mes
// setupCLITaskTest creates a test workspace with an AI task template and agent,
// with a fake agent API configured with the provided set of handlers.
// Returns the user client and workspace.
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (ownerClient *codersdk.Client, memberClient *codersdk.Client, task codersdk.Task) {
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Task) {
t.Helper()
ownerClient = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, ownerClient)
userClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
fakeAPI := startFakeAgentAPI(t, agentAPIHandlers)
authToken := uuid.NewString()
template := createAITaskTemplate(t, ownerClient, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
template := createAITaskTemplate(t, client, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
wantPrompt := "test prompt"
task, err := userClient.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
@@ -296,17 +262,17 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, workspace.LatestBuild.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(authToken))
_ = agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, userClient, workspace.ID).
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).
WaitFor(coderdtest.AgentsReady)
return ownerClient, userClient, task
return userClient, task
}
// setupCLITaskTestWithSnapshot creates a task in the specified status with a log snapshot.
-4
View File
@@ -139,10 +139,8 @@ func (r *RootCmd) templateVersionsList() *serpent.Command {
type templateVersionRow struct {
// For json format:
TemplateVersion codersdk.TemplateVersion `table:"-"`
ActiveJSON bool `json:"active" table:"-"`
// For table format:
ID string `json:"-" table:"id"`
Name string `json:"-" table:"name,default_sort"`
CreatedAt time.Time `json:"-" table:"created at"`
CreatedBy string `json:"-" table:"created by"`
@@ -168,8 +166,6 @@ func templateVersionsToRows(activeVersionID uuid.UUID, templateVersions ...coder
rows[i] = templateVersionRow{
TemplateVersion: templateVersion,
ActiveJSON: templateVersion.ID == activeVersionID,
ID: templateVersion.ID.String(),
Name: templateVersion.Name,
CreatedAt: templateVersion.CreatedAt,
CreatedBy: templateVersion.CreatedBy.Username,
-29
View File
@@ -1,9 +1,7 @@
package cli_test
import (
"bytes"
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
@@ -42,33 +40,6 @@ func TestTemplateVersions(t *testing.T) {
pty.ExpectMatch(version.CreatedBy.Username)
pty.ExpectMatch("Active")
})
t.Run("ListVersionsJSON", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "templates", "versions", "list", template.Name, "--output", "json")
clitest.SetupConfig(t, member, root)
var stdout bytes.Buffer
inv.Stdout = &stdout
require.NoError(t, inv.Run())
var rows []struct {
TemplateVersion codersdk.TemplateVersion `json:"TemplateVersion"`
Active bool `json:"active"`
}
require.NoError(t, json.Unmarshal(stdout.Bytes(), &rows))
require.Len(t, rows, 1)
assert.Equal(t, version.ID, rows[0].TemplateVersion.ID)
assert.True(t, rows[0].Active)
})
}
func TestTemplateVersionsPromote(t *testing.T) {
+1 -1
View File
@@ -74,7 +74,7 @@ OPTIONS:
--socket-path string, $CODER_AGENT_SOCKET_PATH
Specify the path for the agent socket.
--socket-server-enabled bool, $CODER_AGENT_SOCKET_SERVER_ENABLED (default: true)
--socket-server-enabled bool, $CODER_AGENT_SOCKET_SERVER_ENABLED (default: false)
Enable the agent socket server.
--ssh-max-timeout duration, $CODER_AGENT_SSH_MAX_TIMEOUT (default: 72h)
+11 -22
View File
@@ -49,9 +49,10 @@ OPTIONS:
security purposes if a --wildcard-access-url is configured.
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
Disable workspace sharing. Workspace ACL checking is disabled and only
owners can have ssh, apps and terminal access to workspaces. Access
based on the 'owner' role is also allowed unless disabled via
Disable workspace sharing (requires the "workspace-sharing" experiment
to be enabled). Workspace ACL checking is disabled and only owners can
have ssh, apps and terminal access to workspaces. Access based on the
'owner' role is also allowed unless disabled via
--disable-owner-workspace-access.
--swagger-enable bool, $CODER_SWAGGER_ENABLE
@@ -62,9 +63,6 @@ OPTIONS:
Separate multiple experiments with commas, or enter '*' to opt-in to
all available experiments.
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
Enable the default GitHub external auth provider managed by Coder.
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
Type of auth to use when connecting to postgres. For AWS RDS, using
IAM authentication (awsiamrds) is recommended.
@@ -175,22 +173,19 @@ AI BRIDGE OPTIONS:
exporting these records to external SIEM or observability systems.
AI BRIDGE PROXY OPTIONS:
--aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
Path to the CA certificate file for AI Bridge Proxy.
--aibridge-proxy-enabled bool, $CODER_AIBRIDGE_PROXY_ENABLED (default: false)
Enable the AI Bridge MITM Proxy for intercepting and decrypting AI
provider requests.
--aibridge-proxy-key-file string, $CODER_AIBRIDGE_PROXY_KEY_FILE
Path to the CA private key file for AI Bridge Proxy.
--aibridge-proxy-listen-addr string, $CODER_AIBRIDGE_PROXY_LISTEN_ADDR (default: :8888)
The address the AI Bridge Proxy will listen on.
--aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
Path to the CA certificate file used to intercept (MITM) HTTPS traffic
from AI clients. This CA must be trusted by AI clients for the proxy
to decrypt their requests.
--aibridge-proxy-key-file string, $CODER_AIBRIDGE_PROXY_KEY_FILE
Path to the CA private key file used to intercept (MITM) HTTPS traffic
from AI clients.
--aibridge-proxy-upstream string, $CODER_AIBRIDGE_PROXY_UPSTREAM
URL of an upstream HTTP proxy to chain tunneled (non-allowlisted)
requests through. Format: http://[user:pass@]host:port or
@@ -388,19 +383,13 @@ NETWORKING OPTIONS:
--samesite-auth-cookie lax|none, $CODER_SAMESITE_AUTH_COOKIE (default: lax)
Controls the 'SameSite' property is set on browser session cookies.
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE (default: false)
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE
Controls if the 'Secure' property is set on browser session cookies.
--wildcard-access-url string, $CODER_WILDCARD_ACCESS_URL
Specifies the wildcard hostname to use for workspace applications in
the form "*.example.com".
--host-prefix-cookie bool, $CODER_HOST_PREFIX_COOKIE (default: false)
Recommended to be enabled. Enables `__Host-` prefix for cookies to
guarantee they are only set by the right domain. This change is
disruptive to any workspaces built before release 2.31, requiring a
workspace restart.
NETWORKING / DERP OPTIONS:
Most Coder deployments never have to think about DERP because all connections
between workspaces and users are peer-to-peer. However, when Coder cannot
-2
View File
@@ -12,8 +12,6 @@ SUBCOMMANDS:
delete Delete tasks
list List tasks
logs Show a task's logs
pause Pause a task
resume Resume a task
send Send input to a task
status Show the status of a task.
-25
View File
@@ -1,25 +0,0 @@
coder v0.0.0-devel
USAGE:
coder task pause [flags] <task>
Pause a task
- Pause a task by name:
$ coder task pause my-task
- Pause another user's task:
$ coder task pause alice/my-task
- Pause a task without confirmation:
$ coder task pause my-task --yes
OPTIONS:
-y, --yes bool
Bypass confirmation prompts.
———
Run `coder --help` for a list of global options.
-28
View File
@@ -1,28 +0,0 @@
coder v0.0.0-devel
USAGE:
coder task resume [flags] <task>
Resume a task
- Resume a task by name:
$ coder task resume my-task
- Resume another user's task:
$ coder task resume alice/my-task
- Resume a task without confirmation:
$ coder task resume my-task --yes
OPTIONS:
--no-wait bool
Return immediately after resuming the task.
-y, --yes bool
Bypass confirmation prompts.
———
Run `coder --help` for a list of global options.
+1 -1
View File
@@ -9,7 +9,7 @@ OPTIONS:
-O, --org string, $CODER_ORGANIZATION
Select which organization (uuid or name) to use.
-c, --column [id|name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
-c, --column [name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
Columns to display in table output.
--include-archived bool
+1 -1
View File
@@ -27,7 +27,7 @@ USAGE:
SUBCOMMANDS:
create Create a token
list List tokens
remove Expire or delete a token
remove Delete a token
view Display detailed information about a token
———
-4
View File
@@ -15,10 +15,6 @@ OPTIONS:
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
Columns to display in table output.
--include-expired bool
Include expired tokens in the output. By default, expired tokens are
hidden.
-o, --output table|json (default: table)
Output format.
+2 -10
View File
@@ -1,19 +1,11 @@
coder v0.0.0-devel
USAGE:
coder tokens remove [flags] <name|id|token>
coder tokens remove <name|id|token>
Expire or delete a token
Delete a token
Aliases: delete, rm
Remove a token by expiring it. Use --delete to permanently hard-delete the
token instead.
OPTIONS:
--delete bool
Permanently delete the token instead of expiring it. This removes the
audit trail.
———
Run `coder --help` for a list of global options.
+7 -23
View File
@@ -176,16 +176,11 @@ networking:
# (default: <unset>, type: string-array)
proxyTrustedOrigins: []
# Controls if the 'Secure' property is set on browser session cookies.
# (default: false, type: bool)
# (default: <unset>, type: bool)
secureAuthCookie: false
# Controls the 'SameSite' property is set on browser session cookies.
# (default: lax, type: enum[lax\|none])
sameSiteAuthCookie: lax
# Recommended to be enabled. Enables `__Host-` prefix for cookies to guarantee
# they are only set by the right domain. This change is disruptive to any
# workspaces built before release 2.31, requiring a workspace restart.
# (default: false, type: bool)
hostPrefixCookie: false
# Whether Coder only allows connections to workspaces via the browser.
# (default: <unset>, type: bool)
browserOnly: false
@@ -422,11 +417,6 @@ oidc:
# an insecure OIDC configuration. It is not recommended to use this flag.
# (default: <unset>, type: bool)
dangerousSkipIssuerChecks: false
# Optional override of the default redirect url which uses the deployment's access
# url. Useful in situations where a deployment has more than 1 domain. Using this
# setting can also break OIDC, so use with caution.
# (default: <unset>, type: url)
oidc-redirect-url:
# Telemetry is critical to our ability to improve Coder. We strip all personal
# information before sending data to our servers. Please only disable telemetry
# when required by your organization's security policy.
@@ -524,10 +514,10 @@ disablePathApps: false
# workspaces.
# (default: <unset>, type: bool)
disableOwnerWorkspaceAccess: false
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
# can have ssh, apps and terminal access to workspaces. Access based on the
# 'owner' role is also allowed unless disabled via
# --disable-owner-workspace-access.
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
# and terminal access to workspaces. Access based on the 'owner' role is also
# allowed unless disabled via --disable-owner-workspace-access.
# (default: <unset>, type: bool)
disableWorkspaceSharing: false
# These options change the behavior of how clients interact with the Coder.
@@ -564,9 +554,6 @@ supportLinks: []
# External Authentication providers.
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
externalAuthProviders: []
# Enable the default GitHub external auth provider managed by Coder.
# (default: true, type: bool)
externalAuthGithubDefaultProviderEnable: true
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
# "tunnel.example.com".
@@ -830,13 +817,10 @@ aibridgeproxy:
# The address the AI Bridge Proxy will listen on.
# (default: :8888, type: string)
listen_addr: :8888
# Path to the CA certificate file used to intercept (MITM) HTTPS traffic from AI
# clients. This CA must be trusted by AI clients for the proxy to decrypt their
# requests.
# Path to the CA certificate file for AI Bridge Proxy.
# (default: <unset>, type: string)
cert_file: ""
# Path to the CA private key file used to intercept (MITM) HTTPS traffic from AI
# clients.
# Path to the CA private key file for AI Bridge Proxy.
# (default: <unset>, type: string)
key_file: ""
# Comma-separated list of AI provider domains for which HTTPS traffic will be
+14 -37
View File
@@ -218,10 +218,9 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
var (
all bool
includeExpired bool
displayTokens []tokenListRow
formatter = cliui.NewOutputFormatter(
all bool
displayTokens []tokenListRow
formatter = cliui.NewOutputFormatter(
cliui.TableFormat([]tokenListRow{}, defaultCols),
cliui.JSONFormat(),
)
@@ -241,8 +240,7 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
IncludeAll: all,
IncludeExpired: includeExpired,
IncludeAll: all,
})
if err != nil {
return xerrors.Errorf("list tokens: %w", err)
@@ -276,12 +274,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
Description: "Specifies whether all users' tokens will be listed or not (must have Owner role to see all tokens).",
Value: serpent.BoolOf(&all),
},
{
Name: "include-expired",
Flag: "include-expired",
Description: "Include expired tokens in the output. By default, expired tokens are hidden.",
Value: serpent.BoolOf(&includeExpired),
},
}
formatter.AttachOptions(&cmd.Options)
@@ -331,13 +323,10 @@ func (r *RootCmd) viewToken() *serpent.Command {
}
func (r *RootCmd) removeToken() *serpent.Command {
var deleteToken bool
cmd := &serpent.Command{
Use: "remove <name|id|token>",
Aliases: []string{"delete"},
Short: "Expire or delete a token",
Long: "Remove a token by expiring it. Use --delete to permanently hard-" +
"delete the token instead.",
Short: "Delete a token",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
@@ -349,7 +338,7 @@ func (r *RootCmd) removeToken() *serpent.Command {
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, inv.Args[0])
if err != nil {
// If it's a token, we need to extract the ID.
// If it's a token, we need to extract the ID
maybeID := strings.Split(inv.Args[0], "-")[0]
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
if err != nil {
@@ -357,29 +346,17 @@ func (r *RootCmd) removeToken() *serpent.Command {
}
}
if deleteToken {
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
if err != nil {
return xerrors.Errorf("delete api key: %w", err)
}
cliui.Infof(inv.Stdout, "Token has been deleted.")
return nil
}
err = client.ExpireAPIKey(inv.Context(), codersdk.Me, token.ID)
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
if err != nil {
return xerrors.Errorf("expire api key: %w", err)
return xerrors.Errorf("delete api key: %w", err)
}
cliui.Infof(inv.Stdout, "Token has been expired.")
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "delete",
Description: "Permanently delete the token instead of expiring it. This removes the audit trail.",
Value: serpent.BoolOf(&deleteToken),
cliui.Infof(
inv.Stdout,
"Token has been deleted.",
)
return nil
},
}

Some files were not shown because too many files have changed in this diff Show More