Compare commits

..

50 Commits

Author SHA1 Message Date
Charlie Voiselle eee13c42a4 docs(cli): reference --oidc-group-mapping flag name instead of 'legacy'
Issue: Used internal nomenclature instead of user-facing flag name

The previous fix referenced 'legacy group name mapping' but users don't
know what that means - it's an internal implementation detail. Users
configure this via the --oidc-group-mapping flag.

Changed to: 'This filter is applied after the oidc-group-mapping.'

This directly references the flag name users would actually use, making
the relationship clear and actionable. Users can now understand:
- The regex filter applies to group names
- Group names may have been transformed by --oidc-group-mapping first
- They need to write regex patterns that match the mapped names

Example: If --oidc-group-mapping transforms 'developers' to 'dev-team',
the regex in --oidc-group-regex-filter will match against 'dev-team'.
2026-02-09 16:14:00 -05:00
Charlie Voiselle 65b48c0f84 docs(cli): fix help text for --oidc-group-regex-filter (clarify mapping order)
Issue: Removed ordering information when it was actually helpful

The previous correction removed the sentence about filter order to avoid
confusion, but this actually made the description LESS clear. Users need
to understand that the regex filter operates on group names AFTER any
legacy name mapping has been applied.

Example: If IdP sends 'developers' and LegacyNameMapping renames it to
'dev-team', the regex filter will match against 'dev-team', not 'developers'.

Changed to: 'This filter is applied after legacy group name mapping.'

This clarifies:
1. It's the LEGACY mapping (name→name) not the new Mapping (name→IDs)
2. The regex operates on potentially-renamed group names
3. The filter happens before the final ID mapping

Code reference: coderd/idpsync/group.go lines 379-398
- Line 380: LegacyNameMapping (name → name)
- Line 386: RegexFilter (on the potentially renamed name)
- Line 392: Mapping (name → []uuid.UUID)
2026-02-09 16:13:59 -05:00
Charlie Voiselle 30cdf29e52 docs(cli): fix help text for --oidc-group-regex-filter (final correction)
Issue: Previous description incorrectly stated filter order

The correction commit stated 'This filter is applied after the group mapping'
but the actual code order in coderd/idpsync/group.go lines 379-398 shows:
1. Legacy group mappings
2. Regex filter
3. (New) group mapping

Since the filter order is complex and the description was causing confusion,
removed the last sentence entirely. The first two sentences clearly explain
what the flag does without introducing incorrect ordering claims.

This follows the verification report's recommendation to remove the
confusing last sentence.
2026-02-09 16:13:59 -05:00
Charlie Voiselle b1d2bb6d71 docs(cli): fix help text for --external-auth-providers
Issue: Clarity - vague description

Changed 'External Authentication providers.' to 'Configure external authentication providers for Git and other services.' to explain what these providers are actually used for.
2026-02-09 16:13:59 -05:00
Charlie Voiselle 94bad2a956 docs(cli): fix help text for --workspace-prebuilds-reconciliation-backoff-lookback-period
Issue: Clarity - unclear purpose

Changed 'Interval to look back to determine number of failed prebuilds, which influences backoff' to 'Time period to look back when counting failed prebuilds to calculate the backoff delay' to clarify this determines the time window for counting failures.
2026-02-09 16:13:59 -05:00
Charlie Voiselle 111714c7ed docs(cli): fix help text for --workspace-prebuilds-reconciliation-backoff-interval
Issue: Clarity - confusing wording about backoff behavior

Changed 'Interval to increase reconciliation backoff by when prebuilds fail, after which a retry attempt is made' to 'Amount of time to add to the reconciliation backoff delay after each prebuild failure, before the next retry attempt is made' to clarify this is an incremental addition to the backoff delay.
2026-02-09 16:13:58 -05:00
Charlie Voiselle 1f9c516c5c docs(cli): fix help text for --workspace-prebuilds-failure-hard-limit
Issue: Clarity - unclear what 'hits the hard limit' means

Changed 'before a preset hits the hard limit' to 'before a preset is considered hard-limited and stops automatic prebuild creation' to explain what actually happens when the limit is reached.
2026-02-09 16:13:58 -05:00
Charlie Voiselle 3645c65bb2 docs(cli): fix help text for --workspace-hostname-suffix
Issue: Clarity - incomplete example hostname

Changed 'in SSH config and Coder Connect on Coder Desktop' to 'for SSH connections and Coder Connect' for conciseness. Updated the example from 'myworkspace.coder' to the full format 'agent.workspace.owner.coder' to show the complete hostname structure.
2026-02-09 16:13:58 -05:00
Charlie Voiselle d3d2d2fb1e docs(cli): fix help text for --workspace-agent-logs-retention
Issue: Clarity - ambiguous scope

Changed 'Logs from the latest build are always retained' to 'Logs from the latest build for each workspace are always retained' to clarify that this applies per-workspace, not just one latest build globally.
2026-02-09 16:13:58 -05:00
Charlie Voiselle 086fb1f5d5 docs(cli): fix help text for --block-direct-connections
Issue: Clarity - imprecise wording about STUN behavior

Clarified that 'Workspace agents' (not 'Workspaces') reach out to STUN servers, changed 'get their address' to 'discover their address', and simplified 'until they are restarted after this change has been made' to just 'until they are restarted'.
2026-02-09 16:13:58 -05:00
Charlie Voiselle a73a535a5b docs(cli): fix help text for --proxy-health-interval
Issue: Clarity - awkward phrasing

Changed 'in which coderd should be checking' to 'at which coderd checks' for more concise, natural phrasing.
2026-02-09 16:13:57 -05:00
Charlie Voiselle 96e01c3018 docs(cli): fix help text for --email-tls-cert-key-file
Issue: Clarity - vague description

Changed 'Certificate key file to use' to 'Private key file for the client certificate' to clarify this is the private key that pairs with --email-tls-cert-file.
2026-02-09 16:13:57 -05:00
Charlie Voiselle 6b10a0359b docs(cli): fix help text for --email-tls-cert-file
Issue: Clarity - vague description

Changed 'Certificate file to use' to 'Client certificate file for mutual TLS authentication' to clarify what this certificate is for and when it's needed.
2026-02-09 16:13:57 -05:00
Charlie Voiselle b62583ad4b docs(cli): fix help text for --oidc-user-role-default
Issue: Clarity - ambiguous relationship between defaults and synced roles

Added 'in addition to synced roles' to clarify that these defaults don't replace synced roles. Also clarified that 'member' is always assigned 'regardless of this setting' to avoid confusion about whether this setting affects the member role.
2026-02-09 16:13:57 -05:00
Charlie Voiselle 3d6727a2cb docs(cli): fix help text for --oidc-group-field
Issue: Clarity - unclear structure

Reordered to put the primary purpose first: 'OIDC claim field to use as the user's groups' before the conditional requirement. This makes the description more scannable and understandable.
2026-02-09 16:13:56 -05:00
Charlie Voiselle b163962a14 docs(cli): fix help text for --aibridge-circuit-breaker-interval
Issue: Clarity - confusing technical jargon

Changed 'Cyclic period of the closed state for clearing internal failure counts' to 'Time window for counting failures before resetting the failure count in the closed state' to explain what the interval actually does in clearer terms.
2026-02-09 16:13:56 -05:00
Charlie Voiselle 9aca4ea27c docs(cli): fix help text for --aibridge-circuit-breaker-enabled
Issue: Clarity - ambiguous error code description

Changed '(429, 503, 529 overloaded)' to '(HTTP 429, 503, 529)' and added 'and overload errors' to clarify that these are HTTP status codes and what they represent.
2026-02-09 16:13:56 -05:00
Charlie Voiselle b0c10131ea docs(cli): fix help text for --aibridge-retention
Issue: Clarity - wordy phrasing

Simplified 'Length of time to retain data such as interceptions and all related records (token, prompt, tool use)' to 'How long to retain AI Bridge data including interceptions, tokens, prompts, and tool usage records' for more natural, clearer phrasing.
2026-02-09 16:13:56 -05:00
Charlie Voiselle c8c7e13e96 docs(cli): fix help text for --aibridge-inject-coder-mcp-tools
Issue: Clarity - awkward phrasing and formatting

Changed 'Whether to inject' to 'Enable injection of' for consistency with other boolean flags. Simplified the requirements clause and changed double quotes to single quotes for consistency.
2026-02-09 16:13:55 -05:00
Charlie Voiselle 249b7ea38e docs(cli): fix help text for --aibridge-enabled
Issue: Clarity - unclear technical jargon

Changed 'Whether to start an in-memory aibridged instance' to 'Enable the embedded AI Bridge service to intercept and record AI provider requests' to explain what the feature actually does in user-friendly terms.
2026-02-09 16:13:55 -05:00
Charlie Voiselle 1333096e25 docs(cli): fix help text for --oidc-group-regex-filter (correction)
Issue: Previous fix introduced confusing circular wording

The previous commit incorrectly changed the ending to 'after the group mapping and regex filter' which is nonsensical since this flag configures THE regex filter itself. Reverted to the correct wording: 'after the group mapping'.

The only valid changes from the original are:
- Added comma after 'If provided'
- Simplified 'allows for filtering' to 'allows filtering'
2026-02-09 16:13:55 -05:00
Charlie Voiselle 54bc9324dd docs(cli): fix help text for --samesite-auth-cookie
Issue: Grammar - missing word

Added missing 'if' to read 'Controls if the SameSite property is set' instead of 'Controls the SameSite property is set'.
2026-02-09 16:13:55 -05:00
Charlie Voiselle 109e5f2b19 docs(cli): fix help text for --enable-authz-recordings
Issue: Grammar - acronym capitalization

Capitalized 'API' (Application Programming Interface) - should always be uppercase.
2026-02-09 16:13:55 -05:00
Charlie Voiselle ee176b4207 docs(cli): fix help text for --ssh-config-options
Issue: Grammar - missing space after period

Added missing space after period between sentences: 'commas.' + 'Using' → 'commas. ' + 'Using'.
2026-02-09 16:13:54 -05:00
Charlie Voiselle 7e1e16be33 docs(cli): fix help text for --prometheus-address
Issue: Grammar - proper noun capitalization

Capitalized 'Prometheus' as it's a proper noun.
2026-02-09 16:13:54 -05:00
Charlie Voiselle 5cfe8082ce docs(cli): fix help text for --prometheus-enable
Issue: Grammar - proper noun capitalization

Capitalized 'Prometheus' as it's a proper noun (the name of the monitoring system).
2026-02-09 16:13:54 -05:00
Charlie Voiselle 6b7f672834 docs(cli): fix help text for --allow-custom-quiet-hours
Issue: Grammar - awkward phrasing

Changed 'for workspaces to stop in' to 'for when workspaces are stopped' for more natural phrasing.
2026-02-09 16:13:53 -05:00
Charlie Voiselle c55f6252a1 docs(cli): fix help text for --tls-client-ca-file
Issue: Grammar - missing article

Added missing article 'the' before 'client' to read 'authenticity of the client'.
2026-02-09 16:13:53 -05:00
Charlie Voiselle 842553b677 docs(cli): fix help text for --tls-ciphers
Issue: Grammar - missing verb

Fixed missing 'are' in 'that allowed to be used' → 'that are allowed to be used'.
2026-02-09 16:13:53 -05:00
Charlie Voiselle 05a771ba77 docs(cli): fix help text for --derp-server-stun-addresses
Issue: Grammar - incorrect possessive

Fixed "it's" (contraction of "it is") → "its" (possessive). Should be 'Each STUN server will get its own DERP region'.
2026-02-09 16:13:53 -05:00
Charlie Voiselle 70a0d42e65 docs(cli): fix help text for --derp-server-region-name
Issue: Grammar - malformed sentence

Fixed malformed sentence 'Region name that for' → 'Region name to use for'. The original was missing a verb.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 6b1d73b466 docs(cli): fix help text for --notifications-store-sync-buffer-size
Issue: Grammar - typo

Fixed typo: 'change' → 'chance'. Same typo as in --notifications-store-sync-interval.
2026-02-09 16:13:52 -05:00
Charlie Voiselle d7b9596145 docs(cli): fix help text for --notifications-store-sync-interval
Issue: Grammar - typo

Fixed typo: 'change' → 'chance'. The sentence should read 'the lower the chance of state inconsistency'.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 7a0aa1a40a docs(cli): fix help text for --oidc-signups-disabled-text
Issue: Grammar - awkward phrasing

Changed 'The custom text to show on the error page informing about disabled OIDC signups' to 'Custom text to show on the error page when OIDC signups are disabled' for clearer, more direct phrasing. Removed unnecessary 'The' article.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 4d8ea43e11 docs(cli): fix help text for --oidc-icon-url
Issue: Grammar - redundant phrasing

Changed 'URL pointing to the icon' to 'URL of the icon'. The phrase 'pointing to' is redundant since a URL inherently points to a resource.
2026-02-09 16:13:52 -05:00
Charlie Voiselle 6fddae98f6 docs(cli): fix help text for --oidc-group-regex-filter
Issue: Grammar - missing comma + simplification + filter order clarification

Added missing comma after 'If provided'. Simplified 'allows for filtering' to 'allows filtering'. Clarified filter order to match the actual implementation.
2026-02-09 16:13:51 -05:00
Charlie Voiselle e33fbb6087 docs(cli): fix help text for --oidc-group-mapping
Issue: Grammar - subject-verb agreement + awkward phrasing

Changed 'the group in Coder it should map to' to 'the groups in Coder they should map to' for proper plural agreement. Also simplified 'for when' to 'when'.
2026-02-09 16:13:51 -05:00
Charlie Voiselle 2337393e13 docs(cli): fix help text for --oidc-client-cert-file
Issue: Grammar - incorrect acronym capitalization

Changed 'Pem' to 'PEM', 'oauth2' to 'OAuth2', and 'x509' to 'X.509'. These are standard capitalizations for these acronyms and standards.
2026-02-09 16:13:51 -05:00
Charlie Voiselle d7357a1b0a docs(cli): fix help text for --oidc-client-key-file
Issue: Grammar - incorrect acronym capitalization

Changed 'Pem' to 'PEM' (Privacy Enhanced Mail), 'oauth2' to 'OAuth2', and 'IDP' to 'IdP' (Identity Provider). These are standard capitalizations for these acronyms.
2026-02-09 16:13:51 -05:00
Charlie Voiselle afbf1af29c docs(cli): fix help text for --oauth2-github-allow-everyone
Issue: Grammar - unclear and run-on sentence

Changed 'Allow all logins, setting this option means...' to 'Allow all GitHub users to authenticate. When enabled, allowed orgs and teams must be empty.' This separates the run-on sentence and clarifies what 'all logins' means (all GitHub users).
2026-02-09 16:13:50 -05:00
Charlie Voiselle 1d834c747c docs(cli): fix help text for --aibridge-circuit-breaker-failure-threshold
Issue: Grammar - subject-verb agreement

Changed 'triggers' to 'trigger' for correct subject-verb agreement. 'Number' is the subject, which takes the singular form, but 'failures' is the head of the relative clause 'that trigger...', making 'trigger' (plural) correct.
2026-02-09 16:13:50 -05:00
Charlie Voiselle a80edec752 docs(cli): fix help text for --aibridge-bedrock-access-key-secret
Issue: Grammar - wordy and redundant phrasing

Simplified from 'The access key secret to use with the access key to authenticate against' to 'AWS secret access key for authenticating with'. Uses standard AWS terminology and eliminates redundancy.
2026-02-09 16:13:50 -05:00
Charlie Voiselle 2a6473e8c6 docs(cli): fix help text for --aibridge-bedrock-access-key
Issue: Grammar - awkward phrasing

Changed 'The access key to authenticate against' to 'AWS access key for authenticating with' for consistency and clarity. Uses standard AWS terminology.
2026-02-09 16:13:50 -05:00
Charlie Voiselle 1f9c0b9b7f docs(cli): fix help text for --aibridge-anthropic-key
Issue: Grammar - awkward phrasing

Changed 'The key to authenticate against' to 'API key for authenticating with' for consistency with --aibridge-openai-key and more natural phrasing.
2026-02-09 16:13:49 -05:00
Charlie Voiselle 5494afabd8 docs(cli): fix help text for --aibridge-openai-key
Issue: Grammar - awkward phrasing

Changed 'The key to authenticate against' to 'API key for authenticating with' for more natural, concise phrasing. This matches standard API documentation conventions.
2026-02-09 16:13:49 -05:00
Charlie Voiselle 07c6e86a50 docs(cli): fix help text for --notifications-email-hello
Issue: Factually incorrect description of SMTP HELO/EHLO (deprecated alias)

Same issue as --email-hello. This is a deprecated alias but still needs the correct description. The HELO/EHLO command identifies the client to the server, not the server itself.

Fix: Clarified this identifies 'this client to the SMTP server'.
2026-02-09 16:13:49 -05:00
Charlie Voiselle b543821a1c docs(cli): fix help text for --email-hello
Issue: Factually incorrect description of SMTP HELO/EHLO

The description incorrectly stated this identifies 'the SMTP server' when it actually identifies the CLIENT to the server. The HELO/EHLO command is how the client introduces itself to the SMTP server during connection.

Fix: Clarified this identifies 'this client to the SMTP server' which accurately reflects the SMTP protocol.
2026-02-09 16:13:49 -05:00
Charlie Voiselle e8b7045a9b docs(cli): fix help text for --pprof-enable
Issue: Factually incorrect terminology

The description incorrectly stated pprof serves 'metrics' when it actually serves profiling data (CPU profiles, memory profiles, goroutines, etc.). Metrics are Prometheus's domain, not pprof's.

Fix: Changed 'metrics' to 'profiling endpoints' to accurately describe what pprof provides.
2026-02-09 16:13:48 -05:00
Charlie Voiselle 2571089528 docs(cli): fix help text for --oidc-user-role-mapping
Issue: Factually incorrect (confuses roles with groups) + grammar error

The description incorrectly stated this maps to 'groups in Coder' when it actually maps to site ROLES (member, admin, etc.). Also had a grammar error: 'will ignored' should be 'will be ignored'.

Fix: Corrected to clarify this maps OIDC role names to Coder role names, and fixed the grammar error.
2026-02-09 16:13:48 -05:00
Charlie Voiselle 1fb733fe1e docs(cli): fix help text for --oidc-allowed-groups
Issue: Factually incorrect filter order

The description incorrectly stated that the check is applied 'after the group mapping and before the regex filter'. This is wrong.

Fix: Updated to reflect actual behavior where the check is applied BEFORE any group mapping or filtering. Also clarified the positive case (users WITH at least one matching group are allowed) instead of the confusing double-negative phrasing.
2026-02-09 16:13:48 -05:00
886 changed files with 13255 additions and 81112 deletions
-4
View File
@@ -1,4 +0,0 @@
# All artifacts of the build processed are dumped here.
# Ignore it for docker context, as all Dockerfiles should build their own
# binaries.
build
+6 -3
View File
@@ -4,7 +4,10 @@ description: |
inputs:
version:
description: "The Go version to use."
default: "1.25.7"
default: "1.25.6"
use-preinstalled-go:
description: "Whether to use preinstalled Go."
default: "false"
use-cache:
description: "Whether to use the cache."
default: "true"
@@ -12,9 +15,9 @@ runs:
using: "composite"
steps:
- name: Setup Go
uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5.6.0
uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2
with:
go-version: ${{ inputs.version }}
go-version: ${{ inputs.use-preinstalled-go == 'false' && inputs.version || '' }}
cache: ${{ inputs.use-cache }}
- name: Install gotestsum
+1 -1
View File
@@ -7,5 +7,5 @@ runs:
- name: Install Terraform
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
with:
terraform_version: 1.14.5
terraform_version: 1.14.1
terraform_wrapper: false
+21 -25
View File
@@ -35,7 +35,7 @@ jobs:
tailnet-integration: ${{ steps.filter.outputs.tailnet-integration }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -157,7 +157,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -247,7 +247,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -272,7 +272,7 @@ jobs:
if: ${{ !cancelled() }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -329,7 +329,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -381,7 +381,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -422,6 +422,10 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
# Runners have Go baked-in and Go will automatically
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
use-cache: true
- name: Setup Terraform
@@ -485,14 +489,6 @@ jobs:
# macOS will output "The default interactive shell is now zsh" intermittently in CI.
touch ~/.bash_profile && echo "export BASH_SILENCE_DEPRECATION_WARNING=1" >> ~/.bash_profile
- name: Increase PTY limit (macOS)
if: runner.os == 'macOS'
shell: bash
run: |
# Increase PTY limit to avoid exhaustion during tests.
# Default is 511; 999 is the maximum value on CI runner.
sudo sysctl -w kern.tty.ptmx_max=999
- name: Test with PostgreSQL Database (Linux)
if: runner.os == 'Linux'
uses: ./.github/actions/test-go-pg
@@ -582,7 +578,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -644,7 +640,7 @@ jobs:
timeout-minutes: 25
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -716,7 +712,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -743,7 +739,7 @@ jobs:
timeout-minutes: 20
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -776,7 +772,7 @@ jobs:
name: ${{ matrix.variant.name }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -856,7 +852,7 @@ jobs:
if: needs.changes.outputs.site == 'true' || needs.changes.outputs.ci == 'true'
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -937,7 +933,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1009,7 +1005,7 @@ jobs:
if: always()
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1124,7 +1120,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1179,7 +1175,7 @@ jobs:
IMAGE: ghcr.io/coder/coder-preview:${{ steps.build-docker.outputs.tag }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -1576,7 +1572,7 @@ jobs:
if: needs.changes.outputs.db == 'true' || needs.changes.outputs.ci == 'true' || github.ref == 'refs/heads/main'
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+3 -3
View File
@@ -36,7 +36,7 @@ jobs:
verdict: ${{ steps.check.outputs.verdict }} # DEPLOY or NOOP
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -65,7 +65,7 @@ jobs:
packages: write # to retag image as dogfood
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
needs: deploy
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+3 -3
View File
@@ -38,7 +38,7 @@ jobs:
if: github.repository_owner == 'coder'
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -58,11 +58,11 @@ jobs:
run: mkdir base-build-context
- name: Install depot.dev CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0
# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
with:
project: wl5hnrrkns
context: base-build-context
+4 -4
View File
@@ -26,7 +26,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-4' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -75,7 +75,7 @@ jobs:
BRANCH_NAME: ${{ steps.branch-name.outputs.current_branch }}
- name: Set up Depot CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
@@ -88,7 +88,7 @@ jobs:
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and push Non-Nix image
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
with:
project: b4q6ltmpzh
token: ${{ secrets.DEPOT_TOKEN }}
@@ -125,7 +125,7 @@ jobs:
id-token: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+6 -1
View File
@@ -28,7 +28,7 @@ jobs:
- windows-2022
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -64,6 +64,11 @@ jobs:
- name: Setup Go
uses: ./.github/actions/setup-go
with:
# Runners have Go baked-in and Go will automatically
# download the toolchain configured in go.mod, so we don't
# need to reinstall it. It's faster on Windows runners.
use-preinstalled-go: ${{ runner.os == 'Windows' }}
- name: Setup Terraform
uses: ./.github/actions/setup-tf
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+1 -1
View File
@@ -19,7 +19,7 @@ jobs:
packages: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+5 -5
View File
@@ -39,7 +39,7 @@ jobs:
PR_OPEN: ${{ steps.check_pr.outputs.pr_open }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -76,7 +76,7 @@ jobs:
runs-on: "ubuntu-latest"
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -184,7 +184,7 @@ jobs:
pull-requests: write # needed for commenting on PRs
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -228,7 +228,7 @@ jobs:
CODER_IMAGE_TAG: ${{ needs.get_info.outputs.CODER_IMAGE_TAG }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -288,7 +288,7 @@ jobs:
PR_HOSTNAME: "pr${{ needs.get_info.outputs.PR_NUMBER }}.${{ secrets.PR_DEPLOYMENTS_DOMAIN }}"
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+1 -1
View File
@@ -14,7 +14,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+6 -6
View File
@@ -158,7 +158,7 @@ jobs:
version: ${{ steps.version.outputs.version }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -386,12 +386,12 @@ jobs:
- name: Install depot.dev CLI
if: steps.image-base-tag.outputs.tag != ''
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
uses: depot/setup-action@b0b1ea4f69e92ebf5dea3f8713a1b0c37b2126a5 # v1.6.0
# This uses OIDC authentication, so no auth variables are required.
- name: Build base Docker image via depot.dev
if: steps.image-base-tag.outputs.tag != ''
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: depot/build-push-action@9785b135c3c76c33db102e45be96a25ab55cd507 # v1.16.2
with:
project: wl5hnrrkns
context: base-build-context
@@ -796,7 +796,7 @@ jobs:
# TODO: skip this if it's not a new release (i.e. a backport). This is
# fine right now because it just makes a PR that we can close.
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -872,7 +872,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -965,7 +965,7 @@ jobs:
if: ${{ !inputs.dry_run }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+1 -1
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+3 -3
View File
@@ -27,7 +27,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -69,7 +69,7 @@ jobs:
runs-on: ${{ github.repository_owner == 'coder' && 'depot-ubuntu-22.04-8' || 'ubuntu-latest' }}
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -146,7 +146,7 @@ jobs:
echo "image=$(cat "$image_job")" >> "$GITHUB_OUTPUT"
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@c1824fd6edce30d7ab345a9989de00bbd46ef284 # v0.34.0
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8
with:
image-ref: ${{ steps.build.outputs.image }}
format: sarif
+4 -4
View File
@@ -18,12 +18,12 @@ jobs:
pull-requests: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
- name: stale
uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
stale-issue-label: "stale"
stale-pr-label: "stale"
@@ -96,7 +96,7 @@ jobs:
contents: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
@@ -120,7 +120,7 @@ jobs:
actions: write
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
+1 -1
View File
@@ -21,7 +21,7 @@ jobs:
pull-requests: write # required to post PR review comments by the action
steps:
- name: Harden Runner
uses: step-security/harden-runner@5ef0c079ce82195b2a36a210272d6b661572d83e # v2.14.2
uses: step-security/harden-runner@e3f713f2d8f53843e71c69a996d56f51aa9adfb9 # v2.14.1
with:
egress-policy: audit
-3
View File
@@ -98,6 +98,3 @@ AGENTS.local.md
# Ignore plans written by AI agents.
PLAN.md
# Ignore any dev licenses
license.txt
+3 -1
View File
@@ -198,7 +198,9 @@ reviewer time and clutters the diff.
**Don't delete existing comments** that explain non-obvious behavior. These
comments preserve important context about why code works a certain way.
**When adding tests for new behavior**, read existing tests first to understand what's covered. Add new cases for uncovered behavior. Edit existing tests as needed, but don't change what they verify.
**When adding tests for new behavior**, add new test cases instead of modifying
existing ones. This preserves coverage for the original behavior and makes it
clear what the new test covers.
## Detailed Development Guides
+9 -20
View File
@@ -854,7 +854,7 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go')
# -C sets the directory for the go run command
go run -C ./scripts/apitypings main.go > $@
./scripts/biome_format.sh src/api/typesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/typesGenerated.ts)
touch "$@"
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
@@ -863,7 +863,7 @@ site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/prot
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*)
go run ./scripts/gensite/ -icons "$@"
./scripts/biome_format.sh src/theme/icons.json
(cd site/ && pnpm exec biome format --write src/theme/icons.json)
touch "$@"
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates)
@@ -901,18 +901,15 @@ codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scope
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go
go run scripts/typegen/main.go rbac typescript > "$@"
./scripts/biome_format.sh src/api/rbacresourcesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/rbacresourcesGenerated.ts)
touch "$@"
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go
go run scripts/typegen/main.go countries > "$@"
./scripts/biome_format.sh src/api/countriesGenerated.ts
(cd site/ && pnpm exec biome format --write src/api/countriesGenerated.ts)
touch "$@"
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES)
go run ./scripts/metricsdocgen/scanner > $@
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics
go run scripts/metricsdocgen/main.go
pnpm exec markdownlint-cli2 --fix ./docs/admin/integrations/prometheus.md
pnpm exec markdown-table-formatter ./docs/admin/integrations/prometheus.md
@@ -950,11 +947,11 @@ coderd/apidoc/.gen: \
touch "$@"
docs/manifest.json: site/node_modules/.installed coderd/apidoc/.gen docs/reference/cli/index.md
./scripts/biome_format.sh ../docs/manifest.json
(cd site/ && pnpm exec biome format --write ../docs/manifest.json)
touch "$@"
coderd/apidoc/swagger.json: site/node_modules/.installed coderd/apidoc/.gen
./scripts/biome_format.sh ../coderd/apidoc/swagger.json
(cd site/ && pnpm exec biome format --write ../coderd/apidoc/swagger.json)
touch "$@"
update-golden-files:
@@ -999,19 +996,11 @@ enterprise/tailnet/testdata/.gen-golden: $(wildcard enterprise/tailnet/testdata/
touch "$@"
helm/coder/tests/testdata/.gen-golden: $(wildcard helm/coder/tests/testdata/*.yaml) $(wildcard helm/coder/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/coder/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/coder golden generation" >&2
fi
TZ=UTC go test ./helm/coder/tests -run=TestUpdateGoldenFiles -update
touch "$@"
helm/provisioner/tests/testdata/.gen-golden: $(wildcard helm/provisioner/tests/testdata/*.yaml) $(wildcard helm/provisioner/tests/testdata/*.golden) $(GO_SRC_FILES) $(wildcard helm/provisioner/tests/*_test.go)
if command -v helm >/dev/null 2>&1; then
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
else
echo "WARNING: helm not found; skipping helm/provisioner golden generation" >&2
fi
TZ=UTC go test ./helm/provisioner/tests -run=TestUpdateGoldenFiles -update
touch "$@"
coderd/.gen-golden: $(wildcard coderd/testdata/*/*.golden) $(GO_SRC_FILES) $(wildcard coderd/*_test.go)
+3 -18
View File
@@ -41,7 +41,6 @@ import (
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentfiles"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/agent/agentscripts"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agentssh"
@@ -112,12 +111,6 @@ type Client interface {
ConnectRPC28(ctx context.Context) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
// ConnectRPC28WithRole is like ConnectRPC28 but sends an explicit
// role query parameter to the server. The workspace agent should
// use role "agent" to enable connection monitoring.
ConnectRPC28WithRole(ctx context.Context, role string) (
proto.DRPCAgentClient28, tailnetproto.DRPCTailnetClient28, error,
)
tailnet.DERPMapRewriter
agentsdk.RefreshableSessionTokenProvider
}
@@ -303,8 +296,7 @@ type agent struct {
containerAPIOptions []agentcontainers.Option
containerAPI *agentcontainers.API
filesAPI *agentfiles.API
processAPI *agentproc.API
filesAPI *agentfiles.API
socketServerEnabled bool
socketPath string
@@ -377,7 +369,6 @@ func (a *agent) init() {
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
a.filesAPI = agentfiles.NewAPI(a.logger.Named("files"), a.filesystem)
a.processAPI = agentproc.NewAPI(a.logger.Named("processes"), a.execer, a.updateCommandEnv)
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
@@ -1006,10 +997,8 @@ func (a *agent) run() (retErr error) {
return xerrors.Errorf("refresh token: %w", err)
}
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs.
// We pass role "agent" to enable connection monitoring on the server, which tracks
// the agent's connectivity state (first_connected_at, last_connected_at, disconnected_at).
aAPI, tAPI, err := a.client.ConnectRPC28WithRole(a.hardCtx, "agent")
// ConnectRPC returns the dRPC connection we use for the Agent and Tailnet v2+ APIs
aAPI, tAPI, err := a.client.ConnectRPC28(a.hardCtx)
if err != nil {
return err
}
@@ -2033,10 +2022,6 @@ func (a *agent) Close() error {
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
}
if err := a.processAPI.Close(); err != nil {
a.logger.Error(a.hardCtx, "process API close", slog.Error(err))
}
if a.boundaryLogProxy != nil {
err = a.boundaryLogProxy.Close()
if err != nil {
-1
View File
@@ -29,7 +29,6 @@ func (api *API) Routes() http.Handler {
r.Post("/list-directory", api.HandleLS)
r.Get("/read-file", api.HandleReadFile)
r.Get("/read-file-lines", api.HandleReadFileLines)
r.Post("/write-file", api.HandleWriteFile)
r.Post("/edit-files", api.HandleEditFiles)
+7 -283
View File
@@ -10,10 +10,11 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"github.com/icholy/replace"
"github.com/spf13/afero"
"golang.org/x/text/transform"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -22,22 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// ReadFileLinesResponse is the JSON response for the line-based file reader.
type ReadFileLinesResponse struct {
// Success indicates whether the read was successful.
Success bool `json:"success"`
// FileSize is the original file size in bytes.
FileSize int64 `json:"file_size,omitempty"`
// TotalLines is the total number of lines in the file.
TotalLines int `json:"total_lines,omitempty"`
// LinesRead is the count of lines returned in this response.
LinesRead int `json:"lines_read,omitempty"`
// Content is the line-numbered file content.
Content string `json:"content,omitempty"`
// Error is the error message when success is false.
Error string `json:"error,omitempty"`
}
type HTTPResponseCode = int
func (api *API) HandleReadFile(rw http.ResponseWriter, r *http.Request) {
@@ -118,166 +103,6 @@ func (api *API) streamFile(ctx context.Context, rw http.ResponseWriter, path str
return 0, nil
}
func (api *API) HandleReadFileLines(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
parser := httpapi.NewQueryParamParser().RequiredNotEmpty("path")
path := parser.String(query, "", "path")
offset := parser.PositiveInt64(query, 1, "offset")
limit := parser.PositiveInt64(query, 0, "limit")
maxFileSize := parser.PositiveInt64(query, workspacesdk.DefaultMaxFileSize, "max_file_size")
maxLineBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxLineBytes, "max_line_bytes")
maxResponseLines := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseLines, "max_response_lines")
maxResponseBytes := parser.PositiveInt64(query, workspacesdk.DefaultMaxResponseBytes, "max_response_bytes")
parser.ErrorExcessParams(query)
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}
resp := api.readFileLines(ctx, path, offset, limit, workspacesdk.ReadFileLinesLimits{
MaxFileSize: maxFileSize,
MaxLineBytes: int(maxLineBytes),
MaxResponseLines: int(maxResponseLines),
MaxResponseBytes: int(maxResponseBytes),
})
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func (api *API) readFileLines(_ context.Context, path string, offset, limit int64, limits workspacesdk.ReadFileLinesLimits) ReadFileLinesResponse {
errResp := func(msg string) ReadFileLinesResponse {
return ReadFileLinesResponse{Success: false, Error: msg}
}
if !filepath.IsAbs(path) {
return errResp(fmt.Sprintf("file path must be absolute: %q", path))
}
f, err := api.filesystem.Open(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return errResp(fmt.Sprintf("file does not exist: %s", path))
}
if errors.Is(err, os.ErrPermission) {
return errResp(fmt.Sprintf("permission denied: %s", path))
}
return errResp(fmt.Sprintf("open file: %s", err))
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return errResp(fmt.Sprintf("stat file: %s", err))
}
if stat.IsDir() {
return errResp(fmt.Sprintf("not a file: %s", path))
}
fileSize := stat.Size()
if fileSize > limits.MaxFileSize {
return errResp(fmt.Sprintf(
"file is %d bytes which exceeds the maximum of %d bytes. Use grep, sed, or awk to extract the content you need, or use offset and limit to read a portion.",
fileSize, limits.MaxFileSize,
))
}
// Read the entire file (up to MaxFileSize).
data, err := io.ReadAll(f)
if err != nil {
return errResp(fmt.Sprintf("read file: %s", err))
}
// Split into lines.
content := string(data)
// Handle empty file.
if content == "" {
return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: 0,
LinesRead: 0,
Content: "",
}
}
lines := strings.Split(content, "\n")
totalLines := len(lines)
// offset is 1-based line number.
if offset < 1 {
offset = 1
}
if offset > int64(totalLines) {
return errResp(fmt.Sprintf(
"offset %d is beyond the file length of %d lines",
offset, totalLines,
))
}
// Default limit.
if limit <= 0 {
limit = int64(limits.MaxResponseLines)
}
startIdx := int(offset - 1) // convert to 0-based
endIdx := startIdx + int(limit)
if endIdx > totalLines {
endIdx = totalLines
}
var numbered []string
totalBytesAccumulated := 0
for i := startIdx; i < endIdx; i++ {
line := lines[i]
// Per-line truncation.
if len(line) > limits.MaxLineBytes {
line = line[:limits.MaxLineBytes] + "... [truncated]"
}
// Format with 1-based line number.
numberedLine := fmt.Sprintf("%d\t%s", i+1, line)
lineBytes := len(numberedLine)
// Check total byte budget.
newTotal := totalBytesAccumulated + lineBytes
if len(numbered) > 0 {
newTotal++ // account for \n joiner
}
if newTotal > limits.MaxResponseBytes {
return errResp(fmt.Sprintf(
"output would exceed %d bytes. Read less at a time using offset and limit parameters.",
limits.MaxResponseBytes,
))
}
// Check line count.
if len(numbered) >= limits.MaxResponseLines {
return errResp(fmt.Sprintf(
"output would exceed %d lines. Read less at a time using offset and limit parameters.",
limits.MaxResponseLines,
))
}
numbered = append(numbered, numberedLine)
totalBytesAccumulated = newTotal
}
return ReadFileLinesResponse{
Success: true,
FileSize: fileSize,
TotalLines: totalLines,
LinesRead: len(numbered),
Content: strings.Join(numbered, "\n"),
}
}
func (api *API) HandleWriteFile(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -420,21 +245,9 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return http.StatusBadRequest, xerrors.Errorf("open %s: not a file", path)
}
data, err := io.ReadAll(f)
if err != nil {
return http.StatusInternalServerError, xerrors.Errorf("read %s: %w", path, err)
}
content := string(data)
for _, edit := range edits {
var ok bool
content, ok = fuzzyReplace(content, edit.Search, edit.Replace)
if !ok {
api.logger.Warn(ctx, "edit search string not found, skipping",
slog.F("path", path),
slog.F("search_preview", truncate(edit.Search, 64)),
)
}
transforms := make([]transform.Transformer, len(edits))
for i, edit := range edits {
transforms[i] = replace.String(edit.Search, edit.Replace)
}
// Create an adjacent file to ensure it will be on the same device and can be
@@ -445,7 +258,8 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
}
defer tmpfile.Close()
if _, err := tmpfile.Write([]byte(content)); err != nil {
_, err = io.Copy(tmpfile, replace.Chain(f, transforms...))
if err != nil {
if rerr := api.filesystem.Remove(tmpfile.Name()); rerr != nil {
api.logger.Warn(ctx, "unable to clean up temp file", slog.Error(rerr))
}
@@ -459,93 +273,3 @@ func (api *API) editFile(ctx context.Context, path string, edits []workspacesdk.
return 0, nil
}
// fuzzyReplace attempts to find `search` inside `content` and replace its first
// occurrence with `replace`. It uses a cascading match strategy inspired by
// openai/codex's apply_patch:
//
// 1. Exact substring match (byte-for-byte).
// 2. Line-by-line match ignoring trailing whitespace on each line.
// 3. Line-by-line match ignoring all leading/trailing whitespace (indentation-tolerant).
//
// When a fuzzy match is found (passes 2 or 3), the replacement is still applied
// at the byte offsets of the original content so that surrounding text (including
// indentation of untouched lines) is preserved.
//
// Returns the (possibly modified) content and a bool indicating whether a match
// was found.
func fuzzyReplace(content, search, replace string) (string, bool) {
// Pass 1 exact substring (replace all occurrences).
if strings.Contains(content, search) {
return strings.ReplaceAll(content, search, replace), true
}
// For line-level fuzzy matching we split both content and search into lines.
contentLines := strings.SplitAfter(content, "\n")
searchLines := strings.SplitAfter(search, "\n")
// A trailing newline in the search produces an empty final element from
// SplitAfter. Drop it so it doesn't interfere with line matching.
if len(searchLines) > 0 && searchLines[len(searchLines)-1] == "" {
searchLines = searchLines[:len(searchLines)-1]
}
// Pass 2 trim trailing whitespace on each line.
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimRight(a, " \t\r\n") == strings.TrimRight(b, " \t\r\n")
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
// Pass 3 trim all leading and trailing whitespace (indentation-tolerant).
if start, end, ok := seekLines(contentLines, searchLines, func(a, b string) bool {
return strings.TrimSpace(a) == strings.TrimSpace(b)
}); ok {
return spliceLines(contentLines, start, end, replace), true
}
return content, false
}
// seekLines scans contentLines looking for a contiguous subsequence that matches
// searchLines according to the provided `eq` function. It returns the start and
// end (exclusive) indices into contentLines of the match.
func seekLines(contentLines, searchLines []string, eq func(a, b string) bool) (start, end int, ok bool) {
if len(searchLines) == 0 {
return 0, 0, true
}
if len(searchLines) > len(contentLines) {
return 0, 0, false
}
outer:
for i := 0; i <= len(contentLines)-len(searchLines); i++ {
for j, sLine := range searchLines {
if !eq(contentLines[i+j], sLine) {
continue outer
}
}
return i, i + len(searchLines), true
}
return 0, 0, false
}
// spliceLines replaces contentLines[start:end] with replacement text, returning
// the full content as a single string.
func spliceLines(contentLines []string, start, end int, replacement string) string {
var b strings.Builder
for _, l := range contentLines[:start] {
_, _ = b.WriteString(l)
}
_, _ = b.WriteString(replacement)
for _, l := range contentLines[end:] {
_, _ = b.WriteString(l)
}
return b.String()
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
-285
View File
@@ -649,106 +649,6 @@ func TestEditFiles(t *testing.T) {
filepath.Join(tmpdir, "file3"): "edited3 3",
},
},
{
name: "TrailingWhitespace",
contents: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "foo \nbar\t\t\nbaz"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "trailing-ws"),
Edits: []workspacesdk.FileEdit{
{
Search: "foo\nbar\nbaz",
Replace: "replaced",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "trailing-ws"): "replaced"},
},
{
name: "TabsVsSpaces",
contents: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tfoo()\n\t}"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "tabs-vs-spaces"),
Edits: []workspacesdk.FileEdit{
{
// Search uses spaces but file uses tabs.
Search: " if true {\n foo()\n }",
Replace: "\tif true {\n\t\tbar()\n\t}",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "tabs-vs-spaces"): "\tif true {\n\t\tbar()\n\t}"},
},
{
name: "DifferentIndentDepth",
contents: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tnested()"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "indent-depth"),
Edits: []workspacesdk.FileEdit{
{
// Search has wrong indent depth (1 tab instead of 3).
Search: "\tdeep()\n\tnested()",
Replace: "\t\t\tdeep()\n\t\t\tchanged()",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "indent-depth"): "\t\t\tdeep()\n\t\t\tchanged()"},
},
{
name: "ExactMatchPreferred",
contents: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "hello world"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "exact-preferred"),
Edits: []workspacesdk.FileEdit{
{
Search: "hello world",
Replace: "goodbye world",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "exact-preferred"): "goodbye world"},
},
{
name: "NoMatchStillSucceeds",
contents: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "no-match"),
Edits: []workspacesdk.FileEdit{
{
Search: "this does not exist in the file",
Replace: "whatever",
},
},
},
},
// File should remain unchanged.
expected: map[string]string{filepath.Join(tmpdir, "no-match"): "original content"},
},
{
name: "MixedWhitespaceMultiline",
contents: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tfmt.Println(result)\n}"},
edits: []workspacesdk.FileEdits{
{
Path: filepath.Join(tmpdir, "mixed-ws"),
Edits: []workspacesdk.FileEdit{
{
// Search uses spaces, file uses tabs.
Search: " result := compute()\n fmt.Println(result)\n",
Replace: "\tresult := compute()\n\tlog.Println(result)\n",
},
},
},
},
expected: map[string]string{filepath.Join(tmpdir, "mixed-ws"): "func main() {\n\tresult := compute()\n\tlog.Println(result)\n}"},
},
{
name: "MultiError",
contents: map[string]string{
@@ -837,188 +737,3 @@ func TestEditFiles(t *testing.T) {
})
}
}
func TestReadFileLines(t *testing.T) {
t.Parallel()
tmpdir := os.TempDir()
noPermsFilePath := filepath.Join(tmpdir, "no-perms-lines")
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
fs := newTestFs(afero.NewMemMapFs(), func(call, file string) error {
if file == noPermsFilePath {
return os.ErrPermission
}
return nil
})
api := agentfiles.NewAPI(logger, fs)
dirPath := filepath.Join(tmpdir, "a-directory-lines")
err := fs.MkdirAll(dirPath, 0o755)
require.NoError(t, err)
emptyFilePath := filepath.Join(tmpdir, "empty-file")
err = afero.WriteFile(fs, emptyFilePath, []byte(""), 0o644)
require.NoError(t, err)
basicFilePath := filepath.Join(tmpdir, "basic-file")
err = afero.WriteFile(fs, basicFilePath, []byte("line1\nline2\nline3"), 0o644)
require.NoError(t, err)
longLine := string(bytes.Repeat([]byte("x"), 1025))
longLineFilePath := filepath.Join(tmpdir, "long-line-file")
err = afero.WriteFile(fs, longLineFilePath, []byte(longLine), 0o644)
require.NoError(t, err)
largeFilePath := filepath.Join(tmpdir, "large-file")
err = afero.WriteFile(fs, largeFilePath, bytes.Repeat([]byte("x"), 1<<20+1), 0o644)
require.NoError(t, err)
tests := []struct {
name string
path string
offset int64
limit int64
expSuccess bool
expError string
expContent string
expTotal int
expRead int
expSize int64
// useCodersdk is set for cases where the handler returns
// codersdk.Response (query param validation) instead of ReadFileLinesResponse.
useCodersdk bool
}{
{
name: "NoPath",
path: "",
useCodersdk: true,
expError: "is required",
},
{
name: "RelativePath",
path: "relative/path",
expError: "file path must be absolute",
},
{
name: "NonExistent",
path: filepath.Join(tmpdir, "does-not-exist"),
expError: "file does not exist",
},
{
name: "IsDir",
path: dirPath,
expError: "not a file",
},
{
name: "NoPermissions",
path: noPermsFilePath,
expError: "permission denied",
},
{
name: "EmptyFile",
path: emptyFilePath,
expSuccess: true,
expTotal: 0,
expRead: 0,
expSize: 0,
},
{
name: "BasicRead",
path: basicFilePath,
expSuccess: true,
expContent: "1\tline1\n2\tline2\n3\tline3",
expTotal: 3,
expRead: 3,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Offset2",
path: basicFilePath,
offset: 2,
expSuccess: true,
expContent: "2\tline2\n3\tline3",
expTotal: 3,
expRead: 2,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Limit1",
path: basicFilePath,
limit: 1,
expSuccess: true,
expContent: "1\tline1",
expTotal: 3,
expRead: 1,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "Offset2Limit1",
path: basicFilePath,
offset: 2,
limit: 1,
expSuccess: true,
expContent: "2\tline2",
expTotal: 3,
expRead: 1,
expSize: int64(len("line1\nline2\nline3")),
},
{
name: "OffsetBeyondFile",
path: basicFilePath,
offset: 100,
expError: "offset 100 is beyond the file length of 3 lines",
},
{
name: "LongLineTruncation",
path: longLineFilePath,
expSuccess: true,
expContent: "1\t" + string(bytes.Repeat([]byte("x"), 1024)) + "... [truncated]",
expTotal: 1,
expRead: 1,
expSize: 1025,
},
{
name: "LargeFile",
path: largeFilePath,
expError: "exceeds the maximum",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/read-file-lines?path=%s&offset=%d&limit=%d", tt.path, tt.offset, tt.limit), nil)
api.Routes().ServeHTTP(w, r)
if tt.useCodersdk {
// Query param validation errors return codersdk.Response.
require.Equal(t, http.StatusBadRequest, w.Code)
require.Contains(t, w.Body.String(), tt.expError)
return
}
var resp agentfiles.ReadFileLinesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
if tt.expSuccess {
require.Equal(t, http.StatusOK, w.Code)
require.True(t, resp.Success)
require.Equal(t, tt.expContent, resp.Content)
require.Equal(t, tt.expTotal, resp.TotalLines)
require.Equal(t, tt.expRead, resp.LinesRead)
require.Equal(t, tt.expSize, resp.FileSize)
} else {
require.Equal(t, http.StatusOK, w.Code)
require.False(t, resp.Success)
require.Contains(t, resp.Error, tt.expError)
}
})
}
}
-175
View File
@@ -1,175 +0,0 @@
package agentproc
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// API exposes process-related operations through the agent.
type API struct {
logger slog.Logger
manager *manager
}
// NewAPI creates a new process API handler.
func NewAPI(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *API {
return &API{
logger: logger,
manager: newManager(logger, execer, updateEnv),
}
}
// Close shuts down the process manager, killing all running
// processes.
func (api *API) Close() error {
return api.manager.Close()
}
// Routes returns the HTTP handler for process-related routes.
func (api *API) Routes() http.Handler {
r := chi.NewRouter()
r.Post("/start", api.handleStartProcess)
r.Get("/list", api.handleListProcesses)
r.Get("/{id}/output", api.handleProcessOutput)
r.Post("/{id}/signal", api.handleSignalProcess)
return r
}
// handleStartProcess starts a new process.
func (api *API) handleStartProcess(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req workspacesdk.StartProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Request body must be valid JSON.",
Detail: err.Error(),
})
return
}
if req.Command == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Command is required.",
})
return
}
proc, err := api.manager.start(req)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to start process.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.StartProcessResponse{
ID: proc.id,
Started: true,
})
}
// handleListProcesses lists all tracked processes.
func (api *API) handleListProcesses(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
infos := api.manager.list()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ListProcessesResponse{
Processes: infos,
})
}
// handleProcessOutput returns the output of a process.
func (api *API) handleProcessOutput(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id")
proc, ok := api.manager.get(id)
if !ok {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
return
}
output, truncated := proc.output()
info := proc.info()
httpapi.Write(ctx, rw, http.StatusOK, workspacesdk.ProcessOutputResponse{
Output: output,
Truncated: truncated,
Running: info.Running,
ExitCode: info.ExitCode,
})
}
// handleSignalProcess sends a signal to a running process.
func (api *API) handleSignalProcess(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id")
var req workspacesdk.SignalProcessRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Request body must be valid JSON.",
Detail: err.Error(),
})
return
}
if req.Signal == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Signal is required.",
})
return
}
if req.Signal != "kill" && req.Signal != "terminate" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf(
"Unsupported signal %q. Use \"kill\" or \"terminate\".",
req.Signal,
),
})
return
}
if err := api.manager.signal(id, req.Signal); err != nil {
switch {
case errors.Is(err, errProcessNotFound):
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
Message: fmt.Sprintf("Process %q not found.", id),
})
case errors.Is(err, errProcessNotRunning):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: fmt.Sprintf(
"Process %q is not running.", id,
),
})
default:
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to signal process.",
Detail: err.Error(),
})
}
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{
Message: fmt.Sprintf(
"Signal %q sent to process %q.", req.Signal, id,
),
})
}
-691
View File
@@ -1,691 +0,0 @@
package agentproc_test
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/agentproc"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/testutil"
)
// postStart sends a POST /start request and returns the recorder.
func postStart(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
body, err := json.Marshal(req)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", bytes.NewReader(body))
handler.ServeHTTP(w, r)
return w
}
// getList sends a GET /list request and returns the recorder.
func getList(t *testing.T, handler http.Handler) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/list", nil)
handler.ServeHTTP(w, r)
return w
}
// getOutput sends a GET /{id}/output request and returns the
// recorder.
func getOutput(t *testing.T, handler http.Handler, id string) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("/%s/output", id), nil)
handler.ServeHTTP(w, r)
return w
}
// postSignal sends a POST /{id}/signal request and returns
// the recorder.
func postSignal(t *testing.T, handler http.Handler, id string, req workspacesdk.SignalProcessRequest) *httptest.ResponseRecorder {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
body, err := json.Marshal(req)
require.NoError(t, err)
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("/%s/signal", id), bytes.NewReader(body))
handler.ServeHTTP(w, r)
return w
}
// newTestAPI creates a new API with a test logger and default
// execer, returning the handler and API.
func newTestAPI(t *testing.T) http.Handler {
t.Helper()
return newTestAPIWithUpdateEnv(t, nil)
}
// newTestAPIWithUpdateEnv creates a new API with an optional
// updateEnv hook for testing environment injection.
func newTestAPIWithUpdateEnv(t *testing.T, updateEnv func([]string) ([]string, error)) http.Handler {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
api := agentproc.NewAPI(logger, agentexec.DefaultExecer, updateEnv)
t.Cleanup(func() {
_ = api.Close()
})
return api.Routes()
}
// waitForExit polls the output endpoint until the process is
// no longer running or the context expires.
func waitForExit(t *testing.T, handler http.Handler, id string) workspacesdk.ProcessOutputResponse {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
t.Fatal("timed out waiting for process to exit")
case <-ticker.C:
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
if !resp.Running {
return resp
}
}
}
}
// startAndGetID is a helper that starts a process and returns
// the process ID.
func startAndGetID(t *testing.T, handler http.Handler, req workspacesdk.StartProcessRequest) string {
t.Helper()
w := postStart(t, handler, req)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
return resp.ID
}
func TestStartProcess(t *testing.T) {
t.Parallel()
t.Run("ForegroundCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello",
})
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
})
t.Run("BackgroundCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "echo background",
Background: true,
})
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.StartProcessResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Started)
require.NotEmpty(t, resp.ID)
})
t.Run("EmptyCommand", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postStart(t, handler, workspacesdk.StartProcessRequest{
Command: "",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Command is required")
})
t.Run("MalformedJSON", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
w := httptest.NewRecorder()
r := httptest.NewRequestWithContext(ctx, http.MethodPost, "/start", strings.NewReader("{invalid json"))
handler.ServeHTTP(w, r)
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "valid JSON")
})
t.Run("CustomWorkDir", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
tmpDir := t.TempDir()
// Write a marker file to verify the command ran in
// the correct directory. Comparing pwd output is
// unreliable on Windows where Git Bash returns POSIX
// paths.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "touch marker.txt && ls marker.txt",
WorkDir: tmpDir,
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "marker.txt")
})
t.Run("CustomEnv", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Use a unique env var name to avoid collisions in
// parallel tests.
envKey := fmt.Sprintf("TEST_PROC_ENV_%d", time.Now().UnixNano())
envVal := "custom_value_12345"
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
Env: map[string]string{envKey: envVal},
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, strings.TrimSpace(resp.Output), envVal)
})
t.Run("UpdateEnvHook", func(t *testing.T) {
t.Parallel()
envKey := fmt.Sprintf("TEST_UPDATE_ENV_%d", time.Now().UnixNano())
envVal := "injected_by_hook"
handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
return append(current, fmt.Sprintf("%s=%s", envKey, envVal)), nil
})
// The process should see the variable even though it
// was not passed in req.Env.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, strings.TrimSpace(resp.Output), envVal)
})
t.Run("UpdateEnvHookOverriddenByReqEnv", func(t *testing.T) {
t.Parallel()
envKey := fmt.Sprintf("TEST_OVERRIDE_%d", time.Now().UnixNano())
hookVal := "from_hook"
reqVal := "from_request"
handler := newTestAPIWithUpdateEnv(t, func(current []string) ([]string, error) {
return append(current, fmt.Sprintf("%s=%s", envKey, hookVal)), nil
})
// req.Env should take precedence over the hook.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: fmt.Sprintf("printenv %s", envKey),
Env: map[string]string{envKey: reqVal},
})
resp := waitForExit(t, handler, id)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// When duplicate env vars exist, shells use the last
// value. Since req.Env is appended after the hook,
// the request value wins.
require.Contains(t, strings.TrimSpace(resp.Output), reqVal)
})
}
func TestListProcesses(t *testing.T) {
t.Parallel()
t.Run("NoProcesses", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.NotNil(t, resp.Processes)
require.Empty(t, resp.Processes)
})
t.Run("MixedRunningAndExited", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a process that exits quickly.
exitedID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
waitForExit(t, handler, exitedID)
// Start a long-running process.
runningID := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
// List should contain both.
w := getList(t, handler)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ListProcessesResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Len(t, resp.Processes, 2)
procMap := make(map[string]workspacesdk.ProcessInfo)
for _, p := range resp.Processes {
procMap[p.ID] = p
}
exited, ok := procMap[exitedID]
require.True(t, ok, "exited process should be in list")
require.False(t, exited.Running)
require.NotNil(t, exited.ExitCode)
running, ok := procMap[runningID]
require.True(t, ok, "running process should be in list")
require.True(t, running.Running)
// Clean up the long-running process.
sw := postSignal(t, handler, runningID, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, sw.Code)
})
}
func TestProcessOutput(t *testing.T) {
t.Parallel()
t.Run("ExitedProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo hello-output",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "hello-output")
})
t.Run("RunningProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var resp workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.True(t, resp.Running)
// Kill and wait for the process so cleanup does
// not hang.
postSignal(
t, handler, id,
workspacesdk.SignalProcessRequest{Signal: "kill"},
)
waitForExit(t, handler, id)
})
t.Run("NonexistentProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := getOutput(t, handler, "nonexistent-id-12345")
require.Equal(t, http.StatusNotFound, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "not found")
})
}
func TestSignalProcess(t *testing.T) {
t.Parallel()
t.Run("KillRunning", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, w.Code)
// Verify the process exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
})
t.Run("TerminateRunning", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("SIGTERM is not supported on Windows")
}
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "terminate",
})
require.Equal(t, http.StatusOK, w.Code)
// Verify the process exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
})
t.Run("NonexistentProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
w := postSignal(t, handler, "nonexistent-id-12345", workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("AlreadyExitedProcess", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo done",
})
// Wait for exit first.
waitForExit(t, handler, id)
// Signaling an exited process should return 409
// Conflict via the errProcessNotRunning sentinel.
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
assert.Equal(t, http.StatusConflict, w.Code,
"expected 409 for signaling exited process, got %d", w.Code)
})
t.Run("EmptySignal", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Signal is required")
// Clean up.
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
})
t.Run("InvalidSignal", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
w := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "SIGFOO",
})
require.Equal(t, http.StatusBadRequest, w.Code)
var resp codersdk.Response
err := json.NewDecoder(w.Body).Decode(&resp)
require.NoError(t, err)
require.Contains(t, resp.Message, "Unsupported signal")
// Clean up.
postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
})
}
func TestProcessLifecycle(t *testing.T) {
t.Parallel()
t.Run("StartWaitCheckOutput", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo lifecycle-test && echo second-line",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
require.Contains(t, resp.Output, "lifecycle-test")
require.Contains(t, resp.Output, "second-line")
})
t.Run("NonZeroExitCode", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "exit 42",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 42, *resp.ExitCode)
})
t.Run("StartSignalVerifyExit", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Start a long-running background process.
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "sleep 300",
Background: true,
})
// Verify it's running.
w := getOutput(t, handler, id)
require.Equal(t, http.StatusOK, w.Code)
var running workspacesdk.ProcessOutputResponse
err := json.NewDecoder(w.Body).Decode(&running)
require.NoError(t, err)
require.True(t, running.Running)
// Signal it.
sw := postSignal(t, handler, id, workspacesdk.SignalProcessRequest{
Signal: "kill",
})
require.Equal(t, http.StatusOK, sw.Code)
// Verify it exits.
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
})
t.Run("OutputExceedsBuffer", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
// Generate output that exceeds MaxHeadBytes +
// MaxTailBytes. Each line is ~100 chars, and we
// need more than 32KB total (16KB head + 16KB
// tail).
lineCount := (agentproc.MaxHeadBytes+agentproc.MaxTailBytes)/50 + 500
cmd := fmt.Sprintf(
"for i in $(seq 1 %d); do echo \"line-$i-padding-to-make-this-longer-than-fifty-characters-total\"; done",
lineCount,
)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: cmd,
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// The output should be truncated with head/tail
// strategy metadata.
require.NotNil(t, resp.Truncated, "large output should be truncated")
require.Equal(t, "head_tail", resp.Truncated.Strategy)
require.Greater(t, resp.Truncated.OmittedBytes, 0)
require.Greater(t, resp.Truncated.OriginalBytes, resp.Truncated.RetainedBytes)
// Verify the output contains the omission marker.
require.Contains(t, resp.Output, "... [omitted")
})
t.Run("StderrCaptured", func(t *testing.T) {
t.Parallel()
handler := newTestAPI(t)
id := startAndGetID(t, handler, workspacesdk.StartProcessRequest{
Command: "echo stdout-msg && echo stderr-msg >&2",
})
resp := waitForExit(t, handler, id)
require.False(t, resp.Running)
require.NotNil(t, resp.ExitCode)
require.Equal(t, 0, *resp.ExitCode)
// Both stdout and stderr should be captured.
require.Contains(t, resp.Output, "stdout-msg")
require.Contains(t, resp.Output, "stderr-msg")
})
}
-309
View File
@@ -1,309 +0,0 @@
package agentproc
import (
"fmt"
"strings"
"sync"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
const (
// MaxHeadBytes is the number of bytes retained from the
// beginning of the output for LLM consumption.
MaxHeadBytes = 16 << 10 // 16KB
// MaxTailBytes is the number of bytes retained from the
// end of the output for LLM consumption.
MaxTailBytes = 16 << 10 // 16KB
// MaxLineLength is the maximum length of a single line
// before it is truncated. This prevents minified files
// or other long single-line output from consuming the
// entire buffer.
MaxLineLength = 2048
// lineTruncationSuffix is appended to lines that exceed
// MaxLineLength.
lineTruncationSuffix = " ... [truncated]"
)
// HeadTailBuffer is a thread-safe buffer that captures process
// output and provides head+tail truncation for LLM consumption.
// It implements io.Writer so it can be used directly as
// cmd.Stdout or cmd.Stderr.
//
// The buffer stores up to MaxHeadBytes from the beginning of
// the output and up to MaxTailBytes from the end in a ring
// buffer, keeping total memory usage bounded regardless of
// how much output is written.
type HeadTailBuffer struct {
mu sync.Mutex
head []byte
tail []byte
tailPos int
tailFull bool
headFull bool
totalBytes int
maxHead int
maxTail int
}
// NewHeadTailBuffer creates a new HeadTailBuffer with the
// default head and tail sizes.
func NewHeadTailBuffer() *HeadTailBuffer {
return &HeadTailBuffer{
maxHead: MaxHeadBytes,
maxTail: MaxTailBytes,
}
}
// NewHeadTailBufferSized creates a HeadTailBuffer with custom
// head and tail sizes. This is useful for testing truncation
// logic with smaller buffers.
func NewHeadTailBufferSized(maxHead, maxTail int) *HeadTailBuffer {
return &HeadTailBuffer{
maxHead: maxHead,
maxTail: maxTail,
}
}
// Write implements io.Writer. It is safe for concurrent use.
// All bytes are accepted; the return value always equals
// len(p) with a nil error.
func (b *HeadTailBuffer) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
b.mu.Lock()
defer b.mu.Unlock()
n := len(p)
b.totalBytes += n
// Fill head buffer if it is not yet full.
if !b.headFull {
remaining := b.maxHead - len(b.head)
if remaining > 0 {
take := remaining
if take > len(p) {
take = len(p)
}
b.head = append(b.head, p[:take]...)
p = p[take:]
if len(b.head) >= b.maxHead {
b.headFull = true
}
}
if len(p) == 0 {
return n, nil
}
}
// Write remaining bytes into the tail ring buffer.
b.writeTail(p)
return n, nil
}
// writeTail appends data to the tail ring buffer. The caller
// must hold b.mu.
func (b *HeadTailBuffer) writeTail(p []byte) {
if b.maxTail <= 0 {
return
}
// Lazily allocate the tail buffer on first use.
if b.tail == nil {
b.tail = make([]byte, b.maxTail)
}
for len(p) > 0 {
// Write as many bytes as fit starting at tailPos.
space := b.maxTail - b.tailPos
take := space
if take > len(p) {
take = len(p)
}
copy(b.tail[b.tailPos:b.tailPos+take], p[:take])
p = p[take:]
b.tailPos += take
if b.tailPos >= b.maxTail {
b.tailPos = 0
b.tailFull = true
}
}
}
// tailBytes returns the current tail contents in order. The
// caller must hold b.mu.
func (b *HeadTailBuffer) tailBytes() []byte {
if b.tail == nil {
return nil
}
if !b.tailFull {
// Haven't wrapped yet; data is [0, tailPos).
return b.tail[:b.tailPos]
}
// Wrapped: data is [tailPos, maxTail) + [0, tailPos).
out := make([]byte, b.maxTail)
n := copy(out, b.tail[b.tailPos:])
copy(out[n:], b.tail[:b.tailPos])
return out
}
// Bytes returns a copy of the raw buffer contents. If no
// truncation has occurred the full output is returned;
// otherwise the head and tail portions are concatenated.
func (b *HeadTailBuffer) Bytes() []byte {
b.mu.Lock()
defer b.mu.Unlock()
tail := b.tailBytes()
if len(tail) == 0 {
out := make([]byte, len(b.head))
copy(out, b.head)
return out
}
out := make([]byte, len(b.head)+len(tail))
copy(out, b.head)
copy(out[len(b.head):], tail)
return out
}
// Len returns the number of bytes currently stored in the
// buffer.
func (b *HeadTailBuffer) Len() int {
b.mu.Lock()
defer b.mu.Unlock()
tailLen := 0
if b.tailFull {
tailLen = b.maxTail
} else if b.tail != nil {
tailLen = b.tailPos
}
return len(b.head) + tailLen
}
// TotalWritten returns the total number of bytes written to
// the buffer, which may exceed the stored capacity.
func (b *HeadTailBuffer) TotalWritten() int {
b.mu.Lock()
defer b.mu.Unlock()
return b.totalBytes
}
// Output returns the truncated output suitable for LLM
// consumption, along with truncation metadata. If the total
// output fits within the head buffer alone, the full output is
// returned with nil truncation info. Otherwise the head and
// tail are joined with an omission marker and long lines are
// truncated.
func (b *HeadTailBuffer) Output() (string, *workspacesdk.ProcessTruncation) {
b.mu.Lock()
head := make([]byte, len(b.head))
copy(head, b.head)
tail := b.tailBytes()
total := b.totalBytes
headFull := b.headFull
b.mu.Unlock()
storedLen := len(head) + len(tail)
// If everything fits, no head/tail split is needed.
if !headFull || len(tail) == 0 {
out := truncateLines(string(head))
if total == 0 {
return "", nil
}
return out, nil
}
// We have both head and tail data, meaning the total
// output exceeded the head capacity. Build the
// combined output with an omission marker.
omitted := total - storedLen
headStr := truncateLines(string(head))
tailStr := truncateLines(string(tail))
var sb strings.Builder
_, _ = sb.WriteString(headStr)
if omitted > 0 {
_, _ = sb.WriteString(fmt.Sprintf(
"\n\n... [omitted %d bytes] ...\n\n",
omitted,
))
} else {
// Head and tail are contiguous but were stored
// separately because the head filled up.
_, _ = sb.WriteString("\n")
}
_, _ = sb.WriteString(tailStr)
result := sb.String()
return result, &workspacesdk.ProcessTruncation{
OriginalBytes: total,
RetainedBytes: len(result),
OmittedBytes: omitted,
Strategy: "head_tail",
}
}
// truncateLines scans the input line by line and truncates
// any line longer than MaxLineLength.
func truncateLines(s string) string {
if len(s) <= MaxLineLength {
// Fast path: if the entire string is shorter than
// the max line length, no line can exceed it.
return s
}
var b strings.Builder
b.Grow(len(s))
for len(s) > 0 {
idx := strings.IndexByte(s, '\n')
var line string
if idx == -1 {
line = s
s = ""
} else {
line = s[:idx]
s = s[idx+1:]
}
if len(line) > MaxLineLength {
// Truncate preserving the suffix length so the
// total does not exceed a reasonable size.
cut := MaxLineLength - len(lineTruncationSuffix)
if cut < 0 {
cut = 0
}
_, _ = b.WriteString(line[:cut])
_, _ = b.WriteString(lineTruncationSuffix)
} else {
_, _ = b.WriteString(line)
}
// Re-add the newline unless this was the final
// segment without a trailing newline.
if idx != -1 {
_ = b.WriteByte('\n')
}
}
return b.String()
}
// Reset clears the buffer, discarding all data.
func (b *HeadTailBuffer) Reset() {
b.mu.Lock()
defer b.mu.Unlock()
b.head = nil
b.tail = nil
b.tailPos = 0
b.tailFull = false
b.headFull = false
b.totalBytes = 0
}
-338
View File
@@ -1,338 +0,0 @@
package agentproc_test
import (
"fmt"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/agent/agentproc"
)
func TestHeadTailBuffer_EmptyBuffer(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
out, info := buf.Output()
require.Empty(t, out)
require.Nil(t, info)
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.TotalWritten())
require.Empty(t, buf.Bytes())
}
func TestHeadTailBuffer_SmallOutput(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
data := "hello world\n"
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, len(data), n)
out, info := buf.Output()
require.Equal(t, data, out)
require.Nil(t, info, "small output should not be truncated")
require.Equal(t, len(data), buf.Len())
require.Equal(t, len(data), buf.TotalWritten())
}
func TestHeadTailBuffer_ExactlyHeadSize(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Build data that is exactly MaxHeadBytes using short
// lines so that line truncation does not apply.
line := strings.Repeat("x", 79) + "\n" // 80 bytes per line
count := agentproc.MaxHeadBytes / len(line)
pad := agentproc.MaxHeadBytes - (count * len(line))
data := strings.Repeat(line, count) + strings.Repeat("y", pad)
require.Equal(t, agentproc.MaxHeadBytes, len(data),
"test data must be exactly MaxHeadBytes")
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, agentproc.MaxHeadBytes, n)
out, info := buf.Output()
require.Equal(t, data, out)
require.Nil(t, info, "output fitting in head should not be truncated")
require.Equal(t, agentproc.MaxHeadBytes, buf.Len())
}
func TestHeadTailBuffer_HeadPlusTailNoOmission(t *testing.T) {
t.Parallel()
// Use a small buffer so we can test the boundary where
// head fills and tail starts but nothing is omitted.
// With maxHead=10, maxTail=10, writing exactly 20 bytes
// means head gets 10, tail gets 10, omitted = 0.
buf := agentproc.NewHeadTailBufferSized(10, 10)
data := "0123456789abcdefghij" // 20 bytes
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, 20, n)
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 0, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// The output should contain both head and tail.
require.Contains(t, out, "0123456789")
require.Contains(t, out, "abcdefghij")
}
func TestHeadTailBuffer_LargeOutputTruncation(t *testing.T) {
t.Parallel()
// Use small head/tail so truncation is easy to verify.
buf := agentproc.NewHeadTailBufferSized(10, 10)
// Write 100 bytes: head=10, tail=10, omitted=80.
data := strings.Repeat("A", 50) + strings.Repeat("Z", 50)
n, err := buf.Write([]byte(data))
require.NoError(t, err)
require.Equal(t, 100, n)
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 100, info.OriginalBytes)
require.Equal(t, 80, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// Head should be first 10 bytes (all A's).
require.True(t, strings.HasPrefix(out, "AAAAAAAAAA"))
// Tail should be last 10 bytes (all Z's).
require.True(t, strings.HasSuffix(out, "ZZZZZZZZZZ"))
// Omission marker should be present.
require.Contains(t, out, "... [omitted 80 bytes] ...")
require.Equal(t, 20, buf.Len())
require.Equal(t, 100, buf.TotalWritten())
}
func TestHeadTailBuffer_MultiMBStaysBounded(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write 5MB of data in chunks.
chunk := []byte(strings.Repeat("x", 4096) + "\n")
totalWritten := 0
for totalWritten < 5*1024*1024 {
n, err := buf.Write(chunk)
require.NoError(t, err)
require.Equal(t, len(chunk), n)
totalWritten += n
}
// Memory should be bounded to head+tail.
require.LessOrEqual(t, buf.Len(),
agentproc.MaxHeadBytes+agentproc.MaxTailBytes)
require.Equal(t, totalWritten, buf.TotalWritten())
out, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, totalWritten, info.OriginalBytes)
require.Greater(t, info.OmittedBytes, 0)
require.NotEmpty(t, out)
}
func TestHeadTailBuffer_LongLineTruncation(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write a line longer than MaxLineLength.
longLine := strings.Repeat("m", agentproc.MaxLineLength+500)
_, err := buf.Write([]byte(longLine + "\n"))
require.NoError(t, err)
out, _ := buf.Output()
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
require.Len(t, lines, 1)
require.LessOrEqual(t, len(lines[0]), agentproc.MaxLineLength)
require.True(t, strings.HasSuffix(lines[0], "... [truncated]"))
}
func TestHeadTailBuffer_LongLineInTail(t *testing.T) {
t.Parallel()
// Use small buffers so we can force data into the tail.
buf := agentproc.NewHeadTailBufferSized(20, 5000)
// Fill head with short data.
_, err := buf.Write([]byte("head data goes here\n"))
require.NoError(t, err)
// Now write a very long line into the tail.
longLine := strings.Repeat("T", agentproc.MaxLineLength+100)
_, err = buf.Write([]byte(longLine + "\n"))
require.NoError(t, err)
out, info := buf.Output()
require.NotNil(t, info)
// The long line in the tail should be truncated.
require.Contains(t, out, "... [truncated]")
}
func TestHeadTailBuffer_ConcurrentWrites(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
const goroutines = 10
const writes = 1000
var wg sync.WaitGroup
wg.Add(goroutines)
for g := range goroutines {
go func() {
defer wg.Done()
line := fmt.Sprintf("goroutine-%d: data\n", g)
for range writes {
_, err := buf.Write([]byte(line))
assert.NoError(t, err)
}
}()
}
wg.Wait()
// Verify totals are consistent.
require.Greater(t, buf.TotalWritten(), 0)
require.Greater(t, buf.Len(), 0)
out, _ := buf.Output()
require.NotEmpty(t, out)
}
func TestHeadTailBuffer_TruncationInfoFields(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBufferSized(10, 10)
// Write enough to cause omission.
data := strings.Repeat("D", 50)
_, err := buf.Write([]byte(data))
require.NoError(t, err)
_, info := buf.Output()
require.NotNil(t, info)
require.Equal(t, 50, info.OriginalBytes)
require.Equal(t, 30, info.OmittedBytes)
require.Equal(t, "head_tail", info.Strategy)
// RetainedBytes is the length of the formatted output
// string including the omission marker.
require.Greater(t, info.RetainedBytes, 0)
}
func TestHeadTailBuffer_MultipleSmallWrites(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
// Write one byte at a time.
expected := "hello world"
for i := range len(expected) {
n, err := buf.Write([]byte{expected[i]})
require.NoError(t, err)
require.Equal(t, 1, n)
}
out, info := buf.Output()
require.Equal(t, expected, out)
require.Nil(t, info)
}
func TestHeadTailBuffer_WriteEmptySlice(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
n, err := buf.Write([]byte{})
require.NoError(t, err)
require.Equal(t, 0, n)
require.Equal(t, 0, buf.TotalWritten())
}
func TestHeadTailBuffer_Reset(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
_, err := buf.Write([]byte("some data"))
require.NoError(t, err)
require.Greater(t, buf.Len(), 0)
buf.Reset()
require.Equal(t, 0, buf.Len())
require.Equal(t, 0, buf.TotalWritten())
out, info := buf.Output()
require.Empty(t, out)
require.Nil(t, info)
}
func TestHeadTailBuffer_BytesReturnsCopy(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
_, err := buf.Write([]byte("original"))
require.NoError(t, err)
b := buf.Bytes()
require.Equal(t, []byte("original"), b)
// Mutating the returned slice should not affect the
// buffer.
b[0] = 'X'
require.Equal(t, []byte("original"), buf.Bytes())
}
func TestHeadTailBuffer_RingBufferWraparound(t *testing.T) {
t.Parallel()
// Use a tail of 10 bytes and write enough to wrap
// around multiple times.
buf := agentproc.NewHeadTailBufferSized(5, 10)
// Fill head (5 bytes).
_, err := buf.Write([]byte("HEADD"))
require.NoError(t, err)
// Write 25 bytes into tail, wrapping 2.5 times.
_, err = buf.Write([]byte("0123456789"))
require.NoError(t, err)
_, err = buf.Write([]byte("abcdefghij"))
require.NoError(t, err)
_, err = buf.Write([]byte("ABCDE"))
require.NoError(t, err)
out, info := buf.Output()
require.NotNil(t, info)
// Tail should contain the last 10 bytes: "fghijABCDE".
require.True(t, strings.HasSuffix(out, "fghijABCDE"),
"expected tail to be last 10 bytes, got: %q", out)
}
func TestHeadTailBuffer_MultipleLinesTruncated(t *testing.T) {
t.Parallel()
buf := agentproc.NewHeadTailBuffer()
short := "short line\n"
long := strings.Repeat("L", agentproc.MaxLineLength+100) + "\n"
_, err := buf.Write([]byte(short + long + short))
require.NoError(t, err)
out, _ := buf.Output()
lines := strings.Split(strings.TrimRight(out, "\n"), "\n")
require.Len(t, lines, 3)
require.Equal(t, "short line", lines[0])
require.True(t, strings.HasSuffix(lines[1], "... [truncated]"))
require.Equal(t, "short line", lines[2])
}
-294
View File
@@ -1,294 +0,0 @@
package agentproc
import (
"context"
"fmt"
"os"
"os/exec"
"sync"
"syscall"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
var (
errProcessNotFound = xerrors.New("process not found")
errProcessNotRunning = xerrors.New("process is not running")
)
// process represents a running or completed process.
type process struct {
mu sync.Mutex
id string
command string
workDir string
background bool
cmd *exec.Cmd
cancel context.CancelFunc
buf *HeadTailBuffer
running bool
exitCode *int
startedAt int64
exitedAt *int64
done chan struct{} // closed when process exits
}
// info returns a snapshot of the process state.
func (p *process) info() workspacesdk.ProcessInfo {
p.mu.Lock()
defer p.mu.Unlock()
return workspacesdk.ProcessInfo{
ID: p.id,
Command: p.command,
WorkDir: p.workDir,
Background: p.background,
Running: p.running,
ExitCode: p.exitCode,
StartedAt: p.startedAt,
ExitedAt: p.exitedAt,
}
}
// output returns the truncated output from the process buffer
// along with optional truncation metadata.
func (p *process) output() (string, *workspacesdk.ProcessTruncation) {
return p.buf.Output()
}
// manager tracks processes spawned by the agent.
type manager struct {
mu sync.Mutex
logger slog.Logger
execer agentexec.Execer
clock quartz.Clock
procs map[string]*process
closed bool
updateEnv func(current []string) (updated []string, err error)
}
// newManager creates a new process manager.
func newManager(logger slog.Logger, execer agentexec.Execer, updateEnv func(current []string) (updated []string, err error)) *manager {
return &manager{
logger: logger,
execer: execer,
clock: quartz.NewReal(),
procs: make(map[string]*process),
updateEnv: updateEnv,
}
}
// start spawns a new process. Both foreground and background
// processes use a long-lived context so the process survives
// the HTTP request lifecycle. The background flag only affects
// client-side polling behavior.
func (m *manager) start(req workspacesdk.StartProcessRequest) (*process, error) {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
return nil, xerrors.New("manager is closed")
}
m.mu.Unlock()
id := uuid.New().String()
// Use a cancellable context so Close() can terminate
// all processes. context.Background() is the parent so
// the process is not tied to any HTTP request.
ctx, cancel := context.WithCancel(context.Background())
cmd := m.execer.CommandContext(ctx, "sh", "-c", req.Command)
if req.WorkDir != "" {
cmd.Dir = req.WorkDir
}
cmd.Stdin = nil
// WaitDelay ensures cmd.Wait returns promptly after
// the process is killed, even if child processes are
// still holding the stdout/stderr pipes open.
cmd.WaitDelay = 5 * time.Second
buf := NewHeadTailBuffer()
cmd.Stdout = buf
cmd.Stderr = buf
// Build the process environment. If the manager has an
// updateEnv hook (provided by the agent), use it to get the
// full agent environment including GIT_ASKPASS, CODER_* vars,
// etc. Otherwise fall back to the current process env.
baseEnv := os.Environ()
if m.updateEnv != nil {
updated, err := m.updateEnv(baseEnv)
if err != nil {
m.logger.Warn(
context.Background(),
"failed to update command environment, falling back to os env",
slog.Error(err),
)
} else {
baseEnv = updated
}
}
// Always set cmd.Env explicitly so that req.Env overrides
// are applied on top of the full agent environment.
cmd.Env = baseEnv
for k, v := range req.Env {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v))
}
if err := cmd.Start(); err != nil {
cancel()
return nil, xerrors.Errorf("start process: %w", err)
}
now := m.clock.Now().Unix()
proc := &process{
id: id,
command: req.Command,
workDir: req.WorkDir,
background: req.Background,
cmd: cmd,
cancel: cancel,
buf: buf,
running: true,
startedAt: now,
done: make(chan struct{}),
}
m.mu.Lock()
if m.closed {
m.mu.Unlock()
// Manager closed between our check and now. Kill the
// process we just started.
cancel()
_ = cmd.Wait()
return nil, xerrors.New("manager is closed")
}
m.procs[id] = proc
m.mu.Unlock()
go func() {
err := cmd.Wait()
exitedAt := m.clock.Now().Unix()
proc.mu.Lock()
proc.running = false
proc.exitedAt = &exitedAt
code := 0
if err != nil {
// Extract the exit code from the error.
var exitErr *exec.ExitError
if xerrors.As(err, &exitErr) {
code = exitErr.ExitCode()
} else {
// Unknown error; use -1 as a sentinel.
code = -1
m.logger.Warn(
context.Background(),
"process wait returned non-exit error",
slog.F("id", id),
slog.Error(err),
)
}
}
proc.exitCode = &code
proc.mu.Unlock()
close(proc.done)
}()
return proc, nil
}
// get returns a process by ID.
func (m *manager) get(id string) (*process, bool) {
m.mu.Lock()
defer m.mu.Unlock()
proc, ok := m.procs[id]
return proc, ok
}
// list returns info about all tracked processes.
func (m *manager) list() []workspacesdk.ProcessInfo {
m.mu.Lock()
defer m.mu.Unlock()
infos := make([]workspacesdk.ProcessInfo, 0, len(m.procs))
for _, proc := range m.procs {
infos = append(infos, proc.info())
}
return infos
}
// signal sends a signal to a running process. It returns
// sentinel errors errProcessNotFound and errProcessNotRunning
// so callers can distinguish failure modes.
func (m *manager) signal(id string, sig string) error {
m.mu.Lock()
proc, ok := m.procs[id]
m.mu.Unlock()
if !ok {
return errProcessNotFound
}
proc.mu.Lock()
defer proc.mu.Unlock()
if !proc.running {
return errProcessNotRunning
}
switch sig {
case "kill":
if err := proc.cmd.Process.Kill(); err != nil {
return xerrors.Errorf("kill process: %w", err)
}
case "terminate":
//nolint:revive // syscall.SIGTERM is portable enough
// for our supported platforms.
if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil {
return xerrors.Errorf("terminate process: %w", err)
}
default:
return xerrors.Errorf("unsupported signal %q", sig)
}
return nil
}
// Close kills all running processes and prevents new ones from
// starting. It cancels each process's context, which causes
// CommandContext to kill the process and its pipe goroutines to
// drain.
func (m *manager) Close() error {
m.mu.Lock()
if m.closed {
m.mu.Unlock()
return nil
}
m.closed = true
procs := make([]*process, 0, len(m.procs))
for _, p := range m.procs {
procs = append(procs, p)
}
m.mu.Unlock()
for _, p := range procs {
p.cancel()
}
// Wait for all processes to exit.
for _, p := range procs {
<-p.done
}
return nil
}
+103 -2
View File
@@ -1,22 +1,37 @@
package agentsocket_test
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentsocket"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
func TestServer(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("agentsocket is not supported on Windows")
}
t.Run("StartStop", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
@@ -26,7 +41,7 @@ func TestServer(t *testing.T) {
t.Run("AlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server1, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
@@ -34,4 +49,90 @@ func TestServer(t *testing.T) {
_, err = agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.ErrorContains(t, err, "create socket")
})
t.Run("AutoSocketPath", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
server, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.NoError(t, err)
require.NoError(t, server.Close())
})
}
func TestServerWindowsNotSupported(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" {
t.Skip("this test only runs on Windows")
}
t.Run("NewServer", func(t *testing.T) {
t.Parallel()
socketPath := filepath.Join(t.TempDir(), "test.sock")
logger := slog.Make().Leveled(slog.LevelDebug)
_, err := agentsocket.NewServer(logger, agentsocket.WithPath(socketPath))
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
})
t.Run("NewClient", func(t *testing.T) {
t.Parallel()
_, err := agentsocket.NewClient(context.Background(), agentsocket.WithPath("test.sock"))
require.ErrorContains(t, err, "agentsocket is not supported on Windows")
})
}
func TestAgentInitializesOnWindowsWithoutSocketServer(t *testing.T) {
t.Parallel()
if runtime.GOOS != "windows" {
t.Skip("this test only runs on Windows")
}
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t).Named("agent")
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
statsCh := make(chan *agentproto.Stats, 50)
agentID := uuid.New()
manifest := agentsdk.Manifest{
AgentID: agentID,
AgentName: "test-agent",
WorkspaceName: "test-workspace",
OwnerName: "test-user",
WorkspaceID: uuid.New(),
DERPMap: derpMap,
}
client := agenttest.NewClient(t, logger.Named("agenttest"), agentID, manifest, statsCh, coordinator)
t.Cleanup(client.Close)
options := agent.Options{
Client: client,
Filesystem: afero.NewMemMapFs(),
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: testutil.WaitShort,
EnvironmentVariables: map[string]string{},
SocketPath: "",
}
agnt := agent.New(options)
t.Cleanup(func() {
_ = agnt.Close()
})
startup := testutil.TryReceive(ctx, t, client.GetStartup())
require.NotNil(t, startup, "agent should send startup message")
err := agnt.Close()
require.NoError(t, err, "agent should close cleanly")
}
+17 -11
View File
@@ -2,6 +2,8 @@ package agentsocket_test
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/require"
@@ -28,10 +30,14 @@ func newSocketClient(ctx context.Context, t *testing.T, socketPath string) *agen
func TestDRPCAgentSocketService(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("agentsocket is not supported on Windows")
}
t.Run("Ping", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -51,7 +57,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("NewUnit", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -73,7 +79,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitAlreadyStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -103,7 +109,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitAlreadyCompleted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -142,7 +148,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitNotReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -172,7 +178,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("NewUnits", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -197,7 +203,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("DependencyAlreadyRegistered", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -232,7 +238,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("DependencyAddedAfterDependentStarted", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -274,7 +280,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnregisteredUnit", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -293,7 +299,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitNotReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
@@ -317,7 +323,7 @@ func TestDRPCAgentSocketService(t *testing.T) {
t.Run("UnitReady", func(t *testing.T) {
t.Parallel()
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
ctx := testutil.Context(t, testutil.WaitShort)
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
+6 -47
View File
@@ -4,60 +4,19 @@ package agentsocket
import (
"context"
"fmt"
"net"
"os"
"os/user"
"strings"
"github.com/Microsoft/go-winio"
"golang.org/x/xerrors"
)
const defaultSocketPath = `\\.\pipe\com.coder.agentsocket`
func createSocket(path string) (net.Listener, error) {
if path == "" {
path = defaultSocketPath
}
if !strings.HasPrefix(path, `\\.\pipe\`) {
return nil, xerrors.Errorf("%q is not a valid local socket path", path)
}
user, err := user.Current()
if err != nil {
return nil, fmt.Errorf("unable to look up current user: %w", err)
}
sid := user.Uid
// SecurityDescriptor is in SDDL format. c.f.
// https://learn.microsoft.com/en-us/windows/win32/secauthz/security-descriptor-string-format for full details.
// D: indicates this is a Discretionary Access Control List (DACL), which is Windows-speak for ACLs that allow or
// deny access (as opposed to SACL which controls audit logging).
// P indicates that this DACL is "protected" from being modified thru inheritance
// () delimit access control entries (ACEs), here we only have one, which, allows (A) generic all (GA) access to our
// specific user's security ID (SID).
//
// Note that although Microsoft docs at https://learn.microsoft.com/en-us/windows/win32/ipc/named-pipes warns that
// named pipes are accessible from remote machines in the general case, the `winio` package sets the flag
// windows.FILE_PIPE_REJECT_REMOTE_CLIENTS when creating pipes, so connections from remote machines are always
// denied. This is important because we sort of expect customers to run the Coder agent under a generic user
// account unless they are very sophisticated. We don't want this socket to cross the boundary of the local machine.
configuration := &winio.PipeConfig{
SecurityDescriptor: fmt.Sprintf("D:P(A;;GA;;;%s)", sid),
}
listener, err := winio.ListenPipe(path, configuration)
if err != nil {
return nil, xerrors.Errorf("failed to open named pipe: %w", err)
}
return listener, nil
func createSocket(_ string) (net.Listener, error) {
return nil, xerrors.New("agentsocket is not supported on Windows")
}
func cleanupSocket(path string) error {
return os.Remove(path)
func cleanupSocket(_ string) error {
return nil
}
func dialSocket(ctx context.Context, path string) (net.Conn, error) {
return winio.DialPipeContext(ctx, path)
func dialSocket(_ context.Context, _ string) (net.Conn, error) {
return nil, xerrors.New("agentsocket is not supported on Windows")
}
-10
View File
@@ -124,12 +124,6 @@ func (c *Client) Close() {
c.derpMapOnce.Do(func() { close(c.derpMapUpdates) })
}
func (c *Client) ConnectRPC28WithRole(ctx context.Context, _ string) (
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
return c.ConnectRPC28(ctx)
}
func (c *Client) ConnectRPC28(ctx context.Context) (
agentproto.DRPCAgentClient28, proto.DRPCTailnetClient28, error,
) {
@@ -235,10 +229,6 @@ type FakeAgentAPI struct {
pushResourcesMonitoringUsageFunc func(*agentproto.PushResourcesMonitoringUsageRequest) (*agentproto.PushResourcesMonitoringUsageResponse, error)
}
func (*FakeAgentAPI) UpdateAppStatus(context.Context, *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
panic("unimplemented")
}
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}
-1
View File
@@ -28,7 +28,6 @@ func (a *agent) apiHandler() http.Handler {
})
r.Mount("/api/v0", a.filesAPI.Routes())
r.Mount("/api/v0/processes", a.processAPI.Routes())
if a.devcontainers {
r.Mount("/api/v0/containers", a.containerAPI.Routes())
-316
View File
@@ -1,316 +0,0 @@
package filefinder_test
import (
"context"
"fmt"
"math/rand"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/filefinder"
)
var (
dirNames = []string{
"cmd", "internal", "pkg", "api", "auth", "database", "server", "client", "middleware",
"handler", "config", "utils", "models", "service", "worker", "scheduler", "notification",
"provisioner", "template", "workspace", "agent", "proxy", "crypto", "telemetry", "billing",
}
fileExts = []string{
".go", ".ts", ".tsx", ".js", ".py", ".sql", ".yaml", ".json", ".md", ".proto", ".sh",
}
fileStems = []string{
"main", "handler", "middleware", "service", "model", "query", "config", "utils", "helpers",
"types", "interface", "test", "mock", "factory", "builder", "adapter", "observer", "provider",
"resolver", "schema", "migration", "fixture", "snapshot", "checkpoint",
}
)
// generateFileTree creates n files under root in a realistic nested directory structure.
func generateFileTree(t testing.TB, root string, n int, seed int64) {
t.Helper()
rng := rand.New(rand.NewSource(seed)) //nolint:gosec // deterministic benchmarks
numDirs := n / 5
if numDirs < 10 {
numDirs = 10
}
dirs := make([]string, 0, numDirs)
for i := 0; i < numDirs; i++ {
depth := rng.Intn(6) + 1
parts := make([]string, depth)
for d := 0; d < depth; d++ {
parts[d] = dirNames[rng.Intn(len(dirNames))]
}
dirs = append(dirs, filepath.Join(parts...))
}
created := make(map[string]struct{})
for _, d := range dirs {
full := filepath.Join(root, d)
if _, ok := created[full]; ok {
continue
}
require.NoError(t, os.MkdirAll(full, 0o755))
created[full] = struct{}{}
}
for i := 0; i < n; i++ {
dir := dirs[rng.Intn(len(dirs))]
stem := fileStems[rng.Intn(len(fileStems))]
ext := fileExts[rng.Intn(len(fileExts))]
name := fmt.Sprintf("%s_%d%s", stem, i, ext)
full := filepath.Join(root, dir, name)
f, err := os.Create(full)
require.NoError(t, err)
_ = f.Close()
}
}
// buildIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func buildIndex(t testing.TB, root string) *filefinder.Index {
t.Helper()
absRoot, err := filepath.Abs(root)
require.NoError(t, err)
idx, err := filefinder.BuildTestIndex(absRoot)
require.NoError(t, err)
return idx
}
func BenchmarkBuildIndex(b *testing.B) {
scales := []struct {
name string
n int
}{
{"1K", 1_000},
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale benchmark")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
b.ResetTimer()
for i := 0; i < b.N; i++ {
idx := buildIndex(b, dir)
if idx.Len() == 0 {
b.Fatal("expected non-empty index")
}
}
b.StopTimer()
idx := buildIndex(b, dir)
b.ReportMetric(float64(idx.Len())/b.Elapsed().Seconds(), "files/sec")
})
}
}
func BenchmarkSearch_ByScale(b *testing.B) {
queries := []struct {
name string
query string
}{
{"exact_basename", "handler.go"},
{"short_query", "ha"},
{"fuzzy_basename", "hndlr"},
{"path_structured", "internal/handler"},
{"multi_token", "api handler"},
}
scales := []struct {
name string
n int
}{
{"1K", 1_000},
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale benchmark")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
idx := buildIndex(b, dir)
snap := idx.Snapshot()
opts := filefinder.DefaultSearchOptions()
for _, q := range queries {
b.Run(q.name, func(b *testing.B) {
p := filefinder.NewQueryPlanForTest(q.query)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = filefinder.SearchSnapshotForTest(p, snap, opts.MaxCandidates)
}
})
}
})
}
}
func BenchmarkSearch_ConcurrentReads(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
logger := slogtest.Make(b, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelError)
ctx := context.Background()
eng := filefinder.NewEngine(logger)
require.NoError(b, eng.AddRoot(ctx, dir))
b.Cleanup(func() { _ = eng.Close() })
opts := filefinder.DefaultSearchOptions()
goroutines := []int{1, 4, 16, 64}
for _, g := range goroutines {
b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
b.SetParallelism(g)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
results, err := eng.Search(ctx, "handler", opts)
if err != nil {
b.Fatal(err)
}
_ = results
}
})
})
}
}
func BenchmarkDeltaUpdate(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
addCounts := []int{1, 10, 100}
for _, count := range addCounts {
b.Run(fmt.Sprintf("add_%d_files", count), func(b *testing.B) {
paths := make([]string, count)
for i := range paths {
paths[i] = fmt.Sprintf("injected/dir_%d/newfile_%d.go", i%10, i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
idx := buildIndex(b, dir)
b.StartTimer()
for _, p := range paths {
idx.Add(p, 0)
}
}
b.ReportMetric(float64(count), "files_added/op")
})
}
b.Run("search_after_100_additions", func(b *testing.B) {
idx := buildIndex(b, dir)
for i := 0; i < 100; i++ {
idx.Add(fmt.Sprintf("injected/extra/file_%d.go", i), 0)
}
snap := idx.Snapshot()
plan := filefinder.NewQueryPlanForTest("handler")
opts := filefinder.DefaultSearchOptions()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = filefinder.SearchSnapshotForTest(plan, snap, opts.MaxCandidates)
}
})
}
func BenchmarkMemoryProfile(b *testing.B) {
scales := []struct {
name string
n int
}{
{"10K", 10_000},
{"100K", 100_000},
}
for _, sc := range scales {
b.Run(sc.name, func(b *testing.B) {
if sc.n >= 100_000 && testing.Short() {
b.Skip("skipping large-scale memory profile")
}
dir := b.TempDir()
generateFileTree(b, dir, sc.n, 42)
b.ResetTimer()
for i := 0; i < b.N; i++ {
idx := buildIndex(b, dir)
_ = idx.Snapshot()
}
b.StopTimer()
// Report memory stats on the last iteration.
runtime.GC()
var before runtime.MemStats
runtime.ReadMemStats(&before)
idx := buildIndex(b, dir)
var after runtime.MemStats
runtime.ReadMemStats(&after)
allocDelta := after.TotalAlloc - before.TotalAlloc
b.ReportMetric(float64(allocDelta)/float64(idx.Len()), "bytes/file")
runtime.GC()
runtime.ReadMemStats(&before)
snap := idx.Snapshot()
_ = snap
runtime.GC()
runtime.ReadMemStats(&after)
snapAlloc := after.TotalAlloc - before.TotalAlloc
b.ReportMetric(float64(snapAlloc)/float64(idx.Len()), "snap-bytes/file")
})
}
}
func BenchmarkSearch_ConcurrentReads_Throughput(b *testing.B) {
dir := b.TempDir()
generateFileTree(b, dir, 10_000, 42)
idx := buildIndex(b, dir)
snap := idx.Snapshot()
goroutines := []int{1, 4, 16, 64}
plan := filefinder.NewQueryPlanForTest("handler.go")
maxCands := filefinder.DefaultSearchOptions().MaxCandidates
for _, g := range goroutines {
b.Run(fmt.Sprintf("goroutines_%d", g), func(b *testing.B) {
b.ResetTimer()
var wg sync.WaitGroup
perGoroutine := b.N / g
if perGoroutine < 1 {
perGoroutine = 1
}
for gi := 0; gi < g; gi++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < perGoroutine; j++ {
_ = filefinder.SearchSnapshotForTest(plan, snap, maxCands)
}
}()
}
wg.Wait()
totalOps := float64(g * perGoroutine)
b.ReportMetric(totalOps/b.Elapsed().Seconds(), "searches/sec")
})
}
}
-125
View File
@@ -1,125 +0,0 @@
package filefinder
import "strings"
// FileFlag represents the type of filesystem entry.
type FileFlag uint16
const (
FlagFile FileFlag = 0
FlagDir FileFlag = 1
FlagSymlink FileFlag = 2
)
type doc struct {
path string
baseOff int
baseLen int
depth int
flags uint16
}
// Index is an append-only in-memory file index with snapshot support.
type Index struct {
docs []doc
byGram map[uint32][]uint32
byPrefix1 [256][]uint32
byPrefix2 map[uint16][]uint32
byPath map[string]uint32
deleted map[uint32]bool
}
// Snapshot is a frozen, read-only view of the index at a point in time.
type Snapshot struct {
docs []doc
deleted map[uint32]bool
byGram map[uint32][]uint32
byPrefix1 [256][]uint32
byPrefix2 map[uint16][]uint32
}
// NewIndex creates an empty Index.
func NewIndex() *Index {
return &Index{
byGram: make(map[uint32][]uint32),
byPrefix2: make(map[uint16][]uint32),
byPath: make(map[string]uint32),
deleted: make(map[uint32]bool),
}
}
// Add inserts a path into the index, tombstoning any previous entry.
func (idx *Index) Add(path string, flags uint16) uint32 {
norm := string(normalizePathBytes([]byte(path)))
if oldID, ok := idx.byPath[norm]; ok {
idx.deleted[oldID] = true
}
id := uint32(len(idx.docs)) //nolint:gosec // Index will never exceed 2^32 docs.
baseOff, baseLen := extractBasename([]byte(norm))
idx.docs = append(idx.docs, doc{
path: norm, baseOff: baseOff, baseLen: baseLen,
depth: strings.Count(norm, "/"), flags: flags,
})
idx.byPath[norm] = id
for _, g := range extractTrigrams([]byte(norm)) {
idx.byGram[g] = append(idx.byGram[g], id)
}
if baseLen > 0 {
basename := []byte(norm[baseOff : baseOff+baseLen])
p1 := prefix1(basename)
idx.byPrefix1[p1] = append(idx.byPrefix1[p1], id)
p2 := prefix2(basename)
idx.byPrefix2[p2] = append(idx.byPrefix2[p2], id)
}
return id
}
// Remove marks the entry for path as deleted.
func (idx *Index) Remove(path string) bool {
norm := string(normalizePathBytes([]byte(path)))
id, ok := idx.byPath[norm]
if !ok {
return false
}
idx.deleted[id] = true
delete(idx.byPath, norm)
return true
}
// Has reports whether path exists (not deleted) in the index.
func (idx *Index) Has(path string) bool {
_, ok := idx.byPath[string(normalizePathBytes([]byte(path)))]
return ok
}
// Len returns the number of live (non-deleted) documents.
func (idx *Index) Len() int { return len(idx.byPath) }
func copyPostings[K comparable](m map[K][]uint32) map[K][]uint32 {
cp := make(map[K][]uint32, len(m))
for k, v := range m {
cp[k] = v[:len(v):len(v)]
}
return cp
}
// Snapshot returns a frozen read-only view of the index.
func (idx *Index) Snapshot() *Snapshot {
del := make(map[uint32]bool, len(idx.deleted))
for id := range idx.deleted {
del[id] = true
}
var p1Copy [256][]uint32
for i, ids := range idx.byPrefix1 {
if len(ids) > 0 {
p1Copy[i] = ids[:len(ids):len(ids)]
}
}
return &Snapshot{
docs: idx.docs[:len(idx.docs):len(idx.docs)],
deleted: del,
byGram: copyPostings(idx.byGram),
byPrefix1: p1Copy,
byPrefix2: copyPostings(idx.byPrefix2),
}
}
-120
View File
@@ -1,120 +0,0 @@
package filefinder_test
import (
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestIndex_AddAndLen(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
idx.Add("foo/baz.go", 0)
if idx.Len() != 2 {
t.Fatalf("expected 2, got %d", idx.Len())
}
}
func TestIndex_Has(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
if !idx.Has("foo/bar.go") {
t.Fatal("expected Has to return true")
}
if idx.Has("foo/missing.go") {
t.Fatal("expected Has to return false for missing path")
}
}
func TestIndex_Remove(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
if !idx.Remove("foo/bar.go") {
t.Fatal("expected Remove to return true")
}
if idx.Has("foo/bar.go") {
t.Fatal("expected Has to return false after Remove")
}
if idx.Len() != 0 {
t.Fatalf("expected Len 0 after Remove, got %d", idx.Len())
}
}
func TestIndex_AddOverwrite(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", uint16(filefinder.FlagFile))
idx.Add("foo/bar.go", uint16(filefinder.FlagDir)) // overwrite
if idx.Len() != 1 {
t.Fatalf("expected 1 after overwrite, got %d", idx.Len())
}
// The old entry should be tombstoned.
if !filefinder.IndexIsDeleted(idx, 0) {
t.Fatal("expected old entry to be deleted")
}
if filefinder.IndexIsDeleted(idx, 1) {
t.Fatal("expected new entry to be live")
}
}
func TestIndex_Snapshot(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("foo/bar.go", 0)
idx.Add("foo/baz.go", 0)
snap := idx.Snapshot()
if filefinder.SnapshotCount(snap) != 2 {
t.Fatalf("expected snapshot count 2, got %d", filefinder.SnapshotCount(snap))
}
// Adding more docs after snapshot doesn't affect it.
idx.Add("foo/qux.go", 0)
if filefinder.SnapshotCount(snap) != 2 {
t.Fatal("snapshot count should not change after new adds")
}
}
func TestIndex_TrigramIndex(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
// "handler.go" should produce trigrams for "handler.go".
// Check that at least one trigram exists.
if filefinder.IndexByGramLen(idx) == 0 {
t.Fatal("expected non-empty trigram index")
}
}
func TestIndex_PrefixIndex(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
// basename is "handler.go", first byte is 'h'
if filefinder.IndexByPrefix1Len(idx, 'h') == 0 {
t.Fatal("expected prefix1['h'] to be non-empty")
}
}
func TestIndex_RemoveNonexistent(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
if idx.Remove("nonexistent.go") {
t.Fatal("expected Remove to return false for missing path")
}
}
func TestIndex_PathNormalization(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("Foo/Bar.go", 0)
// Should be findable with lowercase.
if !idx.Has("foo/bar.go") {
t.Fatal("expected case-insensitive Has")
}
}
-364
View File
@@ -1,364 +0,0 @@
// Package filefinder provides an in-memory file index with trigram
// matching, fuzzy search, and filesystem watching. It is designed
// to power file-finding features on workspace agents.
package filefinder
import (
"context"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// SearchOptions controls search behavior.
type SearchOptions struct {
Limit int
MaxCandidates int
}
// DefaultSearchOptions returns sensible default search options.
func DefaultSearchOptions() SearchOptions {
return SearchOptions{Limit: 100, MaxCandidates: 10000}
}
type rootSnapshot struct {
root string
snap *Snapshot
}
// Engine is the main file finder. Safe for concurrent use.
type Engine struct {
snap atomic.Pointer[[]*rootSnapshot]
logger slog.Logger
mu sync.Mutex
roots map[string]*rootState
eventCh chan rootEvent
closeCh chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
type rootState struct {
root string
index *Index
watcher *fsWatcher
cancel context.CancelFunc
}
type rootEvent struct {
root string
events []FSEvent
}
// walkRoot performs a full filesystem walk of absRoot and returns
// a populated Index containing all discovered files and directories.
func walkRoot(absRoot string) (*Index, error) {
idx := NewIndex()
err := filepath.Walk(absRoot, func(path string, info os.FileInfo, walkErr error) error {
if walkErr != nil {
return nil //nolint:nilerr
}
base := filepath.Base(path)
if _, skip := skipDirs[base]; skip && info.IsDir() {
return filepath.SkipDir
}
if path == absRoot {
return nil
}
relPath, relErr := filepath.Rel(absRoot, path)
if relErr != nil {
return nil //nolint:nilerr
}
relPath = filepath.ToSlash(relPath)
var flags uint16
if info.IsDir() {
flags = uint16(FlagDir)
} else if info.Mode()&os.ModeSymlink != 0 {
flags = uint16(FlagSymlink)
}
idx.Add(relPath, flags)
return nil
})
return idx, err
}
// NewEngine creates a new Engine.
func NewEngine(logger slog.Logger) *Engine {
e := &Engine{
logger: logger,
roots: make(map[string]*rootState),
eventCh: make(chan rootEvent, 256),
closeCh: make(chan struct{}),
}
empty := make([]*rootSnapshot, 0)
e.snap.Store(&empty)
e.wg.Add(1)
go e.start()
return e
}
// ErrClosed is returned when operations are attempted on a
// closed engine.
var ErrClosed = xerrors.New("engine is closed")
// AddRoot adds a directory root to the engine.
func (e *Engine) AddRoot(ctx context.Context, root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
e.mu.Lock()
if e.closed.Load() {
e.mu.Unlock()
return ErrClosed
}
if _, exists := e.roots[absRoot]; exists {
e.mu.Unlock()
return nil
}
e.mu.Unlock()
// Walk and create the watcher outside the lock to avoid
// blocking the event pipeline on filesystem I/O.
idx, walkErr := walkRoot(absRoot)
if walkErr != nil {
return xerrors.Errorf("walk root: %w", walkErr)
}
wCtx, wCancel := context.WithCancel(context.Background())
w, wErr := newFSWatcher(absRoot, e.logger)
if wErr != nil {
wCancel()
return xerrors.Errorf("create watcher: %w", wErr)
}
e.mu.Lock()
// Re-check after re-acquiring the lock: another goroutine
// may have added this root or closed the engine while we
// were walking.
if e.closed.Load() {
e.mu.Unlock()
wCancel()
_ = w.Close()
return ErrClosed
}
if _, exists := e.roots[absRoot]; exists {
e.mu.Unlock()
wCancel()
_ = w.Close()
return nil
}
rs := &rootState{root: absRoot, index: idx, watcher: w, cancel: wCancel}
e.roots[absRoot] = rs
w.Start(wCtx)
e.wg.Add(1)
go e.forwardEvents(wCtx, absRoot, w)
e.publishSnapshot()
fileCount := idx.Len()
e.mu.Unlock()
e.logger.Info(ctx, "added root to engine",
slog.F("root", absRoot),
slog.F("files", fileCount),
)
return nil
}
// RemoveRoot stops watching a root and removes it.
func (e *Engine) RemoveRoot(root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
e.mu.Lock()
defer e.mu.Unlock()
rs, exists := e.roots[absRoot]
if !exists {
return xerrors.Errorf("root %q not found", absRoot)
}
rs.cancel()
_ = rs.watcher.Close()
delete(e.roots, absRoot)
e.publishSnapshot()
return nil
}
// Search performs a fuzzy file search across all roots.
func (e *Engine) Search(_ context.Context, query string, opts SearchOptions) ([]Result, error) {
if e.closed.Load() {
return nil, ErrClosed
}
snapPtr := e.snap.Load()
if snapPtr == nil || len(*snapPtr) == 0 {
return nil, nil
}
roots := *snapPtr
plan := newQueryPlan(query)
if len(plan.Normalized) == 0 {
return nil, nil
}
if opts.Limit <= 0 {
opts.Limit = 100
}
if opts.MaxCandidates <= 0 {
opts.MaxCandidates = 10000
}
params := defaultScoreParams()
var allCands []candidate
for _, rs := range roots {
allCands = append(allCands, searchSnapshot(plan, rs.snap, opts.MaxCandidates)...)
}
results := mergeAndScore(allCands, plan, params, opts.Limit)
return results, nil
}
// Close shuts down the engine.
func (e *Engine) Close() error {
if e.closed.Swap(true) {
return nil
}
close(e.closeCh)
e.mu.Lock()
for _, rs := range e.roots {
rs.cancel()
_ = rs.watcher.Close()
}
e.roots = make(map[string]*rootState)
e.mu.Unlock()
e.wg.Wait()
return nil
}
// Rebuild forces a complete re-walk and re-index of a root.
func (e *Engine) Rebuild(ctx context.Context, root string) error {
absRoot, err := filepath.Abs(root)
if err != nil {
return xerrors.Errorf("resolve root: %w", err)
}
// Walk outside the lock to avoid blocking the event
// pipeline on potentially slow filesystem I/O.
idx, walkErr := walkRoot(absRoot)
if walkErr != nil {
return xerrors.Errorf("rebuild walk: %w", walkErr)
}
e.mu.Lock()
rs, exists := e.roots[absRoot]
if !exists {
e.mu.Unlock()
return xerrors.Errorf("root %q not found", absRoot)
}
rs.index = idx
e.publishSnapshot()
fileCount := idx.Len()
e.mu.Unlock()
e.logger.Info(ctx, "rebuilt root in engine",
slog.F("root", absRoot),
slog.F("files", fileCount),
)
return nil
}
func (e *Engine) start() {
defer e.wg.Done()
for {
select {
case <-e.closeCh:
return
case re, ok := <-e.eventCh:
if !ok {
return
}
e.applyEvents(re)
}
}
}
func (e *Engine) forwardEvents(ctx context.Context, root string, w *fsWatcher) {
defer e.wg.Done()
for {
select {
case <-ctx.Done():
return
case <-e.closeCh:
return
case evts, ok := <-w.Events():
if !ok {
return
}
select {
case e.eventCh <- rootEvent{root: root, events: evts}:
case <-ctx.Done():
return
case <-e.closeCh:
return
}
}
}
}
func (e *Engine) applyEvents(re rootEvent) {
e.mu.Lock()
defer e.mu.Unlock()
rs, exists := e.roots[re.root]
if !exists {
return
}
changed := false
for _, ev := range re.events {
relPath, err := filepath.Rel(rs.root, ev.Path)
if err != nil {
continue
}
relPath = filepath.ToSlash(relPath)
switch ev.Op {
case OpCreate:
if rs.index.Has(relPath) {
continue
}
var flags uint16
if ev.IsDir {
flags = uint16(FlagDir)
}
rs.index.Add(relPath, flags)
changed = true
case OpRemove, OpRename:
if rs.index.Remove(relPath) {
changed = true
}
if ev.IsDir || ev.Op == OpRename {
prefix := strings.ToLower(filepath.ToSlash(relPath)) + "/"
for path := range rs.index.byPath {
if strings.HasPrefix(path, prefix) {
rs.index.Remove(path)
changed = true
}
}
}
case OpModify:
}
}
if changed {
e.publishSnapshot()
}
}
// publishSnapshot builds and atomically publishes a new snapshot.
// Must be called with e.mu held.
func (e *Engine) publishSnapshot() {
roots := make([]*rootSnapshot, 0, len(e.roots))
for _, rs := range e.roots {
roots = append(roots, &rootSnapshot{
root: rs.root,
snap: rs.index.Snapshot(),
})
}
slices.SortFunc(roots, func(a, b *rootSnapshot) int {
return strings.Compare(a.root, b.root)
})
e.snap.Store(&roots)
}
-233
View File
@@ -1,233 +0,0 @@
package filefinder_test
import (
"context"
"os"
"path/filepath"
"sort"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/filefinder"
"github.com/coder/coder/v2/testutil"
)
func newTestEngine(t *testing.T) (*filefinder.Engine, context.Context) {
t.Helper()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
eng := filefinder.NewEngine(logger)
t.Cleanup(func() { _ = eng.Close() })
return eng, context.Background()
}
func requireResultHasPath(t *testing.T, results []filefinder.Result, path string) {
t.Helper()
for _, r := range results {
if r.Path == path {
return
}
}
t.Errorf("expected %q in results, got %v", path, resultPaths(results))
}
func TestEngine_SearchFindsKnownFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "src/main.go", "package main")
createFile(t, dir, "src/handler.go", "package main")
createFile(t, dir, "README.md", "# hello")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "main.go", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results, "expected to find main.go")
requireResultHasPath(t, results, "src/main.go")
}
func TestEngine_SearchFuzzyMatch(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "src/controllers/user_handler.go", "package controllers")
createFile(t, dir, "src/models/user.go", "package models")
createFile(t, dir, "docs/api.md", "# API")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
// "handler" should match "user_handler.go".
results, err := eng.Search(ctx, "handler", filefinder.DefaultSearchOptions())
require.NoError(t, err)
// The query is a subsequence of "user_handler.go" so it
// should appear somewhere in the results.
requireResultHasPath(t, results, "src/controllers/user_handler.go")
}
func TestEngine_IndexPicksUpNewFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "existing.txt", "hello")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
createFile(t, dir, "newfile_unique.txt", "world")
require.Eventually(t, func() bool {
results, sErr := eng.Search(ctx, "newfile_unique", filefinder.DefaultSearchOptions())
if sErr != nil {
return false
}
for _, r := range results {
if r.Path == "newfile_unique.txt" {
return true
}
}
return false
}, testutil.WaitShort, testutil.IntervalFast, "expected newfile_unique.txt to appear via watcher")
}
func TestEngine_IndexRemovesDeletedFile(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "deleteme_unique.txt", "goodbye")
createFile(t, dir, "keeper.txt", "stay")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results, "expected to find deleteme_unique.txt initially")
require.NoError(t, os.Remove(filepath.Join(dir, "deleteme_unique.txt")))
require.Eventually(t, func() bool {
results, sErr := eng.Search(ctx, "deleteme_unique", filefinder.DefaultSearchOptions())
if sErr != nil {
return false
}
for _, r := range results {
if r.Path == "deleteme_unique.txt" {
return false // still found
}
}
return true
}, testutil.WaitShort, testutil.IntervalFast, "expected deleteme_unique.txt to disappear after removal")
}
func TestEngine_MultipleRoots(t *testing.T) {
t.Parallel()
dir1 := t.TempDir()
dir2 := t.TempDir()
createFile(t, dir1, "alpha_unique.go", "package alpha")
createFile(t, dir2, "beta_unique.go", "package beta")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir1))
require.NoError(t, eng.AddRoot(ctx, dir2))
results, err := eng.Search(ctx, "alpha_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "alpha_unique.go")
results, err = eng.Search(ctx, "beta_unique", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "beta_unique.go")
}
func TestEngine_EmptyQueryReturnsEmpty(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "something.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.Empty(t, results, "empty query should return no results")
}
func TestEngine_CloseIsClean(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
ctx := context.Background()
eng := filefinder.NewEngine(logger)
require.NoError(t, eng.AddRoot(ctx, dir))
require.NoError(t, eng.Close())
_, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.Error(t, err)
}
func TestEngine_AddRootIdempotent(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
require.NoError(t, eng.AddRoot(ctx, dir))
snapLen := filefinder.EngineSnapLen(eng)
require.Equal(t, 1, snapLen, "expected exactly one root after duplicate add")
}
func TestEngine_RemoveRoot(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "file.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
results, err := eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.NotEmpty(t, results)
require.NoError(t, eng.RemoveRoot(dir))
results, err = eng.Search(ctx, "file", filefinder.DefaultSearchOptions())
require.NoError(t, err)
require.Empty(t, results)
}
func TestEngine_Rebuild(t *testing.T) {
t.Parallel()
dir := t.TempDir()
createFile(t, dir, "original.txt", "data")
eng, ctx := newTestEngine(t)
require.NoError(t, eng.AddRoot(ctx, dir))
createFile(t, dir, "sneaky_rebuild.txt", "hidden")
require.NoError(t, eng.Rebuild(ctx, dir))
results, err := eng.Search(ctx, "sneaky_rebuild", filefinder.DefaultSearchOptions())
require.NoError(t, err)
requireResultHasPath(t, results, "sneaky_rebuild.txt")
}
// createFile creates a file (and parent dirs) at relPath under dir.
func createFile(t *testing.T, dir, relPath, content string) {
t.Helper()
full := filepath.Join(dir, relPath)
require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755))
require.NoError(t, os.WriteFile(full, []byte(content), 0o600))
}
func resultPaths(results []filefinder.Result) []string {
paths := make([]string, len(results))
for i, r := range results {
paths[i] = r.Path
}
sort.Strings(paths)
return paths
}
-85
View File
@@ -1,85 +0,0 @@
package filefinder
// Test helpers that need internal access.
// MakeTestSnapshot builds a Snapshot from a list of paths. Useful for
// query-level tests that don't need a real filesystem.
func MakeTestSnapshot(paths []string) *Snapshot {
idx := NewIndex()
for _, p := range paths {
idx.Add(p, 0)
}
return idx.Snapshot()
}
// BuildTestIndex walks root and returns a populated Index, the same
// way Engine.AddRoot does but without starting a watcher.
func BuildTestIndex(root string) (*Index, error) {
return walkRoot(root)
}
// IndexIsDeleted reports whether the document at id is tombstoned.
func IndexIsDeleted(idx *Index, id uint32) bool {
return idx.deleted[id]
}
// IndexByGramLen returns the number of entries in the trigram index.
func IndexByGramLen(idx *Index) int {
return len(idx.byGram)
}
// IndexByPrefix1Len returns the number of posting-list entries for
// the given single-byte prefix.
func IndexByPrefix1Len(idx *Index, b byte) int {
return len(idx.byPrefix1[b])
}
// SnapshotCount returns the number of documents in a Snapshot.
func SnapshotCount(snap *Snapshot) int {
return len(snap.docs)
}
// EngineSnapLen returns the number of root snapshots currently held
// by the engine, or -1 if the pointer is nil.
func EngineSnapLen(eng *Engine) int {
p := eng.snap.Load()
if p == nil {
return -1
}
return len(*p)
}
// DefaultScoreParamsForTest exposes defaultScoreParams for tests.
var DefaultScoreParamsForTest = defaultScoreParams
// ScoreParamsForTest is a type alias for scoreParams.
type ScoreParamsForTest = scoreParams
// Exported aliases for internal functions used in tests.
var (
NewQueryPlanForTest = newQueryPlan
SearchSnapshotForTest = searchSnapshot
IntersectSortedForTest = intersectSorted
IntersectAllForTest = intersectAll
MergeAndScoreForTest = mergeAndScore
NormalizeQueryForTest = normalizeQuery
NormalizePathBytesForTest = normalizePathBytes
ExtractTrigramsForTest = extractTrigrams
ExtractBasenameForTest = extractBasename
ExtractSegmentsForTest = extractSegments
Prefix1ForTest = prefix1
Prefix2ForTest = prefix2
IsSubsequenceForTest = isSubsequence
LongestContiguousMatchForTest = longestContiguousMatch
IsBoundaryForTest = isBoundary
CountBoundaryHitsForTest = countBoundaryHits
EqualFoldASCIIForTest = equalFoldASCII
ScorePathForTest = scorePath
PackTrigramForTest = packTrigram
)
// Type aliases for internal types used in tests.
type (
CandidateForTest = candidate
QueryPlanForTest = queryPlan
)
-299
View File
@@ -1,299 +0,0 @@
package filefinder
import (
"container/heap"
"slices"
"strings"
)
type candidate struct {
DocID uint32
Path string
BaseOff int
BaseLen int
Depth int
Flags uint16
}
// Result is a scored search result returned to callers.
type Result struct {
Path string
Score float32
IsDir bool
}
type queryPlan struct {
Original string
Normalized string
Tokens [][]byte
Trigrams []uint32
IsShort bool
HasSlash bool
BasenameQ []byte
DirTokens [][]byte
}
func newQueryPlan(q string) *queryPlan {
norm := normalizeQuery(q)
p := &queryPlan{Original: q, Normalized: norm}
if len(norm) == 0 {
p.IsShort = true
return p
}
raw := strings.ReplaceAll(norm, "/", " ")
parts := strings.Fields(raw)
p.HasSlash = strings.ContainsRune(norm, '/')
for _, part := range parts {
p.Tokens = append(p.Tokens, []byte(part))
}
if len(p.Tokens) > 0 {
p.BasenameQ = p.Tokens[len(p.Tokens)-1]
if len(p.Tokens) > 1 {
p.DirTokens = p.Tokens[:len(p.Tokens)-1]
}
}
p.IsShort = true
for _, tok := range p.Tokens {
if len(tok) >= 3 {
p.IsShort = false
break
}
}
if !p.IsShort {
p.Trigrams = extractQueryTrigrams(p.Tokens)
}
return p
}
func extractQueryTrigrams(tokens [][]byte) []uint32 {
seen := make(map[uint32]struct{})
for _, tok := range tokens {
if len(tok) < 3 {
continue
}
for i := 0; i <= len(tok)-3; i++ {
seen[packTrigram(tok[i], tok[i+1], tok[i+2])] = struct{}{}
}
}
if len(seen) == 0 {
return nil
}
result := make([]uint32, 0, len(seen))
for g := range seen {
result = append(result, g)
}
return result
}
func packTrigram(a, b, c byte) uint32 {
return uint32(toLowerASCII(a))<<16 | uint32(toLowerASCII(b))<<8 | uint32(toLowerASCII(c))
}
// searchSnapshot runs the full search pipeline against a single
// root snapshot: it selects a strategy (prefix, trigram, or
// fuzzy fallback) based on query length, retrieves candidate
// doc IDs, and converts them into candidate structs.
func searchSnapshot(plan *queryPlan, snap *Snapshot, limit int) []candidate {
if snap == nil || len(snap.docs) == 0 || len(plan.Normalized) == 0 {
return nil
}
var ids []uint32
if plan.IsShort {
ids = searchShort(plan, snap)
} else {
ids = searchTrigrams(plan, snap)
if len(ids) == 0 && len(plan.BasenameQ) > 0 {
ids = searchFuzzyFallback(plan, snap)
}
}
if len(ids) == 0 {
return nil
}
cands := make([]candidate, 0, min(len(ids), limit))
for _, id := range ids {
if snap.deleted[id] || int(id) >= len(snap.docs) {
continue
}
d := snap.docs[id]
cands = append(cands, candidate{
DocID: id, Path: d.path, BaseOff: d.baseOff,
BaseLen: d.baseLen, Depth: d.depth, Flags: d.flags,
})
if len(cands) >= limit {
break
}
}
return cands
}
func searchShort(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
if len(plan.BasenameQ) >= 2 {
if ids := snap.byPrefix2[prefix2(plan.BasenameQ)]; len(ids) > 0 {
return ids
}
}
return snap.byPrefix1[prefix1(plan.BasenameQ)]
}
func searchTrigrams(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.Trigrams) == 0 {
return nil
}
lists := make([][]uint32, 0, len(plan.Trigrams))
for _, g := range plan.Trigrams {
ids, ok := snap.byGram[g]
if !ok || len(ids) == 0 {
return nil
}
lists = append(lists, ids)
}
return intersectAll(lists)
}
func searchFuzzyFallback(plan *queryPlan, snap *Snapshot) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
bucket := snap.byPrefix1[prefix1(plan.BasenameQ)]
if len(bucket) == 0 {
return searchSubsequenceScan(plan, snap, 5000)
}
var ids []uint32
for _, id := range bucket {
if snap.deleted[id] || int(id) >= len(snap.docs) {
continue
}
if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
ids = append(ids, id)
}
}
if len(ids) == 0 {
return searchSubsequenceScan(plan, snap, 5000)
}
return ids
}
func searchSubsequenceScan(plan *queryPlan, snap *Snapshot, maxCheck int) []uint32 {
if len(plan.BasenameQ) == 0 {
return nil
}
var ids []uint32
checked := 0
for id := 0; id < len(snap.docs) && checked < maxCheck; id++ {
uid := uint32(id) //nolint:gosec // Snapshot count is bounded well below 2^32.
if snap.deleted[uid] {
continue
}
checked++
if isSubsequence([]byte(snap.docs[id].path), plan.BasenameQ) {
ids = append(ids, uid)
}
}
return ids
}
func intersectSorted(a, b []uint32) []uint32 {
if len(a) == 0 || len(b) == 0 {
return nil
}
var result []uint32
ai, bi := 0, 0
for ai < len(a) && bi < len(b) {
switch {
case a[ai] < b[bi]:
ai++
case a[ai] > b[bi]:
bi++
default:
result = append(result, a[ai])
ai++
bi++
}
}
return result
}
func intersectAll(lists [][]uint32) []uint32 {
if len(lists) == 0 {
return nil
}
if len(lists) == 1 {
return lists[0]
}
slices.SortFunc(lists, func(a, b []uint32) int { return len(a) - len(b) })
result := lists[0]
for i := 1; i < len(lists) && len(result) > 0; i++ {
result = intersectSorted(result, lists[i])
}
return result
}
func mergeAndScore(cands []candidate, plan *queryPlan, params scoreParams, topK int) []Result {
if topK <= 0 || len(cands) == 0 {
return nil
}
query := []byte(plan.Normalized)
h := &resultHeap{}
heap.Init(h)
for i := range cands {
c := &cands[i]
s := scorePath([]byte(c.Path), c.BaseOff, c.BaseLen, c.Depth, query, plan.Tokens, params)
if s <= 0 {
continue
}
// DirTokenHit is applied here rather than in scorePath because
// it depends on the query plan's directory tokens, which are
// split from the full query during planning. scorePath operates
// on raw query bytes without knowledge of token boundaries.
if len(plan.DirTokens) > 0 {
segments := extractSegments([]byte(c.Path))
for _, dt := range plan.DirTokens {
for _, seg := range segments {
if equalFoldASCII(seg, dt) {
s += params.DirTokenHit
break
}
}
}
}
r := Result{Path: c.Path, Score: s, IsDir: c.Flags == uint16(FlagDir)}
if h.Len() < topK {
heap.Push(h, r)
} else if s > (*h)[0].Score {
(*h)[0] = r
heap.Fix(h, 0)
}
}
n := h.Len()
results := make([]Result, n)
for i := n - 1; i >= 0; i-- {
v := heap.Pop(h)
if r, ok := v.(Result); ok {
results[i] = r
}
}
return results
}
type resultHeap []Result
func (h resultHeap) Len() int { return len(h) }
func (h resultHeap) Less(i, j int) bool { return h[i].Score < h[j].Score }
func (h resultHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *resultHeap) Push(x interface{}) {
r, ok := x.(Result)
if ok {
*h = append(*h, r)
}
}
func (h *resultHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[:n-1]
return x
}
-343
View File
@@ -1,343 +0,0 @@
package filefinder_test
import (
"slices"
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestNewQueryPlan(t *testing.T) {
t.Parallel()
tests := []struct {
name string
query string
wantNorm string
wantShort bool
wantSlash bool
wantBase string
wantTokens []string
wantDirTok []string
wantTriCnt int // -1 to skip check
}{
{"Simple", "foo", "foo", false, false, "foo", []string{"foo"}, nil, 1},
{"MultiToken", "foo bar", "foo bar", false, false, "bar", []string{"foo", "bar"}, []string{"foo"}, -1},
{"Slash", "internal/foo", "internal/foo", false, true, "foo", []string{"internal", "foo"}, []string{"internal"}, -1},
{"SingleChar", "a", "a", true, false, "a", []string{"a"}, nil, 0},
{"TwoChars", "ab", "ab", true, false, "ab", []string{"ab"}, nil, -1},
{"ThreeChars", "abc", "abc", false, false, "abc", []string{"abc"}, nil, 1},
{"DotPrefix", ".go", ".go", false, false, ".go", []string{".go"}, nil, -1},
{"UpperCase", "FOO", "foo", false, false, "foo", []string{"foo"}, nil, -1},
{"Empty", "", "", true, false, "", nil, nil, -1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest(tt.query)
if plan.Normalized != tt.wantNorm {
t.Errorf("normalized = %q, want %q", plan.Normalized, tt.wantNorm)
}
if plan.IsShort != tt.wantShort {
t.Errorf("isShort = %v, want %v", plan.IsShort, tt.wantShort)
}
if plan.HasSlash != tt.wantSlash {
t.Errorf("hasSlash = %v, want %v", plan.HasSlash, tt.wantSlash)
}
if string(plan.BasenameQ) != tt.wantBase {
t.Errorf("basenameQ = %q, want %q", plan.BasenameQ, tt.wantBase)
}
if tt.wantTokens == nil {
if len(plan.Tokens) != 0 {
t.Errorf("expected 0 tokens, got %d", len(plan.Tokens))
}
} else {
if len(plan.Tokens) != len(tt.wantTokens) {
t.Fatalf("tokens len = %d, want %d", len(plan.Tokens), len(tt.wantTokens))
}
for i, tok := range plan.Tokens {
if string(tok) != tt.wantTokens[i] {
t.Errorf("tokens[%d] = %q, want %q", i, tok, tt.wantTokens[i])
}
}
}
if tt.wantDirTok != nil {
if len(plan.DirTokens) != len(tt.wantDirTok) {
t.Fatalf("dirTokens len = %d, want %d", len(plan.DirTokens), len(tt.wantDirTok))
}
for i, tok := range plan.DirTokens {
if string(tok) != tt.wantDirTok[i] {
t.Errorf("dirTokens[%d] = %q, want %q", i, tok, tt.wantDirTok[i])
}
}
}
if tt.wantTriCnt >= 0 && len(plan.Trigrams) != tt.wantTriCnt {
t.Errorf("trigram count = %d, want %d", len(plan.Trigrams), tt.wantTriCnt)
}
})
}
// ThreeChars: verify the actual trigram value.
plan := filefinder.NewQueryPlanForTest("abc")
if want := filefinder.PackTrigramForTest('a', 'b', 'c'); plan.Trigrams[0] != want {
t.Errorf("trigram = %x, want %x", plan.Trigrams[0], want)
}
// ShortMultiToken: both tokens < 3 chars so isShort should be true.
plan = filefinder.NewQueryPlanForTest("ab cd")
if !plan.IsShort {
t.Error("expected isShort=true when all tokens < 3 chars")
}
// One token >= 3 chars, so isShort should be false.
plan = filefinder.NewQueryPlanForTest("ab cde")
if plan.IsShort {
t.Error("expected isShort=false when any token >= 3 chars")
}
}
func requireCandHasPath(t *testing.T, cands []filefinder.CandidateForTest, path string) {
t.Helper()
for _, c := range cands {
if c.Path == path {
return
}
}
t.Errorf("expected to find %q in candidates", path)
}
func TestSearchSnapshot_TrigramMatch(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'handler'")
}
requireCandHasPath(t, cands, "src/handler.go")
}
func TestSearchSnapshot_ShortQuery(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"foo.go", "bar.go", "fab.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("fo"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'fo'")
}
requireCandHasPath(t, cands, "foo.go")
}
func TestSearchSnapshot_FuzzyFallback(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/router.go", "lib/utils.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("hndlr"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected fuzzy fallback to find 'handler.go' for query 'hndlr'")
}
requireCandHasPath(t, cands, "src/handler.go")
}
func TestSearchSnapshot_FuzzyFallbackNoFirstCharMatch(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/xylophone.go", "lib/extra.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("xylo"), snap, 100)
if len(cands) == 0 {
t.Fatal("expected at least 1 candidate for 'xylo'")
}
requireCandHasPath(t, cands, "src/xylophone.go")
}
func TestSearchSnapshot_NilSnapshot(t *testing.T) {
t.Parallel()
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("foo"), nil, 100)
if cands != nil {
t.Errorf("expected nil for nil snapshot, got %v", cands)
}
}
func TestSearchSnapshot_EmptyQuery(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"foo.go"})
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest(""), snap, 100)
if cands != nil {
t.Errorf("expected nil for empty query, got %v", cands)
}
}
func TestSearchSnapshot_DeletedDocsExcluded(t *testing.T) {
t.Parallel()
idx := filefinder.NewIndex()
idx.Add("handler.go", 0)
idx.Remove("handler.go")
snap := idx.Snapshot()
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 100)
for _, c := range cands {
if c.Path == "handler.go" {
t.Error("deleted doc should not appear in results")
}
}
}
func TestSearchSnapshot_Limit(t *testing.T) {
t.Parallel()
paths := make([]string, 50)
for i := range paths {
paths[i] = "handler" + string(rune('a'+i%26)) + ".go"
}
snap := filefinder.MakeTestSnapshot(paths)
cands := filefinder.SearchSnapshotForTest(filefinder.NewQueryPlanForTest("handler"), snap, 3)
if len(cands) > 3 {
t.Errorf("expected at most 3 candidates, got %d", len(cands))
}
}
func TestIntersectSorted(t *testing.T) {
t.Parallel()
tests := []struct {
name string
a, b []uint32
want []uint32
}{
{"both empty", nil, nil, nil},
{"a empty", nil, []uint32{1, 2}, nil},
{"b empty", []uint32{1, 2}, nil, nil},
{"no overlap", []uint32{1, 3, 5}, []uint32{2, 4, 6}, nil},
{"full overlap", []uint32{1, 2, 3}, []uint32{1, 2, 3}, []uint32{1, 2, 3}},
{"partial overlap", []uint32{1, 2, 3, 5}, []uint32{2, 4, 5}, []uint32{2, 5}},
{"single match", []uint32{1, 2, 3}, []uint32{2}, []uint32{2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.IntersectSortedForTest(tt.a, tt.b)
if len(tt.want) == 0 {
if len(got) != 0 {
t.Errorf("got %v, want empty/nil", got)
}
return
}
if !slices.Equal(got, tt.want) {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestIntersectAll(t *testing.T) {
t.Parallel()
t.Run("empty", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest(nil); got != nil {
t.Errorf("got %v, want nil", got)
}
})
t.Run("single", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3}}); len(got) != 3 {
t.Fatalf("len = %d, want 3", len(got))
}
})
t.Run("multiple", func(t *testing.T) {
t.Parallel()
got := filefinder.IntersectAllForTest([][]uint32{{1, 2, 3, 4, 5}, {2, 3, 5}, {3, 5, 7}})
if !slices.Equal(got, []uint32{3, 5}) {
t.Errorf("got %v, want [3 5]", got)
}
})
t.Run("no overlap", func(t *testing.T) {
t.Parallel()
if got := filefinder.IntersectAllForTest([][]uint32{{1, 2}, {3, 4}}); got != nil {
t.Errorf("got %v, want nil", got)
}
})
}
func TestMergeAndScore_SortedDescending(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
params := filefinder.DefaultScoreParamsForTest()
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "a/b/c/d/e/foo", BaseOff: 10, BaseLen: 3, Depth: 5},
{DocID: 1, Path: "src/foo", BaseOff: 4, BaseLen: 3, Depth: 1},
{DocID: 2, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0},
}
results := filefinder.MergeAndScoreForTest(cands, plan, params, 10)
if len(results) == 0 {
t.Fatal("expected non-empty results")
}
for i := 1; i < len(results); i++ {
if results[i].Score > results[i-1].Score {
t.Errorf("results not sorted: [%d].Score=%f > [%d].Score=%f",
i, results[i].Score, i-1, results[i-1].Score)
}
}
}
func TestMergeAndScore_TopKLimit(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("f")
params := filefinder.DefaultScoreParamsForTest()
var cands []filefinder.CandidateForTest
for i := range 20 {
p := "f" + string(rune('a'+i))
cands = append(cands, filefinder.CandidateForTest{DocID: uint32(i), Path: p, BaseOff: 0, BaseLen: len(p), Depth: 0}) //nolint:gosec // test index is tiny
}
if results := filefinder.MergeAndScoreForTest(cands, plan, params, 5); len(results) != 5 {
t.Errorf("expected 5 results, got %d", len(results))
}
}
func TestMergeAndScore_ZeroTopK(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
cands := []filefinder.CandidateForTest{{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0}}
if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 0); len(results) != 0 {
t.Errorf("expected 0 results for topK=0, got %d", len(results))
}
}
func TestMergeAndScore_NoMatchCandidatesDropped(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("xyz")
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "abc", BaseOff: 0, BaseLen: 3, Depth: 0},
{DocID: 1, Path: "def", BaseOff: 0, BaseLen: 3, Depth: 0},
}
if results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
t.Errorf("expected 0 results for non-matching candidates, got %d", len(results))
}
}
func TestMergeAndScore_IsDirFlag(t *testing.T) {
t.Parallel()
plan := filefinder.NewQueryPlanForTest("foo")
cands := []filefinder.CandidateForTest{
{DocID: 0, Path: "foo", BaseOff: 0, BaseLen: 3, Depth: 0, Flags: uint16(filefinder.FlagDir)},
}
results := filefinder.MergeAndScoreForTest(cands, plan, filefinder.DefaultScoreParamsForTest(), 10)
if len(results) != 1 {
t.Fatalf("expected 1 result, got %d", len(results))
}
if !results[0].IsDir {
t.Error("expected IsDir=true for FlagDir candidate")
}
}
func TestMergeAndScore_EmptyCandidates(t *testing.T) {
t.Parallel()
if results := filefinder.MergeAndScoreForTest(nil, filefinder.NewQueryPlanForTest("foo"), filefinder.DefaultScoreParamsForTest(), 10); len(results) != 0 {
t.Errorf("expected 0 results for nil candidates, got %d", len(results))
}
}
func TestSearchSnapshot_FuzzyFallbackEndToEnd(t *testing.T) {
t.Parallel()
snap := filefinder.MakeTestSnapshot([]string{"src/handler.go", "src/middleware.go", "pkg/config.go"})
plan := filefinder.NewQueryPlanForTest("hndlr")
results := filefinder.MergeAndScoreForTest(filefinder.SearchSnapshotForTest(plan, snap, 100), plan, filefinder.DefaultScoreParamsForTest(), 10)
if len(results) == 0 {
t.Fatal("expected fuzzy fallback to produce scored results for 'hndlr'")
}
if results[0].Path != "src/handler.go" {
t.Errorf("expected top result 'src/handler.go', got %q", results[0].Path)
}
}
-288
View File
@@ -1,288 +0,0 @@
package filefinder
import "slices"
func toLowerASCII(b byte) byte {
if b >= 'A' && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
func normalizeQuery(q string) string {
b := make([]byte, 0, len(q))
prevSpace := true
for i := 0; i < len(q); i++ {
c := q[i]
if c == '\\' {
c = '/'
}
c = toLowerASCII(c)
if c == ' ' {
if prevSpace {
continue
}
prevSpace = true
} else {
prevSpace = false
}
b = append(b, c)
}
if len(b) > 0 && b[len(b)-1] == ' ' {
b = b[:len(b)-1]
}
return string(b)
}
func normalizePathBytes(p []byte) []byte {
j := 0
prevSlash := false
for i := 0; i < len(p); i++ {
c := p[i]
if c == '\\' {
c = '/'
}
c = toLowerASCII(c)
if c == '/' {
if prevSlash {
continue
}
prevSlash = true
} else {
prevSlash = false
}
p[j] = c
j++
}
return p[:j]
}
// extractTrigrams returns deduplicated, sorted trigrams (three-byte
// subsequences) from s. Trigrams are the primary index key: a
// document matches a query only if every query trigram appears in
// the document, giving O(1) candidate filtering per trigram.
func extractTrigrams(s []byte) []uint32 {
if len(s) < 3 {
return nil
}
seen := make(map[uint32]struct{}, len(s))
for i := 0; i <= len(s)-3; i++ {
b0 := toLowerASCII(s[i])
b1 := toLowerASCII(s[i+1])
b2 := toLowerASCII(s[i+2])
gram := uint32(b0)<<16 | uint32(b1)<<8 | uint32(b2)
seen[gram] = struct{}{}
}
result := make([]uint32, 0, len(seen))
for g := range seen {
result = append(result, g)
}
slices.Sort(result)
return result
}
func extractBasename(path []byte) (offset int, length int) {
end := len(path)
if end > 0 && path[end-1] == '/' {
end--
}
if end == 0 {
return 0, 0
}
i := end - 1
for i >= 0 && path[i] != '/' {
i--
}
start := i + 1
return start, end - start
}
func extractSegments(path []byte) [][]byte {
var segments [][]byte
start := 0
for i := 0; i <= len(path); i++ {
if i == len(path) || path[i] == '/' {
if i > start {
segments = append(segments, path[start:i])
}
start = i + 1
}
}
return segments
}
func prefix1(name []byte) byte {
if len(name) == 0 {
return 0
}
return toLowerASCII(name[0])
}
func prefix2(name []byte) uint16 {
if len(name) == 0 {
return 0
}
hi := uint16(toLowerASCII(name[0])) << 8
if len(name) < 2 {
return hi
}
return hi | uint16(toLowerASCII(name[1]))
}
// scoreParams controls the weights for each scoring signal.
type scoreParams struct {
BasenameMatch float32
BasenamePrefix float32
ExactSegment float32
BoundaryHit float32
ContiguousRun float32
DirTokenHit float32
DepthPenalty float32
LengthPenalty float32
}
func defaultScoreParams() scoreParams {
return scoreParams{
BasenameMatch: 6.0,
BasenamePrefix: 3.5,
ExactSegment: 2.5,
BoundaryHit: 1.8,
ContiguousRun: 1.2,
DirTokenHit: 0.4,
DepthPenalty: 0.08,
LengthPenalty: 0.01,
}
}
func isSubsequence(haystack, needle []byte) bool {
if len(needle) == 0 {
return true
}
ni := 0
for _, hb := range haystack {
if toLowerASCII(hb) == toLowerASCII(needle[ni]) {
ni++
if ni == len(needle) {
return true
}
}
}
return false
}
func longestContiguousMatch(haystack, needle []byte) int {
if len(needle) == 0 || len(haystack) == 0 {
return 0
}
best := 0
ni := 0
run := 0
for _, hb := range haystack {
if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
run++
ni++
if run > best {
best = run
}
} else {
run = 0
ni = 0
if ni < len(needle) && toLowerASCII(hb) == toLowerASCII(needle[ni]) {
run = 1
ni = 1
if run > best {
best = run
}
}
}
}
return best
}
func isBoundary(b byte) bool {
return b == '/' || b == '.' || b == '_' || b == '-'
}
func countBoundaryHits(path []byte, query []byte) int {
if len(query) == 0 || len(path) == 0 {
return 0
}
hits := 0
qi := 0
for pi := 0; pi < len(path) && qi < len(query); pi++ {
atBoundary := pi == 0 || isBoundary(path[pi-1])
if atBoundary && toLowerASCII(path[pi]) == toLowerASCII(query[qi]) {
hits++
qi++
}
}
return hits
}
func equalFoldASCII(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if toLowerASCII(a[i]) != toLowerASCII(b[i]) {
return false
}
}
return true
}
func hasPrefixFoldASCII(haystack, prefix []byte) bool {
if len(prefix) > len(haystack) {
return false
}
for i := range prefix {
if toLowerASCII(haystack[i]) != toLowerASCII(prefix[i]) {
return false
}
}
return true
}
// scorePath computes a relevance score for a candidate path
// against a query. The score combines several signals:
// basename match, basename prefix, exact segment match,
// word-boundary hits, longest contiguous run, and penalties
// for depth and length. A return value of 0 means no match
// (the query is not a subsequence of the path).
func scorePath(
path []byte,
baseOff int,
baseLen int,
depth int,
query []byte,
queryTokens [][]byte,
params scoreParams,
) float32 {
if !isSubsequence(path, query) {
return 0
}
var score float32
basename := path[baseOff : baseOff+baseLen]
if isSubsequence(basename, query) {
score += params.BasenameMatch
}
if hasPrefixFoldASCII(basename, query) {
score += params.BasenamePrefix
}
segments := extractSegments(path)
for _, token := range queryTokens {
for _, seg := range segments {
if equalFoldASCII(seg, token) {
score += params.ExactSegment
break
}
}
}
bh := countBoundaryHits(path, query)
score += float32(bh) * params.BoundaryHit
lcm := longestContiguousMatch(path, query)
score += float32(lcm) * params.ContiguousRun
score -= float32(depth) * params.DepthPenalty
score -= float32(len(path)) * params.LengthPenalty
return score
}
-388
View File
@@ -1,388 +0,0 @@
package filefinder_test
import (
"slices"
"testing"
"github.com/coder/coder/v2/agent/filefinder"
)
func TestNormalizeQuery(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{"empty", "", ""},
{"leading and trailing spaces", " hello ", "hello"},
{"multiple internal spaces", "foo bar baz", "foo bar baz"},
{"uppercase to lower", "FooBar", "foobar"},
{"backslash to slash", `foo\bar\baz`, "foo/bar/baz"},
{"mixed case and spaces", " Hello World ", "hello world"},
{"unicode passthrough", "héllo wörld", "héllo wörld"},
{"only spaces", " ", ""},
{"single char", "A", "a"},
{"slashes preserved", "/foo/bar/", "/foo/bar/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.NormalizeQueryForTest(tt.input)
if got != tt.want {
t.Errorf("normalizeQuery(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestExtractTrigrams(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want []uint32
}{
{"too short", "ab", nil},
{"exactly three bytes", "abc", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
{"case insensitive", "ABC", []uint32{uint32('a')<<16 | uint32('b')<<8 | uint32('c')}},
{"deduplication", "aaaa", []uint32{uint32('a')<<16 | uint32('a')<<8 | uint32('a')}},
{"four bytes produces two trigrams", "abcd", []uint32{
uint32('a')<<16 | uint32('b')<<8 | uint32('c'),
uint32('b')<<16 | uint32('c')<<8 | uint32('d'),
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.ExtractTrigramsForTest([]byte(tt.input))
if !slices.Equal(got, tt.want) {
t.Errorf("extractTrigrams(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestExtractBasename(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
wantOff int
wantName string
}{
{"full path", "/foo/bar/baz.go", 9, "baz.go"},
{"bare filename", "baz.go", 0, "baz.go"},
{"trailing slash", "/a/b/", 3, "b"},
{"root slash", "/", 0, ""},
{"empty", "", 0, ""},
{"single dir with slash", "/foo", 1, "foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
off, length := filefinder.ExtractBasenameForTest([]byte(tt.path))
if off != tt.wantOff {
t.Errorf("extractBasename(%q) offset = %d, want %d", tt.path, off, tt.wantOff)
}
gotName := string([]byte(tt.path)[off : off+length])
if gotName != tt.wantName {
t.Errorf("extractBasename(%q) name = %q, want %q", tt.path, gotName, tt.wantName)
}
})
}
}
func TestExtractSegments(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
want []string
}{
{"absolute path", "/foo/bar/baz", []string{"foo", "bar", "baz"}},
{"relative path", "foo/bar", []string{"foo", "bar"}},
{"trailing slash", "/a/b/", []string{"a", "b"}},
{"multiple slashes", "//a///b//", []string{"a", "b"}},
{"empty", "", nil},
{"single segment", "foo", []string{"foo"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.ExtractSegmentsForTest([]byte(tt.path))
if len(got) != len(tt.want) {
t.Fatalf("extractSegments(%q) got %d segments, want %d", tt.path, len(got), len(tt.want))
}
for i := range got {
if string(got[i]) != tt.want[i] {
t.Errorf("extractSegments(%q)[%d] = %q, want %q", tt.path, i, got[i], tt.want[i])
}
}
})
}
}
func TestPrefix1(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want byte
}{
{"lowercase", "foo", 'f'},
{"uppercase", "Foo", 'f'},
{"empty", "", 0},
{"digit", "1abc", '1'},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.Prefix1ForTest([]byte(tt.in))
if got != tt.want {
t.Errorf("prefix1(%q) = %d (%c), want %d (%c)", tt.in, got, got, tt.want, tt.want)
}
})
}
}
func TestPrefix2(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want uint16
}{
{"two chars", "ab", uint16('a')<<8 | uint16('b')},
{"uppercase", "AB", uint16('a')<<8 | uint16('b')},
{"single char", "A", uint16('a') << 8},
{"empty", "", 0},
{"longer string", "Hello", uint16('h')<<8 | uint16('e')},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.Prefix2ForTest([]byte(tt.in))
if got != tt.want {
t.Errorf("prefix2(%q) = %d, want %d", tt.in, got, tt.want)
}
})
}
}
func TestNormalizePathBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want string
}{
{"backslash to slash", `C:\Users\test`, "c:/users/test"},
{"collapse slashes", "//foo///bar//", "/foo/bar/"},
{"lowercase", "FooBar", "foobar"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
buf := []byte(tt.input)
got := string(filefinder.NormalizePathBytesForTest(buf))
if got != tt.want {
t.Errorf("normalizePathBytes(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestIsSubsequence(t *testing.T) {
t.Parallel()
tests := []struct {
name string
haystack string
needle string
want bool
}{
{"empty needle", "anything", "", true},
{"empty both", "", "", true},
{"empty haystack", "", "a", false},
{"exact match", "abc", "abc", true},
{"scattered", "axbycz", "abc", true},
{"prefix", "abcdef", "abc", true},
{"suffix", "xyzabc", "abc", true},
{"case insensitive", "AbCdEf", "ace", true},
{"case insensitive reverse", "abcdef", "ACE", true},
{"no match", "abcdef", "xyz", false},
{"partial match", "abcdef", "abz", false},
{"longer needle", "ab", "abc", false},
{"single char match", "hello", "l", true},
{"single char no match", "hello", "z", false},
{"path like", "src/internal/foo.go", "sif", true},
{"path like no match", "src/internal/foo.go", "zzz", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.IsSubsequenceForTest([]byte(tt.haystack), []byte(tt.needle))
if got != tt.want {
t.Errorf("isSubsequence(%q, %q) = %v, want %v", tt.haystack, tt.needle, got, tt.want)
}
})
}
}
func TestLongestContiguousMatch(t *testing.T) {
t.Parallel()
tests := []struct {
name string
haystack string
needle string
want int
}{
{"empty needle", "abc", "", 0},
{"empty haystack", "", "abc", 0},
{"full match", "abc", "abc", 3},
{"prefix match", "abcdef", "abc", 3},
{"middle match", "xxabcyy", "abc", 3},
{"suffix match", "xxabc", "abc", 3},
{"partial", "axbc", "abc", 1},
{"scattered no contiguous", "axbxcx", "abc", 1},
{"case insensitive", "ABCdef", "abc", 3},
{"no match", "xyz", "abc", 0},
{"single char", "abc", "b", 1},
{"repeated", "aababc", "abc", 3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.LongestContiguousMatchForTest([]byte(tt.haystack), []byte(tt.needle))
if got != tt.want {
t.Errorf("longestContiguousMatch(%q, %q) = %d, want %d", tt.haystack, tt.needle, got, tt.want)
}
})
}
}
func TestIsBoundary(t *testing.T) {
t.Parallel()
for _, b := range []byte{'/', '.', '_', '-'} {
if !filefinder.IsBoundaryForTest(b) {
t.Errorf("isBoundary(%q) = false, want true", b)
}
}
for _, b := range []byte{'a', 'Z', '0', ' ', '('} {
if filefinder.IsBoundaryForTest(b) {
t.Errorf("isBoundary(%q) = true, want false", b)
}
}
}
func TestCountBoundaryHits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
query string
want int
}{
{"start of string", "foo/bar", "f", 1},
{"after slash", "foo/bar", "fb", 2},
{"after dot", "foo.bar", "fb", 2},
{"after underscore", "foo_bar", "fb", 2},
{"no hits", "xxxx", "y", 0},
{"empty query", "foo", "", 0},
{"empty path", "", "f", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := filefinder.CountBoundaryHitsForTest([]byte(tt.path), []byte(tt.query))
if got != tt.want {
t.Errorf("countBoundaryHits(%q, %q) = %d, want %d", tt.path, tt.query, got, tt.want)
}
})
}
}
func TestScorePath_NoSubsequenceReturnsZero(t *testing.T) {
t.Parallel()
path := []byte("src/internal/handler.go")
query := []byte("zzz")
tokens := [][]byte{[]byte("zzz")}
params := filefinder.DefaultScoreParamsForTest()
s := filefinder.ScorePathForTest(path, 13, 10, 2, query, tokens, params)
if s != 0 {
t.Errorf("expected 0 for no subsequence match, got %f", s)
}
}
func TestScorePath_ExactBasenameOverPartial(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("main")
tokens := [][]byte{query}
pathExact := []byte("src/main")
scoreExact := filefinder.ScorePathForTest(pathExact, 4, 4, 1, query, tokens, params)
pathPartial := []byte("module/amazing")
scorePartial := filefinder.ScorePathForTest(pathPartial, 7, 7, 1, query, tokens, params)
if scoreExact <= scorePartial {
t.Errorf("exact basename (%f) should score higher than partial (%f)", scoreExact, scorePartial)
}
}
func TestScorePath_BasenamePrefixOverScattered(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("han")
tokens := [][]byte{query}
pathPrefix := []byte("src/handler.go")
scorePrefix := filefinder.ScorePathForTest(pathPrefix, 4, 10, 1, query, tokens, params)
pathScattered := []byte("has/another/thing")
scoreScattered := filefinder.ScorePathForTest(pathScattered, 12, 5, 2, query, tokens, params)
if scorePrefix <= scoreScattered {
t.Errorf("basename prefix (%f) should score higher than scattered (%f)", scorePrefix, scoreScattered)
}
}
func TestScorePath_ShallowOverDeep(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("foo")
tokens := [][]byte{query}
pathShallow := []byte("src/foo.go")
scoreShallow := filefinder.ScorePathForTest(pathShallow, 4, 6, 1, query, tokens, params)
pathDeep := []byte("a/b/c/d/e/foo.go")
scoreDeep := filefinder.ScorePathForTest(pathDeep, 10, 6, 5, query, tokens, params)
if scoreShallow <= scoreDeep {
t.Errorf("shallow path (%f) should score higher than deep (%f)", scoreShallow, scoreDeep)
}
}
func TestScorePath_ShorterOverLongerSameMatch(t *testing.T) {
t.Parallel()
params := filefinder.DefaultScoreParamsForTest()
query := []byte("foo")
tokens := [][]byte{query}
pathShort := []byte("x/foo")
scoreShort := filefinder.ScorePathForTest(pathShort, 2, 3, 1, query, tokens, params)
pathLong := []byte("x/foo_extremely_long_suffix_name")
scoreLong := filefinder.ScorePathForTest(pathLong, 2, 29, 1, query, tokens, params)
if scoreShort <= scoreLong {
t.Errorf("shorter path (%f) should score higher than longer (%f)", scoreShort, scoreLong)
}
}
func BenchmarkScorePath(b *testing.B) {
path := []byte("src/internal/coderd/database/queries/workspaces.sql")
query := []byte("workspace")
tokens := [][]byte{query}
params := filefinder.DefaultScoreParamsForTest()
baseOff, baseLen := filefinder.ExtractBasenameForTest(path)
s := filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
if s == 0 {
b.Fatal("expected non-zero score for benchmark path")
}
b.ResetTimer()
for b.Loop() {
filefinder.ScorePathForTest(path, baseOff, baseLen, 4, query, tokens, params)
}
}
-210
View File
@@ -1,210 +0,0 @@
package filefinder
import (
"context"
"os"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
"cdr.dev/slog/v3"
)
// FSEvent represents a filesystem change event.
type FSEvent struct {
Op FSEventOp
Path string
IsDir bool
}
// FSEventOp represents the type of filesystem operation.
type FSEventOp uint8
// Filesystem operations reported by the watcher.
const (
OpCreate FSEventOp = iota
OpRemove
OpRename
OpModify
)
var skipDirs = map[string]struct{}{
".git": {}, "node_modules": {}, ".hg": {}, ".svn": {},
"__pycache__": {}, ".cache": {}, ".venv": {}, "vendor": {}, ".terraform": {},
}
type fsWatcher struct {
w *fsnotify.Watcher
root string
events chan []FSEvent
logger slog.Logger
mu sync.Mutex
closed bool
done chan struct{}
}
func newFSWatcher(root string, logger slog.Logger) (*fsWatcher, error) {
w, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return &fsWatcher{
w: w,
root: root,
events: make(chan []FSEvent, 64),
logger: logger,
done: make(chan struct{}),
}, nil
}
func (fw *fsWatcher) Start(ctx context.Context) {
initEvents := fw.addRecursive(fw.root)
if len(initEvents) > 0 {
select {
case fw.events <- initEvents:
case <-ctx.Done():
return
}
}
fw.logger.Debug(ctx, "fs watcher started", slog.F("root", fw.root))
go fw.loop(ctx)
}
func (fw *fsWatcher) Events() <-chan []FSEvent { return fw.events }
func (fw *fsWatcher) Close() error {
fw.mu.Lock()
if fw.closed {
fw.mu.Unlock()
return nil
}
fw.closed = true
fw.mu.Unlock()
err := fw.w.Close()
<-fw.done
return err
}
func (fw *fsWatcher) loop(ctx context.Context) {
defer close(fw.done)
const batchWindow = 50 * time.Millisecond
var (
batch []FSEvent
seen = make(map[string]struct{})
timer *time.Timer
timerC <-chan time.Time
)
flush := func() {
if len(batch) == 0 {
return
}
select {
case fw.events <- batch:
default:
fw.logger.Warn(ctx, "fs watcher dropping batch", slog.F("count", len(batch)))
}
batch = nil
seen = make(map[string]struct{})
if timer != nil {
timer.Stop()
}
timer = nil
timerC = nil
}
addToBatch := func(ev FSEvent) {
if _, dup := seen[ev.Path]; dup {
return
}
seen[ev.Path] = struct{}{}
batch = append(batch, ev)
if timer == nil {
timer = time.NewTimer(batchWindow)
timerC = timer.C
}
}
for {
select {
case <-ctx.Done():
flush()
return
case ev, ok := <-fw.w.Events:
if !ok {
flush()
return
}
fsev := translateEvent(ev)
if fsev == nil {
continue
}
if fsev.IsDir && fsev.Op == OpCreate {
for _, s := range fw.addRecursive(fsev.Path) {
addToBatch(s)
}
}
addToBatch(*fsev)
case err, ok := <-fw.w.Errors:
if !ok {
flush()
return
}
fw.logger.Warn(ctx, "fsnotify watcher error", slog.Error(err))
case <-timerC:
flush()
}
}
}
func (fw *fsWatcher) addRecursive(dir string) []FSEvent {
var events []FSEvent
_ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil //nolint:nilerr // best-effort
}
base := filepath.Base(path)
if _, skip := skipDirs[base]; skip && info.IsDir() {
return filepath.SkipDir
}
if info.IsDir() {
if addErr := fw.w.Add(path); addErr != nil {
fw.logger.Debug(context.Background(), "failed to add watch",
slog.F("path", path), slog.Error(addErr))
}
if path != dir {
events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: true})
}
return nil
}
events = append(events, FSEvent{Op: OpCreate, Path: path, IsDir: false})
return nil
})
return events
}
func translateEvent(ev fsnotify.Event) *FSEvent {
var op FSEventOp
switch {
case ev.Op&fsnotify.Create != 0:
op = OpCreate
case ev.Op&fsnotify.Remove != 0:
op = OpRemove
case ev.Op&fsnotify.Rename != 0:
op = OpRename
case ev.Op&fsnotify.Write != 0:
op = OpModify
default:
return nil
}
isDir := false
if op == OpCreate || op == OpModify {
fi, err := os.Lstat(ev.Name)
if err == nil {
isDir = fi.IsDir()
}
}
if isDir {
if _, skip := skipDirs[filepath.Base(ev.Name)]; skip {
return nil
}
}
return &FSEvent{Op: op, Path: ev.Name, IsDir: isDir}
}
+330 -544
View File
File diff suppressed because it is too large Load Diff
+1 -20
View File
@@ -436,7 +436,7 @@ message CreateSubAgentRequest {
}
repeated DisplayApp display_apps = 6;
optional bytes id = 7;
}
@@ -494,24 +494,6 @@ message ReportBoundaryLogsRequest {
message ReportBoundaryLogsResponse {}
// UpdateAppStatusRequest updates the given Workspace App's status. c.f. agentsdk.PatchAppStatus
message UpdateAppStatusRequest {
string slug = 1;
enum AppStatusState {
WORKING = 0;
IDLE = 1;
COMPLETE = 2;
FAILURE = 3;
}
AppStatusState state = 2;
string message = 3;
string uri = 4;
}
message UpdateAppStatusResponse {}
service Agent {
rpc GetManifest(GetManifestRequest) returns (Manifest);
rpc GetServiceBanner(GetServiceBannerRequest) returns (ServiceBanner);
@@ -530,5 +512,4 @@ service Agent {
rpc DeleteSubAgent(DeleteSubAgentRequest) returns (DeleteSubAgentResponse);
rpc ListSubAgents(ListSubAgentsRequest) returns (ListSubAgentsResponse);
rpc ReportBoundaryLogs(ReportBoundaryLogsRequest) returns (ReportBoundaryLogsResponse);
rpc UpdateAppStatus(UpdateAppStatusRequest) returns (UpdateAppStatusResponse);
}
+1 -41
View File
@@ -56,7 +56,6 @@ type DRPCAgentClient interface {
DeleteSubAgent(ctx context.Context, in *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(ctx context.Context, in *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type drpcAgentClient struct {
@@ -222,15 +221,6 @@ func (c *drpcAgentClient) ReportBoundaryLogs(ctx context.Context, in *ReportBoun
return out, nil
}
func (c *drpcAgentClient) UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
out := new(UpdateAppStatusResponse)
err := c.cc.Invoke(ctx, "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCAgentServer interface {
GetManifest(context.Context, *GetManifestRequest) (*Manifest, error)
GetServiceBanner(context.Context, *GetServiceBannerRequest) (*ServiceBanner, error)
@@ -249,7 +239,6 @@ type DRPCAgentServer interface {
DeleteSubAgent(context.Context, *DeleteSubAgentRequest) (*DeleteSubAgentResponse, error)
ListSubAgents(context.Context, *ListSubAgentsRequest) (*ListSubAgentsResponse, error)
ReportBoundaryLogs(context.Context, *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
type DRPCAgentUnimplementedServer struct{}
@@ -322,13 +311,9 @@ func (s *DRPCAgentUnimplementedServer) ReportBoundaryLogs(context.Context, *Repo
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCAgentUnimplementedServer) UpdateAppStatus(context.Context, *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCAgentDescription struct{}
func (DRPCAgentDescription) NumMethods() int { return 18 }
func (DRPCAgentDescription) NumMethods() int { return 17 }
func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n {
@@ -485,15 +470,6 @@ func (DRPCAgentDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*ReportBoundaryLogsRequest),
)
}, DRPCAgentServer.ReportBoundaryLogs, true
case 17:
return "/coder.agent.v2.Agent/UpdateAppStatus", drpcEncoding_File_agent_proto_agent_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCAgentServer).
UpdateAppStatus(
ctx,
in1.(*UpdateAppStatusRequest),
)
}, DRPCAgentServer.UpdateAppStatus, true
default:
return "", nil, nil, nil, false
}
@@ -774,19 +750,3 @@ func (x *drpcAgent_ReportBoundaryLogsStream) SendAndClose(m *ReportBoundaryLogsR
}
return x.CloseSend()
}
type DRPCAgent_UpdateAppStatusStream interface {
drpc.Stream
SendAndClose(*UpdateAppStatusResponse) error
}
type drpcAgent_UpdateAppStatusStream struct {
drpc.Stream
}
func (x *drpcAgent_UpdateAppStatusStream) SendAndClose(m *UpdateAppStatusResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_agent_proto_agent_proto{}); err != nil {
return err
}
return x.CloseSend()
}
+3 -7
View File
@@ -73,13 +73,9 @@ type DRPCAgentClient27 interface {
ReportBoundaryLogs(ctx context.Context, in *ReportBoundaryLogsRequest) (*ReportBoundaryLogsResponse, error)
}
// DRPCAgentClient28 is the Agent API at v2.8. It adds
// - a SubagentId field to the WorkspaceAgentDevcontainer message
// - an Id field to the CreateSubAgentRequest message.
// - UpdateAppStatus RPC.
//
// Compatible with Coder v2.31+
// DRPCAgentClient28 is the Agent API at v2.8. It adds a SubagentId field to the
// WorkspaceAgentDevcontainer message, and a Id field to the CreateSubAgentRequest
// message. Compatible with Coder v2.31+
type DRPCAgentClient28 interface {
DRPCAgentClient27
UpdateAppStatus(ctx context.Context, in *UpdateAppStatusRequest) (*UpdateAppStatusResponse, error)
}
+21 -25
View File
@@ -3,11 +3,11 @@
"enabled": true,
"clientKind": "git",
"useIgnoreFile": true,
"defaultBranch": "main",
"defaultBranch": "main"
},
"files": {
"includes": ["**", "!**/pnpm-lock.yaml"],
"ignoreUnknown": true,
"ignoreUnknown": true
},
"linter": {
"rules": {
@@ -15,18 +15,18 @@
"noSvgWithoutTitle": "off",
"useButtonType": "off",
"useSemanticElements": "off",
"noStaticElementInteractions": "off",
"noStaticElementInteractions": "off"
},
"correctness": {
"noUnusedImports": "warn",
"correctness": {
"noUnusedImports": "warn",
"useUniqueElementIds": "off", // TODO: This is new but we want to fix it
"noNestedComponentDefinitions": "off", // TODO: Investigate, since it is used by shadcn components
"noUnusedVariables": {
"level": "warn",
"noUnusedVariables": {
"level": "warn",
"options": {
"ignoreRestSiblings": true,
},
},
"ignoreRestSiblings": true
}
}
},
"style": {
"noNonNullAssertion": "off",
@@ -45,10 +45,6 @@
"level": "error",
"options": {
"paths": {
"react": {
"message": "React 19 no longer requires forwardRef. Use ref as a prop instead.",
"importNames": ["forwardRef"],
},
// "@mui/material/Alert": "Use components/Alert/Alert instead.",
// "@mui/material/AlertTitle": "Use components/Alert/Alert instead.",
// "@mui/material/Autocomplete": "Use shadcn/ui Combobox instead.",
@@ -115,10 +111,10 @@
"@emotion/styled": "Use Tailwind CSS instead.",
// "@emotion/cache": "Use Tailwind CSS instead.",
// "components/Stack/Stack": "Use Tailwind flex utilities instead (e.g., <div className='flex flex-col gap-4'>).",
"lodash": "Use lodash/<name> instead.",
},
},
},
"lodash": "Use lodash/<name> instead."
}
}
}
},
"suspicious": {
"noArrayIndexKey": "off",
@@ -129,14 +125,14 @@
"noConsole": {
"level": "error",
"options": {
"allow": ["error", "info", "warn"],
},
},
"allow": ["error", "info", "warn"]
}
}
},
"complexity": {
"noImportantStyles": "off", // TODO: check and fix !important styles
},
},
"noImportantStyles": "off" // TODO: check and fix !important styles
}
}
},
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json"
}
+3 -9
View File
@@ -30,15 +30,9 @@ func RichParameter(inv *serpent.Invocation, templateVersionParameter codersdk.Te
_, _ = fmt.Fprint(inv.Stdout, "\033[1A")
var defaults []string
defaultSource := defaultValue
if defaultSource == "" {
defaultSource = templateVersionParameter.DefaultValue
}
if defaultSource != "" {
err = json.Unmarshal([]byte(defaultSource), &defaults)
if err != nil {
return "", err
}
err = json.Unmarshal([]byte(templateVersionParameter.DefaultValue), &defaults)
if err != nil {
return "", err
}
values, err := RichMultiSelect(inv, RichMultiSelectOptions{
+45 -50
View File
@@ -10,7 +10,6 @@ import (
"path/filepath"
"slices"
"strings"
"time"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -24,7 +23,6 @@ import (
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/codersdk/toolsdk"
"github.com/coder/retry"
"github.com/coder/serpent"
)
@@ -541,6 +539,7 @@ func (r *RootCmd) mcpServer() *serpent.Command {
defer cancel()
defer srv.queue.Close()
cliui.Infof(inv.Stderr, "Failed to watch screen events")
// Start the reporter, watcher, and server. These are all tied to the
// lifetime of the MCP server, which is itself tied to the lifetime of the
// AI agent.
@@ -614,51 +613,48 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
}
func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
return
}
go func() {
for retrier := retry.New(time.Second, 30*time.Second); retrier.Wait(ctx); {
eventsCh, errCh, err := s.aiAgentAPIClient.SubscribeEvents(ctx)
if err == nil {
retrier.Reset()
loop:
for {
select {
case <-ctx.Done():
for {
select {
case <-ctx.Done():
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
// If the screen is stable, report idle.
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
case event := <-eventsCh:
switch ev := event.(type) {
case agentapi.EventStatusChange:
state := codersdk.WorkspaceAppStatusStateWorking
if ev.Status == agentapi.StatusStable {
state = codersdk.WorkspaceAppStatusStateIdle
}
err := s.queue.Push(taskReport{
state: state,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
}
}
case agentapi.EventMessageUpdate:
if ev.Role == agentapi.RoleUser {
err := s.queue.Push(taskReport{
messageID: &ev.Id,
state: codersdk.WorkspaceAppStatusStateWorking,
})
if err != nil {
cliui.Warnf(inv.Stderr, "Failed to queue update: %s", err)
return
}
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
break loop
}
}
} else {
cliui.Warnf(inv.Stderr, "Failed to watch screen events: %s", err)
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
cliui.Warnf(inv.Stderr, "Received error from screen event watcher: %s", err)
}
return
}
}
}()
@@ -696,14 +692,13 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
// Add tool dependencies.
toolOpts := []func(*toolsdk.Deps){
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
state := codersdk.WorkspaceAppStatusState(args.State)
// The agent does not reliably report idle, so when AgentAPI is
// enabled we override idle to working and let the screen watcher
// detect the real idle via StatusStable. Final states (failure,
// complete) are trusted from the agent since the screen watcher
// cannot produce them.
if s.aiAgentAPIClient != nil && state == codersdk.WorkspaceAppStatusStateIdle {
state = codersdk.WorkspaceAppStatusStateWorking
// The agent does not reliably report its status correctly. If AgentAPI
// is enabled, we will always set the status to "working" when we get an
// MCP message, and rely on the screen watcher to eventually catch the
// idle state.
state := codersdk.WorkspaceAppStatusStateWorking
if s.aiAgentAPIClient == nil {
state = codersdk.WorkspaceAppStatusState(args.State)
}
return s.queue.Push(taskReport{
link: args.Link,
+1 -185
View File
@@ -921,7 +921,7 @@ func TestExpMcpReporter(t *testing.T) {
},
},
},
// We override idle from the agent to working, but trust final states.
// We ignore the state from the agent and assume "working".
{
name: "IgnoreAgentState",
// AI agent reports that it is finished but the summary says it is doing
@@ -953,46 +953,6 @@ func TestExpMcpReporter(t *testing.T) {
Message: "finished",
},
},
// Agent reports failure; trusted even with AgentAPI enabled.
{
state: codersdk.WorkspaceAppStatusStateFailure,
summary: "something broke",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateFailure,
Message: "something broke",
},
},
// After failure, watcher reports stable -> idle.
{
event: makeStatusEvent(agentapi.StatusStable),
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateIdle,
Message: "something broke",
},
},
},
},
// Final states pass through with AgentAPI enabled.
{
name: "AllowFinalStates",
tests: []test{
{
state: codersdk.WorkspaceAppStatusStateWorking,
summary: "doing work",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateWorking,
Message: "doing work",
},
},
// Agent reports complete; not overridden.
{
state: codersdk.WorkspaceAppStatusStateComplete,
summary: "all done",
expected: &codersdk.WorkspaceAppStatus{
State: codersdk.WorkspaceAppStatusStateComplete,
Message: "all done",
},
},
},
},
// When AgentAPI is not being used, we accept agent state updates as-is.
@@ -1150,148 +1110,4 @@ func TestExpMcpReporter(t *testing.T) {
<-cmdDone
})
}
t.Run("Reconnect", func(t *testing.T) {
t.Parallel()
// Create a test deployment and workspace.
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user2.ID,
}).WithAgent(func(a []*proto.Agent) []*proto.Agent {
a[0].Apps = []*proto.App{
{
Slug: "vscode",
},
}
return a
}).Do()
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitLong))
// Watch the workspace for changes.
watcher, err := client.WatchWorkspace(ctx, r.Workspace.ID)
require.NoError(t, err)
var lastAppStatus codersdk.WorkspaceAppStatus
nextUpdate := func() codersdk.WorkspaceAppStatus {
for {
select {
case <-ctx.Done():
require.FailNow(t, "timed out waiting for status update")
case w, ok := <-watcher:
require.True(t, ok, "watch channel closed")
if w.LatestAppStatus != nil && w.LatestAppStatus.ID != lastAppStatus.ID {
t.Logf("Got status update: %s > %s", lastAppStatus.State, w.LatestAppStatus.State)
lastAppStatus = *w.LatestAppStatus
return lastAppStatus
}
}
}
}
// Mock AI AgentAPI server that supports disconnect/reconnect.
disconnect := make(chan struct{})
listening := make(chan func(sse codersdk.ServerSentEvent) error)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create a cancelable context so we can stop the SSE sender
// goroutine on disconnect without waiting for the HTTP
// serve loop to cancel r.Context().
sseCtx, sseCancel := context.WithCancel(r.Context())
defer sseCancel()
r = r.WithContext(sseCtx)
send, closed, err := httpapi.ServerSentEventSender(w, r)
if err != nil {
httpapi.Write(sseCtx, w, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error setting up server-sent events.",
Detail: err.Error(),
})
return
}
// Send initial message so the watcher knows the agent is active.
send(*makeMessageEvent(0, agentapi.RoleAgent))
select {
case listening <- send:
case <-r.Context().Done():
return
}
select {
case <-closed:
case <-disconnect:
sseCancel()
<-closed
}
}))
t.Cleanup(srv.Close)
inv, _ := clitest.New(t,
"exp", "mcp", "server",
"--agent-url", client.URL.String(),
"--agent-token", r.AgentToken,
"--app-status-slug", "vscode",
"--allowed-tools=coder_report_task",
"--ai-agentapi-url", srv.URL,
)
inv = inv.WithContext(ctx)
pty := ptytest.New(t)
inv.Stdin = pty.Input()
inv.Stdout = pty.Output()
stderr := ptytest.New(t)
inv.Stderr = stderr.Output()
// Run the MCP server.
clitest.Start(t, inv)
// Initialize.
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
pty.WriteLine(payload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore init response
// Get first sender from the initial SSE connection.
sender := testutil.RequireReceive(ctx, t, listening)
// Self-report a working status via tool call.
toolPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"doing work","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got := nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "doing work", got.Message)
// Watcher sends stable, verify idle is reported.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
// Disconnect the SSE connection by signaling the handler to return.
testutil.RequireSend(ctx, t, disconnect, struct{}{})
// Wait for the watcher to reconnect and get the new sender.
sender = testutil.RequireReceive(ctx, t, listening)
// After reconnect, self-report a working status again.
toolPayload = `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"coder_report_task","arguments":{"state":"working","summary":"reconnected","link":""}}}`
pty.WriteLine(toolPayload)
_ = pty.ReadLine(ctx) // ignore echo
_ = pty.ReadLine(ctx) // ignore response
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateWorking, got.State)
require.Equal(t, "reconnected", got.Message)
// Verify the watcher still processes events after reconnect.
err = sender(*makeStatusEvent(agentapi.StatusStable))
require.NoError(t, err)
got = nextUpdate()
require.Equal(t, codersdk.WorkspaceAppStatusStateIdle, got.State)
cancel()
})
}
-12
View File
@@ -29,7 +29,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
templateVersionJobTimeout time.Duration
prebuildWorkspaceTimeout time.Duration
noCleanup bool
provisionerTags []string
tracingFlags = &scaletestTracingFlags{}
timeoutStrategy = &timeoutFlags{}
@@ -112,16 +111,10 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
for i := range numTemplates {
id := strconv.Itoa(int(i))
cfg := prebuilds.Config{
OrganizationID: me.OrganizationIDs[0],
ProvisionerTags: tags,
NumPresets: int(numPresets),
NumPresetPrebuilds: int(numPresetPrebuilds),
TemplateVersionJobTimeout: templateVersionJobTimeout,
@@ -290,11 +283,6 @@ func (r *RootCmd) scaletestPrebuilds() *serpent.Command {
Description: "Skip cleanup (deletion test) and leave resources intact.",
Value: serpent.BoolOf(&noCleanup),
},
{
Flag: "provisioner-tag",
Description: "Specify a set of tags to target provisioner daemons.",
Value: serpent.StringArrayOf(&provisionerTags),
},
}
tracingFlags.attach(&cmd.Options)
+1 -45
View File
@@ -4,9 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"time"
"golang.org/x/xerrors"
@@ -19,29 +16,6 @@ import (
"github.com/coder/serpent"
)
// detectGitRef attempts to resolve the current git branch and remote
// origin URL from the given working directory. These are sent to the
// control plane so it can look up PR/diff status via the GitHub API
// without SSHing into the workspace. Failures are silently ignored
// since this is best-effort.
func detectGitRef(workingDirectory string) (branch string, remoteOrigin string) {
run := func(args ...string) string {
//nolint:gosec
cmd := exec.Command(args[0], args[1:]...)
if workingDirectory != "" {
cmd.Dir = workingDirectory
}
out, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
branch = run("git", "rev-parse", "--abbrev-ref", "HEAD")
remoteOrigin = run("git", "config", "--get", "remote.origin.url")
return branch, remoteOrigin
}
// gitAskpass is used by the Coder agent to automatically authenticate
// with Git providers based on a hostname.
func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
@@ -64,20 +38,8 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("create agent client: %w", err)
}
workingDirectory, err := os.Getwd()
if err != nil {
workingDirectory = ""
}
// Detect the current git branch and remote origin so
// the control plane can resolve diffs without needing
// to SSH back into the workspace.
gitBranch, gitRemoteOrigin := detectGitRef(workingDirectory)
token, err := client.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
Match: host,
GitBranch: gitBranch,
GitRemoteOrigin: gitRemoteOrigin,
Match: host,
})
if err != nil {
var apiError *codersdk.Error
@@ -96,12 +58,6 @@ func gitAskpass(agentAuth *AgentAuth) *serpent.Command {
return xerrors.Errorf("get git token: %w", err)
}
if token.URL != "" {
// This is to help the agent authenticate with Git.
if inv.Environ.Get("CODER_CHAT_AGENT") == "true" {
_, _ = fmt.Fprintf(inv.Stderr, `You must notify the user to authenticate with Git.\n\nThe URL is: %s\n`, token.URL)
return cliui.ErrCanceled
}
if err := openURL(inv, token.URL); err == nil {
cliui.Infof(inv.Stderr, "Your browser has been opened to authenticate with Git:\n%s", token.URL)
} else {
+5 -1
View File
@@ -106,7 +106,11 @@ func TestList(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
memberClient, member = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
sharedWorkspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+5 -29
View File
@@ -1,7 +1,6 @@
package cli
import (
"encoding/json"
"fmt"
"strings"
@@ -232,7 +231,7 @@ next:
continue // immutables should not be passed to consecutive builds
}
if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, *tvp) {
if len(tvp.Options) > 0 && !isValidTemplateParameterOption(buildParameter, tvp.Options) {
continue // do not propagate invalid options
}
@@ -298,7 +297,7 @@ func (pr *ParameterResolver) verifyConstraints(resolved []codersdk.WorkspaceBuil
return xerrors.Errorf("ephemeral parameter %q can be used only with --prompt-ephemeral-parameters or --ephemeral-parameter flag", r.Name)
}
if !tvp.Mutable && action != WorkspaceCreate && !pr.isFirstTimeUse(r.Name) {
if !tvp.Mutable && action != WorkspaceCreate {
return xerrors.Errorf("parameter %q is immutable and cannot be updated", r.Name)
}
}
@@ -366,7 +365,7 @@ func (pr *ParameterResolver) isLastBuildParameterInvalidOption(templateVersionPa
for _, buildParameter := range pr.lastBuildParameters {
if buildParameter.Name == templateVersionParameter.Name {
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter)
return !isValidTemplateParameterOption(buildParameter, templateVersionParameter.Options)
}
}
return false
@@ -390,31 +389,8 @@ func findWorkspaceBuildParameter(parameterName string, params []codersdk.Workspa
return nil
}
func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, templateVersionParameter codersdk.TemplateVersionParameter) bool {
// Multi-select parameters store values as a JSON array (e.g.
// '["vim","emacs"]'), so we need to parse the array and validate
// each element individually against the allowed options.
if templateVersionParameter.Type == "list(string)" {
var values []string
if err := json.Unmarshal([]byte(buildParameter.Value), &values); err != nil {
return false
}
for _, v := range values {
found := false
for _, opt := range templateVersionParameter.Options {
if opt.Value == v {
found = true
break
}
}
if !found {
return false
}
}
return true
}
for _, opt := range templateVersionParameter.Options {
func isValidTemplateParameterOption(buildParameter codersdk.WorkspaceBuildParameter, options []codersdk.TemplateVersionParameterOption) bool {
for _, opt := range options {
if opt.Value == buildParameter.Value {
return true
}
-85
View File
@@ -1,85 +0,0 @@
package cli
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/codersdk"
)
func TestIsValidTemplateParameterOption(t *testing.T) {
t.Parallel()
options := []codersdk.TemplateVersionParameterOption{
{Name: "Vim", Value: "vim"},
{Name: "Emacs", Value: "emacs"},
{Name: "VS Code", Value: "vscode"},
}
t.Run("SingleSelectValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "vim"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("SingleSelectInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editor", Value: "notepad"}
tvp := codersdk.TemplateVersionParameter{
Name: "editor",
Type: "string",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectAllValid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","emacs"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectOneInvalid", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `["vim","notepad"]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectEmptyArray", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `[]`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.True(t, isValidTemplateParameterOption(bp, tvp))
})
t.Run("MultiSelectInvalidJSON", func(t *testing.T) {
t.Parallel()
bp := codersdk.WorkspaceBuildParameter{Name: "editors", Value: `not-json`}
tvp := codersdk.TemplateVersionParameter{
Name: "editors",
Type: "list(string)",
Options: options,
}
assert.False(t, isValidTemplateParameterOption(bp, tvp))
})
}
+4 -15
View File
@@ -884,27 +884,16 @@ func (o *OrganizationContext) Selected(inv *serpent.Invocation, client *codersdk
index := slices.IndexFunc(orgs, func(org codersdk.Organization) bool {
return org.Name == o.FlagSelect || org.ID.String() == o.FlagSelect
})
if index >= 0 {
return orgs[index], nil
}
// Not in membership list - try direct fetch.
// This allows site-wide admins (e.g., Owners) to use orgs they aren't
// members of.
org, err := client.OrganizationByName(inv.Context(), o.FlagSelect)
if err != nil {
if index < 0 {
var names []string
for _, org := range orgs {
names = append(names, org.Name)
}
var sdkErr *codersdk.Error
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusNotFound {
return codersdk.Organization{}, xerrors.Errorf("organization %q not found, are you sure you are a member of this organization? "+
"Valid options for '--org=' are [%s].", o.FlagSelect, strings.Join(names, ", "))
}
return codersdk.Organization{}, xerrors.Errorf("get organization %q: %w", o.FlagSelect, err)
return codersdk.Organization{}, xerrors.Errorf("organization %q not found, are you sure you are a member of this organization? "+
"Valid options for '--org=' are [%s].", o.FlagSelect, strings.Join(names, ", "))
}
return org, nil
return orgs[index], nil
}
if len(orgs) == 1 {
+44 -128
View File
@@ -95,7 +95,6 @@ import (
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/cryptorand"
@@ -137,15 +136,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
if err != nil {
return nil, xerrors.Errorf("parse oidc oauth callback url: %w", err)
}
if vals.OIDC.RedirectURL.String() != "" {
redirectURL, err = vals.OIDC.RedirectURL.Value().Parse("/api/v2/users/oidc/callback")
if err != nil {
return nil, xerrors.Errorf("parse oidc redirect url %q", err)
}
logger.Warn(ctx, "custom OIDC redirect URL used instead of 'access_url', ensure this matches the value configured in your OIDC provider")
}
// If the scopes contain 'groups', we enable group support.
// Do not override any custom value set by the user.
if slice.Contains(vals.OIDC.Scopes, "groups") && vals.OIDC.GroupField == "" {
@@ -617,8 +607,28 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
}
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
promRegistry := prometheus.NewRegistry()
oauthInstrument := promoauth.NewFactory(promRegistry)
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
externalAuthConfigs, err := externalauth.ConvertConfig(
oauthInstrument,
vals.ExternalAuthConfigs.Value,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range externalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
if err != nil {
@@ -649,7 +659,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
Pubsub: nil,
CacheDir: cacheDir,
GoogleTokenValidator: googleTokenValidator,
ExternalAuthConfigs: nil,
ExternalAuthConfigs: externalAuthConfigs,
RealIPConfig: realIPConfig,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
@@ -809,40 +819,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("set deployment id: %w", err)
}
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
if err != nil {
return xerrors.Errorf("read external auth providers from env: %w", err)
}
mergedExternalAuthProviders := append([]codersdk.ExternalAuthConfig{}, vals.ExternalAuthConfigs.Value...)
mergedExternalAuthProviders = append(mergedExternalAuthProviders, extAuthEnv...)
vals.ExternalAuthConfigs.Value = mergedExternalAuthProviders
mergedExternalAuthProviders, err = maybeAppendDefaultGithubExternalAuthProvider(
ctx,
options.Logger,
options.Database,
vals,
mergedExternalAuthProviders,
)
if err != nil {
return xerrors.Errorf("maybe append default github external auth provider: %w", err)
}
options.ExternalAuthConfigs, err = externalauth.ConvertConfig(
oauthInstrument,
mergedExternalAuthProviders,
vals.AccessURL.Value(),
)
if err != nil {
return xerrors.Errorf("convert external auth config: %w", err)
}
for _, c := range options.ExternalAuthConfigs {
logger.Debug(
ctx, "loaded external auth config",
slog.F("id", c.ID),
)
}
// Manage push notifications.
experiments := coderd.ReadExperiments(options.Logger, options.DeploymentValues.Experiments.Value())
if experiments.Enabled(codersdk.ExperimentWebPush) {
@@ -959,12 +935,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
options.StatsBatcher = batcher
defer closeBatcher()
wsBuilderMetrics, err := wsbuilder.NewMetrics(options.PrometheusRegistry)
if err != nil {
return xerrors.Errorf("failed to register workspace builder metrics: %w", err)
}
options.WorkspaceBuilderMetrics = wsBuilderMetrics
// Manage notifications.
var (
notificationsCfg = options.DeploymentValues.Notifications
@@ -1148,7 +1118,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value())
defer autobuildTicker.Stop()
autobuildExecutor := autobuild.NewExecutor(
ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments, coderAPI.WorkspaceBuilderMetrics)
ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments)
autobuildExecutor.Run()
jobReaperTicker := time.NewTicker(vals.JobReaperDetectorInterval.Value())
@@ -1940,79 +1910,6 @@ type githubOAuth2ConfigParams struct {
enterpriseBaseURL string
}
func isDeploymentEligibleForGithubDefaultProvider(ctx context.Context, db database.Store) (bool, error) {
// We want to enable the default provider only for new deployments, and avoid
// enabling it if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return false, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return false, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
return defaultEligible, nil
}
func maybeAppendDefaultGithubExternalAuthProvider(
ctx context.Context,
logger slog.Logger,
db database.Store,
vals *codersdk.DeploymentValues,
mergedExplicitProviders []codersdk.ExternalAuthConfig,
) ([]codersdk.ExternalAuthConfig, error) {
if !vals.ExternalAuthGithubDefaultProviderEnable.Value() {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "disabled by configuration"),
slog.F("flag", "external-auth-github-default-provider-enable"),
)
return mergedExplicitProviders, nil
}
if len(mergedExplicitProviders) > 0 {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "explicit external auth providers configured"),
slog.F("provider_count", len(mergedExplicitProviders)),
)
return mergedExplicitProviders, nil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
}
if !defaultEligible {
logger.Info(ctx, "default github external auth provider suppressed",
slog.F("reason", "deployment is not eligible"),
)
return mergedExplicitProviders, nil
}
logger.Info(ctx, "injecting default github external auth provider",
slog.F("type", codersdk.EnhancedExternalAuthProviderGitHub.String()),
slog.F("client_id", GithubOAuth2DefaultProviderClientID),
slog.F("device_flow", GithubOAuth2DefaultProviderDeviceFlow),
)
return append(mergedExplicitProviders, codersdk.ExternalAuthConfig{
Type: codersdk.EnhancedExternalAuthProviderGitHub.String(),
ClientID: GithubOAuth2DefaultProviderClientID,
DeviceFlow: GithubOAuth2DefaultProviderDeviceFlow,
}), nil
}
func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *codersdk.DeploymentValues) (*githubOAuth2ConfigParams, error) {
params := githubOAuth2ConfigParams{
accessURL: vals.AccessURL.Value(),
@@ -2037,9 +1934,28 @@ func getGithubOAuth2ConfigParams(ctx context.Context, db database.Store, vals *c
return nil, nil //nolint:nilnil
}
defaultEligible, err := isDeploymentEligibleForGithubDefaultProvider(ctx, db)
if err != nil {
return nil, err
// Check if the deployment is eligible for the default GitHub OAuth2 provider.
// We want to enable it only for new deployments, and avoid enabling it
// if a deployment was upgraded from an older version.
// nolint:gocritic // Requires system privileges
defaultEligible, err := db.GetOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx))
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, xerrors.Errorf("get github default eligible: %w", err)
}
defaultEligibleNotSet := errors.Is(err, sql.ErrNoRows)
if defaultEligibleNotSet {
// nolint:gocritic // User count requires system privileges
userCount, err := db.GetUserCount(dbauthz.AsSystemRestricted(ctx), false)
if err != nil {
return nil, xerrors.Errorf("get user count: %w", err)
}
// We check if a deployment is new by checking if it has any users.
defaultEligible = userCount == 0
// nolint:gocritic // Requires system privileges
if err := db.UpsertOAuth2GithubDefaultEligible(dbauthz.AsSystemRestricted(ctx), defaultEligible); err != nil {
return nil, xerrors.Errorf("upsert github default eligible: %w", err)
}
}
if !defaultEligible {
-163
View File
@@ -53,7 +53,6 @@ import (
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/userpassword"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
@@ -303,7 +302,6 @@ func TestServer(t *testing.T) {
"open install.sh: file does not exist",
"telemetry disabled, unable to notify of security issues",
"installed terraform version newer than expected",
"report generator",
}
countLines := func(fullOutput string) int {
@@ -1742,18 +1740,6 @@ func TestServer(t *testing.T) {
// Next, we instruct the same server to display the YAML config
// and then save it.
// Because this is literally the same invocation, DefaultFn sets the
// value of 'Default'. Which triggers a mutually exclusive error
// on the next parse.
// Usually we only parse flags once, so this is not an issue
for _, c := range inv.Command.Children {
if c.Name() == "server" {
for i := range c.Options {
c.Options[i].DefaultFn = nil
}
break
}
}
inv = inv.WithContext(testutil.Context(t, testutil.WaitMedium))
//nolint:gocritic
inv.Args = append(args, "--write-config")
@@ -1807,155 +1793,6 @@ func TestServer(t *testing.T) {
})
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_ExternalAuthGitHubDefaultProvider(t *testing.T) {
type testCase struct {
name string
args []string
env map[string]string
createUserPreStart bool
expectedProviders []string
}
run := func(t *testing.T, tc testCase) {
ctx := testutil.Context(t, testutil.WaitLong)
unsetPrefixedEnv := func(prefix string) {
t.Helper()
for _, envVar := range os.Environ() {
envKey, _, found := strings.Cut(envVar, "=")
if !found || !strings.HasPrefix(envKey, prefix) {
continue
}
value, had := os.LookupEnv(envKey)
require.True(t, had)
require.NoError(t, os.Unsetenv(envKey))
keyCopy := envKey
valueCopy := value
t.Cleanup(func() {
// This is for setting/unsetting a number of prefixed env vars.
// t.Setenv doesn't cover this use case.
// nolint:usetesting
_ = os.Setenv(keyCopy, valueCopy)
})
}
}
unsetPrefixedEnv("CODER_EXTERNAL_AUTH_")
unsetPrefixedEnv("CODER_GITAUTH_")
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
db, _ := dbtestutil.NewDB(t, dbtestutil.WithURL(dbURL))
const (
existingUserEmail = "existing-user@coder.com"
existingUserUsername = "existing-user"
existingUserPassword = "SomeSecurePassword!"
)
if tc.createUserPreStart {
hashedPassword, err := userpassword.Hash(existingUserPassword)
require.NoError(t, err)
_ = dbgen.User(t, db, database.User{
Email: existingUserEmail,
Username: existingUserUsername,
HashedPassword: []byte(hashedPassword),
})
}
args := []string{
"server",
"--postgres-url", dbURL,
"--http-address", ":0",
"--access-url", "https://example.com",
}
args = append(args, tc.args...)
inv, cfg := clitest.New(t, args...)
for envKey, value := range tc.env {
t.Setenv(envKey, value)
}
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
if tc.createUserPreStart {
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: existingUserEmail,
Password: existingUserPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
} else {
_ = coderdtest.CreateFirstUser(t, client)
}
externalAuthResp, err := client.ListExternalAuths(ctx)
require.NoError(t, err)
gotProviders := map[string]codersdk.ExternalAuthLinkProvider{}
for _, provider := range externalAuthResp.Providers {
gotProviders[provider.ID] = provider
}
require.Len(t, gotProviders, len(tc.expectedProviders))
for _, providerID := range tc.expectedProviders {
provider, ok := gotProviders[providerID]
require.Truef(t, ok, "expected provider %q to be configured", providerID)
if providerID == codersdk.EnhancedExternalAuthProviderGitHub.String() {
require.Equal(t, codersdk.EnhancedExternalAuthProviderGitHub.String(), provider.Type)
require.True(t, provider.Device)
}
}
}
for _, tc := range []testCase{
{
name: "NewDeployment_NoExplicitProviders_InjectsDefaultGithub",
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitHub.String()},
},
{
name: "ExistingDeployment_DoesNotInjectDefaultGithub",
createUserPreStart: true,
expectedProviders: nil,
},
{
name: "DefaultProviderDisabled_DoesNotInjectDefaultGithub",
args: []string{
"--external-auth-github-default-provider-enable=false",
},
expectedProviders: nil,
},
{
name: "ExplicitProviderViaConfig_DoesNotInjectDefaultGithub",
args: []string{
`--external-auth-providers=[{"type":"gitlab","client_id":"config-client-id"}]`,
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_EXTERNAL_AUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_EXTERNAL_AUTH_0_CLIENT_ID": "env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
{
name: "ExplicitProviderViaLegacyEnv_DoesNotInjectDefaultGithub",
env: map[string]string{
"CODER_GITAUTH_0_TYPE": codersdk.EnhancedExternalAuthProviderGitLab.String(),
"CODER_GITAUTH_0_CLIENT_ID": "legacy-env-client-id",
},
expectedProviders: []string{codersdk.EnhancedExternalAuthProviderGitLab.String()},
},
} {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_Logging_NoParallel(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+31 -7
View File
@@ -25,7 +25,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -64,8 +68,12 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
orgOwner = coderdtest.CreateFirstUser(t, client)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -119,7 +127,11 @@ func TestSharingShare(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -170,7 +182,11 @@ func TestSharingStatus(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -214,7 +230,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -271,7 +291,11 @@ func TestSharingRemove(t *testing.T) {
t.Parallel()
var (
client, db = coderdtest.NewWithDatabase(t, nil)
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
}),
})
orgOwner = coderdtest.CreateFirstUser(t, client)
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, orgOwner.OrganizationID, rbac.ScopedRoleOrgAuditor(orgOwner.OrganizationID))
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
+1 -1
View File
@@ -120,7 +120,7 @@ func (r *RootCmd) start() *serpent.Command {
func buildWorkspaceStartRequest(inv *serpent.Invocation, client *codersdk.Client, workspace codersdk.Workspace, parameterFlags workspaceParameterFlags, buildFlags buildFlags, action WorkspaceCLIAction) (codersdk.CreateWorkspaceBuildRequest, error) {
version := workspace.LatestBuild.TemplateVersionID
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || workspace.TemplateRequireActiveVersion || action == WorkspaceUpdate {
if workspace.AutomaticUpdates == codersdk.AutomaticUpdatesAlways || action == WorkspaceUpdate {
version = workspace.TemplateActiveVersionID
if version != workspace.LatestBuild.TemplateVersionID {
action = WorkspaceUpdate
+4 -4
View File
@@ -33,7 +33,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
statefilePath := filepath.Join(t.TempDir(), "state")
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name, statefilePath)
@@ -54,7 +54,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", r.Workspace.Name)
var gotState bytes.Buffer
@@ -74,7 +74,7 @@ func TestStatePull(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(wantState).
Seed(database.WorkspaceBuild{ProvisionerState: wantState}).
Do()
inv, root := clitest.New(t, "state", "pull", taUser.Username+"/"+r.Workspace.Name,
"--build", fmt.Sprintf("%d", r.Build.BuildNumber))
@@ -170,7 +170,7 @@ func TestStatePush(t *testing.T) {
OrganizationID: owner.OrganizationID,
OwnerID: taUser.ID,
}).
Seed(database.WorkspaceBuild{}).ProvisionerState(initialState).
Seed(database.WorkspaceBuild{ProvisionerState: initialState}).
Do()
wantState := []byte("updated state")
stateFile, err := os.CreateTemp(t.TempDir(), "")
+7 -9
View File
@@ -1,3 +1,5 @@
//go:build !windows
package cli_test
import (
@@ -5,7 +7,6 @@ import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"time"
@@ -24,15 +25,12 @@ func setupSocketServer(t *testing.T) (path string, cleanup func()) {
t.Helper()
// Use a temporary socket path for each test
socketPath := testutil.AgentSocketPath(t)
socketPath := filepath.Join(testutil.TempDirUnixSocket(t), "test.sock")
// Create parent directory if needed. Not necessary on Windows because named pipes live in an abstract namespace
// not tied to any real files.
if runtime.GOOS != "windows" {
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")
}
// Create parent directory if needed
parentDir := filepath.Dir(socketPath)
err := os.MkdirAll(parentDir, 0o700)
require.NoError(t, err, "create socket directory")
server, err := agentsocket.NewServer(
slog.Make().Leveled(slog.LevelDebug),
-2
View File
@@ -17,8 +17,6 @@ func (r *RootCmd) tasksCommand() *serpent.Command {
r.taskDelete(),
r.taskList(),
r.taskLogs(),
r.taskPause(),
r.taskResume(),
r.taskSend(),
r.taskStatus(),
},
+10 -5
View File
@@ -41,7 +41,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client // user already has access to their own workspace
inv, root := clitest.New(t, "task", "logs", task.Name, "--output", "json")
output := clitest.Capture(inv)
@@ -64,7 +65,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
inv, root := clitest.New(t, "task", "logs", task.ID.String(), "--output", "json")
output := clitest.Capture(inv)
@@ -87,7 +89,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsOK(testMessages))
userClient := client
inv, root := clitest.New(t, "task", "logs", task.ID.String())
output := clitest.Capture(inv)
@@ -141,7 +144,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskLogsErr(assert.AnError))
userClient := client
inv, root := clitest.New(t, "task", "logs", task.ID.String())
clitest.SetupConfig(t, userClient, root)
@@ -197,7 +201,8 @@ func Test_TaskLogs_Golden(t *testing.T) {
t.Run("SnapshotWithoutLogs_NoSnapshotCaptured", func(t *testing.T) {
t.Parallel()
userClient, task := setupCLITaskTestWithoutSnapshot(t, codersdk.TaskStatusPaused)
client, task := setupCLITaskTestWithoutSnapshot(t, codersdk.TaskStatusPaused)
userClient := client
inv, root := clitest.New(t, "task", "logs", task.Name)
output := clitest.Capture(inv)
-90
View File
@@ -1,90 +0,0 @@
package cli
import (
"fmt"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func (r *RootCmd) taskPause() *serpent.Command {
cmd := &serpent.Command{
Use: "pause <task>",
Short: "Pause a task",
Long: FormatExamples(
Example{
Description: "Pause a task by name",
Command: "coder task pause my-task",
},
Example{
Description: "Pause another user's task",
Command: "coder task pause alice/my-task",
},
Example{
Description: "Pause a task without confirmation",
Command: "coder task pause my-task --yes",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Options: serpent.OptionSet{
cliui.SkipPromptOption(),
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
task, err := client.TaskByIdentifier(ctx, inv.Args[0])
if err != nil {
return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
}
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
if task.Status == codersdk.TaskStatusPaused {
return xerrors.Errorf("task %q is already paused", display)
}
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Pause task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
IsConfirm: true,
Default: cliui.ConfirmNo,
})
if err != nil {
return err
}
resp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
if err != nil {
return xerrors.Errorf("pause task %q: %w", display, err)
}
if resp.WorkspaceBuild == nil {
return xerrors.Errorf("pause task %q: no workspace build returned", display)
}
err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID)
if err != nil {
return xerrors.Errorf("watch pause build for task %q: %w", display, err)
}
_, _ = fmt.Fprintf(
inv.Stdout,
"\nThe %s task has been paused at %s!\n",
cliui.Keyword(task.Name),
cliui.Timestamp(time.Now()),
)
return nil
},
}
return cmd
}
-144
View File
@@ -1,144 +0,0 @@
package cli_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
func TestExpTaskPause(t *testing.T) {
t.Parallel()
t.Run("WithYesFlag", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: Expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
// OtherUserTask verifies that an admin can pause a task owned by
// another user using the "owner/name" identifier format.
t.Run("OtherUserTask", func(t *testing.T) {
t.Parallel()
// Given: A different user's running task
setupCtx := testutil.Context(t, testutil.WaitLong)
adminClient, _, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause their task
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "pause", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, adminClient, root)
// Then: We expect the task to be paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been paused")
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("PromptConfirm", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)
// And: We confirm we want to pause the task
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Pause task")
pty.WriteLine("yes")
// Then: We expect the task to be paused
pty.ExpectMatchContext(ctx, "has been paused")
require.NoError(t, w.Wait())
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("PromptDecline", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to pause the task
inv, root := clitest.New(t, "task", "pause", task.Name)
clitest.SetupConfig(t, userClient, root)
// But: We say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Pause task")
pty.WriteLine("no")
require.Error(t, w.Wait())
// Then: We expect the task to not be paused
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.NotEqual(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("TaskAlreadyPaused", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// And: We paused the running task
ctx := testutil.Context(t, testutil.WaitMedium)
resp, err := userClient.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, resp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, resp.WorkspaceBuild.ID)
// When: We attempt to pause the task again
inv, root := clitest.New(t, "task", "pause", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)
// Then: We expect to get an error that the task is already paused
err = inv.WithContext(ctx).Run()
require.ErrorContains(t, err, "is already paused")
})
}
-95
View File
@@ -1,95 +0,0 @@
package cli
import (
"fmt"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/pretty"
"github.com/coder/serpent"
)
func (r *RootCmd) taskResume() *serpent.Command {
var noWait bool
cmd := &serpent.Command{
Use: "resume <task>",
Short: "Resume a task",
Long: FormatExamples(
Example{
Description: "Resume a task by name",
Command: "coder task resume my-task",
},
Example{
Description: "Resume another user's task",
Command: "coder task resume alice/my-task",
},
Example{
Description: "Resume a task without confirmation",
Command: "coder task resume my-task --yes",
},
),
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Options: serpent.OptionSet{
{
Flag: "no-wait",
Description: "Return immediately after resuming the task.",
Value: serpent.BoolOf(&noWait),
},
cliui.SkipPromptOption(),
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
client, err := r.InitClient(inv)
if err != nil {
return err
}
task, err := client.TaskByIdentifier(ctx, inv.Args[0])
if err != nil {
return xerrors.Errorf("resolve task %q: %w", inv.Args[0], err)
}
display := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
if task.Status == codersdk.TaskStatusError || task.Status == codersdk.TaskStatusUnknown {
return xerrors.Errorf("task %q is in %s state and cannot be resumed; check the workspace build logs and agent status for details", display, task.Status)
} else if task.Status != codersdk.TaskStatusPaused {
return xerrors.Errorf("task %q cannot be resumed (current status: %s)", display, task.Status)
}
_, err = cliui.Prompt(inv, cliui.PromptOptions{
Text: fmt.Sprintf("Resume task %s?", pretty.Sprint(cliui.DefaultStyles.Code, display)),
IsConfirm: true,
Default: cliui.ConfirmNo,
})
if err != nil {
return err
}
resp, err := client.ResumeTask(ctx, task.OwnerName, task.ID)
if err != nil {
return xerrors.Errorf("resume task %q: %w", display, err)
} else if resp.WorkspaceBuild == nil {
return xerrors.Errorf("resume task %q: no workspace build returned", display)
}
if noWait {
_, _ = fmt.Fprintf(inv.Stdout, "Resuming task %q in the background.\n", cliui.Keyword(display))
return nil
}
if err = cliui.WorkspaceBuild(ctx, inv.Stdout, client, resp.WorkspaceBuild.ID); err != nil {
return xerrors.Errorf("watch resume build for task %q: %w", display, err)
}
_, _ = fmt.Fprintf(inv.Stdout, "\nThe %s task has been resumed.\n", cliui.Keyword(display))
return nil
},
}
return cmd
}
-183
View File
@@ -1,183 +0,0 @@
package cli_test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
func TestExpTaskResume(t *testing.T) {
t.Parallel()
// pauseTask is a helper that pauses a task and waits for the stop
// build to complete.
pauseTask := func(ctx context.Context, t *testing.T, client *codersdk.Client, task codersdk.Task) {
t.Helper()
pauseResp, err := client.PauseTask(ctx, task.OwnerName, task.ID)
require.NoError(t, err)
require.NotNil(t, pauseResp.WorkspaceBuild)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
}
t.Run("WithYesFlag", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
// OtherUserTask verifies that an admin can resume a task owned by
// another user using the "owner/name" identifier format.
t.Run("OtherUserTask", func(t *testing.T) {
t.Parallel()
// Given: A different user's paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
adminClient, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume their task
identifier := fmt.Sprintf("%s/%s", task.OwnerName, task.Name)
inv, root := clitest.New(t, "task", "resume", identifier, "--yes")
output := clitest.Capture(inv)
clitest.SetupConfig(t, adminClient, root)
// Then: We expect the task to be resumed
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "has been resumed")
updated, err := adminClient.TaskByIdentifier(ctx, identifier)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("NoWait", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task (and specify no wait)
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes", "--no-wait")
output := clitest.Capture(inv)
clitest.SetupConfig(t, userClient, root)
// Then: We expect the task to be resumed in the background
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
require.Contains(t, output.Stdout(), "in the background")
// And: The task to eventually be resumed
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
ws := coderdtest.MustWorkspace(t, userClient, task.WorkspaceID.UUID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, ws.LatestBuild.ID)
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("PromptConfirm", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// And: We confirm we want to resume the task
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Resume task")
pty.WriteLine("yes")
// Then: We expect the task to be resumed
pty.ExpectMatchContext(ctx, "has been resumed")
require.NoError(t, w.Wait())
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusInitializing, updated.Status)
})
t.Run("PromptDecline", func(t *testing.T) {
t.Parallel()
// Given: A paused task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
pauseTask(setupCtx, t, userClient, task)
// When: We attempt to resume the task
inv, root := clitest.New(t, "task", "resume", task.Name)
clitest.SetupConfig(t, userClient, root)
// But: Say no at the confirmation screen
ctx := testutil.Context(t, testutil.WaitMedium)
inv = inv.WithContext(ctx)
pty := ptytest.New(t).Attach(inv)
w := clitest.StartWithWaiter(t, inv)
pty.ExpectMatchContext(ctx, "Resume task")
pty.WriteLine("no")
require.Error(t, w.Wait())
// Then: We expect the task to still be paused
updated, err := userClient.TaskByIdentifier(ctx, task.Name)
require.NoError(t, err)
require.Equal(t, codersdk.TaskStatusPaused, updated.Status)
})
t.Run("TaskNotPaused", func(t *testing.T) {
t.Parallel()
// Given: A running task
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, nil)
// When: We attempt to resume the task that is not paused
inv, root := clitest.New(t, "task", "resume", task.Name, "--yes")
clitest.SetupConfig(t, userClient, root)
// Then: We expect to get an error that the task is not paused
ctx := testutil.Context(t, testutil.WaitMedium)
err := inv.WithContext(ctx).Run()
require.ErrorContains(t, err, "cannot be resumed")
})
}
+7 -4
View File
@@ -25,7 +25,8 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.Name, "carry on with the task")
@@ -41,7 +42,8 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.ID.String(), "carry on with the task")
@@ -57,7 +59,8 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
client, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendOK(t, "carry on with the task", "you got it"))
userClient := client
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.Name, "--stdin")
@@ -110,7 +113,7 @@ func Test_TaskSend(t *testing.T) {
t.Parallel()
setupCtx := testutil.Context(t, testutil.WaitLong)
_, userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
userClient, task := setupCLITaskTest(setupCtx, t, fakeAgentAPITaskSendErr(t, assert.AnError))
var stdout strings.Builder
inv, root := clitest.New(t, "task", "send", task.Name, "some task input")
+10 -44
View File
@@ -120,40 +120,6 @@ func Test_Tasks(t *testing.T) {
require.Equal(t, logs[2].Type, codersdk.TaskLogTypeOutput, "third message should be an output")
},
},
{
name: "pause task",
cmdArgs: []string{"task", "pause", taskName, "--yes"},
assertFn: func(stdout string, userClient *codersdk.Client) {
require.Contains(t, stdout, "has been paused", "pause output should confirm task was paused")
},
},
{
name: "get task status after pause",
cmdArgs: []string{"task", "status", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var task codersdk.Task
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
require.Equal(t, taskName, task.Name, "task name should match")
require.Equal(t, codersdk.TaskStatusPaused, task.Status, "task should be paused")
},
},
{
name: "resume task",
cmdArgs: []string{"task", "resume", taskName, "--yes"},
assertFn: func(stdout string, userClient *codersdk.Client) {
require.Contains(t, stdout, "has been resumed", "resume output should confirm task was resumed")
},
},
{
name: "get task status after resume",
cmdArgs: []string{"task", "status", taskName, "--output", "json"},
assertFn: func(stdout string, userClient *codersdk.Client) {
var task codersdk.Task
require.NoError(t, json.NewDecoder(strings.NewReader(stdout)).Decode(&task), "should unmarshal task status")
require.Equal(t, taskName, task.Name, "task name should match")
require.Equal(t, codersdk.TaskStatusInitializing, task.Status, "task should be initializing after resume")
},
},
{
name: "delete task",
cmdArgs: []string{"task", "delete", taskName, "--yes"},
@@ -272,17 +238,17 @@ func fakeAgentAPIEcho(ctx context.Context, t testing.TB, initMsg agentapisdk.Mes
// setupCLITaskTest creates a test workspace with an AI task template and agent,
// with a fake agent API configured with the provided set of handlers.
// Returns the user client and workspace.
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (ownerClient *codersdk.Client, memberClient *codersdk.Client, task codersdk.Task) {
func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[string]http.HandlerFunc) (*codersdk.Client, codersdk.Task) {
t.Helper()
ownerClient = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, ownerClient)
userClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
fakeAPI := startFakeAgentAPI(t, agentAPIHandlers)
authToken := uuid.NewString()
template := createAITaskTemplate(t, ownerClient, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
template := createAITaskTemplate(t, client, owner.OrganizationID, withSidebarURL(fakeAPI.URL()), withAgentToken(authToken))
wantPrompt := "test prompt"
task, err := userClient.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
@@ -296,17 +262,17 @@ func setupCLITaskTest(ctx context.Context, t *testing.T, agentAPIHandlers map[st
require.True(t, task.WorkspaceID.Valid, "task should have a workspace ID")
workspace, err := userClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, userClient, workspace.LatestBuild.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
agentClient := agentsdk.New(userClient.URL, agentsdk.WithFixedToken(authToken))
_ = agenttest.New(t, userClient.URL, authToken, func(o *agent.Options) {
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(authToken))
_ = agenttest.New(t, client.URL, authToken, func(o *agent.Options) {
o.Client = agentClient
})
coderdtest.NewWorkspaceAgentWaiter(t, userClient, workspace.ID).
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).
WaitFor(coderdtest.AgentsReady)
return ownerClient, userClient, task
return userClient, task
}
// setupCLITaskTestWithSnapshot creates a task in the specified status with a log snapshot.
-4
View File
@@ -139,10 +139,8 @@ func (r *RootCmd) templateVersionsList() *serpent.Command {
type templateVersionRow struct {
// For json format:
TemplateVersion codersdk.TemplateVersion `table:"-"`
ActiveJSON bool `json:"active" table:"-"`
// For table format:
ID string `json:"-" table:"id"`
Name string `json:"-" table:"name,default_sort"`
CreatedAt time.Time `json:"-" table:"created at"`
CreatedBy string `json:"-" table:"created by"`
@@ -168,8 +166,6 @@ func templateVersionsToRows(activeVersionID uuid.UUID, templateVersions ...coder
rows[i] = templateVersionRow{
TemplateVersion: templateVersion,
ActiveJSON: templateVersion.ID == activeVersionID,
ID: templateVersion.ID.String(),
Name: templateVersion.Name,
CreatedAt: templateVersion.CreatedAt,
CreatedBy: templateVersion.CreatedBy.Username,
-29
View File
@@ -1,9 +1,7 @@
package cli_test
import (
"bytes"
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
@@ -42,33 +40,6 @@ func TestTemplateVersions(t *testing.T) {
pty.ExpectMatch(version.CreatedBy.Username)
pty.ExpectMatch("Active")
})
t.Run("ListVersionsJSON", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, nil)
_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "templates", "versions", "list", template.Name, "--output", "json")
clitest.SetupConfig(t, member, root)
var stdout bytes.Buffer
inv.Stdout = &stdout
require.NoError(t, inv.Run())
var rows []struct {
TemplateVersion codersdk.TemplateVersion `json:"TemplateVersion"`
Active bool `json:"active"`
}
require.NoError(t, json.Unmarshal(stdout.Bytes(), &rows))
require.Len(t, rows, 1)
assert.Equal(t, version.ID, rows[0].TemplateVersion.ID)
assert.True(t, rows[0].Active)
})
}
func TestTemplateVersionsPromote(t *testing.T) {
+5 -13
View File
@@ -49,9 +49,10 @@ OPTIONS:
security purposes if a --wildcard-access-url is configured.
--disable-workspace-sharing bool, $CODER_DISABLE_WORKSPACE_SHARING
Disable workspace sharing. Workspace ACL checking is disabled and only
owners can have ssh, apps and terminal access to workspaces. Access
based on the 'owner' role is also allowed unless disabled via
Disable workspace sharing (requires the "workspace-sharing" experiment
to be enabled). Workspace ACL checking is disabled and only owners can
have ssh, apps and terminal access to workspaces. Access based on the
'owner' role is also allowed unless disabled via
--disable-owner-workspace-access.
--swagger-enable bool, $CODER_SWAGGER_ENABLE
@@ -62,9 +63,6 @@ OPTIONS:
Separate multiple experiments with commas, or enter '*' to opt-in to
all available experiments.
--external-auth-github-default-provider-enable bool, $CODER_EXTERNAL_AUTH_GITHUB_DEFAULT_PROVIDER_ENABLE (default: true)
Enable the default GitHub external auth provider managed by Coder.
--postgres-auth password|awsiamrds, $CODER_PG_AUTH (default: password)
Type of auth to use when connecting to postgres. For AWS RDS, using
IAM authentication (awsiamrds) is recommended.
@@ -385,19 +383,13 @@ NETWORKING OPTIONS:
--samesite-auth-cookie lax|none, $CODER_SAMESITE_AUTH_COOKIE (default: lax)
Controls the 'SameSite' property is set on browser session cookies.
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE (default: false)
--secure-auth-cookie bool, $CODER_SECURE_AUTH_COOKIE
Controls if the 'Secure' property is set on browser session cookies.
--wildcard-access-url string, $CODER_WILDCARD_ACCESS_URL
Specifies the wildcard hostname to use for workspace applications in
the form "*.example.com".
--host-prefix-cookie bool, $CODER_HOST_PREFIX_COOKIE (default: false)
Recommended to be enabled. Enables `__Host-` prefix for cookies to
guarantee they are only set by the right domain. This change is
disruptive to any workspaces built before release 2.31, requiring a
workspace restart.
NETWORKING / DERP OPTIONS:
Most Coder deployments never have to think about DERP because all connections
between workspaces and users are peer-to-peer. However, when Coder cannot
-2
View File
@@ -12,8 +12,6 @@ SUBCOMMANDS:
delete Delete tasks
list List tasks
logs Show a task's logs
pause Pause a task
resume Resume a task
send Send input to a task
status Show the status of a task.
-25
View File
@@ -1,25 +0,0 @@
coder v0.0.0-devel
USAGE:
coder task pause [flags] <task>
Pause a task
- Pause a task by name:
$ coder task pause my-task
- Pause another user's task:
$ coder task pause alice/my-task
- Pause a task without confirmation:
$ coder task pause my-task --yes
OPTIONS:
-y, --yes bool
Bypass confirmation prompts.
———
Run `coder --help` for a list of global options.
-28
View File
@@ -1,28 +0,0 @@
coder v0.0.0-devel
USAGE:
coder task resume [flags] <task>
Resume a task
- Resume a task by name:
$ coder task resume my-task
- Resume another user's task:
$ coder task resume alice/my-task
- Resume a task without confirmation:
$ coder task resume my-task --yes
OPTIONS:
--no-wait bool
Return immediately after resuming the task.
-y, --yes bool
Bypass confirmation prompts.
———
Run `coder --help` for a list of global options.
+1 -1
View File
@@ -9,7 +9,7 @@ OPTIONS:
-O, --org string, $CODER_ORGANIZATION
Select which organization (uuid or name) to use.
-c, --column [id|name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
-c, --column [name|created at|created by|status|active|archived] (default: name,created at,created by,status,active)
Columns to display in table output.
--include-archived bool
+1 -1
View File
@@ -27,7 +27,7 @@ USAGE:
SUBCOMMANDS:
create Create a token
list List tokens
remove Expire or delete a token
remove Delete a token
view Display detailed information about a token
———
-4
View File
@@ -15,10 +15,6 @@ OPTIONS:
-c, --column [id|name|scopes|allow list|last used|expires at|created at|owner] (default: id,name,scopes,allow list,last used,expires at,created at)
Columns to display in table output.
--include-expired bool
Include expired tokens in the output. By default, expired tokens are
hidden.
-o, --output table|json (default: table)
Output format.
+2 -10
View File
@@ -1,19 +1,11 @@
coder v0.0.0-devel
USAGE:
coder tokens remove [flags] <name|id|token>
coder tokens remove <name|id|token>
Expire or delete a token
Delete a token
Aliases: delete, rm
Remove a token by expiring it. Use --delete to permanently hard-delete the
token instead.
OPTIONS:
--delete bool
Permanently delete the token instead of expiring it. This removes the
audit trail.
———
Run `coder --help` for a list of global options.
+5 -18
View File
@@ -176,16 +176,11 @@ networking:
# (default: <unset>, type: string-array)
proxyTrustedOrigins: []
# Controls if the 'Secure' property is set on browser session cookies.
# (default: false, type: bool)
# (default: <unset>, type: bool)
secureAuthCookie: false
# Controls the 'SameSite' property is set on browser session cookies.
# (default: lax, type: enum[lax\|none])
sameSiteAuthCookie: lax
# Recommended to be enabled. Enables `__Host-` prefix for cookies to guarantee
# they are only set by the right domain. This change is disruptive to any
# workspaces built before release 2.31, requiring a workspace restart.
# (default: false, type: bool)
hostPrefixCookie: false
# Whether Coder only allows connections to workspaces via the browser.
# (default: <unset>, type: bool)
browserOnly: false
@@ -422,11 +417,6 @@ oidc:
# an insecure OIDC configuration. It is not recommended to use this flag.
# (default: <unset>, type: bool)
dangerousSkipIssuerChecks: false
# Optional override of the default redirect url which uses the deployment's access
# url. Useful in situations where a deployment has more than 1 domain. Using this
# setting can also break OIDC, so use with caution.
# (default: <unset>, type: url)
oidc-redirect-url:
# Telemetry is critical to our ability to improve Coder. We strip all personal
# information before sending data to our servers. Please only disable telemetry
# when required by your organization's security policy.
@@ -524,10 +514,10 @@ disablePathApps: false
# workspaces.
# (default: <unset>, type: bool)
disableOwnerWorkspaceAccess: false
# Disable workspace sharing. Workspace ACL checking is disabled and only owners
# can have ssh, apps and terminal access to workspaces. Access based on the
# 'owner' role is also allowed unless disabled via
# --disable-owner-workspace-access.
# Disable workspace sharing (requires the "workspace-sharing" experiment to be
# enabled). Workspace ACL checking is disabled and only owners can have ssh, apps
# and terminal access to workspaces. Access based on the 'owner' role is also
# allowed unless disabled via --disable-owner-workspace-access.
# (default: <unset>, type: bool)
disableWorkspaceSharing: false
# These options change the behavior of how clients interact with the Coder.
@@ -564,9 +554,6 @@ supportLinks: []
# External Authentication providers.
# (default: <unset>, type: struct[[]codersdk.ExternalAuthConfig])
externalAuthProviders: []
# Enable the default GitHub external auth provider managed by Coder.
# (default: true, type: bool)
externalAuthGithubDefaultProviderEnable: true
# Hostname of HTTPS server that runs https://github.com/coder/wgtunnel. By
# default, this will pick the best available wgtunnel server hosted by Coder. e.g.
# "tunnel.example.com".
+14 -37
View File
@@ -218,10 +218,9 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
var (
all bool
includeExpired bool
displayTokens []tokenListRow
formatter = cliui.NewOutputFormatter(
all bool
displayTokens []tokenListRow
formatter = cliui.NewOutputFormatter(
cliui.TableFormat([]tokenListRow{}, defaultCols),
cliui.JSONFormat(),
)
@@ -241,8 +240,7 @@ func (r *RootCmd) listTokens() *serpent.Command {
}
tokens, err := client.Tokens(inv.Context(), codersdk.Me, codersdk.TokensFilter{
IncludeAll: all,
IncludeExpired: includeExpired,
IncludeAll: all,
})
if err != nil {
return xerrors.Errorf("list tokens: %w", err)
@@ -276,12 +274,6 @@ func (r *RootCmd) listTokens() *serpent.Command {
Description: "Specifies whether all users' tokens will be listed or not (must have Owner role to see all tokens).",
Value: serpent.BoolOf(&all),
},
{
Name: "include-expired",
Flag: "include-expired",
Description: "Include expired tokens in the output. By default, expired tokens are hidden.",
Value: serpent.BoolOf(&includeExpired),
},
}
formatter.AttachOptions(&cmd.Options)
@@ -331,13 +323,10 @@ func (r *RootCmd) viewToken() *serpent.Command {
}
func (r *RootCmd) removeToken() *serpent.Command {
var deleteToken bool
cmd := &serpent.Command{
Use: "remove <name|id|token>",
Aliases: []string{"delete"},
Short: "Expire or delete a token",
Long: "Remove a token by expiring it. Use --delete to permanently hard-" +
"delete the token instead.",
Short: "Delete a token",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
@@ -349,7 +338,7 @@ func (r *RootCmd) removeToken() *serpent.Command {
token, err := client.APIKeyByName(inv.Context(), codersdk.Me, inv.Args[0])
if err != nil {
// If it's a token, we need to extract the ID.
// If it's a token, we need to extract the ID
maybeID := strings.Split(inv.Args[0], "-")[0]
token, err = client.APIKeyByID(inv.Context(), codersdk.Me, maybeID)
if err != nil {
@@ -357,29 +346,17 @@ func (r *RootCmd) removeToken() *serpent.Command {
}
}
if deleteToken {
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
if err != nil {
return xerrors.Errorf("delete api key: %w", err)
}
cliui.Infof(inv.Stdout, "Token has been deleted.")
return nil
}
err = client.ExpireAPIKey(inv.Context(), codersdk.Me, token.ID)
err = client.DeleteAPIKey(inv.Context(), codersdk.Me, token.ID)
if err != nil {
return xerrors.Errorf("expire api key: %w", err)
return xerrors.Errorf("delete api key: %w", err)
}
cliui.Infof(inv.Stdout, "Token has been expired.")
return nil
},
}
cmd.Options = serpent.OptionSet{
{
Flag: "delete",
Description: "Permanently delete the token instead of expiring it. This removes the audit trail.",
Value: serpent.BoolOf(&deleteToken),
cliui.Infof(
inv.Stdout,
"Token has been deleted.",
)
return nil
},
}
+17 -153
View File
@@ -6,16 +6,12 @@ import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
@@ -26,7 +22,7 @@ func TestTokens(t *testing.T) {
adminUser := coderdtest.CreateFirstUser(t, client)
secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
thirdUserClient, thirdUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
_, thirdUser := coderdtest.CreateAnotherUser(t, client, adminUser.OrganizationID)
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancelFunc()
@@ -159,7 +155,7 @@ func TestTokens(t *testing.T) {
require.Len(t, scopedToken.AllowList, 1)
require.Equal(t, allowSpec, scopedToken.AllowList[0].String())
// Delete by name (default behavior is now expire)
// Delete by name
inv, root = clitest.New(t, "tokens", "rm", "token-one")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
@@ -168,42 +164,21 @@ func TestTokens(t *testing.T) {
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "expired")
// Regular users cannot expire other users' tokens (expire is default now).
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
clitest.SetupConfig(t, thirdUserClient, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
// Only admin users can expire other users' tokens (expire is default now).
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
// Validate that token was expired
if token, err := client.APIKeyByName(ctx, secondUser.ID.String(), "token-two"); assert.NoError(t, err) {
require.True(t, token.ExpiresAt.Before(time.Now()))
}
// Delete by ID (explicit delete flag)
inv, root = clitest.New(t, "tokens", "rm", "--delete", secondTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Delete scoped token by ID (explicit delete flag)
inv, root = clitest.New(t, "tokens", "rm", "--delete", scopedTokenID)
// Delete by ID
inv, root = clitest.New(t, "tokens", "rm", secondTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
// Delete scoped token by ID
inv, root = clitest.New(t, "tokens", "rm", scopedTokenID)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
@@ -224,8 +199,8 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
fourthToken := res
// Delete by token (explicit delete flag)
inv, root = clitest.New(t, "tokens", "rm", "--delete", fourthToken)
// Delete by token
inv, root = clitest.New(t, "tokens", "rm", fourthToken)
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
@@ -235,114 +210,3 @@ func TestTokens(t *testing.T) {
require.NotEmpty(t, res)
require.Contains(t, res, "deleted")
}
func TestTokensListExpiredFiltering(t *testing.T) {
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, nil)
owner := coderdtest.CreateFirstUser(t, client)
// Create a valid (non-expired) token
validToken, _ := dbgen.APIKey(t, api.Database, database.APIKey{
UserID: owner.UserID,
ExpiresAt: time.Now().Add(24 * time.Hour),
LoginType: database.LoginTypeToken,
TokenName: "valid-token",
})
// Create an expired token
expiredToken, _ := dbgen.APIKey(t, api.Database, database.APIKey{
UserID: owner.UserID,
ExpiresAt: time.Now().Add(-24 * time.Hour),
LoginType: database.LoginTypeToken,
TokenName: "expired-token",
})
t.Run("HidesExpiredByDefault", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
inv, root := clitest.New(t, "tokens", "ls")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
require.Contains(t, res, validToken.ID)
require.Contains(t, res, "valid-token")
require.NotContains(t, res, expiredToken.ID)
require.NotContains(t, res, "expired-token")
})
t.Run("ShowsExpiredWithFlag", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
inv, root := clitest.New(t, "tokens", "ls", "--include-expired")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
require.Contains(t, res, validToken.ID)
require.Contains(t, res, "valid-token")
require.Contains(t, res, expiredToken.ID)
require.Contains(t, res, "expired-token")
})
t.Run("JSONOutputRespectsFilter", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
// Default (no expired)
inv, root := clitest.New(t, "tokens", "ls", "--output=json")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
require.Contains(t, res, "valid-token")
require.NotContains(t, res, "expired-token")
// With --include-expired
inv, root = clitest.New(t, "tokens", "ls", "--output=json", "--include-expired")
clitest.SetupConfig(t, client, root)
buf = new(bytes.Buffer)
inv.Stdout = buf
err = inv.WithContext(ctx).Run()
require.NoError(t, err)
res = buf.String()
require.Contains(t, res, "valid-token")
require.Contains(t, res, "expired-token")
})
t.Run("AllUsersWithIncludeExpired", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
inv, root := clitest.New(t, "tokens", "ls", "--all", "--include-expired")
clitest.SetupConfig(t, client, root)
buf := new(bytes.Buffer)
inv.Stdout = buf
err := inv.WithContext(ctx).Run()
require.NoError(t, err)
res := buf.String()
// Should show both valid and expired tokens
require.Contains(t, res, validToken.ID)
require.Contains(t, res, "valid-token")
require.Contains(t, res, expiredToken.ID)
require.Contains(t, res, "expired-token")
})
}
-70
View File
@@ -990,74 +990,4 @@ func TestUpdateValidateRichParameters(t *testing.T) {
_ = testutil.TryReceive(ctx, t, doneChan)
})
t.Run("NewImmutableParameterViaFlag", func(t *testing.T) {
t.Parallel()
// Create template and workspace with only a mutable parameter.
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
member, memberUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
templateParameters := []*proto.RichParameter{
{Name: stringParameterName, Type: "string", Mutable: true, Required: true, Options: []*proto.RichParameterOption{
{Name: "First option", Description: "This is first option", Value: "1st"},
{Name: "Second option", Description: "This is second option", Value: "2nd"},
}},
}
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(templateParameters))
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
inv, root := clitest.New(t, "create", "my-workspace", "--yes", "--template", template.Name, "--parameter", fmt.Sprintf("%s=%s", stringParameterName, "1st"))
clitest.SetupConfig(t, member, root)
err := inv.Run()
require.NoError(t, err)
// Update template: add a new immutable parameter.
updatedTemplateParameters := []*proto.RichParameter{
templateParameters[0],
{Name: immutableParameterName, Type: "string", Mutable: false, Required: true, Options: []*proto.RichParameterOption{
{Name: "fir", Description: "First option for immutable parameter", Value: "I"},
{Name: "sec", Description: "Second option for immutable parameter", Value: "II"},
}},
}
updatedVersion := coderdtest.UpdateTemplateVersion(t, client, owner.OrganizationID, prepareEchoResponses(updatedTemplateParameters), template.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, client, updatedVersion.ID)
err = client.UpdateActiveTemplateVersion(context.Background(), template.ID, codersdk.UpdateActiveTemplateVersion{
ID: updatedVersion.ID,
})
require.NoError(t, err)
// Update workspace, supplying the new immutable parameter via
// the --parameter flag. This should succeed because it's the
// first time this parameter is being set.
inv, root = clitest.New(t, "update", "my-workspace",
"--parameter", fmt.Sprintf("%s=%s", immutableParameterName, "II"))
clitest.SetupConfig(t, member, root)
pty := ptytest.New(t).Attach(inv)
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
err := inv.Run()
assert.NoError(t, err)
}()
pty.ExpectMatch("Planning workspace")
ctx := testutil.Context(t, testutil.WaitLong)
_ = testutil.TryReceive(ctx, t, doneChan)
// Verify the immutable parameter was set correctly.
workspace, err := client.WorkspaceByOwnerAndName(ctx, memberUser.ID.String(), "my-workspace", codersdk.WorkspaceOptions{})
require.NoError(t, err)
actualParameters, err := client.WorkspaceBuildParameters(ctx, workspace.LatestBuild.ID)
require.NoError(t, err)
require.Contains(t, actualParameters, codersdk.WorkspaceBuildParameter{
Name: immutableParameterName,
Value: "II",
})
})
}
+24
View File
@@ -0,0 +1,24 @@
//go:build !windows && !darwin
package cli
import (
"golang.org/x/xerrors"
"github.com/coder/serpent"
)
func (*RootCmd) vpnDaemonRun() *serpent.Command {
cmd := &serpent.Command{
Use: "run",
Short: "Run the VPN daemon on Windows.",
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
),
Handler: func(_ *serpent.Invocation) error {
return xerrors.New("vpn-daemon subcommand is not supported on this platform")
},
}
return cmd
}
@@ -1,4 +1,4 @@
//go:build windows || linux
//go:build windows
package cli
@@ -11,7 +11,7 @@ import (
"github.com/coder/serpent"
)
func (*RootCmd) vpnDaemonRun() *serpent.Command {
func (r *RootCmd) vpnDaemonRun() *serpent.Command {
var (
rpcReadHandleInt int64
rpcWriteHandleInt int64
@@ -19,7 +19,7 @@ func (*RootCmd) vpnDaemonRun() *serpent.Command {
cmd := &serpent.Command{
Use: "run",
Short: "Run the VPN daemon on Windows and Linux.",
Short: "Run the VPN daemon on Windows.",
Middleware: serpent.Chain(
serpent.RequireNArgs(0),
),
@@ -53,8 +53,8 @@ func (*RootCmd) vpnDaemonRun() *serpent.Command {
return xerrors.Errorf("rpc-read-handle (%v) and rpc-write-handle (%v) must be different", rpcReadHandleInt, rpcWriteHandleInt)
}
// The manager passes the read and write descriptors directly to the
// daemon, so we can open the RPC pipe from the raw values.
// We don't need to worry about duplicating the handles on Windows,
// which is different from Unix.
logger.Info(ctx, "opening bidirectional RPC pipe", slog.F("rpc_read_handle", rpcReadHandleInt), slog.F("rpc_write_handle", rpcWriteHandleInt))
pipe, err := vpn.NewBidirectionalPipe(uintptr(rpcReadHandleInt), uintptr(rpcWriteHandleInt))
if err != nil {
@@ -62,7 +62,7 @@ func (*RootCmd) vpnDaemonRun() *serpent.Command {
}
defer pipe.Close()
logger.Info(ctx, "starting VPN tunnel")
logger.Info(ctx, "starting tunnel")
tunnel, err := vpn.NewTunnel(ctx, logger, pipe, vpn.NewClient(), vpn.UseOSNetworkingStack())
if err != nil {
return xerrors.Errorf("create new tunnel for client: %w", err)
@@ -1,19 +0,0 @@
//go:build linux
package cli_test
import (
"os"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
func dupHandle(t *testing.T, f *os.File) uintptr {
t.Helper()
dupFD, err := unix.Dup(int(f.Fd()))
require.NoError(t, err)
return uintptr(dupFD)
}
@@ -1,33 +0,0 @@
//go:build windows
package cli_test
import (
"os"
"syscall"
"testing"
"github.com/stretchr/testify/require"
)
func dupHandle(t *testing.T, f *os.File) uintptr {
t.Helper()
src := syscall.Handle(f.Fd())
var dup syscall.Handle
proc, err := syscall.GetCurrentProcess()
require.NoError(t, err)
err = syscall.DuplicateHandle(
proc,
src,
proc,
&dup,
0,
false,
syscall.DUPLICATE_SAME_ACCESS,
)
require.NoError(t, err)
return uintptr(dup)
}
@@ -1,4 +1,4 @@
//go:build windows || linux
//go:build windows
package cli_test
@@ -67,35 +67,22 @@ func TestVPNDaemonRun(t *testing.T) {
r1, w1, err := os.Pipe()
require.NoError(t, err)
defer r1.Close()
defer w1.Close()
r2, w2, err := os.Pipe()
require.NoError(t, err)
defer r2.Close()
// The daemon closes the handles passed via NewBidirectionalPipe. Since our
// CLI tests run in-process, pass duplicated handles so we can close the
// originals without risking a double-close on FD reuse.
rpcReadHandle := dupHandle(t, r1)
rpcWriteHandle := dupHandle(t, w2)
require.NoError(t, r1.Close())
require.NoError(t, w2.Close())
defer w2.Close()
ctx := testutil.Context(t, testutil.WaitLong)
inv, _ := clitest.New(t,
"vpn-daemon",
"run",
"--rpc-read-handle",
fmt.Sprint(rpcReadHandle),
"--rpc-write-handle",
fmt.Sprint(rpcWriteHandle),
)
inv, _ := clitest.New(t, "vpn-daemon", "run", "--rpc-read-handle", fmt.Sprint(r1.Fd()), "--rpc-write-handle", fmt.Sprint(w2.Fd()))
waiter := clitest.StartWithWaiter(t, inv.WithContext(ctx))
// Send an invalid header, including a newline delimiter, so the handshake
// fails without requiring context cancellation.
_, err = w1.Write([]byte("garbage\n"))
// Send garbage which should cause the handshake to fail and the daemon
// to exit.
_, err = w1.Write([]byte("garbage"))
require.NoError(t, err)
waiter.Cancel()
err = waiter.Wait()
require.ErrorContains(t, err, "handshake failed")
})
-2
View File
@@ -179,8 +179,6 @@ func New(opts Options, workspace database.Workspace) *API {
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
Clock: opts.Clock,
NotificationsEnqueuer: opts.NotificationsEnqueuer,
}
api.MetadataAPI = &MetadataAPI{
-240
View File
@@ -2,10 +2,6 @@ package agentapi
import (
"context"
"database/sql"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
@@ -13,14 +9,7 @@ import (
"cdr.dev/slog/v3"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/notifications"
strutil "github.com/coder/coder/v2/coderd/util/strings"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
type AppsAPI struct {
@@ -28,8 +17,6 @@ type AppsAPI struct {
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent, wspubsub.WorkspaceEventKind) error
NotificationsEnqueuer notifications.Enqueuer
Clock quartz.Clock
}
func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
@@ -117,230 +104,3 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
func (a *AppsAPI) UpdateAppStatus(ctx context.Context, req *agentproto.UpdateAppStatusRequest) (*agentproto.UpdateAppStatusResponse, error) {
if len(req.Message) > 160 {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Message is too long.",
Detail: "Message must be less than 160 characters.",
Validations: []codersdk.ValidationError{
{Field: "message", Detail: "Message must be less than 160 characters."},
},
})
}
var dbState database.WorkspaceAppStatusState
switch req.State {
case agentproto.UpdateAppStatusRequest_COMPLETE:
dbState = database.WorkspaceAppStatusStateComplete
case agentproto.UpdateAppStatusRequest_FAILURE:
dbState = database.WorkspaceAppStatusStateFailure
case agentproto.UpdateAppStatusRequest_WORKING:
dbState = database.WorkspaceAppStatusStateWorking
case agentproto.UpdateAppStatusRequest_IDLE:
dbState = database.WorkspaceAppStatusStateIdle
default:
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Invalid state provided.",
Detail: fmt.Sprintf("invalid state: %q", req.State),
Validations: []codersdk.ValidationError{
{Field: "state", Detail: "State must be one of: complete, failure, working, idle."},
},
})
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
}
app, err := a.Database.GetWorkspaceAppByAgentIDAndSlug(ctx, database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: workspaceAgent.ID,
Slug: req.Slug,
})
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace app.",
Detail: fmt.Sprintf("No app found with slug %q", req.Slug),
})
}
workspace, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID)
if err != nil {
return nil, codersdk.NewError(http.StatusBadRequest, codersdk.Response{
Message: "Failed to get workspace.",
Detail: err.Error(),
})
}
// Treat the message as untrusted input.
cleaned := strutil.UISanitize(req.Message)
// Get the latest status for the workspace app to detect no-op updates
// nolint:gocritic // This is a system restricted operation.
latestAppStatus, err := a.Database.GetLatestWorkspaceAppStatusByAppID(dbauthz.AsSystemRestricted(ctx), app.ID)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get latest workspace app status.",
Detail: err.Error(),
})
}
// If no rows found, latestAppStatus will be a zero-value struct (ID == uuid.Nil)
// nolint:gocritic // This is a system restricted operation.
_, err = a.Database.InsertWorkspaceAppStatus(dbauthz.AsSystemRestricted(ctx), database.InsertWorkspaceAppStatusParams{
ID: uuid.New(),
CreatedAt: dbtime.Now(),
WorkspaceID: workspace.ID,
AgentID: workspaceAgent.ID,
AppID: app.ID,
State: dbState,
Message: cleaned,
Uri: sql.NullString{
String: req.Uri,
Valid: req.Uri != "",
},
})
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to insert workspace app status.",
Detail: err.Error(),
})
}
if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent, wspubsub.WorkspaceEventKindAgentAppStatusUpdate)
if err != nil {
return nil, codersdk.NewError(http.StatusInternalServerError, codersdk.Response{
Message: "Failed to publish workspace update.",
Detail: err.Error(),
})
}
}
// Notify on state change to Working/Idle for AI tasks
a.enqueueAITaskStateNotification(ctx, app.ID, latestAppStatus, dbState, workspace, workspaceAgent)
if shouldBump(dbState, latestAppStatus) {
// We pass time.Time{} for nextAutostart since we don't have access to
// TemplateScheduleStore here. The activity bump logic handles this by
// defaulting to the template's activity_bump duration (typically 1 hour).
workspacestats.ActivityBumpWorkspace(ctx, a.Log, a.Database, workspace.ID, time.Time{})
}
// just return a blank response because it doesn't contain any settable fields at present.
return new(agentproto.UpdateAppStatusResponse), nil
}
func shouldBump(dbState database.WorkspaceAppStatusState, latestAppStatus database.WorkspaceAppStatus) bool {
// Bump deadline when agent reports working or transitions away from working.
// This prevents auto-pause during active work and gives users time to interact
// after work completes.
// Bump if reporting working state.
if dbState == database.WorkspaceAppStatusStateWorking {
return true
}
// Bump if transitioning away from working state.
if latestAppStatus.ID != uuid.Nil {
prevState := latestAppStatus.State
if prevState == database.WorkspaceAppStatusStateWorking {
return true
}
}
return false
}
// enqueueAITaskStateNotification enqueues a notification when an AI task's app
// transitions to Working or Idle.
// No-op if:
// - the workspace agent app isn't configured as an AI task,
// - the new state equals the latest persisted state,
// - the workspace agent is not ready (still starting up).
func (a *AppsAPI) enqueueAITaskStateNotification(
ctx context.Context,
appID uuid.UUID,
latestAppStatus database.WorkspaceAppStatus,
newAppStatus database.WorkspaceAppStatusState,
workspace database.Workspace,
agent database.WorkspaceAgent,
) {
var notificationTemplate uuid.UUID
switch newAppStatus {
case database.WorkspaceAppStatusStateWorking:
notificationTemplate = notifications.TemplateTaskWorking
case database.WorkspaceAppStatusStateIdle:
notificationTemplate = notifications.TemplateTaskIdle
case database.WorkspaceAppStatusStateComplete:
notificationTemplate = notifications.TemplateTaskCompleted
case database.WorkspaceAppStatusStateFailure:
notificationTemplate = notifications.TemplateTaskFailed
default:
// Not a notifiable state, do nothing
return
}
if !workspace.TaskID.Valid {
// Workspace has no task ID, do nothing.
return
}
// Only send notifications when the agent is ready. We want to skip
// any state transitions that occur whilst the workspace is starting
// up as it doesn't make sense to receive them.
if agent.LifecycleState != database.WorkspaceAgentLifecycleStateReady {
a.Log.Debug(ctx, "skipping AI task notification because agent is not ready",
slog.F("agent_id", agent.ID),
slog.F("lifecycle_state", agent.LifecycleState),
slog.F("new_app_status", newAppStatus),
)
return
}
task, err := a.Database.GetTaskByID(ctx, workspace.TaskID.UUID)
if err != nil {
a.Log.Warn(ctx, "failed to get task", slog.Error(err))
return
}
if !task.WorkspaceAppID.Valid || task.WorkspaceAppID.UUID != appID {
// Non-task app, do nothing.
return
}
// Skip if the latest persisted state equals the new state (no new transition)
// Note: uuid.Nil check is valid here. If no previous status exists,
// GetLatestWorkspaceAppStatusByAppID returns sql.ErrNoRows and we get a zero-value struct.
if latestAppStatus.ID != uuid.Nil && latestAppStatus.State == newAppStatus {
return
}
// Skip the initial "Working" notification when the task first starts.
// This is obvious to the user since they just created the task.
// We still notify on the first "Idle" status and all subsequent transitions.
if latestAppStatus.ID == uuid.Nil && newAppStatus == database.WorkspaceAppStatusStateWorking {
return
}
if _, err := a.NotificationsEnqueuer.EnqueueWithData(
// nolint:gocritic // Need notifier actor to enqueue notifications
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notificationTemplate,
map[string]string{
"task": task.Name,
"workspace": workspace.Name,
},
map[string]any{
// Use a 1-minute bucketed timestamp to bypass per-day dedupe,
// allowing identical content to resend within the same day
// (but not more than once every 10s).
"dedupe_bypass_ts": a.Clock.Now().UTC().Truncate(time.Minute),
},
"api-workspace-agent-app-status",
// Associate this notification with related entities
workspace.ID, workspace.OwnerID, workspace.OrganizationID, appID,
); err != nil {
a.Log.Warn(ctx, "failed to notify of task state", slog.Error(err))
return
}
}
-115
View File
@@ -1,115 +0,0 @@
package agentapi
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/util/ptr"
)
func TestShouldBump(t *testing.T) {
t.Parallel()
tests := []struct {
name string
prevState *database.WorkspaceAppStatusState // nil means no previous state
newState database.WorkspaceAppStatusState
shouldBump bool
}{
{
name: "FirstStatusBumps",
prevState: nil,
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "WorkingToIdleBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: true,
},
{
name: "WorkingToCompleteBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: true,
},
{
name: "CompleteToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "CompleteToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "FailureToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "FailureToFailureNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: false,
},
{
name: "CompleteToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateComplete),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
{
name: "FailureToCompleteNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateFailure),
newState: database.WorkspaceAppStatusStateComplete,
shouldBump: false,
},
{
name: "WorkingToFailureBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateWorking),
newState: database.WorkspaceAppStatusStateFailure,
shouldBump: true,
},
{
name: "IdleToIdleNoBump",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateIdle,
shouldBump: false,
},
{
name: "IdleToWorkingBumps",
prevState: ptr.Ref(database.WorkspaceAppStatusStateIdle),
newState: database.WorkspaceAppStatusStateWorking,
shouldBump: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var prevAppStatus database.WorkspaceAppStatus
// If there's a previous state, report it first.
if tt.prevState != nil {
prevAppStatus.ID = uuid.UUID{1}
prevAppStatus.State = *tt.prevState
}
didBump := shouldBump(tt.newState, prevAppStatus)
if tt.shouldBump {
require.True(t, didBump, "wanted deadline to bump but it didn't")
} else {
require.False(t, didBump, "wanted deadline not to bump but it did")
}
})
}
}
-188
View File
@@ -2,13 +2,9 @@ package agentapi_test
import (
"context"
"database/sql"
"net/http"
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
@@ -16,12 +12,8 @@ import (
"github.com/coder/coder/v2/coderd/agentapi"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/wspubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestBatchUpdateAppHealths(t *testing.T) {
@@ -261,183 +253,3 @@ func TestBatchUpdateAppHealths(t *testing.T) {
require.Nil(t, resp)
})
}
func TestWorkspaceAgentAppStatus(t *testing.T) {
t.Parallel()
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
fEnq := &notificationstest.FakeEnqueuer{}
mClock := quartz.NewMock(t)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
workspaceUpdates := make(chan wspubsub.WorkspaceEventKind, 100)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
PublishWorkspaceUpdateFn: func(_ context.Context, agnt *database.WorkspaceAgent, kind wspubsub.WorkspaceEventKind) error {
assert.Equal(t, *agnt, agent)
testutil.AssertSend(ctx, t, workspaceUpdates, kind)
return nil
},
NotificationsEnqueuer: fEnq,
Clock: mClock,
}
app := database.WorkspaceApp{
ID: uuid.UUID{8},
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), database.GetWorkspaceAppByAgentIDAndSlugParams{
AgentID: agent.ID,
Slug: "vscode",
}).Times(1).Return(app, nil)
task := database.Task{
ID: uuid.UUID{7},
WorkspaceAppID: uuid.NullUUID{
Valid: true,
UUID: app.ID,
},
}
mDB.EXPECT().GetTaskByID(gomock.Any(), task.ID).Times(1).Return(task, nil)
workspace := database.Workspace{
ID: uuid.UUID{9},
TaskID: uuid.NullUUID{
Valid: true,
UUID: task.ID,
},
}
mDB.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Times(1).Return(workspace, nil)
appStatus := database.WorkspaceAppStatus{
ID: uuid.UUID{6},
}
mDB.EXPECT().GetLatestWorkspaceAppStatusByAppID(gomock.Any(), app.ID).Times(1).Return(appStatus, nil)
mDB.EXPECT().InsertWorkspaceAppStatus(
gomock.Any(),
gomock.Cond(func(params database.InsertWorkspaceAppStatusParams) bool {
if params.AgentID == agent.ID && params.AppID == app.ID {
assert.Equal(t, "testing", params.Message)
assert.Equal(t, database.WorkspaceAppStatusStateComplete, params.State)
assert.True(t, params.Uri.Valid)
assert.Equal(t, "https://example.com", params.Uri.String)
return true
}
return false
})).Times(1).Return(database.WorkspaceAppStatus{}, nil)
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.NoError(t, err)
kind := testutil.RequireReceive(ctx, t, workspaceUpdates)
require.Equal(t, wspubsub.WorkspaceEventKindAgentAppStatusUpdate, kind)
sent := fEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskCompleted))
require.Len(t, sent, 1)
})
t.Run("FailUnknownApp", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
mDB.EXPECT().GetWorkspaceAppByAgentIDAndSlug(gomock.Any(), gomock.Any()).
Times(1).
Return(database.WorkspaceApp{}, sql.ErrNoRows)
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "unknown",
Message: "testing",
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "No app found with slug")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailUnknownState", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: "testing",
Uri: "https://example.com",
State: 77,
})
require.ErrorContains(t, err, "Invalid state")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
t.Run("FailTooLong", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
agent := database.WorkspaceAgent{
ID: uuid.UUID{2},
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
}
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: mDB,
Log: testutil.Logger(t),
}
_, err := api.UpdateAppStatus(ctx, &agentproto.UpdateAppStatusRequest{
Slug: "vscode",
Message: strings.Repeat("a", 161),
Uri: "https://example.com",
State: agentproto.UpdateAppStatusRequest_COMPLETE,
})
require.ErrorContains(t, err, "Message is too long")
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
})
}
+1 -1
View File
@@ -128,7 +128,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
Name: agentName,
ResourceID: parentAgent.ResourceID,
AuthToken: uuid.New(),
AuthInstanceID: sql.NullString{},
AuthInstanceID: parentAgent.AuthInstanceID,
Architecture: req.Architecture,
EnvironmentVariables: pqtype.NullRawMessage{},
OperatingSystem: req.OperatingSystem,
+1 -46
View File
@@ -175,52 +175,6 @@ func TestSubAgentAPI(t *testing.T) {
}
})
// Context: https://github.com/coder/coder/pull/22196
t.Run("CreateSubAgentDoesNotInheritAuthInstanceID", func(t *testing.T) {
t.Parallel()
var (
log = testutil.Logger(t)
clock = quartz.NewMock(t)
db, org = newDatabaseWithOrg(t)
user, agent = newUserWithWorkspaceAgent(t, db, org)
)
// Given: The parent agent has an AuthInstanceID set
ctx := testutil.Context(t, testutil.WaitShort)
parentAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agent.ID)
require.NoError(t, err)
require.True(t, parentAgent.AuthInstanceID.Valid, "parent agent should have an AuthInstanceID")
require.NotEmpty(t, parentAgent.AuthInstanceID.String)
api := newAgentAPI(t, log, db, clock, user, org, agent)
// When: We create a sub agent
createResp, err := api.CreateSubAgent(ctx, &proto.CreateSubAgentRequest{
Name: "sub-agent",
Directory: "/workspaces/test",
Architecture: "amd64",
OperatingSystem: "linux",
})
require.NoError(t, err)
subAgentID, err := uuid.FromBytes(createResp.Agent.Id)
require.NoError(t, err)
// Then: The sub-agent must NOT re-use the parent's AuthInstanceID.
subAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), subAgentID)
require.NoError(t, err)
assert.False(t, subAgent.AuthInstanceID.Valid, "sub-agent should not have an AuthInstanceID")
assert.Empty(t, subAgent.AuthInstanceID.String, "sub-agent AuthInstanceID string should be empty")
// Double-check: looking up by the parent's instance ID must
// still return the parent, not the sub-agent.
lookedUp, err := db.GetWorkspaceAgentByInstanceID(dbauthz.AsSystemRestricted(ctx), parentAgent.AuthInstanceID.String)
require.NoError(t, err)
assert.Equal(t, parentAgent.ID, lookedUp.ID, "instance ID lookup should still return the parent agent")
})
type expectedAppError struct {
index int32
field string
@@ -1366,6 +1320,7 @@ func TestSubAgentAPI(t *testing.T) {
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
+3 -124
View File
@@ -21,12 +21,10 @@ import (
agentapisdk "github.com/coder/agentapi-sdk-go"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/searchquery"
@@ -192,8 +190,7 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) {
})
defer commitAuditWS()
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, &createWorkspaceOptions{
remoteAddr: r.RemoteAddr,
workspace, err := createWorkspace(ctx, aReqWS, apiKey.UserID, api, owner, createReq, r, &createWorkspaceOptions{
// Before creating the workspace, ensure that this task can be created.
preCreateInTX: func(ctx context.Context, tx database.Store) error {
// Create task record in the database before creating the workspace so that
@@ -467,6 +464,7 @@ func (api *API) convertTasks(ctx context.Context, requesterID uuid.UUID, dbTasks
apiWorkspaces, err := convertWorkspaces(
ctx,
api.Experiments,
api.Logger,
requesterID,
workspaces,
@@ -546,6 +544,7 @@ func (api *API) taskGet(rw http.ResponseWriter, r *http.Request) {
ws, err := convertWorkspace(
ctx,
api.Experiments,
api.Logger,
apiKey.UserID,
workspace,
@@ -1301,127 +1300,7 @@ func (api *API) pauseTask(rw http.ResponseWriter, r *http.Request) {
return
}
if _, err := api.NotificationsEnqueuer.Enqueue(
// nolint:gocritic // Need notifier actor to enqueue notifications.
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notifications.TemplateTaskPaused,
map[string]string{
"task": task.Name,
"task_id": task.ID.String(),
"workspace": workspace.Name,
"pause_reason": "manual",
},
"api-task-pause",
workspace.ID, workspace.OwnerID, workspace.OrganizationID,
); err != nil {
api.Logger.Warn(ctx, "failed to notify of task paused", slog.Error(err), slog.F("task_id", task.ID), slog.F("workspace_id", workspace.ID))
}
httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.PauseTaskResponse{
WorkspaceBuild: &build,
})
}
// @Summary Resume task
// @ID resume-task
// @Security CoderSessionToken
// @Accept json
// @Tags Tasks
// @Param user path string true "Username, user ID, or 'me' for the authenticated user"
// @Param task path string true "Task ID" format(uuid)
// @Success 202 {object} codersdk.ResumeTaskResponse
// @Router /tasks/{user}/{task}/resume [post]
func (api *API) resumeTask(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apiKey = httpmw.APIKey(r)
task = httpmw.TaskParam(r)
)
if !task.WorkspaceID.Valid {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Task does not have a workspace.",
})
return
}
workspace, err := api.Database.GetWorkspaceByID(ctx, task.WorkspaceID.UUID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task workspace.",
Detail: err.Error(),
})
return
}
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task workspace build.",
Detail: err.Error(),
})
return
}
job, err := api.Database.GetProvisionerJobByID(ctx, latestBuild.JobID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching task workspace build job.",
Detail: err.Error(),
})
return
}
workspaceStatus := codersdk.ConvertWorkspaceStatus(
codersdk.ProvisionerJobStatus(job.JobStatus),
codersdk.WorkspaceTransition(latestBuild.Transition),
)
if workspaceStatus == codersdk.WorkspaceStatusRunning {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Task workspace is already running.",
Detail: fmt.Sprintf("Workspace status is %q.", workspaceStatus),
})
return
}
buildReq := codersdk.CreateWorkspaceBuildRequest{
Transition: codersdk.WorkspaceTransitionStart,
Reason: codersdk.CreateWorkspaceBuildReasonTaskResume,
}
build, err := api.postWorkspaceBuildsInternal(
ctx,
apiKey,
workspace,
buildReq,
func(action policy.Action, object rbac.Objecter) bool {
return api.Authorize(r, action, object)
},
audit.WorkspaceBuildBaggageFromRequest(r),
)
if err != nil {
httperror.WriteWorkspaceBuildError(ctx, rw, err)
return
}
if _, err := api.NotificationsEnqueuer.Enqueue(
// nolint:gocritic // Need notifier actor to enqueue notifications.
dbauthz.AsNotifier(ctx),
workspace.OwnerID,
notifications.TemplateTaskResumed,
map[string]string{
"task": task.Name,
"task_id": task.ID.String(),
"workspace": workspace.Name,
},
"api-task-resume",
workspace.ID, workspace.OwnerID, workspace.OrganizationID,
); err != nil {
api.Logger.Warn(ctx, "failed to notify of task resumed", slog.Error(err), slog.F("task_id", task.ID), slog.F("workspace_id", workspace.ID))
}
httpapi.Write(ctx, rw, http.StatusAccepted, codersdk.ResumeTaskResponse{
WorkspaceBuild: &build,
})
}
+40 -448
View File
@@ -45,10 +45,10 @@ import (
)
// createTaskInState is a helper to create a task in the desired state.
// It returns a function that takes context, test, and status, and returns the task.
// It returns a function that takes context, test, and status, and returns the task ID.
// The caller is responsible for setting up the database, owner, and user.
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) database.Task {
return func(ctx context.Context, t *testing.T, status database.TaskStatus) database.Task {
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) uuid.UUID {
return func(ctx context.Context, t *testing.T, status database.TaskStatus) uuid.UUID {
ctx = dbauthz.As(ctx, ownerSubject)
builder := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
@@ -65,9 +65,6 @@ func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID,
builder = builder.Pending()
case database.TaskStatusInitializing:
builder = builder.Starting()
case database.TaskStatusActive:
// Default builder produces a succeeded start build.
// Post-processing below sets agent and app to active.
case database.TaskStatusPaused:
builder = builder.Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
@@ -79,32 +76,31 @@ func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID,
}
resp := builder.Do()
taskID := resp.Task.ID
// Post-process by manipulating agent and app state.
if status == database.TaskStatusActive || status == database.TaskStatusError {
// Set agent to ready state so agent_status returns 'active'.
if status == database.TaskStatusError {
// First, set agent to ready state so agent_status returns 'active'.
// This ensures the cascade reaches app_status.
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: resp.Agents[0].ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
require.NoError(t, err)
// Then set workspace app health to unhealthy to trigger error state.
apps, err := db.GetWorkspaceAppsByAgentID(ctx, resp.Agents[0].ID)
require.NoError(t, err)
require.Len(t, apps, 1, "expected exactly one app for task")
appHealth := database.WorkspaceAppHealthHealthy
if status == database.TaskStatusError {
appHealth = database.WorkspaceAppHealthUnhealthy
}
err = db.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
ID: apps[0].ID,
Health: appHealth,
Health: database.WorkspaceAppHealthUnhealthy,
})
require.NoError(t, err)
}
return resp.Task
return taskID
}
}
@@ -832,7 +828,7 @@ func TestTasks(t *testing.T) {
t.Run("SendToNonActiveStates", func(t *testing.T) {
t.Parallel()
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{})
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitMedium)
@@ -849,9 +845,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPaused)
taskID := createTask(ctx, t, database.TaskStatusPaused)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -867,9 +863,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusInitializing)
taskID := createTask(ctx, t, database.TaskStatusInitializing)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -885,9 +881,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -903,9 +899,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusError)
taskID := createTask(ctx, t, database.TaskStatusError)
err := client.TaskSend(ctx, "me", task.ID, codersdk.TaskSendRequest{
err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})
@@ -1124,16 +1120,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(snapshotJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err, "upserting task snapshot")
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err, "fetching task logs")
verifySnapshotLogs(t, logsResp)
})
@@ -1142,16 +1138,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusInitializing)
taskID := createTask(ctx, t, database.TaskStatusInitializing)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(snapshotJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err, "upserting task snapshot")
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err, "fetching task logs")
verifySnapshotLogs(t, logsResp)
})
@@ -1160,16 +1156,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPaused)
taskID := createTask(ctx, t, database.TaskStatusPaused)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(snapshotJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err, "upserting task snapshot")
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err, "fetching task logs")
verifySnapshotLogs(t, logsResp)
})
@@ -1178,9 +1174,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
logsResp, err := client.TaskLogs(ctx, "me", task.ID)
logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err)
assert.True(t, logsResp.Snapshot)
@@ -1192,7 +1188,7 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
invalidEnvelope := coderd.TaskLogSnapshotEnvelope{
Format: "unknown-format",
@@ -1202,13 +1198,13 @@ func TestTasks(t *testing.T) {
require.NoError(t, err)
err = db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(invalidJSON),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err)
_, err = client.TaskLogs(ctx, "me", task.ID)
_, err = client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
var sdkErr *codersdk.Error
@@ -1221,16 +1217,16 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)
err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: task.ID,
TaskID: taskID,
LogSnapshot: json.RawMessage(`{"format":"agentapi","data":"not an object"}`),
LogSnapshotCreatedAt: snapshotTime,
})
require.NoError(t, err)
_, err = client.TaskLogs(ctx, "me", task.ID)
_, err = client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
var sdkErr *codersdk.Error
@@ -1242,9 +1238,9 @@ func TestTasks(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
task := createTask(ctx, t, database.TaskStatusError)
taskID := createTask(ctx, t, database.TaskStatusError)
_, err := client.TaskLogs(ctx, "me", task.ID)
_, err := client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
var sdkErr *codersdk.Error
@@ -2516,20 +2512,13 @@ func TestPauseTask(t *testing.T) {
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
resp, err := client.PauseTask(ctx, codersdk.Me, task.ID)
// Verify that the request was accepted correctly:
require.NoError(t, err)
build := *resp.WorkspaceBuild
require.NotNil(t, build)
require.Equal(t, codersdk.WorkspaceTransitionStop, build.Transition)
require.Equal(t, task.WorkspaceID.UUID, build.WorkspaceID)
require.Equal(t, workspace.LatestBuild.BuildNumber+1, build.BuildNumber)
require.Equal(t, string(codersdk.CreateWorkspaceBuildReasonTaskManualPause), string(build.Reason))
// Verify that the accepted request was processed correctly:
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
workspace, err = client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
require.Equal(t, codersdk.WorkspaceStatusStopped, workspace.LatestBuild.Status)
})
t.Run("Non-owner role access", func(t *testing.T) {
@@ -2567,6 +2556,7 @@ func TestPauseTask(t *testing.T) {
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
task, _ := setupWorkspaceTask(t, db, owner)
userClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, tc.roles...)
@@ -2790,402 +2780,4 @@ func TestPauseTask(t *testing.T) {
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
})
t.Run("Notification", func(t *testing.T) {
t.Parallel()
var (
notifyEnq = &notificationstest.FakeEnqueuer{}
ownerClient, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{NotificationsEnqueuer: notifyEnq})
owner = coderdtest.CreateFirstUser(t, ownerClient)
)
ctx := testutil.Context(t, testutil.WaitMedium)
ownerUser, err := ownerClient.User(ctx, owner.UserID.String())
require.NoError(t, err)
createTask := createTaskInState(db, coderdtest.AuthzUserSubject(ownerUser), owner.OrganizationID, owner.UserID)
// Given: A task in an active state
task := createTask(ctx, t, database.TaskStatusActive)
workspace, err := ownerClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
// When: We pause the task
_, err = ownerClient.PauseTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
// Then: A notification should be sent
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskPaused))
require.Len(t, sent, 1)
require.Equal(t, owner.UserID, sent[0].UserID)
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
require.Equal(t, "manual", sent[0].Labels["pause_reason"])
})
}
func TestResumeTask(t *testing.T) {
t.Parallel()
setupClient := func(t *testing.T, db database.Store, ps pubsub.Pubsub, authorizer rbac.Authorizer) *codersdk.Client {
t.Helper()
client, _, _ := coderdtest.NewWithAPI(t, &coderdtest.Options{
Database: db,
Pubsub: ps,
Authorizer: authorizer,
IncludeProvisionerDaemon: true,
})
return client
}
setupWorkspaceTask := func(t *testing.T, db database.Store, user codersdk.CreateFirstUserResponse) (database.Task, uuid.UUID) {
t.Helper()
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithTask(database.TaskTable{
Prompt: "resume me",
}, nil).Do()
return workspaceBuild.Task, workspaceBuild.Workspace.ID
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionApply: echo.ApplyComplete,
ProvisionGraph: []*proto.Response{
{Type: &proto.Response_Graph{Graph: &proto.GraphComplete{
HasAiTasks: true,
}}},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
task, err := client.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "resume me",
})
require.NoError(t, err)
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
pauseResp, err := client.PauseTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
resumeResp, err := client.ResumeTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
build := *resumeResp.WorkspaceBuild
require.Equal(t, codersdk.WorkspaceTransitionStart, build.Transition)
require.Equal(t, task.WorkspaceID.UUID, build.WorkspaceID)
require.Equal(t, workspace.LatestBuild.BuildNumber+2, build.BuildNumber)
require.Equal(t, string(codersdk.CreateWorkspaceBuildReasonTaskResume), string(build.Reason))
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, build.ID)
workspace, err = client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
require.Equal(t, codersdk.WorkspaceStatusRunning, workspace.LatestBuild.Status)
})
t.Run("Resume a task that is not paused", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
db, ps := dbtestutil.NewDB(t)
client := setupClient(t, db, ps, nil)
user := coderdtest.CreateFirstUser(t, client)
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).
WithTask(database.TaskTable{
Prompt: "pause me",
}, nil).
Succeeded().
Do()
_, err := client.ResumeTask(ctx, codersdk.Me, workspaceBuild.Task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
})
t.Run("Task not found", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
_ = coderdtest.CreateFirstUser(t, client)
_, err := client.ResumeTask(ctx, codersdk.Me, uuid.New())
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})
t.Run("Task lookup forbidden", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
auth := &coderdtest.FakeAuthorizer{
ConditionalReturn: func(_ context.Context, _ rbac.Subject, action policy.Action, object rbac.Object) error {
if action == policy.ActionRead && object.Type == rbac.ResourceTask.Type {
return rbac.UnauthorizedError{}
}
return nil
},
}
client := setupClient(t, db, ps, auth)
user := coderdtest.CreateFirstUser(t, client)
task, _ := setupWorkspaceTask(t, db, user)
_, err := client.ResumeTask(ctx, codersdk.Me, task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})
t.Run("Workspace lookup forbidden", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
auth := &coderdtest.FakeAuthorizer{
ConditionalReturn: func(_ context.Context, _ rbac.Subject, action policy.Action, object rbac.Object) error {
if action == policy.ActionRead && object.Type == rbac.ResourceWorkspace.Type {
return rbac.UnauthorizedError{}
}
return nil
},
}
client := setupClient(t, db, ps, auth)
user := coderdtest.CreateFirstUser(t, client)
task, _ := setupWorkspaceTask(t, db, user)
_, err := client.ResumeTask(ctx, codersdk.Me, task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})
t.Run("No Workspace for Task", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
client := setupClient(t, db, ps, nil)
user := coderdtest.CreateFirstUser(t, client)
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).Do()
task := dbgen.Task(t, db, database.TaskTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
TemplateVersionID: workspaceBuild.Build.TemplateVersionID,
Prompt: "no workspace",
})
_, err := client.ResumeTask(ctx, codersdk.Me, task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
require.Equal(t, "Task does not have a workspace.", apiErr.Message)
})
t.Run("Workspace not found", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
var workspaceID uuid.UUID
wrapped := aiTaskStoreWrapper{
Store: db,
getWorkspaceByID: func(ctx context.Context, id uuid.UUID) (database.Workspace, error) {
if id == workspaceID && id != uuid.Nil {
return database.Workspace{}, sql.ErrNoRows
}
return db.GetWorkspaceByID(ctx, id)
},
}
client := setupClient(t, wrapped, ps, nil)
user := coderdtest.CreateFirstUser(t, client)
task, workspaceIDValue := setupWorkspaceTask(t, db, user)
workspaceID = workspaceIDValue
_, err := client.ResumeTask(ctx, codersdk.Me, task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusNotFound, apiErr.StatusCode())
})
t.Run("Workspace lookup internal error", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
var workspaceID uuid.UUID
wrapped := aiTaskStoreWrapper{
Store: db,
getWorkspaceByID: func(ctx context.Context, id uuid.UUID) (database.Workspace, error) {
if id == workspaceID && id != uuid.Nil {
return database.Workspace{}, xerrors.New("boom")
}
return db.GetWorkspaceByID(ctx, id)
},
}
client := setupClient(t, wrapped, ps, nil)
user := coderdtest.CreateFirstUser(t, client)
task, workspaceIDValue := setupWorkspaceTask(t, db, user)
workspaceID = workspaceIDValue
_, err := client.ResumeTask(ctx, codersdk.Me, task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
require.Equal(t, "Internal error fetching task workspace.", apiErr.Message)
})
t.Run("Build Forbidden", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
auth := &coderdtest.FakeAuthorizer{
ConditionalReturn: func(_ context.Context, _ rbac.Subject, action policy.Action, object rbac.Object) error {
if action == policy.ActionWorkspaceStart && object.Type == rbac.ResourceWorkspace.Type {
return rbac.UnauthorizedError{}
}
return nil
},
}
client := setupClient(t, db, ps, auth)
user := coderdtest.CreateFirstUser(t, client)
task, _ := setupWorkspaceTask(t, db, user)
pauseResp, err := client.PauseTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
_, err = client.ResumeTask(ctx, codersdk.Me, task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
})
t.Run("Job already in progress", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
client := setupClient(t, db, ps, nil)
user := coderdtest.CreateFirstUser(t, client)
workspaceBuild := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).
WithTask(database.TaskTable{
Prompt: "resume me",
}, nil).
Starting().
Do()
_, err := client.ResumeTask(ctx, codersdk.Me, workspaceBuild.Task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
})
t.Run("Build Internal Error", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
db, ps := dbtestutil.NewDB(t)
wrapped := aiTaskStoreWrapper{
Store: db,
}
client := setupClient(t, &wrapped, ps, nil)
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionApply: echo.ApplyComplete,
ProvisionGraph: []*proto.Response{
{Type: &proto.Response_Graph{Graph: &proto.GraphComplete{
HasAiTasks: true,
}}},
},
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
task, err := client.CreateTask(ctx, codersdk.Me, codersdk.CreateTaskRequest{
TemplateVersionID: template.ActiveVersionID,
Input: "resume me",
})
require.NoError(t, err)
workspace, err := client.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
pauseResp, err := client.PauseTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, pauseResp.WorkspaceBuild.ID)
// Induce a transient failure in the database after the task has been paused.
wrapped.insertWorkspaceBuild = func(ctx context.Context, arg database.InsertWorkspaceBuildParams) error {
return xerrors.New("insert failed")
}
_, err = client.ResumeTask(ctx, codersdk.Me, task.ID)
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusInternalServerError, apiErr.StatusCode())
})
t.Run("Notification", func(t *testing.T) {
t.Parallel()
var (
notifyEnq = &notificationstest.FakeEnqueuer{}
ownerClient, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{NotificationsEnqueuer: notifyEnq})
owner = coderdtest.CreateFirstUser(t, ownerClient)
)
ctx := testutil.Context(t, testutil.WaitMedium)
ownerUser, err := ownerClient.User(ctx, owner.UserID.String())
require.NoError(t, err)
createTask := createTaskInState(db, coderdtest.AuthzUserSubject(ownerUser), owner.OrganizationID, owner.UserID)
// Given: A task in a paused state
task := createTask(ctx, t, database.TaskStatusPaused)
workspace, err := ownerClient.Workspace(ctx, task.WorkspaceID.UUID)
require.NoError(t, err)
// When: We resume the task
_, err = ownerClient.ResumeTask(ctx, codersdk.Me, task.ID)
require.NoError(t, err)
// Then: A notification should be sent
sent := notifyEnq.Sent(notificationstest.WithTemplateID(notifications.TemplateTaskResumed))
require.Len(t, sent, 1)
require.Equal(t, owner.UserID, sent[0].UserID)
require.Equal(t, task.Name, sent[0].Labels["task"])
require.Equal(t, task.ID.String(), sent[0].Labels["task_id"])
require.Equal(t, workspace.Name, sent[0].Labels["workspace"])
})
}

Some files were not shown because too many files have changed in this diff Show More