Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e60bde9b9 | |||
| 3db5558603 | |||
| 61961db41d | |||
| d2d7c0ee40 | |||
| d25d95231f | |||
| 3a62a8e70e | |||
| 7fc84ecf0b | |||
| 0ebe8e57ad | |||
| 3894edbcc3 | |||
| d5296a4855 | |||
| 5073493850 | |||
| 32354261d3 | |||
| 6683d807ac | |||
| 7c2479ce92 | |||
| e1156b050f | |||
| 0712faef4f | |||
| 7d5cd06f83 | |||
| 8d6a202ee4 | |||
| ffa83a4ebc | |||
| b3a81be1aa |
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.24.10"
|
||||
default: "1.24.11"
|
||||
use-preinstalled-go:
|
||||
description: "Whether to use preinstalled Go."
|
||||
default: "false"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
.eslintcache
|
||||
.gitpod.yml
|
||||
.idea
|
||||
.run
|
||||
**/*.swp
|
||||
gotests.coverage
|
||||
gotests.xml
|
||||
|
||||
@@ -69,6 +69,9 @@ MOST_GO_SRC_FILES := $(shell \
|
||||
# All the shell files in the repo, excluding ignored files.
|
||||
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
|
||||
|
||||
MIGRATION_FILES := $(shell find ./coderd/database/migrations/ -maxdepth 1 $(FIND_EXCLUSIONS) -type f -name '*.sql')
|
||||
FIXTURE_FILES := $(shell find ./coderd/database/migrations/testdata/fixtures/ $(FIND_EXCLUSIONS) -type f -name '*.sql')
|
||||
|
||||
# Ensure we don't use the user's git configs which might cause side-effects
|
||||
GIT_FLAGS = GIT_CONFIG_GLOBAL=/dev/null GIT_CONFIG_SYSTEM=/dev/null
|
||||
|
||||
@@ -561,7 +564,7 @@ endif
|
||||
|
||||
# Note: we don't run zizmor in the lint target because it takes a while. CI
|
||||
# runs it explicitly.
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes lint/migrations
|
||||
.PHONY: lint
|
||||
|
||||
lint/site-icons:
|
||||
@@ -619,6 +622,12 @@ lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
lint/migrations:
|
||||
./scripts/check_pg_schema.sh "Migrations" $(MIGRATION_FILES)
|
||||
./scripts/check_pg_schema.sh "Fixtures" $(FIXTURE_FILES)
|
||||
.PHONY: lint/migrations
|
||||
|
||||
# All files generated by the database should be added here, and this can be used
|
||||
# as a target for jobs that need to run after the database is generated.
|
||||
DB_GEN_FILES := \
|
||||
|
||||
@@ -68,6 +68,8 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
|
||||
r.scaletestTaskStatus(),
|
||||
r.scaletestSMTP(),
|
||||
r.scaletestPrebuilds(),
|
||||
r.scaletestBridge(),
|
||||
r.scaletestLLMMock(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/bridge"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestBridge() *serpent.Command {
|
||||
var (
|
||||
concurrentUsers int64
|
||||
noCleanup bool
|
||||
mode string
|
||||
upstreamURL string
|
||||
provider string
|
||||
requestsPerUser int64
|
||||
useStreamingAPI bool
|
||||
requestPayloadSize int64
|
||||
numMessages int64
|
||||
httpTimeout time.Duration
|
||||
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "bridge",
|
||||
Short: "Generate load on the AI Bridge service.",
|
||||
Long: `Generate load for AI Bridge testing. Supports two modes: 'bridge' mode routes requests through the Coder AI Bridge, 'direct' mode makes requests directly to an upstream URL (useful for baseline comparisons).
|
||||
|
||||
Examples:
|
||||
# Test OpenAI API through bridge
|
||||
coder scaletest bridge --mode bridge --provider openai --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test Anthropic API through bridge
|
||||
coder scaletest bridge --mode bridge --provider anthropic --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test directly against mock server
|
||||
coder scaletest bridge --mode direct --provider openai --upstream-url http://localhost:8080/v1/chat/completions
|
||||
`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := bridge.NewMetrics(reg)
|
||||
|
||||
logger := inv.Logger
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
defer func() {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
var userConfig createusers.Config
|
||||
if bridge.RequestMode(mode) == bridge.RequestModeBridge {
|
||||
me, err := requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(me.OrganizationIDs) == 0 {
|
||||
return xerrors.Errorf("admin user must have at least one organization")
|
||||
}
|
||||
userConfig = createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge mode: creating users and making requests through AI Bridge...")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Direct mode: making requests directly to %s\n", upstreamURL)
|
||||
}
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse output flags: %w", err)
|
||||
}
|
||||
|
||||
config := bridge.Config{
|
||||
Mode: bridge.RequestMode(mode),
|
||||
Metrics: metrics,
|
||||
Provider: provider,
|
||||
RequestCount: int(requestsPerUser),
|
||||
Stream: useStreamingAPI,
|
||||
RequestPayloadSize: int(requestPayloadSize),
|
||||
NumMessages: int(numMessages),
|
||||
HTTPTimeout: httpTimeout,
|
||||
UpstreamURL: upstreamURL,
|
||||
User: userConfig,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
if err := config.PrepareRequestBody(); err != nil {
|
||||
return xerrors.Errorf("prepare request body: %w", err)
|
||||
}
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
for i := range concurrentUsers {
|
||||
id := strconv.Itoa(int(i))
|
||||
name := fmt.Sprintf("bridge-%s", id)
|
||||
var runner harness.Runnable = bridge.NewRunner(client, config)
|
||||
th.AddRun(name, id, runner)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge scaletest configuration:")
|
||||
tw := tabwriter.NewWriter(inv.Stderr, 0, 0, 2, ' ', 0)
|
||||
for _, opt := range inv.Command.Options {
|
||||
if opt.Hidden || opt.ValueSource == serpent.ValueSourceNone {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(tw, " %s:\t%s", opt.Name, opt.Value.String())
|
||||
if opt.ValueSource != serpent.ValueSourceDefault {
|
||||
_, _ = fmt.Fprintf(tw, "\t(from %s)", opt.ValueSource)
|
||||
}
|
||||
_, _ = fmt.Fprintln(tw)
|
||||
}
|
||||
_ = tw.Flush()
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nRunning bridge scaletest...")
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCleanup {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
err = th.Cleanup(cleanupCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "concurrent-users",
|
||||
FlagShorthand: "c",
|
||||
Env: "CODER_SCALETEST_BRIDGE_CONCURRENT_USERS",
|
||||
Description: "Required: Number of concurrent users.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&concurrentUsers), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--concurrent-users must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Flag: "mode",
|
||||
Env: "CODER_SCALETEST_BRIDGE_MODE",
|
||||
Default: "direct",
|
||||
Description: "Request mode: 'bridge' (create users and use AI Bridge) or 'direct' (make requests directly to upstream-url).",
|
||||
Value: serpent.EnumOf(&mode, string(bridge.RequestModeBridge), string(bridge.RequestModeDirect)),
|
||||
},
|
||||
{
|
||||
Flag: "upstream-url",
|
||||
Env: "CODER_SCALETEST_BRIDGE_UPSTREAM_URL",
|
||||
Description: "URL to make requests to directly (required in direct mode, e.g., http://localhost:8080/v1/chat/completions).",
|
||||
Value: serpent.StringOf(&upstreamURL),
|
||||
},
|
||||
{
|
||||
Flag: "provider",
|
||||
Env: "CODER_SCALETEST_BRIDGE_PROVIDER",
|
||||
Default: "openai",
|
||||
Description: "API provider to use.",
|
||||
Value: serpent.EnumOf(&provider, "openai", "anthropic"),
|
||||
},
|
||||
{
|
||||
Flag: "request-count",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_COUNT",
|
||||
Default: "1",
|
||||
Description: "Number of sequential requests to make per runner.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&requestsPerUser), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--request-count must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
{
|
||||
Flag: "stream",
|
||||
Env: "CODER_SCALETEST_BRIDGE_STREAM",
|
||||
Description: "Enable streaming requests.",
|
||||
Value: serpent.BoolOf(&useStreamingAPI),
|
||||
},
|
||||
{
|
||||
Flag: "request-payload-size",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_PAYLOAD_SIZE",
|
||||
Default: "1024",
|
||||
Description: "Size in bytes of the request payload (user message content). If 0, uses default message content.",
|
||||
Value: serpent.Int64Of(&requestPayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "num-messages",
|
||||
Env: "CODER_SCALETEST_BRIDGE_NUM_MESSAGES",
|
||||
Default: "1",
|
||||
Description: "Number of messages to include in the conversation.",
|
||||
Value: serpent.Int64Of(&numMessages),
|
||||
},
|
||||
{
|
||||
Flag: "no-cleanup",
|
||||
Env: "CODER_SCALETEST_NO_CLEANUP",
|
||||
Description: "Do not clean up resources after the test completes.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
{
|
||||
Flag: "http-timeout",
|
||||
Env: "CODER_SCALETEST_BRIDGE_HTTP_TIMEOUT",
|
||||
Default: "30s",
|
||||
Description: "Timeout for individual HTTP requests to the upstream provider.",
|
||||
Value: serpent.DurationOf(&httpTimeout),
|
||||
},
|
||||
}
|
||||
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/llmmock"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) scaletestLLMMock() *serpent.Command {
|
||||
var (
|
||||
address string
|
||||
artificialLatency time.Duration
|
||||
responsePayloadSize int64
|
||||
|
||||
pprofEnable bool
|
||||
pprofAddress string
|
||||
|
||||
traceEnable bool
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "llm-mock",
|
||||
Short: "Start a mock LLM API server for testing",
|
||||
Long: `Start a mock LLM API server that simulates OpenAI and Anthropic APIs`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx, stop := signal.NotifyContext(inv.Context(), StopSignals...)
|
||||
defer stop()
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)
|
||||
|
||||
if pprofEnable {
|
||||
closePprof := ServeHandler(ctx, logger, nil, pprofAddress, "pprof")
|
||||
defer closePprof()
|
||||
logger.Info(ctx, "pprof server started", slog.F("address", pprofAddress))
|
||||
}
|
||||
|
||||
config := llmmock.Config{
|
||||
Address: address,
|
||||
Logger: logger,
|
||||
ArtificialLatency: artificialLatency,
|
||||
ResponsePayloadSize: int(responsePayloadSize),
|
||||
PprofEnable: pprofEnable,
|
||||
PprofAddress: pprofAddress,
|
||||
TraceEnable: traceEnable,
|
||||
}
|
||||
srv := new(llmmock.Server)
|
||||
|
||||
if err := srv.Start(ctx, config); err != nil {
|
||||
return xerrors.Errorf("start mock LLM server: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = srv.Stop()
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Mock LLM API server started on %s\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " OpenAI endpoint: %s/v1/chat/completions\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " Anthropic endpoint: %s/v1/messages\n", srv.APIAddress())
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = []serpent.Option{
|
||||
{
|
||||
Flag: "address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ADDRESS",
|
||||
Default: "localhost",
|
||||
Description: "Address to bind the mock LLM API server. Can include a port (e.g., 'localhost:8080' or ':8080'). Uses a random port if no port is specified.",
|
||||
Value: serpent.StringOf(&address),
|
||||
},
|
||||
{
|
||||
Flag: "artificial-latency",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ARTIFICIAL_LATENCY",
|
||||
Default: "0s",
|
||||
Description: "Artificial latency to add to each response (e.g., 100ms, 1s). Simulates slow upstream processing.",
|
||||
Value: serpent.DurationOf(&artificialLatency),
|
||||
},
|
||||
{
|
||||
Flag: "response-payload-size",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_RESPONSE_PAYLOAD_SIZE",
|
||||
Default: "0",
|
||||
Description: "Size in bytes of the response payload. If 0, uses default context-aware responses.",
|
||||
Value: serpent.Int64Of(&responsePayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Serve pprof metrics on the address defined by pprof-address.",
|
||||
Value: serpent.BoolOf(&pprofEnable),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ADDRESS",
|
||||
Default: "127.0.0.1:6060",
|
||||
Description: "The bind address to serve pprof.",
|
||||
Value: serpent.StringOf(&pprofAddress),
|
||||
},
|
||||
{
|
||||
Flag: "trace-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_TRACE_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.",
|
||||
Value: serpent.BoolOf(&traceEnable),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -65,6 +65,22 @@ func (r *RootCmd) organizationSettings(orgContext *OrganizationContext) *serpent
|
||||
return cli.OrganizationIDPSyncSettings(ctx)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "workspace-sharing",
|
||||
Aliases: []string{"workspacesharing"},
|
||||
Short: "Workspace sharing settings for the organization.",
|
||||
Patch: func(ctx context.Context, cli *codersdk.Client, org uuid.UUID, input json.RawMessage) (any, error) {
|
||||
var req codersdk.WorkspaceSharingSettings
|
||||
err := json.Unmarshal(input, &req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshalling workspace sharing settings: %w", err)
|
||||
}
|
||||
return cli.PatchWorkspaceSharingSettings(ctx, org.String(), req)
|
||||
},
|
||||
Fetch: func(ctx context.Context, cli *codersdk.Client, org uuid.UUID) (any, error) {
|
||||
return cli.WorkspaceSharingSettings(ctx, org.String())
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd := &serpent.Command{
|
||||
Use: "settings",
|
||||
|
||||
@@ -15,6 +15,7 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
@@ -15,6 +15,7 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
@@ -15,6 +15,7 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
@@ -15,6 +15,7 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
+4
@@ -147,6 +147,10 @@ AI BRIDGE OPTIONS:
|
||||
Maximum number of AI Bridge requests per second per replica. Set to 0
|
||||
to disable (unlimited).
|
||||
|
||||
--aibridge-structured-logging bool, $CODER_AIBRIDGE_STRUCTURED_LOGGING (default: false)
|
||||
Emit structured logs for AI Bridge interception records. Use this for
|
||||
exporting these records to external SIEM or observability systems.
|
||||
|
||||
AI BRIDGE PROXY OPTIONS:
|
||||
--aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
|
||||
Path to the CA certificate file for AI Bridge Proxy.
|
||||
|
||||
+4
@@ -773,6 +773,10 @@ aibridge:
|
||||
# (unlimited).
|
||||
# (default: 0, type: int)
|
||||
rateLimit: 0
|
||||
# Emit structured logs for AI Bridge interception records. Use this for exporting
|
||||
# these records to external SIEM or observability systems.
|
||||
# (default: false, type: bool)
|
||||
structuredLogging: false
|
||||
aibridgeproxy:
|
||||
# Enable the AI Bridge MITM Proxy for intercepting and decrypting AI provider
|
||||
# requests.
|
||||
|
||||
Generated
+165
-20
@@ -2628,7 +2628,8 @@ const docTemplate = `{
|
||||
},
|
||||
{
|
||||
"enum": [
|
||||
"code"
|
||||
"code",
|
||||
"token"
|
||||
],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
@@ -2683,7 +2684,8 @@ const docTemplate = `{
|
||||
},
|
||||
{
|
||||
"enum": [
|
||||
"code"
|
||||
"code",
|
||||
"token"
|
||||
],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
@@ -2914,7 +2916,10 @@ const docTemplate = `{
|
||||
{
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token"
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
],
|
||||
"type": "string",
|
||||
"description": "Grant type",
|
||||
@@ -4566,6 +4571,86 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/settings/workspace-sharing": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Enterprise"
|
||||
],
|
||||
"summary": "Get workspace sharing settings for organization",
|
||||
"operationId": "get-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patch": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Enterprise"
|
||||
],
|
||||
"summary": "Update workspace sharing settings for organization",
|
||||
"operationId": "update-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Workspace sharing settings",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/templates": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -11970,6 +12055,9 @@ const docTemplate = `{
|
||||
},
|
||||
"retention": {
|
||||
"type": "integer"
|
||||
},
|
||||
"structured_logging": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -15761,13 +15849,13 @@ const docTemplate = `{
|
||||
"code_challenge_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2PKCECodeChallengeMethod"
|
||||
}
|
||||
},
|
||||
"grant_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"issuer": {
|
||||
@@ -15779,7 +15867,7 @@ const docTemplate = `{
|
||||
"response_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"revocation_endpoint": {
|
||||
@@ -15797,7 +15885,7 @@ const docTemplate = `{
|
||||
"token_endpoint_auth_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15829,7 +15917,7 @@ const docTemplate = `{
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -15851,10 +15939,7 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"registration_access_token": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
"type": "string"
|
||||
},
|
||||
"registration_client_uri": {
|
||||
"type": "string"
|
||||
@@ -15862,7 +15947,7 @@ const docTemplate = `{
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -15875,7 +15960,7 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -15900,7 +15985,7 @@ const docTemplate = `{
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -15924,7 +16009,7 @@ const docTemplate = `{
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -15940,7 +16025,7 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -15977,7 +16062,7 @@ const docTemplate = `{
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -16007,7 +16092,7 @@ const docTemplate = `{
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -16020,7 +16105,7 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -16073,6 +16158,17 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2PKCECodeChallengeMethod": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"S256",
|
||||
"plain"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2PKCECodeChallengeMethodS256",
|
||||
"OAuth2PKCECodeChallengeMethodPlain"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProtectedResourceMetadata": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -16152,6 +16248,47 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2ProviderGrantType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderGrantTypeAuthorizationCode",
|
||||
"OAuth2ProviderGrantTypeRefreshToken",
|
||||
"OAuth2ProviderGrantTypePassword",
|
||||
"OAuth2ProviderGrantTypeClientCredentials",
|
||||
"OAuth2ProviderGrantTypeImplicit"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProviderResponseType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"code",
|
||||
"token"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderResponseTypeCode",
|
||||
"OAuth2ProviderResponseTypeToken"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2TokenEndpointAuthMethod": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretBasic",
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretPost",
|
||||
"OAuth2TokenEndpointAuthMethodNone"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuthConversionResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21451,6 +21588,14 @@ const docTemplate = `{
|
||||
"WorkspaceRoleDeleted"
|
||||
]
|
||||
},
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sharing_disabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.WorkspaceStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
+146
-20
@@ -2304,7 +2304,7 @@
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"enum": ["code"],
|
||||
"enum": ["code", "token"],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
"name": "response_type",
|
||||
@@ -2355,7 +2355,7 @@
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"enum": ["code"],
|
||||
"enum": ["code", "token"],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
"name": "response_type",
|
||||
@@ -2555,7 +2555,13 @@
|
||||
"in": "formData"
|
||||
},
|
||||
{
|
||||
"enum": ["authorization_code", "refresh_token"],
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
],
|
||||
"type": "string",
|
||||
"description": "Grant type",
|
||||
"name": "grant_type",
|
||||
@@ -4036,6 +4042,76 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/settings/workspace-sharing": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Enterprise"],
|
||||
"summary": "Get workspace sharing settings for organization",
|
||||
"operationId": "get-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patch": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Enterprise"],
|
||||
"summary": "Update workspace sharing settings for organization",
|
||||
"operationId": "update-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Workspace sharing settings",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/templates": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -10631,6 +10707,9 @@
|
||||
},
|
||||
"retention": {
|
||||
"type": "integer"
|
||||
},
|
||||
"structured_logging": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14280,13 +14359,13 @@
|
||||
"code_challenge_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2PKCECodeChallengeMethod"
|
||||
}
|
||||
},
|
||||
"grant_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"issuer": {
|
||||
@@ -14298,7 +14377,7 @@
|
||||
"response_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"revocation_endpoint": {
|
||||
@@ -14316,7 +14395,7 @@
|
||||
"token_endpoint_auth_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14348,7 +14427,7 @@
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -14370,10 +14449,7 @@
|
||||
}
|
||||
},
|
||||
"registration_access_token": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
"type": "string"
|
||||
},
|
||||
"registration_client_uri": {
|
||||
"type": "string"
|
||||
@@ -14381,7 +14457,7 @@
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -14394,7 +14470,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -14419,7 +14495,7 @@
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -14443,7 +14519,7 @@
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -14459,7 +14535,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -14496,7 +14572,7 @@
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -14526,7 +14602,7 @@
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -14539,7 +14615,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"type": "string"
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -14592,6 +14668,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2PKCECodeChallengeMethod": {
|
||||
"type": "string",
|
||||
"enum": ["S256", "plain"],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2PKCECodeChallengeMethodS256",
|
||||
"OAuth2PKCECodeChallengeMethodPlain"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProtectedResourceMetadata": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -14671,6 +14755,40 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2ProviderGrantType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderGrantTypeAuthorizationCode",
|
||||
"OAuth2ProviderGrantTypeRefreshToken",
|
||||
"OAuth2ProviderGrantTypePassword",
|
||||
"OAuth2ProviderGrantTypeClientCredentials",
|
||||
"OAuth2ProviderGrantTypeImplicit"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProviderResponseType": {
|
||||
"type": "string",
|
||||
"enum": ["code", "token"],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderResponseTypeCode",
|
||||
"OAuth2ProviderResponseTypeToken"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2TokenEndpointAuthMethod": {
|
||||
"type": "string",
|
||||
"enum": ["client_secret_basic", "client_secret_post", "none"],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretBasic",
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretPost",
|
||||
"OAuth2TokenEndpointAuthMethodNone"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuthConversionResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19730,6 +19848,14 @@
|
||||
"WorkspaceRoleDeleted"
|
||||
]
|
||||
},
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sharing_disabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.WorkspaceStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
+7
-4
@@ -205,7 +205,7 @@ type Options struct {
|
||||
// tokens issued by and passed to the coordinator DRPC API.
|
||||
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport
|
||||
HealthcheckTimeout time.Duration
|
||||
HealthcheckRefresh time.Duration
|
||||
WorkspaceProxiesFetchUpdater *atomic.Pointer[healthcheck.WorkspaceProxiesFetchUpdater]
|
||||
@@ -681,7 +681,7 @@ func New(options *Options) *API {
|
||||
}
|
||||
|
||||
if options.HealthcheckFunc == nil {
|
||||
options.HealthcheckFunc = func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
options.HealthcheckFunc = func(ctx context.Context, apiKey string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
// NOTE: dismissed healthchecks are marked in formatHealthcheck.
|
||||
// Not here, as this result gets cached.
|
||||
return healthcheck.Run(ctx, &healthcheck.ReportOptions{
|
||||
@@ -709,6 +709,7 @@ func New(options *Options) *API {
|
||||
StaleInterval: provisionerdserver.StaleInterval,
|
||||
// TimeNow set to default, see healthcheck/provisioner.go
|
||||
},
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -881,6 +882,7 @@ func New(options *Options) *API {
|
||||
loggermw.Logger(api.Logger),
|
||||
singleSlashMW,
|
||||
rolestore.CustomRoleMW,
|
||||
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
|
||||
prometheusMW,
|
||||
// Build-Version is helpful for debugging.
|
||||
func(next http.Handler) http.Handler {
|
||||
@@ -1859,8 +1861,9 @@ type API struct {
|
||||
// This is used to gate features that are not yet ready for production.
|
||||
Experiments codersdk.Experiments
|
||||
|
||||
healthCheckGroup *singleflight.Group[string, *healthsdk.HealthcheckReport]
|
||||
healthCheckCache atomic.Pointer[healthsdk.HealthcheckReport]
|
||||
healthCheckGroup *singleflight.Group[string, *healthsdk.HealthcheckReport]
|
||||
healthCheckCache atomic.Pointer[healthsdk.HealthcheckReport]
|
||||
healthCheckProgress healthcheck.Progress
|
||||
|
||||
statsReporter *workspacestats.Reporter
|
||||
|
||||
|
||||
@@ -69,6 +69,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/files"
|
||||
"github.com/coder/coder/v2/coderd/gitsshkey"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/jobreaper"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
@@ -131,7 +132,7 @@ type Options struct {
|
||||
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
|
||||
ConnectionLogger connectionlog.ConnectionLogger
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport
|
||||
HealthcheckTimeout time.Duration
|
||||
HealthcheckRefresh time.Duration
|
||||
|
||||
|
||||
@@ -1965,6 +1965,14 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
// This is a system-only function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
w, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID)
|
||||
if err != nil {
|
||||
@@ -3592,7 +3600,7 @@ func (q *querier) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (databa
|
||||
if err != nil {
|
||||
return database.GetWorkspaceACLByIDRow{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionShare, workspace); err != nil {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, workspace); err != nil {
|
||||
return database.GetWorkspaceACLByIDRow{}, err
|
||||
}
|
||||
return q.db.GetWorkspaceACLByID(ctx, id)
|
||||
@@ -5099,6 +5107,13 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
|
||||
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg database.UpdateOrganizationWorkspaceSharingSettingsParams) (database.Organization, error) {
|
||||
fetch := func(ctx context.Context, arg database.UpdateOrganizationWorkspaceSharingSettingsParams) (database.Organization, error) {
|
||||
return q.db.GetOrganizationByID(ctx, arg.ID)
|
||||
}
|
||||
return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateOrganizationWorkspaceSharingSettings)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
|
||||
|
||||
@@ -880,6 +880,16 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
dbm.EXPECT().InsertOrganization(gomock.Any(), arg).Return(database.Organization{ID: arg.ID, Name: arg.Name}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceOrganization, policy.ActionCreate)
|
||||
}))
|
||||
s.Run("UpdateOrganizationWorkspaceSharingSettings", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
WorkspaceSharingDisabled: true,
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(org, policy.ActionUpdate).Returns(org)
|
||||
}))
|
||||
s.Run("InsertOrganizationMember", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
o := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
@@ -1784,7 +1794,7 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
dbM.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
dbM.EXPECT().GetWorkspaceACLByID(gomock.Any(), ws.ID).Return(database.GetWorkspaceACLByIDRow{}, nil).AnyTimes()
|
||||
check.Args(ws.ID).Asserts(ws, policy.ActionShare)
|
||||
check.Args(ws.ID).Asserts(ws, policy.ActionRead)
|
||||
}))
|
||||
s.Run("UpdateWorkspaceACLByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
@@ -1799,6 +1809,11 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().DeleteWorkspaceACLByID(gomock.Any(), w.ID).Return(nil).AnyTimes()
|
||||
check.Args(w.ID).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
orgID := uuid.New()
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
|
||||
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
b := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: w.ID})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1055,6 +1055,20 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLByID", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceAgentPortShare mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6645,6 +6659,21 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOrganizationDeletedByID", reflect.TypeOf((*MockStore)(nil).UpdateOrganizationDeletedByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateOrganizationWorkspaceSharingSettings mocks base method.
|
||||
func (m *MockStore) UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg database.UpdateOrganizationWorkspaceSharingSettingsParams) (database.Organization, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateOrganizationWorkspaceSharingSettings", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Organization)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateOrganizationWorkspaceSharingSettings indicates an expected call of UpdateOrganizationWorkspaceSharingSettings.
|
||||
func (mr *MockStoreMockRecorder) UpdateOrganizationWorkspaceSharingSettings(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOrganizationWorkspaceSharingSettings", reflect.TypeOf((*MockStore)(nil).UpdateOrganizationWorkspaceSharingSettings), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
|
||||
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -1 +1 @@
|
||||
DROP INDEX IF EXISTS public.workspace_agents_auth_instance_id_deleted_idx;
|
||||
DROP INDEX IF EXISTS workspace_agents_auth_instance_id_deleted_idx;
|
||||
|
||||
@@ -1 +1 @@
|
||||
CREATE INDEX IF NOT EXISTS workspace_agents_auth_instance_id_deleted_idx ON public.workspace_agents (auth_instance_id, deleted);
|
||||
CREATE INDEX IF NOT EXISTS workspace_agents_auth_instance_id_deleted_idx ON workspace_agents (auth_instance_id, deleted);
|
||||
|
||||
+685
-685
File diff suppressed because one or more lines are too long
@@ -1,34 +1,34 @@
|
||||
-- This is a deleted user that shares the same username and linked_id as the existing user below.
|
||||
-- Any future migrations need to handle this case.
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('a0061a8e-7db7-4585-838c-3116a003dd21', 'githubuser@coder.com', 'githubuser', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', true) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.organization_members VALUES ('a0061a8e-7db7-4585-838c-3116a003dd21', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO organization_members VALUES ('a0061a8e-7db7-4585-838c-3116a003dd21', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('a0061a8e-7db7-4585-838c-3116a003dd21', 'github', '100', '');
|
||||
|
||||
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'githubuser@coder.com', 'githubuser', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', false) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.organization_members VALUES ('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO organization_members VALUES ('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'github', '100', '');
|
||||
|
||||
-- Additionally, there is no unique constraint on user_id. So also add another user_link for the same user.
|
||||
-- This has happened on a production database.
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'oidc', 'foo', '');
|
||||
|
||||
|
||||
-- Lastly, make 2 other users who have the same user link.
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('580ed397-727d-4aaf-950a-51f89f556c24', 'dup_link_a@coder.com', 'dupe_a', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', false) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.organization_members VALUES ('580ed397-727d-4aaf-950a-51f89f556c24', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO organization_members VALUES ('580ed397-727d-4aaf-950a-51f89f556c24', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('580ed397-727d-4aaf-950a-51f89f556c24', 'github', '500', '');
|
||||
|
||||
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'dup_link_b@coder.com', 'dupe_b', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', false) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.organization_members VALUES ('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO organization_members VALUES ('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'github', '500', '');
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
INSERT INTO public.workspace_app_stats (
|
||||
INSERT INTO workspace_app_stats (
|
||||
id,
|
||||
user_id,
|
||||
workspace_id,
|
||||
|
||||
+1
-1
@@ -1,5 +1,5 @@
|
||||
INSERT INTO
|
||||
public.workspace_modules (
|
||||
workspace_modules (
|
||||
id,
|
||||
job_id,
|
||||
transition,
|
||||
|
||||
+8
-8
@@ -1,15 +1,15 @@
|
||||
INSERT INTO public.organizations (id, name, description, created_at, updated_at, is_default, display_name, icon) VALUES ('20362772-802a-4a72-8e4f-3648b4bfd168', 'strange_hopper58', 'wizardly_stonebraker60', '2025-02-07 07:46:19.507551 +00:00', '2025-02-07 07:46:19.507552 +00:00', false, 'competent_rhodes59', '');
|
||||
INSERT INTO organizations (id, name, description, created_at, updated_at, is_default, display_name, icon) VALUES ('20362772-802a-4a72-8e4f-3648b4bfd168', 'strange_hopper58', 'wizardly_stonebraker60', '2025-02-07 07:46:19.507551 +00:00', '2025-02-07 07:46:19.507552 +00:00', false, 'competent_rhodes59', '');
|
||||
|
||||
INSERT INTO public.users (id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at, quiet_hours_schedule, theme_preference, name, github_com_user_id, hashed_one_time_passcode, one_time_passcode_expires_at) VALUES ('6c353aac-20de-467b-bdfb-3c30a37adcd2', 'vigorous_murdock61', 'affectionate_hawking62', 'lqTu9C5363AwD7NVNH6noaGjp91XIuZJ', '2025-02-07 07:46:19.510861 +00:00', '2025-02-07 07:46:19.512949 +00:00', 'active', '{}', 'password', '', false, '0001-01-01 00:00:00.000000', '', '', 'vigilant_hugle63', null, null, null);
|
||||
INSERT INTO users (id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at, quiet_hours_schedule, theme_preference, name, github_com_user_id, hashed_one_time_passcode, one_time_passcode_expires_at) VALUES ('6c353aac-20de-467b-bdfb-3c30a37adcd2', 'vigorous_murdock61', 'affectionate_hawking62', 'lqTu9C5363AwD7NVNH6noaGjp91XIuZJ', '2025-02-07 07:46:19.510861 +00:00', '2025-02-07 07:46:19.512949 +00:00', 'active', '{}', 'password', '', false, '0001-01-01 00:00:00.000000', '', '', 'vigilant_hugle63', null, null, null);
|
||||
|
||||
INSERT INTO public.templates (id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level) VALUES ('6b298946-7a4f-47ac-9158-b03b08740a41', '2025-02-07 07:46:19.513317 +00:00', '2025-02-07 07:46:19.513317 +00:00', '20362772-802a-4a72-8e4f-3648b4bfd168', false, 'modest_leakey64', 'echo', 'e6cfa2a4-e4cf-4182-9e19-08b975682a28', 'upbeat_wright65', 604800000000000, '6c353aac-20de-467b-bdfb-3c30a37adcd2', 'nervous_keller66', '{}', '{"20362772-802a-4a72-8e4f-3648b4bfd168": ["read", "use"]}', 'determined_aryabhata67', false, true, true, 0, 0, 0, 0, 0, 0, false, '', 3600000000000, 'owner');
|
||||
INSERT INTO public.template_versions (id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id) VALUES ('af58bd62-428c-4c33-849b-d43a3be07d93', '6b298946-7a4f-47ac-9158-b03b08740a41', '20362772-802a-4a72-8e4f-3648b4bfd168', '2025-02-07 07:46:19.514782 +00:00', '2025-02-07 07:46:19.514782 +00:00', 'distracted_shockley68', 'sleepy_turing69', 'f2e2ea1c-5aa3-4a1d-8778-2e5071efae59', '6c353aac-20de-467b-bdfb-3c30a37adcd2', '[]', '', false, null);
|
||||
INSERT INTO templates (id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level) VALUES ('6b298946-7a4f-47ac-9158-b03b08740a41', '2025-02-07 07:46:19.513317 +00:00', '2025-02-07 07:46:19.513317 +00:00', '20362772-802a-4a72-8e4f-3648b4bfd168', false, 'modest_leakey64', 'echo', 'e6cfa2a4-e4cf-4182-9e19-08b975682a28', 'upbeat_wright65', 604800000000000, '6c353aac-20de-467b-bdfb-3c30a37adcd2', 'nervous_keller66', '{}', '{"20362772-802a-4a72-8e4f-3648b4bfd168": ["read", "use"]}', 'determined_aryabhata67', false, true, true, 0, 0, 0, 0, 0, 0, false, '', 3600000000000, 'owner');
|
||||
INSERT INTO template_versions (id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id) VALUES ('af58bd62-428c-4c33-849b-d43a3be07d93', '6b298946-7a4f-47ac-9158-b03b08740a41', '20362772-802a-4a72-8e4f-3648b4bfd168', '2025-02-07 07:46:19.514782 +00:00', '2025-02-07 07:46:19.514782 +00:00', 'distracted_shockley68', 'sleepy_turing69', 'f2e2ea1c-5aa3-4a1d-8778-2e5071efae59', '6c353aac-20de-467b-bdfb-3c30a37adcd2', '[]', '', false, null);
|
||||
|
||||
INSERT INTO public.template_version_presets (id, template_version_id, name, created_at) VALUES ('28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'af58bd62-428c-4c33-849b-d43a3be07d93', 'test', '0001-01-01 00:00:00.000000 +00:00');
|
||||
INSERT INTO template_version_presets (id, template_version_id, name, created_at) VALUES ('28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'af58bd62-428c-4c33-849b-d43a3be07d93', 'test', '0001-01-01 00:00:00.000000 +00:00');
|
||||
|
||||
-- Add presets with the same template version ID and name
|
||||
-- to ensure they're correctly handled by the 00031*_preset_prebuilds migration.
|
||||
INSERT INTO public.template_version_presets (
|
||||
INSERT INTO template_version_presets (
|
||||
id, template_version_id, name, created_at
|
||||
)
|
||||
VALUES (
|
||||
@@ -19,7 +19,7 @@ VALUES (
|
||||
'0001-01-01 00:00:00.000000 +00:00'
|
||||
);
|
||||
|
||||
INSERT INTO public.template_version_presets (
|
||||
INSERT INTO template_version_presets (
|
||||
id, template_version_id, name, created_at
|
||||
)
|
||||
VALUES (
|
||||
@@ -29,4 +29,4 @@ VALUES (
|
||||
'0001-01-01 00:00:00.000000 +00:00'
|
||||
);
|
||||
|
||||
INSERT INTO public.template_version_preset_parameters (id, template_version_preset_id, name, value) VALUES ('ea90ccd2-5024-459e-87e4-879afd24de0f', '28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'test', 'test');
|
||||
INSERT INTO template_version_preset_parameters (id, template_version_preset_id, name, value) VALUES ('ea90ccd2-5024-459e-87e4-879afd24de0f', '28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'test', 'test');
|
||||
|
||||
+2
-2
@@ -1,4 +1,4 @@
|
||||
INSERT INTO public.tasks VALUES (
|
||||
INSERT INTO tasks VALUES (
|
||||
'f5a1c3e4-8b2d-4f6a-9d7e-2a8b5c9e1f3d', -- id
|
||||
'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', -- organization_id
|
||||
'30095c71-380b-457a-8995-97b8ee6e5307', -- owner_id
|
||||
@@ -11,7 +11,7 @@ INSERT INTO public.tasks VALUES (
|
||||
NULL -- deleted_at
|
||||
) ON CONFLICT DO NOTHING;
|
||||
|
||||
INSERT INTO public.task_workspace_apps VALUES (
|
||||
INSERT INTO task_workspace_apps VALUES (
|
||||
'f5a1c3e4-8b2d-4f6a-9d7e-2a8b5c9e1f3d', -- task_id
|
||||
'a8c0b8c5-c9a8-4f33-93a4-8142e6858244', -- workspace_build_id
|
||||
'8fa17bbd-c48c-44c7-91ae-d4acbc755fad', -- workspace_agent_id
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
INSERT INTO public.task_workspace_apps VALUES (
|
||||
INSERT INTO task_workspace_apps VALUES (
|
||||
'f5a1c3e4-8b2d-4f6a-9d7e-2a8b5c9e1f3d', -- task_id
|
||||
NULL, -- workspace_agent_id
|
||||
NULL, -- workspace_app_id
|
||||
|
||||
@@ -139,6 +139,7 @@ type sqlcQuerier interface {
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error
|
||||
DeleteWorkspaceAgentPortShare(ctx context.Context, arg DeleteWorkspaceAgentPortShareParams) error
|
||||
DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
DeleteWorkspaceSubAgentByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -677,6 +678,7 @@ type sqlcQuerier interface {
|
||||
UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg UpdateOAuth2ProviderAppSecretByIDParams) (OAuth2ProviderAppSecret, error)
|
||||
UpdateOrganization(ctx context.Context, arg UpdateOrganizationParams) (Organization, error)
|
||||
UpdateOrganizationDeletedByID(ctx context.Context, arg UpdateOrganizationDeletedByIDParams) error
|
||||
UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg UpdateOrganizationWorkspaceSharingSettingsParams) (Organization, error)
|
||||
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
// inactive template version.
|
||||
// This is an optimization to clean up stale pending jobs.
|
||||
|
||||
@@ -2304,6 +2304,94 @@ func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
|
||||
require.True(t, roles[0].IsSystem)
|
||||
}
|
||||
|
||||
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
WorkspaceSharingDisabled: true,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, updated.WorkspaceSharingDisabled)
|
||||
|
||||
got, err := db.GetOrganizationByID(ctx, org.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.WorkspaceSharingDisabled)
|
||||
}
|
||||
|
||||
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org1 := dbgen.Organization(t, db, database.Organization{})
|
||||
org2 := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
owner1 := dbgen.User(t, db, database.User{})
|
||||
owner2 := dbgen.User(t, db, database.User{})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
sharedGroup := dbgen.Group(t, db, database.Group{
|
||||
OrganizationID: org1.ID,
|
||||
})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: owner1.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org2.ID,
|
||||
UserID: owner2.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
|
||||
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner1.ID,
|
||||
OrganizationID: org1.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
GroupACL: database.WorkspaceACL{
|
||||
sharedGroup.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner2.ID,
|
||||
OrganizationID: org2.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
uuid.NewString(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, org1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, got1.UserACL)
|
||||
require.Empty(t, got1.GroupACL)
|
||||
|
||||
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, got2.UserACL)
|
||||
}
|
||||
|
||||
func TestAuthorizedAuditLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -8197,6 +8197,41 @@ func (q *sqlQuerier) UpdateOrganizationDeletedByID(ctx context.Context, arg Upda
|
||||
return err
|
||||
}
|
||||
|
||||
const updateOrganizationWorkspaceSharingSettings = `-- name: UpdateOrganizationWorkspaceSharingSettings :one
|
||||
UPDATE
|
||||
organizations
|
||||
SET
|
||||
workspace_sharing_disabled = $1,
|
||||
updated_at = $2
|
||||
WHERE
|
||||
id = $3
|
||||
RETURNING id, name, description, created_at, updated_at, is_default, display_name, icon, deleted, workspace_sharing_disabled
|
||||
`
|
||||
|
||||
type UpdateOrganizationWorkspaceSharingSettingsParams struct {
|
||||
WorkspaceSharingDisabled bool `db:"workspace_sharing_disabled" json:"workspace_sharing_disabled"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg UpdateOrganizationWorkspaceSharingSettingsParams) (Organization, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateOrganizationWorkspaceSharingSettings, arg.WorkspaceSharingDisabled, arg.UpdatedAt, arg.ID)
|
||||
var i Organization
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsDefault,
|
||||
&i.DisplayName,
|
||||
&i.Icon,
|
||||
&i.Deleted,
|
||||
&i.WorkspaceSharingDisabled,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getParameterSchemasByJobID = `-- name: GetParameterSchemasByJobID :many
|
||||
SELECT
|
||||
id, created_at, job_id, name, description, default_source_scheme, default_source_value, allow_override_source, default_destination_scheme, allow_override_destination, default_refresh, redisplay_value, validation_error, validation_condition, validation_type_system, validation_value_type, index
|
||||
@@ -22151,6 +22186,21 @@ func (q *sqlQuerier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) e
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteWorkspaceACLsByOrganization = `-- name: DeleteWorkspaceACLsByOrganization :exec
|
||||
UPDATE
|
||||
workspaces
|
||||
SET
|
||||
group_acl = '{}'::jsonb,
|
||||
user_acl = '{}'::jsonb
|
||||
WHERE
|
||||
organization_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteWorkspaceACLsByOrganization, organizationID)
|
||||
return err
|
||||
}
|
||||
|
||||
const favoriteWorkspace = `-- name: FavoriteWorkspace :exec
|
||||
UPDATE workspaces SET favorite = true WHERE id = $1
|
||||
`
|
||||
|
||||
@@ -143,3 +143,13 @@ WHERE
|
||||
id = @id AND
|
||||
is_default = false;
|
||||
|
||||
-- name: UpdateOrganizationWorkspaceSharingSettings :one
|
||||
UPDATE
|
||||
organizations
|
||||
SET
|
||||
workspace_sharing_disabled = @workspace_sharing_disabled,
|
||||
updated_at = @updated_at
|
||||
WHERE
|
||||
id = @id
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -947,6 +947,15 @@ SET
|
||||
WHERE
|
||||
id = @id;
|
||||
|
||||
-- name: DeleteWorkspaceACLsByOrganization :exec
|
||||
UPDATE
|
||||
workspaces
|
||||
SET
|
||||
group_acl = '{}'::jsonb,
|
||||
user_acl = '{}'::jsonb
|
||||
WHERE
|
||||
organization_id = @organization_id;
|
||||
|
||||
-- name: GetRegularWorkspaceCreateMetrics :many
|
||||
-- Count regular workspaces: only those whose first successful 'start' build
|
||||
-- was not initiated by the prebuild system user.
|
||||
|
||||
+6
-2
@@ -83,17 +83,21 @@ func (api *API) debugDeploymentHealth(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), api.Options.HealthcheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
report := api.HealthcheckFunc(ctx, apiKey)
|
||||
// Create and store progress tracker for timeout diagnostics.
|
||||
report := api.HealthcheckFunc(ctx, apiKey, &api.healthCheckProgress)
|
||||
if report != nil { // Only store non-nil reports.
|
||||
api.healthCheckCache.Store(report)
|
||||
}
|
||||
api.healthCheckProgress.Reset()
|
||||
return report, nil
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
summary := api.healthCheckProgress.Summary()
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Healthcheck is in progress and did not complete in time. Try again in a few seconds.",
|
||||
Message: "Healthcheck timed out.",
|
||||
Detail: summary,
|
||||
})
|
||||
return
|
||||
case res := <-resChan:
|
||||
|
||||
+20
-17
@@ -14,6 +14,8 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -28,7 +30,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
sessionToken string
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
calls.Add(1)
|
||||
assert.Equal(t, sessionToken, apiKey)
|
||||
return &healthsdk.HealthcheckReport{
|
||||
@@ -61,7 +63,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
sessionToken string
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
calls.Add(1)
|
||||
assert.Equal(t, sessionToken, apiKey)
|
||||
return &healthsdk.HealthcheckReport{
|
||||
@@ -93,19 +95,14 @@ func TestDebugHealth(t *testing.T) {
|
||||
// Need to ignore errors due to ctx timeout
|
||||
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
done = make(chan struct{})
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
Logger: &logger,
|
||||
HealthcheckTimeout: time.Microsecond,
|
||||
HealthcheckFunc: func(context.Context, string) *healthsdk.HealthcheckReport {
|
||||
t := time.NewTimer(time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
case <-t.C:
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
}
|
||||
HealthcheckTimeout: time.Second,
|
||||
HealthcheckFunc: func(_ context.Context, _ string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
progress.Start("test")
|
||||
<-done
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
},
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
@@ -115,8 +112,14 @@ func TestDebugHealth(t *testing.T) {
|
||||
res, err := client.Request(ctx, "GET", "/api/v2/debug/health", nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
_, _ = io.ReadAll(res.Body)
|
||||
close(done)
|
||||
bs, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err, "reading body")
|
||||
require.Equal(t, http.StatusServiceUnavailable, res.StatusCode)
|
||||
var sdkResp codersdk.Response
|
||||
require.NoError(t, json.Unmarshal(bs, &sdkResp), "unmarshaling sdk response")
|
||||
require.Equal(t, "Healthcheck timed out.", sdkResp.Message)
|
||||
require.Contains(t, sdkResp.Detail, "Still running: test (elapsed:")
|
||||
})
|
||||
|
||||
t.Run("Refresh", func(t *testing.T) {
|
||||
@@ -128,7 +131,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckRefresh: time.Microsecond,
|
||||
HealthcheckFunc: func(context.Context, string) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(context.Context, string, *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
calls <- struct{}{}
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
},
|
||||
@@ -173,7 +176,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckRefresh: time.Hour,
|
||||
HealthcheckTimeout: time.Hour,
|
||||
HealthcheckFunc: func(context.Context, string) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(context.Context, string, *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
calls++
|
||||
return &healthsdk.HealthcheckReport{
|
||||
Time: time.Now(),
|
||||
@@ -207,7 +210,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
sessionToken string
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
assert.Equal(t, sessionToken, apiKey)
|
||||
return &healthsdk.HealthcheckReport{
|
||||
Time: time.Now(),
|
||||
|
||||
@@ -2,6 +2,9 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -10,8 +13,91 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Progress tracks the progress of healthcheck components for timeout
|
||||
// diagnostics. It records which checks have started and completed, along with
|
||||
// their durations, to provide useful information when a healthcheck times out.
|
||||
// The zero value is usable.
|
||||
type Progress struct {
|
||||
Clock quartz.Clock
|
||||
mu sync.Mutex
|
||||
checks map[string]*checkStatus
|
||||
}
|
||||
|
||||
type checkStatus struct {
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
}
|
||||
|
||||
// Start records that a check has started.
|
||||
func (p *Progress) Start(name string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.Clock == nil {
|
||||
p.Clock = quartz.NewReal()
|
||||
}
|
||||
if p.checks == nil {
|
||||
p.checks = make(map[string]*checkStatus)
|
||||
}
|
||||
p.checks[name] = &checkStatus{startedAt: p.Clock.Now()}
|
||||
}
|
||||
|
||||
// Complete records that a check has finished.
|
||||
func (p *Progress) Complete(name string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.Clock == nil {
|
||||
p.Clock = quartz.NewReal()
|
||||
}
|
||||
if p.checks == nil {
|
||||
p.checks = make(map[string]*checkStatus)
|
||||
}
|
||||
if p.checks[name] == nil {
|
||||
p.checks[name] = &checkStatus{startedAt: p.Clock.Now()}
|
||||
}
|
||||
p.checks[name].completedAt = p.Clock.Now()
|
||||
}
|
||||
|
||||
// Reset clears all recorded check statuses.
|
||||
func (p *Progress) Reset() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.checks = make(map[string]*checkStatus)
|
||||
}
|
||||
|
||||
// Summary returns a human-readable summary of check progress.
|
||||
// Example: "Completed: AccessURL (95ms), Database (120ms). Still running: DERP, Websocket"
|
||||
func (p *Progress) Summary() string {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
var completed, running []string
|
||||
for name, status := range p.checks {
|
||||
if status.completedAt.IsZero() {
|
||||
elapsed := p.Clock.Now().Sub(status.startedAt).Round(time.Millisecond)
|
||||
running = append(running, fmt.Sprintf("%s (elapsed: %dms)", name, elapsed.Milliseconds()))
|
||||
continue
|
||||
}
|
||||
duration := status.completedAt.Sub(status.startedAt).Round(time.Millisecond)
|
||||
completed = append(completed, fmt.Sprintf("%s (%dms)", name, duration.Milliseconds()))
|
||||
}
|
||||
|
||||
// Sort for consistent output.
|
||||
slices.Sort(completed)
|
||||
slices.Sort(running)
|
||||
|
||||
var parts []string
|
||||
if len(completed) > 0 {
|
||||
parts = append(parts, "Completed: "+strings.Join(completed, ", "))
|
||||
}
|
||||
if len(running) > 0 {
|
||||
parts = append(parts, "Still running: "+strings.Join(running, ", "))
|
||||
}
|
||||
return strings.Join(parts, ". ")
|
||||
}
|
||||
|
||||
type Checker interface {
|
||||
DERP(ctx context.Context, opts *derphealth.ReportOptions) healthsdk.DERPHealthReport
|
||||
AccessURL(ctx context.Context, opts *AccessURLReportOptions) healthsdk.AccessURLReport
|
||||
@@ -30,6 +116,10 @@ type ReportOptions struct {
|
||||
ProvisionerDaemons ProvisionerDaemonsReportDeps
|
||||
|
||||
Checker Checker
|
||||
|
||||
// Progress tracks healthcheck progress for timeout diagnostics.
|
||||
// If set, each check will record its start and completion time.
|
||||
Progress *Progress
|
||||
}
|
||||
|
||||
type defaultChecker struct{}
|
||||
@@ -89,6 +179,10 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("DERP")
|
||||
defer opts.Progress.Complete("DERP")
|
||||
}
|
||||
report.DERP = opts.Checker.DERP(ctx, &opts.DerpHealth)
|
||||
}()
|
||||
|
||||
@@ -101,6 +195,10 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("AccessURL")
|
||||
defer opts.Progress.Complete("AccessURL")
|
||||
}
|
||||
report.AccessURL = opts.Checker.AccessURL(ctx, &opts.AccessURL)
|
||||
}()
|
||||
|
||||
@@ -113,6 +211,10 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("Websocket")
|
||||
defer opts.Progress.Complete("Websocket")
|
||||
}
|
||||
report.Websocket = opts.Checker.Websocket(ctx, &opts.Websocket)
|
||||
}()
|
||||
|
||||
@@ -125,6 +227,10 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("Database")
|
||||
defer opts.Progress.Complete("Database")
|
||||
}
|
||||
report.Database = opts.Checker.Database(ctx, &opts.Database)
|
||||
}()
|
||||
|
||||
@@ -137,6 +243,10 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("WorkspaceProxy")
|
||||
defer opts.Progress.Complete("WorkspaceProxy")
|
||||
}
|
||||
report.WorkspaceProxy = opts.Checker.WorkspaceProxy(ctx, &opts.WorkspaceProxy)
|
||||
}()
|
||||
|
||||
@@ -149,6 +259,10 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("ProvisionerDaemons")
|
||||
defer opts.Progress.Complete("ProvisionerDaemons")
|
||||
}
|
||||
report.ProvisionerDaemons = opts.Checker.ProvisionerDaemons(ctx, &opts.ProvisionerDaemons)
|
||||
}()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package healthcheck_test
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type testChecker struct {
|
||||
@@ -533,3 +535,69 @@ func TestHealthcheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckProgress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Summary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
|
||||
// Start some checks
|
||||
progress.Start("Database")
|
||||
progress.Start("DERP")
|
||||
progress.Start("AccessURL")
|
||||
|
||||
// Advance time to simulate check duration
|
||||
mClock.Advance(100 * time.Millisecond)
|
||||
|
||||
// Complete some checks
|
||||
progress.Complete("Database")
|
||||
progress.Complete("AccessURL")
|
||||
|
||||
summary := progress.Summary()
|
||||
|
||||
// Verify completed and running checks are listed with duration / elapsed
|
||||
assert.Equal(t, summary, "Completed: AccessURL (100ms), Database (100ms). Still running: DERP (elapsed: 100ms)")
|
||||
})
|
||||
|
||||
t.Run("EmptyProgress", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
summary := progress.Summary()
|
||||
|
||||
// Should be empty string when nothing tracked
|
||||
assert.Empty(t, summary)
|
||||
})
|
||||
|
||||
t.Run("AllCompleted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
progress.Start("Database")
|
||||
progress.Start("DERP")
|
||||
mClock.Advance(50 * time.Millisecond)
|
||||
progress.Complete("Database")
|
||||
progress.Complete("DERP")
|
||||
|
||||
summary := progress.Summary()
|
||||
assert.Equal(t, summary, "Completed: DERP (50ms), Database (50ms)")
|
||||
})
|
||||
|
||||
t.Run("AllRunning", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
progress.Start("Database")
|
||||
progress.Start("DERP")
|
||||
|
||||
summary := progress.Summary()
|
||||
assert.Equal(t, summary, "Still running: DERP (elapsed: 0ms), Database (elapsed: 0ms)")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -493,16 +493,10 @@ func OneWayWebSocketEventSender(rw http.ResponseWriter, r *http.Request) (
|
||||
return sendEvent, closed, nil
|
||||
}
|
||||
|
||||
// OAuth2Error represents an OAuth2-compliant error response per RFC 6749.
|
||||
type OAuth2Error struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
// WriteOAuth2Error writes an OAuth2-compliant error response per RFC 6749.
|
||||
// This should be used for all OAuth2 endpoints (/oauth2/*) to ensure compliance.
|
||||
func WriteOAuth2Error(ctx context.Context, rw http.ResponseWriter, status int, errorCode, description string) {
|
||||
Write(ctx, rw, status, OAuth2Error{
|
||||
func WriteOAuth2Error(ctx context.Context, rw http.ResponseWriter, status int, errorCode codersdk.OAuth2ErrorCode, description string) {
|
||||
Write(ctx, rw, status, codersdk.OAuth2Error{
|
||||
Error: errorCode,
|
||||
ErrorDescription: description,
|
||||
})
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type (
|
||||
httpRouteInfoKey struct{}
|
||||
)
|
||||
|
||||
type httpRouteInfo struct {
|
||||
Route string
|
||||
Method string
|
||||
}
|
||||
|
||||
// ExtractHTTPRoute retrieves just the HTTP route pattern from context.
|
||||
// Returns empty string if not set.
|
||||
func ExtractHTTPRoute(ctx context.Context) string {
|
||||
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||
return ri.Route
|
||||
}
|
||||
|
||||
// ExtractHTTPMethod retrieves just the HTTP method from context.
|
||||
// Returns empty string if not set.
|
||||
func ExtractHTTPMethod(ctx context.Context) string {
|
||||
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||
return ri.Method
|
||||
}
|
||||
|
||||
// HTTPRoute is middleware that stores the HTTP route pattern and method in
|
||||
// context for use by downstream handlers and services (e.g. prometheus).
|
||||
func HTTPRoute(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
route := getRoutePattern(r)
|
||||
ctx := context.WithValue(r.Context(), httpRouteInfoKey{}, httpRouteInfo{
|
||||
Route: route,
|
||||
Method: r.Method,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func getRoutePattern(r *http.Request) string {
|
||||
rctx := chi.RouteContext(r.Context())
|
||||
if rctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
routePath := r.URL.Path
|
||||
if r.URL.RawPath != "" {
|
||||
routePath = r.URL.RawPath
|
||||
}
|
||||
|
||||
tctx := chi.NewRouteContext()
|
||||
routes := rctx.Routes
|
||||
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
||||
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
||||
// All other ones will be matched as "STATIC".
|
||||
if strings.HasPrefix(routePath, "/api/") {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return "STATIC"
|
||||
}
|
||||
|
||||
// tctx has the updated pattern, since Match mutates it
|
||||
return tctx.RoutePattern()
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestHTTPRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
reqFn func() *http.Request
|
||||
registerRoutes map[string]string
|
||||
mws []func(http.Handler) http.Handler
|
||||
expectedRoute string
|
||||
expectedMethod string
|
||||
}{
|
||||
{
|
||||
name: "without middleware",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{},
|
||||
expectedRoute: "",
|
||||
expectedMethod: "",
|
||||
},
|
||||
{
|
||||
name: "root",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "/",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
{
|
||||
name: "parameterized route",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPut, "/users/123", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodPut: "/users/{id}"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "/users/{id}",
|
||||
expectedMethod: http.MethodPut,
|
||||
},
|
||||
{
|
||||
name: "unknown",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/api/a", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/api/b"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "UNKNOWN",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
{
|
||||
name: "static",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/some/static/file.png", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "STATIC",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
r := chi.NewRouter()
|
||||
done := make(chan string)
|
||||
for _, mw := range tc.mws {
|
||||
r.Use(mw)
|
||||
}
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(done)
|
||||
method := httpmw.ExtractHTTPMethod(r.Context())
|
||||
route := httpmw.ExtractHTTPRoute(r.Context())
|
||||
assert.Equal(t, tc.expectedMethod, method, "expected method mismatch")
|
||||
assert.Equal(t, tc.expectedRoute, route, "expected route mismatch")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
for method, route := range tc.registerRoutes {
|
||||
r.MethodFunc(method, route, func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
req := tc.reqFn()
|
||||
r.ServeHTTP(httptest.NewRecorder(), req)
|
||||
_ = testutil.TryReceive(ctx, t, done)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -290,15 +290,15 @@ func (*codersdkErrorWriter) writeClientNotFound(ctx context.Context, rw http.Res
|
||||
type oauth2ErrorWriter struct{}
|
||||
|
||||
func (*oauth2ErrorWriter) writeMissingClientID(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Missing client_id parameter")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, "Missing client_id parameter")
|
||||
}
|
||||
|
||||
func (*oauth2ErrorWriter) writeInvalidClientID(ctx context.Context, rw http.ResponseWriter, _ error) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, "invalid_client", "The client credentials are invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, codersdk.OAuth2ErrorCodeInvalidClient, "The client credentials are invalid")
|
||||
}
|
||||
|
||||
func (*oauth2ErrorWriter) writeClientNotFound(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, "invalid_client", "The client credentials are invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, codersdk.OAuth2ErrorCodeInvalidClient, "The client credentials are invalid")
|
||||
}
|
||||
|
||||
// extractOAuth2ProviderAppBase is the internal implementation that uses the strategy pattern
|
||||
|
||||
@@ -3,10 +3,8 @@ package httpmw
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
@@ -71,7 +69,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
var (
|
||||
dist *prometheus.HistogramVec
|
||||
distOpts []string
|
||||
path = getRoutePattern(r)
|
||||
path = ExtractHTTPRoute(r.Context())
|
||||
)
|
||||
|
||||
// We want to count WebSockets separately.
|
||||
@@ -98,29 +96,3 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getRoutePattern(r *http.Request) string {
|
||||
rctx := chi.RouteContext(r.Context())
|
||||
if rctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
routePath := r.URL.Path
|
||||
if r.URL.RawPath != "" {
|
||||
routePath = r.URL.RawPath
|
||||
}
|
||||
|
||||
tctx := chi.NewRouteContext()
|
||||
routes := rctx.Routes
|
||||
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
||||
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
||||
// All other ones will be matched as "STATIC".
|
||||
if strings.HasPrefix(routePath, "/api/") {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return "STATIC"
|
||||
}
|
||||
|
||||
// tctx has the updated pattern, since Match mutates it
|
||||
return tctx.RoutePattern()
|
||||
}
|
||||
|
||||
@@ -29,9 +29,9 @@ func TestPrometheus(t *testing.T) {
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||
reg := prometheus.NewRegistry()
|
||||
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
httpmw.HTTPRoute(httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(res, req)
|
||||
}))).ServeHTTP(res, req)
|
||||
metrics, err := reg.Gather()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(metrics), 0)
|
||||
@@ -57,7 +57,7 @@ func TestPrometheus(t *testing.T) {
|
||||
wrappedHandler := promMW(testHandler)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(tracing.StatusWriterMiddleware, promMW)
|
||||
r.Use(tracing.StatusWriterMiddleware, httpmw.HTTPRoute, promMW)
|
||||
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
|
||||
wrappedHandler.ServeHTTP(rw, r)
|
||||
})
|
||||
@@ -85,7 +85,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
||||
r.With(httpmw.HTTPRoute).With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
|
||||
|
||||
@@ -115,6 +115,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -145,6 +146,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -173,6 +175,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ func TestOAuth2RegistrationErrorCodes(t *testing.T) {
|
||||
req: codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("test-client-%d", time.Now().UnixNano()),
|
||||
GrantTypes: []string{"unsupported_grant_type"},
|
||||
GrantTypes: []codersdk.OAuth2ProviderGrantType{"unsupported_grant_type"},
|
||||
},
|
||||
expectedError: "invalid_client_metadata",
|
||||
expectedCode: http.StatusBadRequest,
|
||||
@@ -109,7 +109,7 @@ func TestOAuth2RegistrationErrorCodes(t *testing.T) {
|
||||
req: codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("test-client-%d", time.Now().UnixNano()),
|
||||
ResponseTypes: []string{"unsupported_response_type"},
|
||||
ResponseTypes: []codersdk.OAuth2ProviderResponseType{"unsupported_response_type"},
|
||||
},
|
||||
expectedError: "invalid_client_metadata",
|
||||
expectedCode: http.StatusBadRequest,
|
||||
|
||||
@@ -44,10 +44,10 @@ func TestOAuth2AuthorizationServerMetadata(t *testing.T) {
|
||||
require.NotEmpty(t, metadata.Issuer)
|
||||
require.NotEmpty(t, metadata.AuthorizationEndpoint)
|
||||
require.NotEmpty(t, metadata.TokenEndpoint)
|
||||
require.Contains(t, metadata.ResponseTypesSupported, "code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "authorization_code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "refresh_token")
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, "S256")
|
||||
require.Contains(t, metadata.ResponseTypesSupported, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeRefreshToken)
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
// Supported scopes are published from the curated catalog
|
||||
require.Equal(t, rbac.ExternalScopeNames(), metadata.ScopesSupported)
|
||||
}
|
||||
|
||||
@@ -277,47 +277,47 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
grantTypes []string
|
||||
grantTypes []codersdk.OAuth2ProviderGrantType
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
grantTypes: []string{},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidAuthorizationCode",
|
||||
grantTypes: []string{"authorization_code"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"authorization_code"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidRefreshTokenAlone",
|
||||
grantTypes: []string{"refresh_token"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"refresh_token"},
|
||||
expectError: true, // refresh_token requires authorization_code to be present
|
||||
},
|
||||
{
|
||||
name: "ValidMultiple",
|
||||
grantTypes: []string{"authorization_code", "refresh_token"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"authorization_code", "refresh_token"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidUnsupported",
|
||||
grantTypes: []string{"client_credentials"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"client_credentials"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidPassword",
|
||||
grantTypes: []string{"password"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"password"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidImplicit",
|
||||
grantTypes: []string{"implicit"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"implicit"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "MixedValidInvalid",
|
||||
grantTypes: []string{"authorization_code", "client_credentials"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"authorization_code", "client_credentials"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -352,32 +352,32 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseTypes []string
|
||||
responseTypes []codersdk.OAuth2ProviderResponseType
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
responseTypes: []string{},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidCode",
|
||||
responseTypes: []string{"code"},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"code"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidToken",
|
||||
responseTypes: []string{"token"},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"token"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidImplicit",
|
||||
responseTypes: []string{"id_token"},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"id_token"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidMultiple",
|
||||
responseTypes: []string{"code", "token"},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"code", "token"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -412,7 +412,7 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authMethod string
|
||||
authMethod codersdk.OAuth2TokenEndpointAuthMethod
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
@@ -659,14 +659,14 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should default to authorization_code
|
||||
require.Contains(t, config.GrantTypes, "authorization_code")
|
||||
require.Contains(t, config.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
|
||||
// Should default to code
|
||||
require.Contains(t, config.ResponseTypes, "code")
|
||||
require.Contains(t, config.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
|
||||
// Should default to client_secret_basic or client_secret_post
|
||||
require.True(t, config.TokenEndpointAuthMethod == "client_secret_basic" ||
|
||||
config.TokenEndpointAuthMethod == "client_secret_post" ||
|
||||
require.True(t, config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic ||
|
||||
config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost ||
|
||||
config.TokenEndpointAuthMethod == "")
|
||||
|
||||
// Client secret should be generated
|
||||
|
||||
@@ -1329,10 +1329,10 @@ func TestOAuth2DynamicClientRegistration(t *testing.T) {
|
||||
require.Equal(t, int64(0), resp.ClientSecretExpiresAt) // Non-expiring
|
||||
|
||||
// Verify default values
|
||||
require.Contains(t, resp.GrantTypes, "authorization_code")
|
||||
require.Contains(t, resp.GrantTypes, "refresh_token")
|
||||
require.Contains(t, resp.ResponseTypes, "code")
|
||||
require.Equal(t, "client_secret_basic", resp.TokenEndpointAuthMethod)
|
||||
require.Contains(t, resp.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, resp.GrantTypes, codersdk.OAuth2ProviderGrantTypeRefreshToken)
|
||||
require.Contains(t, resp.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Equal(t, codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic, resp.TokenEndpointAuthMethod)
|
||||
|
||||
// Verify request values are preserved
|
||||
require.Equal(t, req.RedirectURIs, resp.RedirectURIs)
|
||||
@@ -1363,9 +1363,9 @@ func TestOAuth2DynamicClientRegistration(t *testing.T) {
|
||||
require.NotEmpty(t, resp.RegistrationClientURI)
|
||||
|
||||
// Should have defaults applied
|
||||
require.Contains(t, resp.GrantTypes, "authorization_code")
|
||||
require.Contains(t, resp.ResponseTypes, "code")
|
||||
require.Equal(t, "client_secret_basic", resp.TokenEndpointAuthMethod)
|
||||
require.Contains(t, resp.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, resp.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Equal(t, codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic, resp.TokenEndpointAuthMethod)
|
||||
})
|
||||
|
||||
t.Run("InvalidRedirectURI", func(t *testing.T) {
|
||||
|
||||
@@ -137,13 +137,13 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
|
||||
callbackURL, err := url.Parse(app.CallbackURL)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, "server_error", "Failed to validate query parameters")
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Failed to validate query parameters")
|
||||
return
|
||||
}
|
||||
|
||||
params, _, err := extractAuthorizeParams(r, callbackURL)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", err.Error())
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -151,10 +151,10 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
if params.codeChallenge != "" {
|
||||
// If code_challenge is provided but method is not, default to S256
|
||||
if params.codeChallengeMethod == "" {
|
||||
params.codeChallengeMethod = "S256"
|
||||
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
}
|
||||
if params.codeChallengeMethod != "S256" {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Invalid code_challenge_method: only S256 is supported")
|
||||
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
// TODO: Ignoring scope for now, but should look into implementing.
|
||||
code, err := GenerateSecret()
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, "server_error", "Failed to generate OAuth2 app authorization code")
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Failed to generate OAuth2 app authorization code")
|
||||
return
|
||||
}
|
||||
err = db.InTx(func(tx database.Store) error {
|
||||
@@ -202,7 +202,7 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, "server_error", "Failed to generate OAuth2 authorization code")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Failed to generate OAuth2 authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -19,11 +19,11 @@ func GetAuthorizationServerMetadata(accessURL *url.URL) http.HandlerFunc {
|
||||
TokenEndpoint: accessURL.JoinPath("/oauth2/tokens").String(),
|
||||
RegistrationEndpoint: accessURL.JoinPath("/oauth2/register").String(), // RFC 7591
|
||||
RevocationEndpoint: accessURL.JoinPath("/oauth2/revoke").String(), // RFC 7009
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||
CodeChallengeMethodsSupported: []string{"S256"},
|
||||
ResponseTypesSupported: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode},
|
||||
GrantTypesSupported: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
CodeChallengeMethodsSupported: []codersdk.OAuth2PKCECodeChallengeMethod{codersdk.OAuth2PKCECodeChallengeMethodS256},
|
||||
ScopesSupported: rbac.ExternalScopeNames(),
|
||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_post"},
|
||||
TokenEndpointAuthMethodsSupported: []codersdk.OAuth2TokenEndpointAuthMethod{codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost},
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, metadata)
|
||||
}
|
||||
|
||||
@@ -32,10 +32,10 @@ func TestOAuth2AuthorizationServerMetadata(t *testing.T) {
|
||||
require.NotEmpty(t, metadata.Issuer)
|
||||
require.NotEmpty(t, metadata.AuthorizationEndpoint)
|
||||
require.NotEmpty(t, metadata.TokenEndpoint)
|
||||
require.Contains(t, metadata.ResponseTypesSupported, "code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "authorization_code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "refresh_token")
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, "S256")
|
||||
require.Contains(t, metadata.ResponseTypesSupported, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeRefreshToken)
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
// Supported scopes are published from the curated catalog
|
||||
require.Equal(t, rbac.ExternalScopeNames(), metadata.ScopesSupported)
|
||||
}
|
||||
|
||||
@@ -105,8 +105,9 @@ func GenerateState(t *testing.T) string {
|
||||
return base64.RawURLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// AuthorizeOAuth2App performs the OAuth2 authorization flow and returns the authorization code
|
||||
func AuthorizeOAuth2App(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams) string {
|
||||
// doAuthorizeRequest performs the OAuth2 authorization request and returns the response.
|
||||
// Caller is responsible for closing the response body.
|
||||
func doAuthorizeRequest(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
@@ -123,6 +124,8 @@ func AuthorizeOAuth2App(t *testing.T, client *codersdk.Client, baseURL string, p
|
||||
|
||||
if params.CodeChallenge != "" {
|
||||
query.Set("code_challenge", params.CodeChallenge)
|
||||
}
|
||||
if params.CodeChallengeMethod != "" {
|
||||
query.Set("code_challenge_method", params.CodeChallengeMethod)
|
||||
}
|
||||
if params.Resource != "" {
|
||||
@@ -151,6 +154,15 @@ func AuthorizeOAuth2App(t *testing.T, client *codersdk.Client, baseURL string, p
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
require.NoError(t, err, "failed to perform authorization request")
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// AuthorizeOAuth2App performs the OAuth2 authorization flow and returns the authorization code
|
||||
func AuthorizeOAuth2App(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams) string {
|
||||
t.Helper()
|
||||
|
||||
resp := doAuthorizeRequest(t, client, baseURL, params)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should get a redirect response (either 302 Found or 307 Temporary Redirect)
|
||||
@@ -326,3 +338,13 @@ func CleanupOAuth2App(t *testing.T, client *codersdk.Client, appID uuid.UUID) {
|
||||
t.Logf("Warning: failed to cleanup OAuth2 app %s: %v", appID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeOAuth2AppExpectingError performs the OAuth2 authorization flow expecting an error
|
||||
func AuthorizeOAuth2AppExpectingError(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams, expectedStatusCode int) {
|
||||
t.Helper()
|
||||
|
||||
resp := doAuthorizeRequest(t, client, baseURL, params)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, expectedStatusCode, resp.StatusCode, "unexpected status code")
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/oauth2provider/oauth2providertest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestOAuth2AuthorizationServerMetadata(t *testing.T) {
|
||||
@@ -185,6 +186,38 @@ func TestOAuth2WithoutPKCE(t *testing.T) {
|
||||
require.NotEmpty(t, token.RefreshToken, "should receive refresh token")
|
||||
}
|
||||
|
||||
func TestOAuth2PKCEPlainMethodRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: false,
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create OAuth2 app
|
||||
app, _ := oauth2providertest.CreateTestOAuth2App(t, client)
|
||||
t.Cleanup(func() {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
// Generate PKCE parameters but use "plain" method (should be rejected)
|
||||
_, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
// Attempt authorization with plain method - should fail
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: string(codersdk.OAuth2ProviderResponseTypeCode),
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: string(codersdk.OAuth2PKCECodeChallengeMethodPlain),
|
||||
}
|
||||
|
||||
// Should get a 400 Bad Request
|
||||
oauth2providertest.AuthorizeOAuth2AppExpectingError(t, client, client.URL.String(), authParams, 400)
|
||||
}
|
||||
|
||||
func TestOAuth2ResourceParameter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/oauth2provider"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestVerifyPKCE(t *testing.T) {
|
||||
@@ -75,3 +76,52 @@ func TestPKCES256Generation(t *testing.T) {
|
||||
require.Equal(t, expectedChallenge, challenge)
|
||||
require.True(t, oauth2provider.VerifyPKCE(challenge, verifier))
|
||||
}
|
||||
|
||||
func TestValidatePKCECodeChallengeMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "EmptyIsValid",
|
||||
method: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "S256IsValid",
|
||||
method: string(codersdk.OAuth2PKCECodeChallengeMethodS256),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "PlainIsRejected",
|
||||
method: string(codersdk.OAuth2PKCECodeChallengeMethodPlain),
|
||||
expectError: true,
|
||||
errorContains: "plain",
|
||||
},
|
||||
{
|
||||
name: "UnknownIsRejected",
|
||||
method: "unknown_method",
|
||||
expectError: true,
|
||||
errorContains: "unsupported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := codersdk.ValidatePKCECodeChallengeMethod(tt.method)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
require.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("valid-grant-types-client-%d", time.Now().UnixNano()),
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
GrantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
}
|
||||
|
||||
resp, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
@@ -266,7 +266,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("invalid-grant-types-client-%d", time.Now().UnixNano()),
|
||||
GrantTypes: []string{"unsupported_grant"},
|
||||
GrantTypes: []codersdk.OAuth2ProviderGrantType{"unsupported_grant"},
|
||||
}
|
||||
|
||||
_, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
@@ -284,7 +284,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("valid-response-types-client-%d", time.Now().UnixNano()),
|
||||
ResponseTypes: []string{"code"},
|
||||
ResponseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode},
|
||||
}
|
||||
|
||||
resp, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
@@ -302,7 +302,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("invalid-response-types-client-%d", time.Now().UnixNano()),
|
||||
ResponseTypes: []string{"token"}, // Not supported
|
||||
ResponseTypes: []codersdk.OAuth2ProviderResponseType{"token"}, // Not supported
|
||||
}
|
||||
|
||||
_, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -85,9 +86,9 @@ func CreateDynamicClientRegistration(db database.Store, accessURL *url.URL, audi
|
||||
DynamicallyRegistered: sql.NullBool{Bool: true, Valid: true},
|
||||
ClientIDIssuedAt: sql.NullTime{Time: now, Valid: true},
|
||||
ClientSecretExpiresAt: sql.NullTime{}, // No expiration for now
|
||||
GrantTypes: req.GrantTypes,
|
||||
ResponseTypes: req.ResponseTypes,
|
||||
TokenEndpointAuthMethod: sql.NullString{String: req.TokenEndpointAuthMethod, Valid: true},
|
||||
GrantTypes: slice.ToStrings(req.GrantTypes),
|
||||
ResponseTypes: slice.ToStrings(req.ResponseTypes),
|
||||
TokenEndpointAuthMethod: sql.NullString{String: string(req.TokenEndpointAuthMethod), Valid: true},
|
||||
Scope: sql.NullString{String: req.Scope, Valid: true},
|
||||
Contacts: req.Contacts,
|
||||
ClientUri: sql.NullString{String: req.ClientURI, Valid: req.ClientURI != ""},
|
||||
@@ -154,9 +155,9 @@ func CreateDynamicClientRegistration(db database.Store, accessURL *url.URL, audi
|
||||
JWKS: app.Jwks.RawMessage,
|
||||
SoftwareID: app.SoftwareID.String,
|
||||
SoftwareVersion: app.SoftwareVersion.String,
|
||||
GrantTypes: app.GrantTypes,
|
||||
ResponseTypes: app.ResponseTypes,
|
||||
TokenEndpointAuthMethod: app.TokenEndpointAuthMethod.String,
|
||||
GrantTypes: slice.StringEnums[codersdk.OAuth2ProviderGrantType](app.GrantTypes),
|
||||
ResponseTypes: slice.StringEnums[codersdk.OAuth2ProviderResponseType](app.ResponseTypes),
|
||||
TokenEndpointAuthMethod: codersdk.OAuth2TokenEndpointAuthMethod(app.TokenEndpointAuthMethod.String),
|
||||
Scope: app.Scope.String,
|
||||
Contacts: app.Contacts,
|
||||
RegistrationAccessToken: registrationToken,
|
||||
@@ -217,12 +218,12 @@ func GetClientConfiguration(db database.Store) http.HandlerFunc {
|
||||
JWKS: app.Jwks.RawMessage,
|
||||
SoftwareID: app.SoftwareID.String,
|
||||
SoftwareVersion: app.SoftwareVersion.String,
|
||||
GrantTypes: app.GrantTypes,
|
||||
ResponseTypes: app.ResponseTypes,
|
||||
TokenEndpointAuthMethod: app.TokenEndpointAuthMethod.String,
|
||||
GrantTypes: slice.StringEnums[codersdk.OAuth2ProviderGrantType](app.GrantTypes),
|
||||
ResponseTypes: slice.StringEnums[codersdk.OAuth2ProviderResponseType](app.ResponseTypes),
|
||||
TokenEndpointAuthMethod: codersdk.OAuth2TokenEndpointAuthMethod(app.TokenEndpointAuthMethod.String),
|
||||
Scope: app.Scope.String,
|
||||
Contacts: app.Contacts,
|
||||
RegistrationAccessToken: nil, // RFC 7592: Not returned in GET responses for security
|
||||
RegistrationAccessToken: "", // RFC 7592: Not returned in GET responses for security
|
||||
RegistrationClientURI: app.RegistrationClientUri.String,
|
||||
}
|
||||
|
||||
@@ -303,9 +304,9 @@ func UpdateClientConfiguration(db database.Store, auditor *audit.Auditor, logger
|
||||
RedirectUris: req.RedirectURIs,
|
||||
ClientType: sql.NullString{String: req.DetermineClientType(), Valid: true},
|
||||
ClientSecretExpiresAt: sql.NullTime{}, // No expiration for now
|
||||
GrantTypes: req.GrantTypes,
|
||||
ResponseTypes: req.ResponseTypes,
|
||||
TokenEndpointAuthMethod: sql.NullString{String: req.TokenEndpointAuthMethod, Valid: true},
|
||||
GrantTypes: slice.ToStrings(req.GrantTypes),
|
||||
ResponseTypes: slice.ToStrings(req.ResponseTypes),
|
||||
TokenEndpointAuthMethod: sql.NullString{String: string(req.TokenEndpointAuthMethod), Valid: true},
|
||||
Scope: sql.NullString{String: req.Scope, Valid: true},
|
||||
Contacts: req.Contacts,
|
||||
ClientUri: sql.NullString{String: req.ClientURI, Valid: req.ClientURI != ""},
|
||||
@@ -341,12 +342,12 @@ func UpdateClientConfiguration(db database.Store, auditor *audit.Auditor, logger
|
||||
JWKS: updatedApp.Jwks.RawMessage,
|
||||
SoftwareID: updatedApp.SoftwareID.String,
|
||||
SoftwareVersion: updatedApp.SoftwareVersion.String,
|
||||
GrantTypes: updatedApp.GrantTypes,
|
||||
ResponseTypes: updatedApp.ResponseTypes,
|
||||
TokenEndpointAuthMethod: updatedApp.TokenEndpointAuthMethod.String,
|
||||
GrantTypes: slice.StringEnums[codersdk.OAuth2ProviderGrantType](updatedApp.GrantTypes),
|
||||
ResponseTypes: slice.StringEnums[codersdk.OAuth2ProviderResponseType](updatedApp.ResponseTypes),
|
||||
TokenEndpointAuthMethod: codersdk.OAuth2TokenEndpointAuthMethod(updatedApp.TokenEndpointAuthMethod.String),
|
||||
Scope: updatedApp.Scope.String,
|
||||
Contacts: updatedApp.Contacts,
|
||||
RegistrationAccessToken: updatedApp.RegistrationAccessToken,
|
||||
RegistrationAccessToken: "", // RFC 7592: Not returned for security
|
||||
RegistrationClientURI: updatedApp.RegistrationClientUri.String,
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,6 +28,26 @@ var (
|
||||
ErrInvalidTokenFormat = xerrors.New("invalid token format")
|
||||
)
|
||||
|
||||
func extractRevocationRequest(r *http.Request) (codersdk.OAuth2TokenRevocationRequest, error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return codersdk.OAuth2TokenRevocationRequest{}, xerrors.Errorf("invalid form data: %w", err)
|
||||
}
|
||||
|
||||
req := codersdk.OAuth2TokenRevocationRequest{
|
||||
Token: r.Form.Get("token"),
|
||||
TokenTypeHint: codersdk.OAuth2RevocationTokenTypeHint(r.Form.Get("token_type_hint")),
|
||||
ClientID: r.Form.Get("client_id"),
|
||||
ClientSecret: r.Form.Get("client_secret"),
|
||||
}
|
||||
|
||||
// RFC 7009 requires 'token' parameter.
|
||||
if req.Token == "" {
|
||||
return codersdk.OAuth2TokenRevocationRequest{}, xerrors.New("missing token parameter")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// RevokeToken implements RFC 7009 OAuth2 Token Revocation
|
||||
// Authentication is unique for this endpoint in that it does not use the
|
||||
// standard token authentication middleware. Instead, it expects the token that
|
||||
@@ -41,35 +62,29 @@ func RevokeToken(db database.Store, logger slog.Logger) http.HandlerFunc {
|
||||
|
||||
// RFC 7009 requires POST method with application/x-www-form-urlencoded
|
||||
if r.Method != http.MethodPost {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusMethodNotAllowed, "invalid_request", "Method not allowed")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusMethodNotAllowed, codersdk.OAuth2ErrorCodeInvalidRequest, "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Invalid form data")
|
||||
return
|
||||
}
|
||||
|
||||
// RFC 7009 requires 'token' parameter
|
||||
token := r.Form.Get("token")
|
||||
if token == "" {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Missing token parameter")
|
||||
req, err := extractRevocationRequest(r)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Determine if this is a refresh token (starts with "coder_") or API key
|
||||
// APIKeys do not have the SecretIdentifier prefix.
|
||||
const coderPrefix = SecretIdentifier + "_"
|
||||
isRefreshToken := strings.HasPrefix(token, coderPrefix)
|
||||
isRefreshToken := strings.HasPrefix(req.Token, coderPrefix)
|
||||
|
||||
// Revoke the token with ownership verification
|
||||
err := db.InTx(func(tx database.Store) error {
|
||||
err = db.InTx(func(tx database.Store) error {
|
||||
if isRefreshToken {
|
||||
// Handle refresh token revocation
|
||||
return revokeRefreshTokenInTx(ctx, tx, token, app.ID)
|
||||
return revokeRefreshTokenInTx(ctx, tx, req.Token, app.ID)
|
||||
}
|
||||
// Handle API key revocation
|
||||
return revokeAPIKeyInTx(ctx, tx, token, app.ID)
|
||||
return revokeAPIKeyInTx(ctx, tx, req.Token, app.ID)
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrTokenNotBelongsToClient) {
|
||||
@@ -85,14 +100,14 @@ func RevokeToken(db database.Store, logger slog.Logger) http.HandlerFunc {
|
||||
logger.Debug(ctx, "token revocation failed: invalid token format",
|
||||
slog.F("client_id", app.ID.String()),
|
||||
slog.F("app_name", app.Name))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Invalid token format")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, "Invalid token format")
|
||||
return
|
||||
}
|
||||
logger.Error(ctx, "token revocation failed with internal server error",
|
||||
slog.Error(err),
|
||||
slog.F("client_id", app.ID.String()),
|
||||
slog.F("app_name", app.Name))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, "server_error", "Internal server error")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
@@ -38,28 +36,18 @@ var (
|
||||
errInvalidResource = xerrors.New("invalid resource parameter")
|
||||
)
|
||||
|
||||
type tokenParams struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
code string
|
||||
grantType codersdk.OAuth2ProviderGrantType
|
||||
redirectURL *url.URL
|
||||
refreshToken string
|
||||
codeVerifier string // PKCE verifier
|
||||
resource string // RFC 8707 resource for token binding
|
||||
scopes []string
|
||||
}
|
||||
|
||||
func extractTokenParams(r *http.Request, callbackURL *url.URL) (tokenParams, []codersdk.ValidationError, error) {
|
||||
func extractTokenRequest(r *http.Request, callbackURL *url.URL) (codersdk.OAuth2TokenRequest, []codersdk.ValidationError, error) {
|
||||
p := httpapi.NewQueryParamParser()
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return tokenParams{}, nil, xerrors.Errorf("parse form: %w", err)
|
||||
return codersdk.OAuth2TokenRequest{}, nil, xerrors.Errorf("parse form: %w", err)
|
||||
}
|
||||
|
||||
vals := r.Form
|
||||
p.RequiredNotEmpty("grant_type")
|
||||
grantType := httpapi.ParseCustom(p, vals, "", "grant_type", httpapi.ParseEnum[codersdk.OAuth2ProviderGrantType])
|
||||
|
||||
// Grant-type specific validation - must be called before parsing values.
|
||||
switch grantType {
|
||||
case codersdk.OAuth2ProviderGrantTypeRefreshToken:
|
||||
p.RequiredNotEmpty("refresh_token")
|
||||
@@ -67,19 +55,23 @@ func extractTokenParams(r *http.Request, callbackURL *url.URL) (tokenParams, []c
|
||||
p.RequiredNotEmpty("client_secret", "client_id", "code")
|
||||
}
|
||||
|
||||
params := tokenParams{
|
||||
clientID: p.String(vals, "", "client_id"),
|
||||
clientSecret: p.String(vals, "", "client_secret"),
|
||||
code: p.String(vals, "", "code"),
|
||||
grantType: grantType,
|
||||
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
|
||||
refreshToken: p.String(vals, "", "refresh_token"),
|
||||
codeVerifier: p.String(vals, "", "code_verifier"),
|
||||
resource: p.String(vals, "", "resource"),
|
||||
scopes: strings.Fields(strings.TrimSpace(p.String(vals, "", "scope"))),
|
||||
req := codersdk.OAuth2TokenRequest{
|
||||
GrantType: grantType,
|
||||
ClientID: p.String(vals, "", "client_id"),
|
||||
ClientSecret: p.String(vals, "", "client_secret"),
|
||||
Code: p.String(vals, "", "code"),
|
||||
RedirectURI: p.String(vals, "", "redirect_uri"),
|
||||
RefreshToken: p.String(vals, "", "refresh_token"),
|
||||
CodeVerifier: p.String(vals, "", "code_verifier"),
|
||||
Resource: p.String(vals, "", "resource"),
|
||||
Scope: p.String(vals, "", "scope"),
|
||||
}
|
||||
// Validate resource parameter syntax (RFC 8707): must be absolute URI without fragment
|
||||
if err := validateResourceParameter(params.resource); err != nil {
|
||||
|
||||
// Validate redirect URI - errors are added to p.Errors.
|
||||
_ = p.RedirectURL(vals, callbackURL, "redirect_uri")
|
||||
|
||||
// Validate resource parameter syntax (RFC 8707): must be absolute URI without fragment.
|
||||
if err := validateResourceParameter(req.Resource); err != nil {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: "resource",
|
||||
Detail: "must be an absolute URI without fragment",
|
||||
@@ -88,9 +80,9 @@ func extractTokenParams(r *http.Request, callbackURL *url.URL) (tokenParams, []c
|
||||
|
||||
p.ErrorExcessParams(vals)
|
||||
if len(p.Errors) > 0 {
|
||||
return tokenParams{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors)
|
||||
return codersdk.OAuth2TokenRequest{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors)
|
||||
}
|
||||
return params, nil, nil
|
||||
return req, nil, nil
|
||||
}
|
||||
|
||||
// Tokens
|
||||
@@ -110,13 +102,13 @@ func Tokens(db database.Store, lifetimes codersdk.SessionLifetime) http.HandlerF
|
||||
return
|
||||
}
|
||||
|
||||
params, validationErrs, err := extractTokenParams(r, callbackURL)
|
||||
req, validationErrs, err := extractTokenRequest(r, callbackURL)
|
||||
if err != nil {
|
||||
// Check for specific validation errors in priority order
|
||||
if slices.ContainsFunc(validationErrs, func(validationError codersdk.ValidationError) bool {
|
||||
return validationError.Field == "grant_type"
|
||||
}) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "unsupported_grant_type", "The grant type is missing or unsupported")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeUnsupportedGrantType, "The grant type is missing or unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -125,47 +117,47 @@ func Tokens(db database.Store, lifetimes codersdk.SessionLifetime) http.HandlerF
|
||||
if slices.ContainsFunc(validationErrs, func(validationError codersdk.ValidationError) bool {
|
||||
return validationError.Field == field
|
||||
}) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", fmt.Sprintf("Missing required parameter: %s", field))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, fmt.Sprintf("Missing required parameter: %s", field))
|
||||
return
|
||||
}
|
||||
}
|
||||
// Generic invalid request for other validation errors
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "The request is missing required parameters or is otherwise malformed")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, "The request is missing required parameters or is otherwise malformed")
|
||||
return
|
||||
}
|
||||
|
||||
var token oauth2.Token
|
||||
var token codersdk.OAuth2TokenResponse
|
||||
//nolint:gocritic,revive // More cases will be added later.
|
||||
switch params.grantType {
|
||||
switch req.GrantType {
|
||||
// TODO: Client creds, device code.
|
||||
case codersdk.OAuth2ProviderGrantTypeRefreshToken:
|
||||
token, err = refreshTokenGrant(ctx, db, app, lifetimes, params)
|
||||
token, err = refreshTokenGrant(ctx, db, app, lifetimes, req)
|
||||
case codersdk.OAuth2ProviderGrantTypeAuthorizationCode:
|
||||
token, err = authorizationCodeGrant(ctx, db, app, lifetimes, params)
|
||||
token, err = authorizationCodeGrant(ctx, db, app, lifetimes, req)
|
||||
default:
|
||||
// This should handle truly invalid grant types
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "unsupported_grant_type", fmt.Sprintf("The grant type %q is not supported", params.grantType))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeUnsupportedGrantType, fmt.Sprintf("The grant type %q is not supported", req.GrantType))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, errBadSecret) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, "invalid_client", "The client credentials are invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, codersdk.OAuth2ErrorCodeInvalidClient, "The client credentials are invalid")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errBadCode) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_grant", "The authorization code is invalid or expired")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidGrant, "The authorization code is invalid or expired")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errInvalidPKCE) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_grant", "The PKCE code verifier is invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidGrant, "The PKCE code verifier is invalid")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errInvalidResource) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_target", "The resource parameter is invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidTarget, "The resource parameter is invalid")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errBadToken) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_grant", "The refresh token is invalid or expired")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidGrant, "The refresh token is invalid or expired")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -182,77 +174,77 @@ func Tokens(db database.Store, lifetimes codersdk.SessionLifetime) http.HandlerF
|
||||
}
|
||||
}
|
||||
|
||||
func authorizationCodeGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, params tokenParams) (oauth2.Token, error) {
|
||||
func authorizationCodeGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, req codersdk.OAuth2TokenRequest) (codersdk.OAuth2TokenResponse, error) {
|
||||
// Validate the client secret.
|
||||
secret, err := ParseFormattedSecret(params.clientSecret)
|
||||
secret, err := ParseFormattedSecret(req.ClientSecret)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, errBadSecret
|
||||
return codersdk.OAuth2TokenResponse{}, errBadSecret
|
||||
}
|
||||
//nolint:gocritic // Users cannot read secrets so we must use the system.
|
||||
dbSecret, err := db.GetOAuth2ProviderAppSecretByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(secret.Prefix))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return oauth2.Token{}, errBadSecret
|
||||
return codersdk.OAuth2TokenResponse{}, errBadSecret
|
||||
}
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
equalSecret := apikey.ValidateHash(dbSecret.HashedSecret, secret.Secret)
|
||||
if !equalSecret {
|
||||
return oauth2.Token{}, errBadSecret
|
||||
return codersdk.OAuth2TokenResponse{}, errBadSecret
|
||||
}
|
||||
|
||||
// Validate the authorization code.
|
||||
code, err := ParseFormattedSecret(params.code)
|
||||
code, err := ParseFormattedSecret(req.Code)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, errBadCode
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
}
|
||||
//nolint:gocritic // There is no user yet so we must use the system.
|
||||
dbCode, err := db.GetOAuth2ProviderAppCodeByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(code.Prefix))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return oauth2.Token{}, errBadCode
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
}
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
equalCode := apikey.ValidateHash(dbCode.HashedSecret, code.Secret)
|
||||
if !equalCode {
|
||||
return oauth2.Token{}, errBadCode
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
}
|
||||
|
||||
// Ensure the code has not expired.
|
||||
if dbCode.ExpiresAt.Before(dbtime.Now()) {
|
||||
return oauth2.Token{}, errBadCode
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
}
|
||||
|
||||
// Verify PKCE challenge if present
|
||||
if dbCode.CodeChallenge.Valid && dbCode.CodeChallenge.String != "" {
|
||||
if params.codeVerifier == "" {
|
||||
return oauth2.Token{}, errInvalidPKCE
|
||||
if req.CodeVerifier == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, params.codeVerifier) {
|
||||
return oauth2.Token{}, errInvalidPKCE
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
}
|
||||
}
|
||||
|
||||
// Verify resource parameter consistency (RFC 8707)
|
||||
if dbCode.ResourceUri.Valid && dbCode.ResourceUri.String != "" {
|
||||
// Resource was specified during authorization - it must match in token request
|
||||
if params.resource == "" {
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
if req.Resource == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
}
|
||||
if params.resource != dbCode.ResourceUri.String {
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
if req.Resource != dbCode.ResourceUri.String {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
}
|
||||
} else if params.resource != "" {
|
||||
} else if req.Resource != "" {
|
||||
// Resource was not specified during authorization but is now provided
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
}
|
||||
|
||||
// Generate a refresh token.
|
||||
refreshToken, err := GenerateSecret()
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
// Generate the API key we will swap for the code.
|
||||
@@ -266,13 +258,13 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
|
||||
TokenName: tokenName,
|
||||
})
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
// Grab the user roles so we can perform the exchange as the user.
|
||||
actor, _, err := httpmw.UserRBACSubject(ctx, db, dbCode.UserID, rbac.ScopeAll)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
return codersdk.OAuth2TokenResponse{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
}
|
||||
|
||||
// Do the actual token exchange in the database.
|
||||
@@ -324,47 +316,47 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
return oauth2.Token{
|
||||
return codersdk.OAuth2TokenResponse{
|
||||
AccessToken: sessionToken,
|
||||
TokenType: "Bearer",
|
||||
TokenType: codersdk.OAuth2TokenTypeBearer,
|
||||
RefreshToken: refreshToken.Formatted,
|
||||
Expiry: key.ExpiresAt,
|
||||
ExpiresIn: int64(time.Until(key.ExpiresAt).Seconds()),
|
||||
Expiry: &key.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, params tokenParams) (oauth2.Token, error) {
|
||||
func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, req codersdk.OAuth2TokenRequest) (codersdk.OAuth2TokenResponse, error) {
|
||||
// Validate the token.
|
||||
token, err := ParseFormattedSecret(params.refreshToken)
|
||||
token, err := ParseFormattedSecret(req.RefreshToken)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, errBadToken
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
}
|
||||
//nolint:gocritic // There is no user yet so we must use the system.
|
||||
dbToken, err := db.GetOAuth2ProviderAppTokenByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(token.Prefix))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return oauth2.Token{}, errBadToken
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
}
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
equal := apikey.ValidateHash(dbToken.RefreshHash, token.Secret)
|
||||
if !equal {
|
||||
return oauth2.Token{}, errBadToken
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
}
|
||||
|
||||
// Ensure the token has not expired.
|
||||
if dbToken.ExpiresAt.Before(dbtime.Now()) {
|
||||
return oauth2.Token{}, errBadToken
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
}
|
||||
|
||||
// Verify resource parameter consistency for refresh tokens (RFC 8707)
|
||||
if params.resource != "" {
|
||||
if req.Resource != "" {
|
||||
// If resource is provided in refresh request, it must match the original token's audience
|
||||
if !dbToken.Audience.Valid || dbToken.Audience.String != params.resource {
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
if !dbToken.Audience.Valid || dbToken.Audience.String != req.Resource {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,18 +364,18 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
|
||||
//nolint:gocritic // There is no user yet so we must use the system.
|
||||
prevKey, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), dbToken.APIKeyID)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
actor, _, err := httpmw.UserRBACSubject(ctx, db, prevKey.UserID, rbac.ScopeAll)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
return codersdk.OAuth2TokenResponse{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
}
|
||||
|
||||
// Generate a new refresh token.
|
||||
refreshToken, err := GenerateSecret()
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
// Generate the new API key.
|
||||
@@ -397,7 +389,7 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
|
||||
TokenName: tokenName,
|
||||
})
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
// Replace the token.
|
||||
@@ -437,15 +429,15 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
}
|
||||
|
||||
return oauth2.Token{
|
||||
return codersdk.OAuth2TokenResponse{
|
||||
AccessToken: sessionToken,
|
||||
TokenType: "Bearer",
|
||||
TokenType: codersdk.OAuth2TokenTypeBearer,
|
||||
RefreshToken: refreshToken.Formatted,
|
||||
Expiry: key.ExpiresAt,
|
||||
ExpiresIn: int64(time.Until(key.ExpiresAt).Seconds()),
|
||||
Expiry: &key.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package oauth2provider
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -10,6 +11,12 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// parseScopes parses a space-delimited scope string into a slice of scopes
|
||||
// per RFC 6749.
|
||||
func parseScopes(scope string) []string {
|
||||
return strings.Fields(strings.TrimSpace(scope))
|
||||
}
|
||||
|
||||
// TestExtractTokenParams_Scopes tests OAuth2 scope parameter parsing
|
||||
// to ensure RFC 6749 compliance where scopes are space-delimited
|
||||
func TestExtractTokenParams_Scopes(t *testing.T) {
|
||||
@@ -115,15 +122,15 @@ func TestExtractTokenParams_Scopes(t *testing.T) {
|
||||
Form: form, // Form is the combination of PostForm and URL query
|
||||
}
|
||||
|
||||
// Extract token params
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
// Extract token request
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
|
||||
// Verify no errors occurred
|
||||
require.NoError(t, err, "extractTokenParams should not return error for: %s", tc.description)
|
||||
require.NoError(t, err, "extractTokenRequest should not return error for: %s", tc.description)
|
||||
require.Empty(t, validationErrs, "should have no validation errors for: %s", tc.description)
|
||||
|
||||
// Verify scopes match expected
|
||||
require.Equal(t, tc.expectedScopes, params.scopes, "scope parsing failed for: %s", tc.description)
|
||||
require.Equal(t, tc.expectedScopes, parseScopes(tokenReq.Scope), "scope parsing failed for: %s", tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -178,15 +185,15 @@ func TestExtractTokenParams_ScopesURLEncoded(t *testing.T) {
|
||||
Form: values,
|
||||
}
|
||||
|
||||
// Extract token params
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
// Extract token request
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
|
||||
// Verify no errors
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validationErrs)
|
||||
|
||||
// Verify scopes
|
||||
require.Equal(t, tc.expectedScopes, params.scopes)
|
||||
require.Equal(t, tc.expectedScopes, parseScopes(tokenReq.Scope))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -259,11 +266,11 @@ func TestExtractTokenParams_ScopesEdgeCases(t *testing.T) {
|
||||
Form: form,
|
||||
}
|
||||
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
|
||||
require.NoError(t, err, "extractTokenParams should not error for: %s", tc.description)
|
||||
require.NoError(t, err, "extractTokenRequest should not error for: %s", tc.description)
|
||||
require.Empty(t, validationErrs)
|
||||
require.Equal(t, tc.expectedScopes, params.scopes, "scope mismatch for: %s", tc.description)
|
||||
require.Equal(t, tc.expectedScopes, parseScopes(tokenReq.Scope), "scope mismatch for: %s", tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -354,10 +361,10 @@ func TestRefreshTokenGrant_Scopes(t *testing.T) {
|
||||
Form: form,
|
||||
}
|
||||
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validationErrs)
|
||||
require.Equal(t, codersdk.OAuth2ProviderGrantTypeRefreshToken, params.grantType)
|
||||
require.Equal(t, []string{"reduced:scope", "subset:scope"}, params.scopes)
|
||||
require.Equal(t, codersdk.OAuth2ProviderGrantTypeRefreshToken, tokenReq.GrantType)
|
||||
require.Equal(t, []string{"reduced:scope", "subset:scope"}, parseScopes(tokenReq.Scope))
|
||||
}
|
||||
|
||||
@@ -277,47 +277,47 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
grantTypes []string
|
||||
grantTypes []codersdk.OAuth2ProviderGrantType
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
grantTypes: []string{},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidAuthorizationCode",
|
||||
grantTypes: []string{"authorization_code"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidRefreshTokenAlone",
|
||||
grantTypes: []string{"refresh_token"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
expectError: true, // refresh_token requires authorization_code to be present
|
||||
},
|
||||
{
|
||||
name: "ValidMultiple",
|
||||
grantTypes: []string{"authorization_code", "refresh_token"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidUnsupported",
|
||||
grantTypes: []string{"client_credentials"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeClientCredentials},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidPassword",
|
||||
grantTypes: []string{"password"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypePassword},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidImplicit",
|
||||
grantTypes: []string{"implicit"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeImplicit},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "MixedValidInvalid",
|
||||
grantTypes: []string{"authorization_code", "client_credentials"},
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeClientCredentials},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -352,32 +352,32 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseTypes []string
|
||||
responseTypes []codersdk.OAuth2ProviderResponseType
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
responseTypes: []string{},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidCode",
|
||||
responseTypes: []string{"code"},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidToken",
|
||||
responseTypes: []string{"token"},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeToken},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidImplicit",
|
||||
responseTypes: []string{"id_token"},
|
||||
name: "InvalidIDToken",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"id_token"}, // OIDC-specific, no constant
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidMultiple",
|
||||
responseTypes: []string{"code", "token"},
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode, codersdk.OAuth2ProviderResponseTypeToken},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -412,7 +412,7 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authMethod string
|
||||
authMethod codersdk.OAuth2TokenEndpointAuthMethod
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
@@ -422,27 +422,27 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "ValidClientSecretBasic",
|
||||
authMethod: "client_secret_basic",
|
||||
authMethod: codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidClientSecretPost",
|
||||
authMethod: "client_secret_post",
|
||||
authMethod: codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidNone",
|
||||
authMethod: "none",
|
||||
authMethod: codersdk.OAuth2TokenEndpointAuthMethodNone,
|
||||
expectError: false, // "none" is valid for public clients per RFC 7591
|
||||
},
|
||||
{
|
||||
name: "InvalidPrivateKeyJWT",
|
||||
authMethod: "private_key_jwt",
|
||||
authMethod: "private_key_jwt", // OIDC-specific, no constant defined
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidClientSecretJWT",
|
||||
authMethod: "client_secret_jwt",
|
||||
authMethod: "client_secret_jwt", // OIDC-specific, no constant defined
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
@@ -659,14 +659,14 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should default to authorization_code
|
||||
require.Contains(t, config.GrantTypes, "authorization_code")
|
||||
require.Contains(t, config.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
|
||||
// Should default to code
|
||||
require.Contains(t, config.ResponseTypes, "code")
|
||||
require.Contains(t, config.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
|
||||
// Should default to client_secret_basic or client_secret_post
|
||||
require.True(t, config.TokenEndpointAuthMethod == "client_secret_basic" ||
|
||||
config.TokenEndpointAuthMethod == "client_secret_post" ||
|
||||
require.True(t, config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic ||
|
||||
config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost ||
|
||||
config.TokenEndpointAuthMethod == "")
|
||||
|
||||
// Client secret should be generated
|
||||
|
||||
@@ -2344,6 +2344,10 @@ func (api *API) patchWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
defer commitAudit()
|
||||
aReq.Old = workspace.WorkspaceTable()
|
||||
|
||||
if !api.allowWorkspaceSharing(ctx, rw, workspace.OrganizationID) {
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateWorkspaceACL
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -2440,6 +2444,10 @@ func (api *API) deleteWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
defer commitAuditor()
|
||||
aReq.Old = workspace.WorkspaceTable()
|
||||
|
||||
if !api.allowWorkspaceSharing(ctx, rw, workspace.OrganizationID) {
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Database.InTx(func(tx database.Store) error {
|
||||
err := tx.DeleteWorkspaceACLByID(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
@@ -2463,6 +2471,27 @@ func (api *API) deleteWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
// allowWorkspaceSharing enforces the workspace-sharing gate for an
|
||||
// organization. It writes an HTTP error response and returns false if
|
||||
// sharing is disabled or the org lookup fails; otherwise it returns
|
||||
// true.
|
||||
func (api *API) allowWorkspaceSharing(ctx context.Context, rw http.ResponseWriter, organizationID uuid.UUID) bool {
|
||||
//nolint:gocritic // Use system context so this check doesn’t
|
||||
// depend on the caller having organization:read.
|
||||
org, err := api.Database.GetOrganizationByID(dbauthz.AsSystemRestricted(ctx), organizationID)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return false
|
||||
}
|
||||
if org.WorkspaceSharingDisabled {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Workspace sharing is disabled for this organization.",
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// workspacesData only returns the data the caller can access. If the caller
|
||||
// does not have the correct perms to read a given template, the template will
|
||||
// not be returned.
|
||||
|
||||
@@ -5266,7 +5266,66 @@ func TestDeleteWorkspaceACL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// nolint:tparallel,paralleltest // Subtests modify package global.
|
||||
// `use`-role shares are granted `workspace:read` via the workspace RBAC ACL
|
||||
// list, so they should be able to read the ACL.
|
||||
//
|
||||
//nolint:tparallel,paralleltest // Test modifies a package global (rbac.workspaceACLDisabled).
|
||||
func TestWorkspaceReadCanListACL(t *testing.T) {
|
||||
// Be defensive by saving/restoring the modified package global.
|
||||
prevWorkspaceACLDisabled := rbac.WorkspaceACLDisabled()
|
||||
rbac.SetWorkspaceACLDisabled(false)
|
||||
t.Cleanup(func() { rbac.SetWorkspaceACLDisabled(prevWorkspaceACLDisabled) })
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
sharedUserClientA, sharedUserA = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
_, sharedUserB = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
sharedGroup = dbgen.Group(t, db, database.Group{OrganizationID: admin.OrganizationID})
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: workspaceOwner.ID,
|
||||
OrganizationID: admin.OrganizationID,
|
||||
}).Do().Workspace
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
err := workspaceOwnerClient.UpdateWorkspaceACL(ctx, workspace.ID, codersdk.UpdateWorkspaceACL{
|
||||
UserRoles: map[string]codersdk.WorkspaceRole{
|
||||
sharedUserA.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
sharedUserB.ID.String(): codersdk.WorkspaceRoleAdmin,
|
||||
},
|
||||
GroupRoles: map[string]codersdk.WorkspaceRole{
|
||||
sharedGroup.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
acl, err := sharedUserClientA.WorkspaceACL(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, acl.Users, 2)
|
||||
require.Len(t, acl.Groups, 1)
|
||||
|
||||
gotRoles := make(map[uuid.UUID]codersdk.WorkspaceRole, len(acl.Users))
|
||||
for _, u := range acl.Users {
|
||||
gotRoles[u.ID] = u.Role
|
||||
}
|
||||
require.Equal(t, codersdk.WorkspaceRoleUse, gotRoles[sharedUserA.ID])
|
||||
require.Equal(t, codersdk.WorkspaceRoleAdmin, gotRoles[sharedUserB.ID])
|
||||
|
||||
gotGroupRoles := make(map[uuid.UUID]codersdk.WorkspaceRole, len(acl.Groups))
|
||||
for _, g := range acl.Groups {
|
||||
gotGroupRoles[g.ID] = g.Role
|
||||
}
|
||||
require.Equal(t, codersdk.WorkspaceRoleUse, gotGroupRoles[sharedGroup.ID])
|
||||
}
|
||||
|
||||
// nolint:tparallel,paralleltest // Subtests modify a package global (rbac.workspaceACLDisabled).
|
||||
func TestWorkspaceSharingDisabled(t *testing.T) {
|
||||
t.Run("CanAccessWhenEnabled", func(t *testing.T) {
|
||||
var (
|
||||
|
||||
@@ -3484,6 +3484,16 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "rateLimit",
|
||||
},
|
||||
{
|
||||
Name: "AI Bridge Structured Logging",
|
||||
Description: "Emit structured logs for AI Bridge interception records. Use this for exporting these records to external SIEM or observability systems.",
|
||||
Flag: "aibridge-structured-logging",
|
||||
Env: "CODER_AIBRIDGE_STRUCTURED_LOGGING",
|
||||
Value: &c.AI.BridgeConfig.StructuredLogging,
|
||||
Default: "false",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "structuredLogging",
|
||||
},
|
||||
|
||||
// AI Bridge Proxy Options
|
||||
{
|
||||
@@ -3610,6 +3620,7 @@ type AIBridgeConfig struct {
|
||||
Retention serpent.Duration `json:"retention" typescript:",notnull"`
|
||||
MaxConcurrency serpent.Int64 `json:"max_concurrency" typescript:",notnull"`
|
||||
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
|
||||
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeOpenAIConfig struct {
|
||||
|
||||
+248
-88
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -186,14 +187,22 @@ func (c *Client) DeleteOAuth2ProviderAppSecret(ctx context.Context, appID uuid.U
|
||||
|
||||
type OAuth2ProviderGrantType string
|
||||
|
||||
// OAuth2ProviderGrantType values (RFC 6749).
|
||||
const (
|
||||
OAuth2ProviderGrantTypeAuthorizationCode OAuth2ProviderGrantType = "authorization_code"
|
||||
OAuth2ProviderGrantTypeRefreshToken OAuth2ProviderGrantType = "refresh_token"
|
||||
OAuth2ProviderGrantTypePassword OAuth2ProviderGrantType = "password"
|
||||
OAuth2ProviderGrantTypeClientCredentials OAuth2ProviderGrantType = "client_credentials"
|
||||
OAuth2ProviderGrantTypeImplicit OAuth2ProviderGrantType = "implicit"
|
||||
)
|
||||
|
||||
func (e OAuth2ProviderGrantType) Valid() bool {
|
||||
switch e {
|
||||
case OAuth2ProviderGrantTypeAuthorizationCode, OAuth2ProviderGrantTypeRefreshToken:
|
||||
case OAuth2ProviderGrantTypeAuthorizationCode,
|
||||
OAuth2ProviderGrantTypeRefreshToken,
|
||||
OAuth2ProviderGrantTypePassword,
|
||||
OAuth2ProviderGrantTypeClientCredentials,
|
||||
OAuth2ProviderGrantTypeImplicit:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -201,19 +210,171 @@ func (e OAuth2ProviderGrantType) Valid() bool {
|
||||
|
||||
type OAuth2ProviderResponseType string
|
||||
|
||||
// OAuth2ProviderResponseType values (RFC 6749).
|
||||
const (
|
||||
OAuth2ProviderResponseTypeCode OAuth2ProviderResponseType = "code"
|
||||
OAuth2ProviderResponseTypeCode OAuth2ProviderResponseType = "code"
|
||||
OAuth2ProviderResponseTypeToken OAuth2ProviderResponseType = "token"
|
||||
)
|
||||
|
||||
func (e OAuth2ProviderResponseType) Valid() bool {
|
||||
//nolint:gocritic,revive // More cases might be added later.
|
||||
switch e {
|
||||
case OAuth2ProviderResponseTypeCode:
|
||||
case OAuth2ProviderResponseTypeCode, OAuth2ProviderResponseTypeToken:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2TokenEndpointAuthMethod string
|
||||
|
||||
const (
|
||||
OAuth2TokenEndpointAuthMethodClientSecretBasic OAuth2TokenEndpointAuthMethod = "client_secret_basic"
|
||||
OAuth2TokenEndpointAuthMethodClientSecretPost OAuth2TokenEndpointAuthMethod = "client_secret_post"
|
||||
OAuth2TokenEndpointAuthMethodNone OAuth2TokenEndpointAuthMethod = "none"
|
||||
)
|
||||
|
||||
func (m OAuth2TokenEndpointAuthMethod) Valid() bool {
|
||||
switch m {
|
||||
case OAuth2TokenEndpointAuthMethodClientSecretBasic,
|
||||
OAuth2TokenEndpointAuthMethodClientSecretPost,
|
||||
OAuth2TokenEndpointAuthMethodNone:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2PKCECodeChallengeMethod string
|
||||
|
||||
// OAuth2PKCECodeChallengeMethod values (RFC 7636).
|
||||
const (
|
||||
OAuth2PKCECodeChallengeMethodS256 OAuth2PKCECodeChallengeMethod = "S256"
|
||||
OAuth2PKCECodeChallengeMethodPlain OAuth2PKCECodeChallengeMethod = "plain"
|
||||
)
|
||||
|
||||
func (m OAuth2PKCECodeChallengeMethod) Valid() bool {
|
||||
switch m {
|
||||
case OAuth2PKCECodeChallengeMethodS256, OAuth2PKCECodeChallengeMethodPlain:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2TokenType string
|
||||
|
||||
// OAuth2TokenType values (RFC 6749, RFC 9449).
|
||||
const (
|
||||
OAuth2TokenTypeBearer OAuth2TokenType = "Bearer"
|
||||
OAuth2TokenTypeDPoP OAuth2TokenType = "DPoP"
|
||||
)
|
||||
|
||||
func (t OAuth2TokenType) Valid() bool {
|
||||
switch t {
|
||||
case OAuth2TokenTypeBearer, OAuth2TokenTypeDPoP:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2RevocationTokenTypeHint string
|
||||
|
||||
const (
|
||||
OAuth2RevocationTokenTypeHintAccessToken OAuth2RevocationTokenTypeHint = "access_token"
|
||||
OAuth2RevocationTokenTypeHintRefreshToken OAuth2RevocationTokenTypeHint = "refresh_token"
|
||||
)
|
||||
|
||||
func (h OAuth2RevocationTokenTypeHint) Valid() bool {
|
||||
switch h {
|
||||
case OAuth2RevocationTokenTypeHintAccessToken, OAuth2RevocationTokenTypeHintRefreshToken:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2ErrorCode string
|
||||
|
||||
// OAuth2 error codes per RFC 6749, RFC 7009, RFC 8707.
|
||||
// This is not comprehensive; it includes only codes relevant to this implementation.
|
||||
const (
|
||||
// RFC 6749 - Token endpoint errors.
|
||||
OAuth2ErrorCodeInvalidRequest OAuth2ErrorCode = "invalid_request"
|
||||
OAuth2ErrorCodeInvalidClient OAuth2ErrorCode = "invalid_client"
|
||||
OAuth2ErrorCodeInvalidGrant OAuth2ErrorCode = "invalid_grant"
|
||||
OAuth2ErrorCodeUnauthorizedClient OAuth2ErrorCode = "unauthorized_client"
|
||||
OAuth2ErrorCodeUnsupportedGrantType OAuth2ErrorCode = "unsupported_grant_type"
|
||||
OAuth2ErrorCodeInvalidScope OAuth2ErrorCode = "invalid_scope"
|
||||
|
||||
// RFC 6749 - Authorization endpoint errors.
|
||||
OAuth2ErrorCodeAccessDenied OAuth2ErrorCode = "access_denied"
|
||||
OAuth2ErrorCodeUnsupportedResponseType OAuth2ErrorCode = "unsupported_response_type"
|
||||
OAuth2ErrorCodeServerError OAuth2ErrorCode = "server_error"
|
||||
OAuth2ErrorCodeTemporarilyUnavailable OAuth2ErrorCode = "temporarily_unavailable"
|
||||
|
||||
// RFC 7009 - Token revocation errors.
|
||||
OAuth2ErrorCodeUnsupportedTokenType OAuth2ErrorCode = "unsupported_token_type"
|
||||
|
||||
// RFC 8707 - Resource indicator errors.
|
||||
OAuth2ErrorCodeInvalidTarget OAuth2ErrorCode = "invalid_target"
|
||||
)
|
||||
|
||||
func (c OAuth2ErrorCode) Valid() bool {
|
||||
switch c {
|
||||
case OAuth2ErrorCodeInvalidRequest,
|
||||
OAuth2ErrorCodeInvalidClient,
|
||||
OAuth2ErrorCodeInvalidGrant,
|
||||
OAuth2ErrorCodeUnauthorizedClient,
|
||||
OAuth2ErrorCodeUnsupportedGrantType,
|
||||
OAuth2ErrorCodeInvalidScope,
|
||||
OAuth2ErrorCodeAccessDenied,
|
||||
OAuth2ErrorCodeUnsupportedResponseType,
|
||||
OAuth2ErrorCodeServerError,
|
||||
OAuth2ErrorCodeTemporarilyUnavailable,
|
||||
OAuth2ErrorCodeUnsupportedTokenType,
|
||||
OAuth2ErrorCodeInvalidTarget:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// OAuth2Error represents an OAuth2-compliant error response per RFC 6749.
|
||||
type OAuth2Error struct {
|
||||
Error OAuth2ErrorCode `json:"error"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ErrorURI string `json:"error_uri,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2TokenRequest represents a token request per RFC 6749. The actual wire
|
||||
// format is application/x-www-form-urlencoded; this struct is for SDK docs.
|
||||
type OAuth2TokenRequest struct {
|
||||
GrantType OAuth2ProviderGrantType `json:"grant_type"`
|
||||
Code string `json:"code,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
CodeVerifier string `json:"code_verifier,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Resource string `json:"resource,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2TokenResponse represents a successful token response per RFC 6749.
|
||||
type OAuth2TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType OAuth2TokenType `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Expiry is not part of RFC 6749 but is included for compatibility with
|
||||
// golang.org/x/oauth2.Token and clients that expect a timestamp.
|
||||
Expiry *time.Time `json:"expiry,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// OAuth2TokenRevocationRequest represents a token revocation request per RFC 7009.
|
||||
type OAuth2TokenRevocationRequest struct {
|
||||
Token string `json:"token"`
|
||||
TokenTypeHint OAuth2RevocationTokenTypeHint `json:"token_type_hint,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
}
|
||||
|
||||
// RevokeOAuth2Token revokes a specific OAuth2 token using RFC 7009 token revocation.
|
||||
func (c *Client) RevokeOAuth2Token(ctx context.Context, clientID uuid.UUID, token string) error {
|
||||
form := url.Values{}
|
||||
@@ -256,18 +417,18 @@ type OAuth2DeviceFlowCallbackResponse struct {
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
}
|
||||
|
||||
// OAuth2AuthorizationServerMetadata represents RFC 8414 OAuth 2.0 Authorization Server Metadata
|
||||
// OAuth2AuthorizationServerMetadata represents RFC 8414 OAuth 2.0 Authorization Server Metadata.
|
||||
type OAuth2AuthorizationServerMetadata struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
|
||||
ResponseTypesSupported []OAuth2ProviderResponseType `json:"response_types_supported"`
|
||||
GrantTypesSupported []OAuth2ProviderGrantType `json:"grant_types_supported,omitempty"`
|
||||
CodeChallengeMethodsSupported []OAuth2PKCECodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
TokenEndpointAuthMethodsSupported []OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2ProtectedResourceMetadata represents RFC 9728 OAuth 2.0 Protected Resource Metadata
|
||||
@@ -278,50 +439,50 @@ type OAuth2ProtectedResourceMetadata struct {
|
||||
BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2ClientRegistrationRequest represents RFC 7591 Dynamic Client Registration Request
|
||||
// OAuth2ClientRegistrationRequest represents RFC 7591 Dynamic Client Registration Request.
|
||||
type OAuth2ClientRegistrationRequest struct {
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
SoftwareStatement string `json:"software_statement,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
SoftwareStatement string `json:"software_statement,omitempty"`
|
||||
GrantTypes []OAuth2ProviderGrantType `json:"grant_types,omitempty"`
|
||||
ResponseTypes []OAuth2ProviderResponseType `json:"response_types,omitempty"`
|
||||
TokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_method,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
}
|
||||
|
||||
func (req OAuth2ClientRegistrationRequest) ApplyDefaults() OAuth2ClientRegistrationRequest {
|
||||
// Apply grant type defaults
|
||||
// Apply grant type defaults.
|
||||
if len(req.GrantTypes) == 0 {
|
||||
req.GrantTypes = []string{
|
||||
string(OAuth2ProviderGrantTypeAuthorizationCode),
|
||||
string(OAuth2ProviderGrantTypeRefreshToken),
|
||||
req.GrantTypes = []OAuth2ProviderGrantType{
|
||||
OAuth2ProviderGrantTypeAuthorizationCode,
|
||||
OAuth2ProviderGrantTypeRefreshToken,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply response type defaults
|
||||
// Apply response type defaults.
|
||||
if len(req.ResponseTypes) == 0 {
|
||||
req.ResponseTypes = []string{
|
||||
string(OAuth2ProviderResponseTypeCode),
|
||||
req.ResponseTypes = []OAuth2ProviderResponseType{
|
||||
OAuth2ProviderResponseTypeCode,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply token endpoint auth method default (RFC 7591 section 2)
|
||||
// Apply token endpoint auth method default (RFC 7591 section 2).
|
||||
if req.TokenEndpointAuthMethod == "" {
|
||||
// Default according to RFC 7591: "client_secret_basic" for confidential clients
|
||||
// For public clients, should be explicitly set to "none"
|
||||
req.TokenEndpointAuthMethod = "client_secret_basic"
|
||||
// Default according to RFC 7591: "client_secret_basic" for confidential clients.
|
||||
// For public clients, should be explicitly set to "none".
|
||||
req.TokenEndpointAuthMethod = OAuth2TokenEndpointAuthMethodClientSecretBasic
|
||||
}
|
||||
|
||||
// Apply client name default if not provided
|
||||
// Apply client name default if not provided.
|
||||
if req.ClientName == "" {
|
||||
req.ClientName = "Dynamically Registered Client"
|
||||
}
|
||||
@@ -377,29 +538,29 @@ func (req *OAuth2ClientRegistrationRequest) GenerateClientName() string {
|
||||
return "Dynamically Registered Client"
|
||||
}
|
||||
|
||||
// OAuth2ClientRegistrationResponse represents RFC 7591 Dynamic Client Registration Response
|
||||
// OAuth2ClientRegistrationResponse represents RFC 7591 Dynamic Client Registration Response.
|
||||
type OAuth2ClientRegistrationResponse struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []OAuth2ProviderGrantType `json:"grant_types"`
|
||||
ResponseTypes []OAuth2ProviderResponseType `json:"response_types"`
|
||||
TokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
}
|
||||
|
||||
// PostOAuth2ClientRegistration dynamically registers a new OAuth2 client (RFC 7591)
|
||||
@@ -466,27 +627,26 @@ func (c *Client) DeleteOAuth2ClientConfiguration(ctx context.Context, clientID s
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuth2ClientConfiguration represents RFC 7592 Client Configuration (for GET/PUT operations)
|
||||
// Same as OAuth2ClientRegistrationResponse but without client_secret in GET responses
|
||||
// OAuth2ClientConfiguration represents RFC 7592 Client Read Response.
|
||||
type OAuth2ClientConfiguration struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken []byte `json:"registration_access_token"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []OAuth2ProviderGrantType `json:"grant_types"`
|
||||
ResponseTypes []OAuth2ProviderResponseType `json:"response_types"`
|
||||
TokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (req *OAuth2ClientRegistrationRequest) Validate() error {
|
||||
}
|
||||
|
||||
// validateRedirectURIs validates redirect URIs according to RFC 7591, 8252
|
||||
func validateRedirectURIs(uris []string, tokenEndpointAuthMethod string) error {
|
||||
func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod) error {
|
||||
if len(uris) == 0 {
|
||||
return xerrors.New("at least one redirect URI is required")
|
||||
}
|
||||
@@ -115,7 +115,7 @@ func validateRedirectURIs(uris []string, tokenEndpointAuthMethod string) error {
|
||||
}
|
||||
|
||||
// Determine if this is a public client based on token endpoint auth method
|
||||
isPublicClient := tokenEndpointAuthMethod == "none"
|
||||
isPublicClient := tokenEndpointAuthMethod == OAuth2TokenEndpointAuthMethodNone
|
||||
|
||||
// Handle different validation for public vs confidential clients
|
||||
if uri.Scheme == "http" || uri.Scheme == "https" {
|
||||
@@ -155,23 +155,15 @@ func validateRedirectURIs(uris []string, tokenEndpointAuthMethod string) error {
|
||||
}
|
||||
|
||||
// validateGrantTypes validates OAuth2 grant types
|
||||
func validateGrantTypes(grantTypes []string) error {
|
||||
validGrants := []string{
|
||||
string(OAuth2ProviderGrantTypeAuthorizationCode),
|
||||
string(OAuth2ProviderGrantTypeRefreshToken),
|
||||
// Add more grant types as they are implemented
|
||||
// "client_credentials",
|
||||
// "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}
|
||||
|
||||
func validateGrantTypes(grantTypes []OAuth2ProviderGrantType) error {
|
||||
for _, grant := range grantTypes {
|
||||
if !slices.Contains(validGrants, grant) {
|
||||
if !isSupportedGrantType(grant) {
|
||||
return xerrors.Errorf("unsupported grant type: %s", grant)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure authorization_code is present if redirect_uris are specified
|
||||
hasAuthCode := slices.Contains(grantTypes, string(OAuth2ProviderGrantTypeAuthorizationCode))
|
||||
hasAuthCode := slices.Contains(grantTypes, OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
if !hasAuthCode {
|
||||
return xerrors.New("authorization_code grant type is required when redirect_uris are specified")
|
||||
}
|
||||
@@ -179,15 +171,18 @@ func validateGrantTypes(grantTypes []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateResponseTypes validates OAuth2 response types
|
||||
func validateResponseTypes(responseTypes []string) error {
|
||||
validResponses := []string{
|
||||
string(OAuth2ProviderResponseTypeCode),
|
||||
// Add more response types as they are implemented
|
||||
func isSupportedGrantType(grant OAuth2ProviderGrantType) bool {
|
||||
switch grant {
|
||||
case OAuth2ProviderGrantTypeAuthorizationCode, OAuth2ProviderGrantTypeRefreshToken:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// validateResponseTypes validates OAuth2 response types
|
||||
func validateResponseTypes(responseTypes []OAuth2ProviderResponseType) error {
|
||||
for _, responseType := range responseTypes {
|
||||
if !slices.Contains(validResponses, responseType) {
|
||||
if !isSupportedResponseType(responseType) {
|
||||
return xerrors.Errorf("unsupported response type: %s", responseType)
|
||||
}
|
||||
}
|
||||
@@ -195,19 +190,34 @@ func validateResponseTypes(responseTypes []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSupportedResponseType(responseType OAuth2ProviderResponseType) bool {
|
||||
return responseType == OAuth2ProviderResponseTypeCode
|
||||
}
|
||||
|
||||
// validateTokenEndpointAuthMethod validates token endpoint authentication method
|
||||
func validateTokenEndpointAuthMethod(method string) error {
|
||||
validMethods := []string{
|
||||
"client_secret_post",
|
||||
"client_secret_basic",
|
||||
"none", // for public clients (RFC 7591)
|
||||
// Add more methods as they are implemented
|
||||
// "private_key_jwt",
|
||||
// "client_secret_jwt",
|
||||
func validateTokenEndpointAuthMethod(method OAuth2TokenEndpointAuthMethod) error {
|
||||
if !method.Valid() {
|
||||
return xerrors.Errorf("unsupported token endpoint auth method: %s", method)
|
||||
}
|
||||
|
||||
if !slices.Contains(validMethods, method) {
|
||||
return xerrors.Errorf("unsupported token endpoint auth method: %s", method)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePKCECodeChallengeMethod validates PKCE code_challenge_method parameter.
|
||||
// Per OAuth 2.1, only S256 is supported; plain is rejected for security reasons.
|
||||
func ValidatePKCECodeChallengeMethod(method string) error {
|
||||
if method == "" {
|
||||
return nil // Optional, defaults to S256 if code_challenge is provided
|
||||
}
|
||||
|
||||
m := OAuth2PKCECodeChallengeMethod(method)
|
||||
|
||||
if m == OAuth2PKCECodeChallengeMethodPlain {
|
||||
return xerrors.New("code_challenge_method 'plain' is not supported; use 'S256'")
|
||||
}
|
||||
|
||||
if m != OAuth2PKCECodeChallengeMethodS256 {
|
||||
return xerrors.Errorf("unsupported code_challenge_method: %s", method)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -146,6 +146,27 @@ func parseVariableValuesFromHCL(content []byte) ([]VariableValue, error) {
|
||||
return nil, err
|
||||
}
|
||||
stringData[attribute.Name] = string(m)
|
||||
case ctyType.IsObjectType() || ctyType.IsMapType():
|
||||
// In case of objects/maps, Coder only supports the map(string) type.
|
||||
result := map[string]string{}
|
||||
var err error
|
||||
_ = ctyValue.ForEachElement(func(key, val cty.Value) (stop bool) {
|
||||
if !val.Type().Equals(cty.String) {
|
||||
err = xerrors.Errorf("unsupported map value type for key %q: %s", key.AsString(), val.Type().GoString())
|
||||
return true
|
||||
}
|
||||
result[key.AsString()] = val.AsString()
|
||||
return false
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stringData[attribute.Name] = string(m)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported value type (name: %s): %s", attribute.Name, ctyType.GoString())
|
||||
}
|
||||
|
||||
@@ -175,3 +175,76 @@ cores: 2`
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), `use the equals sign "=" to introduce the argument value`)
|
||||
}
|
||||
|
||||
func TestParseVariableValuesFromVarsFiles_MapOfStrings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given
|
||||
const (
|
||||
hclFilename = "file.tfvars"
|
||||
hclContent = `region = "us-east-1"
|
||||
default_tags = {
|
||||
owner_name = "John Doe"
|
||||
owner_email = "john@example.com"
|
||||
owner_slack = "@johndoe"
|
||||
}`
|
||||
)
|
||||
|
||||
// Prepare the .tfvars files
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "test-parse-variable-values-from-vars-files-map-of-strings-*")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
err = os.WriteFile(filepath.Join(tempDir, hclFilename), []byte(hclContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When
|
||||
actual, err := codersdk.ParseUserVariableValues([]string{
|
||||
filepath.Join(tempDir, hclFilename),
|
||||
}, "", nil)
|
||||
|
||||
// Then
|
||||
require.NoError(t, err)
|
||||
require.Len(t, actual, 2)
|
||||
|
||||
// Results are sorted by name
|
||||
require.Equal(t, "default_tags", actual[0].Name)
|
||||
require.JSONEq(t, `{"owner_email":"john@example.com","owner_name":"John Doe","owner_slack":"@johndoe"}`, actual[0].Value)
|
||||
require.Equal(t, "region", actual[1].Name)
|
||||
require.Equal(t, "us-east-1", actual[1].Value)
|
||||
}
|
||||
|
||||
func TestParseVariableValuesFromVarsFiles_MapWithNonStringValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given - a map with non-string values should error
|
||||
const (
|
||||
hclFilename = "file.tfvars"
|
||||
hclContent = `config = {
|
||||
name = "test"
|
||||
count = 5
|
||||
}`
|
||||
)
|
||||
|
||||
// Prepare the .tfvars files
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "test-parse-variable-values-from-vars-files-map-non-string-*")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
err = os.WriteFile(filepath.Join(tempDir, hclFilename), []byte(hclContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When
|
||||
actual, err := codersdk.ParseUserVariableValues([]string{
|
||||
filepath.Join(tempDir, hclFilename),
|
||||
}, "", nil)
|
||||
|
||||
// Then
|
||||
require.Nil(t, actual)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "unsupported map value type")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// WorkspaceSharingSettings represents workspace sharing settings for an organization.
|
||||
type WorkspaceSharingSettings struct {
|
||||
SharingDisabled bool `json:"sharing_disabled"`
|
||||
}
|
||||
|
||||
// WorkspaceSharingSettings retrieves the workspace sharing settings for an organization.
|
||||
func (c *Client) WorkspaceSharingSettings(ctx context.Context, orgID string) (WorkspaceSharingSettings, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/organizations/%s/settings/workspace-sharing", orgID), nil)
|
||||
if err != nil {
|
||||
return WorkspaceSharingSettings{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return WorkspaceSharingSettings{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp WorkspaceSharingSettings
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// PatchWorkspaceSharingSettings modifies the workspace sharing settings for an organization.
|
||||
func (c *Client) PatchWorkspaceSharingSettings(ctx context.Context, orgID string, req WorkspaceSharingSettings) (WorkspaceSharingSettings, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/organizations/%s/settings/workspace-sharing", orgID), req)
|
||||
if err != nil {
|
||||
return WorkspaceSharingSettings{}, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return WorkspaceSharingSettings{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp WorkspaceSharingSettings
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
@@ -0,0 +1,215 @@
|
||||
# Workspace Startup Coordination Examples
|
||||
|
||||
## Script Example
|
||||
|
||||
This example shows a complete, production-ready script that starts Claude Code
|
||||
only after a repository has been cloned. It includes error handling, graceful
|
||||
degradation, and cleanup on exit:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
UNIT_NAME="claude-code"
|
||||
DEPENDENCIES="git-clone"
|
||||
REPO_DIR="/workspace/repo"
|
||||
|
||||
# Track if sync started successfully
|
||||
SYNC_STARTED=0
|
||||
|
||||
# Declare dependencies
|
||||
if [ -n "$DEPENDENCIES" ]; then
|
||||
if command -v coder > /dev/null 2>&1; then
|
||||
IFS=',' read -ra DEPS <<< "$DEPENDENCIES"
|
||||
for dep in "${DEPS[@]}"; do
|
||||
dep=$(echo "$dep" | xargs)
|
||||
if [ -n "$dep" ]; then
|
||||
echo "Waiting for dependency: $dep"
|
||||
coder exp sync want "$UNIT_NAME" "$dep" > /dev/null 2>&1 || \
|
||||
echo "Warning: Failed to register dependency $dep, continuing..."
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo "Coder CLI not found, running without sync coordination"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Start sync and track success
|
||||
if [ -n "$UNIT_NAME" ]; then
|
||||
if command -v coder > /dev/null 2>&1; then
|
||||
if coder exp sync start "$UNIT_NAME" > /dev/null 2>&1; then
|
||||
SYNC_STARTED=1
|
||||
echo "Started sync: $UNIT_NAME"
|
||||
else
|
||||
echo "Sync start failed or not available, continuing without sync..."
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# Ensure completion on exit (even if script fails)
|
||||
cleanup_sync() {
|
||||
if [ "$SYNC_STARTED" -eq 1 ] && [ -n "$UNIT_NAME" ]; then
|
||||
echo "Completing sync: $UNIT_NAME"
|
||||
coder exp sync complete "$UNIT_NAME" > /dev/null 2>&1 || \
|
||||
echo "Warning: Sync complete failed, but continuing..."
|
||||
fi
|
||||
}
|
||||
trap cleanup_sync EXIT
|
||||
|
||||
# Now do the actual work
|
||||
echo "Repository cloned, starting Claude Code"
|
||||
cd "$REPO_DIR"
|
||||
claude
|
||||
```
|
||||
|
||||
This script demonstrates several [best practices](./usage.md#best-practices):
|
||||
|
||||
- Checking for Coder CLI availability before using sync commands
|
||||
- Tracking whether `coder exp sync` started successfully
|
||||
- Using `trap` to ensure completion even if the script exits early
|
||||
- Graceful degradation when `coder exp sync` isn't available
|
||||
- Redirecting `coder exp sync` output to reduce noise in logs
|
||||
|
||||
## Template Migration Example
|
||||
|
||||
Below is a simple example Docker template that clones [Miguel Grinberg's example Flask repo](https://github.com/miguelgrinberg/microblog/) using the [`git-clone` module](https://registry.coder.com/modules/coder/git-clone) and installs the required dependencies for the project:
|
||||
|
||||
- Python development headers (required for building some Python packages)
|
||||
- Python dependencies from the project's `requirements.txt`
|
||||
|
||||
We've omitted some details (such as persistent storage) for brevity, but these are easily added.
|
||||
|
||||
### Before
|
||||
|
||||
```terraform
|
||||
data "coder_provisioner" "me" {}
|
||||
data "coder_workspace" "me" {}
|
||||
data "coder_workspace_owner" "me" {}
|
||||
|
||||
resource "docker_container" "workspace" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
image = "codercom/enterprise-base:ubuntu"
|
||||
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
|
||||
entrypoint = ["sh", "-c", coder_agent.main.init_script]
|
||||
env = [
|
||||
"CODER_AGENT_TOKEN=${coder_agent.main.token}",
|
||||
]
|
||||
}
|
||||
|
||||
resource "coder_agent" "main" {
|
||||
arch = data.coder_provisioner.me.arch
|
||||
os = "linux"
|
||||
}
|
||||
|
||||
module "git-clone" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
source = "registry.coder.com/coder/git-clone/coder"
|
||||
version = "1.2.3"
|
||||
agent_id = coder_agent.main.id
|
||||
url = "https://github.com/miguelgrinberg/microblog"
|
||||
}
|
||||
|
||||
resource "coder_script" "setup" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
agent_id = coder_agent.main.id
|
||||
display_name = "Installing Dependencies"
|
||||
run_on_start = true
|
||||
script = <<EOT
|
||||
sudo apt-get update
|
||||
sudo apt-get install --yes python-dev-is-python3
|
||||
cd ${module.git-clone[count.index].repo_dir}
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
EOT
|
||||
}
|
||||
```
|
||||
|
||||
We can note the following issues in the above template:
|
||||
|
||||
1. There is a race between cloning the repository and the `pip install` commands, which can lead to failed workspace startups in some cases.
|
||||
2. The `apt` commands can run independently of the `git clone` command, meaning that there is a potential speedup here.
|
||||
|
||||
Based on the above, we can improve both the startup time and reliability of the template by splitting the monolithic startup script into multiple independent scripts:
|
||||
|
||||
- Install `apt` dependencies
|
||||
- Install `pip` dependencies (depends on the `git-clone` module and the above step)
|
||||
|
||||
### After
|
||||
|
||||
Here is the updated version of the template:
|
||||
|
||||
```terraform
|
||||
data "coder_provisioner" "me" {}
|
||||
data "coder_workspace" "me" {}
|
||||
data "coder_workspace_owner" "me" {}
|
||||
|
||||
resource "docker_container" "workspace" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
image = "codercom/enterprise-base:ubuntu"
|
||||
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
|
||||
entrypoint = ["sh", "-c", coder_agent.main.init_script]
|
||||
env = [
|
||||
"CODER_AGENT_TOKEN=${coder_agent.main.token}",
|
||||
"CODER_AGENT_SOCKET_SERVER_ENABLED=true"
|
||||
]
|
||||
}
|
||||
|
||||
resource "coder_agent" "main" {
|
||||
arch = data.coder_provisioner.me.arch
|
||||
os = "linux"
|
||||
}
|
||||
|
||||
module "git-clone" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
source = "registry.coder.com/coder/git-clone/coder"
|
||||
version = "1.2.3"
|
||||
agent_id = coder_agent.main.id
|
||||
url = "https://github.com/miguelgrinberg/microblog/"
|
||||
post_clone_script = <<-EOT
|
||||
coder exp sync start git-clone && coder exp sync complete git-clone
|
||||
EOT
|
||||
}
|
||||
|
||||
resource "coder_script" "apt-install" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
agent_id = coder_agent.main.id
|
||||
display_name = "Installing APT Dependencies"
|
||||
run_on_start = true
|
||||
script = <<EOT
|
||||
trap 'coder exp sync complete apt-install' EXIT
|
||||
coder exp sync start apt-install
|
||||
|
||||
sudo apt-get update
|
||||
sudo apt-get install --yes python-dev-is-python3
|
||||
EOT
|
||||
}
|
||||
|
||||
resource "coder_script" "pip-install" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
agent_id = coder_agent.main.id
|
||||
display_name = "Installing Python Dependencies"
|
||||
run_on_start = true
|
||||
script = <<EOT
|
||||
trap 'coder exp sync complete pip-install' EXIT
|
||||
coder exp sync want pip-install git-clone apt-install
|
||||
coder exp sync start pip-install
|
||||
|
||||
cd ${module.git-clone[count.index].repo_dir}
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
EOT
|
||||
}
|
||||
```
|
||||
|
||||
A short summary of the changes:
|
||||
|
||||
- We've added `CODER_AGENT_SOCKET_SERVER_ENABLED=true` to the environment variables of the Docker container in which the Coder agent runs.
|
||||
- We've broken the monolithic "setup" script into two separate scripts: one for the `apt` commands, and one for the `pip` commands.
|
||||
- In each script, we've added a `coder exp sync start $SCRIPT_NAME` command to mark the startup script as started.
|
||||
- We've also added an exit trap to ensure that we mark the startup scripts as completed. Without this, the `coder exp sync wait` command would eventually time out.
|
||||
- We have used the `post_clone_script` feature of the `git-clone` module to allow waiting on the Git repository clone.
|
||||
- In the `pip-install` script, we have declared a dependency on both `git-clone` and `apt-install`.
|
||||
|
||||
With these changes, the startup time has been reduced significantly and there is no longer any possibility of a race condition.
|
||||
@@ -0,0 +1,50 @@
|
||||
# Workspace Startup Coordination
|
||||
|
||||
> [!NOTE]
|
||||
> This feature is experimental and may change without notice in future releases.
|
||||
|
||||
When workspaces start, scripts often need to run in a specific order.
|
||||
For example, an IDE or coding agent might need the repository cloned
|
||||
before it can start. Without explicit coordination, these scripts can
|
||||
race against each other, leading to startup failures and inconsistent
|
||||
workspace states.
|
||||
|
||||
Coder's workspace startup coordination feature lets you declare
|
||||
dependencies between startup scripts and ensure they run in the correct order.
|
||||
This eliminates race conditions and makes workspace startup predictable and
|
||||
reliable.
|
||||
|
||||
## Why use this?
|
||||
|
||||
Simply placing all of your workspace initialization logic in a single script works, but leads to slow workspace startup times.
|
||||
Breaking this out into multiple independent `coder_script` resources improves startup times by allowing the scripts to run in parallel.
|
||||
However, this can lead to intermittent failures between dependent scripts due to timing issues.
|
||||
Up until now, template authors have had to rely on manual coordination methods (for example, touching a file upon completion).
|
||||
The goal of startup script coordination is to provide a single reliable source of truth for coordination between workspace startup scripts.
|
||||
|
||||
## Quick Start
|
||||
|
||||
To start using workspace startup coordination, follow these steps:
|
||||
|
||||
1. Set the environment variable `CODER_AGENT_SOCKET_SERVER_ENABLED=true` in your template to enable the agent socket server. The environment variable *must* be readable to the agent process. For example, in a template using the `kreuzwerker/docker` provider:
|
||||
|
||||
```terraform
|
||||
resource "docker_container" "workspace" {
|
||||
image = "codercom/enterprise-base:ubuntu"
|
||||
env = [
|
||||
"CODER_AGENT_TOKEN=${coder_agent.main.token}",
|
||||
"CODER_AGENT_SOCKET_SERVER_ENABLED=true",
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
1. Add calls to `coder exp sync (start|complete)` in your startup scripts where required:
|
||||
|
||||
```bash
|
||||
trap 'coder exp sync complete my-script' EXIT
|
||||
coder exp sync want my-script my-other-script
|
||||
coder exp sync start my-script
|
||||
# Existing startup logic
|
||||
```
|
||||
|
||||
For more information, refer to the [usage documentation](./usage.md), [troubleshooting documentation](./troubleshooting.md), or view our [examples](./example.md).
|
||||
@@ -0,0 +1,98 @@
|
||||
# Workspace Startup Coordination Troubleshooting
|
||||
|
||||
> [!NOTE]
|
||||
> This feature is experimental and may change without notice in future releases.
|
||||
|
||||
## Test Sync Availability
|
||||
|
||||
From a workspace terminal, test if sync is working using `coder exp sync ping`:
|
||||
|
||||
```bash
|
||||
coder exp sync ping
|
||||
```
|
||||
|
||||
* If sync is working, expect the output to be `Success`.
|
||||
* Otherwise, you will see an error message similar to the below:
|
||||
|
||||
```bash
|
||||
error: connect to agent socket: connect to socket: dial unix /tmp/coder-agent.sock: connect: permission denied
|
||||
```
|
||||
|
||||
## Check Unit Status
|
||||
|
||||
You can check the status of a specific unit using `coder exp sync status`:
|
||||
|
||||
```bash
|
||||
coder exp sync status git-clone
|
||||
```
|
||||
|
||||
If the unit exists, you will see output similar to the below:
|
||||
|
||||
```bash
|
||||
# coder exp sync status git-clone
|
||||
Unit: git-clone
|
||||
Status: completed
|
||||
Ready: true
|
||||
```
|
||||
|
||||
If the unit is not known to the agent, you will see output similar to the below:
|
||||
|
||||
```bash
|
||||
# coder exp sync status doesnotexist
|
||||
Unit: doesnotexist
|
||||
Status: not registered
|
||||
Ready: true
|
||||
|
||||
Dependencies:
|
||||
No dependencies found
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Socket not enabled
|
||||
|
||||
If the Coder Agent Socket Server is not enabled, you will see an error message similar to the below when running `coder exp sync ping`:
|
||||
|
||||
```bash
|
||||
error: connect to agent socket: connect to socket: dial unix /tmp/coder-agent.sock: connect: no such file or directory
|
||||
```
|
||||
|
||||
Verify `CODER_AGENT_SOCKET_SERVER_ENABLED=true` is set in the Coder agent's environment:
|
||||
|
||||
```bash
|
||||
tr '\0' '\n' < /proc/$(pidof -s coder)/environ | grep CODER_AGENT_SOCKET_SERVER_ENABLED
|
||||
```
|
||||
|
||||
If the output of the above command is empty, review your template and ensure that the environment variable is set such that it is readable by the Coder agent process. Setting it on the `coder_agent` resource directly is **not** sufficient.
|
||||
|
||||
## Workspace startup script hangs
|
||||
|
||||
If the workspace startup scripts appear to 'hang', one or more of your startup scripts may be waiting for a dependency that never completes.
|
||||
|
||||
* Inside the workspace, review `/tmp/coder-script-*.log` for more details on your script's execution.
|
||||
> **Tip:** add `set -x` to the top of your script to enable debug mode and update/restart the workspace.
|
||||
* Review your template and verify that `coder exp sync complete <unit>` is called after the script completes e.g. with an exit trap.
|
||||
* View the unit status using `coder exp sync status <unit>`.
|
||||
|
||||
## Workspace startup scripts fail
|
||||
|
||||
If the workspace startup scripts fail:
|
||||
|
||||
* Review `/tmp/coder-script-*.log` inside the workspace for script errors.
|
||||
* Verify the Coder CLI is available in `$PATH` inside the workspace:
|
||||
|
||||
```bash
|
||||
command -v coder
|
||||
```
|
||||
|
||||
## Cycle detected
|
||||
|
||||
If you see an error similar to the below in your startup script logs, you have defined a cyclic dependency:
|
||||
|
||||
```bash
|
||||
error: declare dependency failed: cannot add dependency: adding edge for unit "bar": failed to add dependency
|
||||
adding edge (bar -> foo): cycle detected
|
||||
```
|
||||
|
||||
To fix this, review your dependency declarations and redesign them to remove the cycle. It may help to draw out the dependency graph to find
|
||||
the cycle.
|
||||
@@ -0,0 +1,283 @@
|
||||
# Workspace Startup Coordination Usage
|
||||
|
||||
> [!NOTE]
|
||||
> This feature is experimental and may change without notice in future releases.
|
||||
|
||||
Startup coordination is built around the concept of **units**. You declare units in your Coder workspace template using the `coder exp sync` command in `coder_script` resources. When the Coder agent starts, it keeps an in-memory directed acyclic graph (DAG) of all units of which it is aware. When you need to synchronize with another unit, you can use `coder exp sync start $UNIT_NAME` to block until all dependencies of that unit have been marked complete.
|
||||
|
||||
## What is a unit?
|
||||
|
||||
A **unit** is a named phase of work, typically corresponding to a script or initialization
|
||||
task.
|
||||
|
||||
- Units **may** declare dependencies on other units, creating an explicit ordering for workspace initialization.
|
||||
- Units **must** be registered before they can be marked as complete.
|
||||
- Units **may** be marked as dependencies before they are registered.
|
||||
- Units **must not** declare cyclic dependencies. Attempting to create a cyclic dependency will result in an error.
|
||||
|
||||
## Requirements
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The `coder exp sync` command is only available from Coder version >=v2.30 onwards.
|
||||
|
||||
To use startup dependencies in your templates, you must:
|
||||
|
||||
- Enable the Coder Agent Socket Server.
|
||||
- Modify your workspace startup scripts to run in parallel and declare dependencies as required using `coder exp sync`.
|
||||
|
||||
### Enable the Coder Agent Socket Server
|
||||
|
||||
The agent socket server provides the communication layer for startup
|
||||
coordination. To enable it, set `CODER_AGENT_SOCKET_SERVER_ENABLED=true` in the environment in which the agent is running.
|
||||
The exact method for doing this depends on your infrastructure platform:
|
||||
|
||||
<div class="tabs">
|
||||
|
||||
#### Docker / Podman
|
||||
|
||||
```hcl
|
||||
resource "docker_container" "workspace" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
image = "codercom/enterprise-base:ubuntu"
|
||||
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
|
||||
|
||||
env = [
|
||||
"CODER_AGENT_SOCKET_SERVER_ENABLED=true"
|
||||
]
|
||||
|
||||
command = ["sh", "-c", coder_agent.main.init_script]
|
||||
}
|
||||
```
|
||||
|
||||
#### Kubernetes
|
||||
|
||||
```hcl
|
||||
resource "kubernetes_pod" "main" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
|
||||
metadata {
|
||||
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
|
||||
namespace = var.workspaces_namespace
|
||||
}
|
||||
|
||||
spec {
|
||||
container {
|
||||
name = "dev"
|
||||
image = "codercom/enterprise-base:ubuntu"
|
||||
command = ["sh", "-c", coder_agent.main.init_script]
|
||||
|
||||
env {
|
||||
name = "CODER_AGENT_SOCKET_SERVER_ENABLED"
|
||||
value = "true"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### AWS EC2 / VMs
|
||||
|
||||
For virtual machines, pass the environment variable through cloud-init or your
|
||||
provisioning system:
|
||||
|
||||
```hcl
|
||||
locals {
|
||||
agent_env = {
|
||||
"CODER_AGENT_SOCKET_SERVER_ENABLED" = "true"
|
||||
}
|
||||
}
|
||||
|
||||
# In your cloud-init userdata template:
|
||||
# %{ for key, value in local.agent_env ~}
|
||||
# export ${key}="${value}"
|
||||
# %{ endfor ~}
|
||||
```
|
||||
|
||||
</div>
|
||||
|
||||
### Declare Dependencies in your Workspace Startup Scripts
|
||||
|
||||
<div class="tabs">
|
||||
|
||||
#### Single Dependency
|
||||
|
||||
Here's a simple example of a script that depends on another unit completing
|
||||
first:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
UNIT_NAME="my-setup"
|
||||
|
||||
# Declare dependency on git-clone
|
||||
coder exp sync want "$UNIT_NAME" "git-clone"
|
||||
|
||||
# Wait for dependencies and mark as started
|
||||
coder exp sync start "$UNIT_NAME"
|
||||
|
||||
# Do your work here
|
||||
echo "Running after git-clone completes"
|
||||
|
||||
# Signal completion
|
||||
coder exp sync complete "$UNIT_NAME"
|
||||
```
|
||||
|
||||
This script will wait until the `git-clone` unit completes before starting its
|
||||
own work.
|
||||
|
||||
#### Multiple Dependencies
|
||||
|
||||
If your unit depends on multiple other units, you can declare all dependencies
|
||||
before starting:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
UNIT_NAME="my-app"
|
||||
DEPENDENCIES="git-clone,env-setup,database-migration"
|
||||
|
||||
# Declare all dependencies
|
||||
if [ -n "$DEPENDENCIES" ]; then
|
||||
IFS=',' read -ra DEPS <<< "$DEPENDENCIES"
|
||||
for dep in "${DEPS[@]}"; do
|
||||
dep=$(echo "$dep" | xargs) # Trim whitespace
|
||||
if [ -n "$dep" ]; then
|
||||
coder exp sync want "$UNIT_NAME" "$dep"
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
# Wait for all dependencies
|
||||
coder exp sync start "$UNIT_NAME"
|
||||
|
||||
# Your work here
|
||||
echo "All dependencies satisfied, starting application"
|
||||
|
||||
# Signal completion
|
||||
coder exp sync complete "$UNIT_NAME"
|
||||
```
|
||||
|
||||
</div>
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Test your changes before rolling out to all users
|
||||
|
||||
Before rolling out to all users:
|
||||
|
||||
1. Create a test workspace from the updated template
|
||||
2. Check workspace build logs for sync messages
|
||||
3. Verify all units reach "completed" status
|
||||
4. Test workspace functionality
|
||||
|
||||
Once you're satisfied, [promote the new template version](../../../reference/cli/templates_versions_promote.md).
|
||||
|
||||
### Handle missing CLI gracefully
|
||||
|
||||
Not all workspaces will have the Coder CLI available in `$PATH`. Check for availability of the Coder CLI before using
|
||||
sync commands:
|
||||
|
||||
```bash
|
||||
if command -v coder > /dev/null 2>&1; then
|
||||
coder exp sync start "$UNIT_NAME"
|
||||
else
|
||||
echo "Coder CLI not available, continuing without coordination"
|
||||
fi
|
||||
```
|
||||
|
||||
### Complete units that start successfully
|
||||
|
||||
Units **must** call `coder exp sync complete` to unblock dependent units. Use `trap` to ensure
|
||||
completion even if your script exits early or encounters errors:
|
||||
|
||||
```bash
|
||||
|
||||
SYNC_STARTED=0
|
||||
if coder exp sync start "$UNIT_NAME"; then
|
||||
SYNC_STARTED=1
|
||||
fi
|
||||
|
||||
cleanup_sync() {
|
||||
if [ "$SYNC_STARTED" -eq 1 ]; then
|
||||
coder exp sync complete "$UNIT_NAME"
|
||||
fi
|
||||
}
|
||||
trap cleanup_sync EXIT
|
||||
```
|
||||
|
||||
### Use descriptive unit names
|
||||
|
||||
Names should explain what the unit does, not its position in a sequence:
|
||||
|
||||
- Good: `git-clone`, `env-setup`, `database-migration`
|
||||
- Avoid: `step1`, `init`, `script-1`
|
||||
|
||||
### Prefix a unique name to your units
|
||||
|
||||
When using `coder exp sync` in modules, note that unit names like `git-clone` might be common. Prefix the name of your module to your units to
|
||||
ensure that your unit does not conflict with others.
|
||||
|
||||
- Good: `<module>.git-clone`, `<module>.claude`
|
||||
- Bad: `git-clone`, `claude`
|
||||
|
||||
### Document dependencies
|
||||
|
||||
Add comments explaining why dependencies exist:
|
||||
|
||||
```hcl
|
||||
resource "coder_script" "ide_setup" {
|
||||
# Depends on git-clone because we need .vscode/extensions.json
|
||||
# Depends on env-setup because we need $NODE_PATH configured
|
||||
script = <<-EOT
|
||||
coder exp sync want "ide-setup" "git-clone"
|
||||
coder exp sync want "ide-setup" "env-setup"
|
||||
# ...
|
||||
EOT
|
||||
}
|
||||
```
|
||||
|
||||
### Avoid circular dependencies
|
||||
|
||||
The Coder Agent detects and rejects circular dependencies, but they indicate a design problem:
|
||||
|
||||
```bash
|
||||
# This will fail
|
||||
coder exp sync want "unit-a" "unit-b"
|
||||
coder exp sync want "unit-b" "unit-a"
|
||||
```
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
||||
### How do I identify scripts that can benefit from startup coordination?
|
||||
|
||||
Look for these patterns in existing templates:
|
||||
|
||||
- `sleep` commands used to order scripts
|
||||
- Using files to coordinate startup between scripts (e.g. `touch /tmp/startup-complete`)
|
||||
- Scripts that fail intermittently on startup
|
||||
- Comments like "must run after X" or "wait for Y"
|
||||
|
||||
### Will this slow down my workspace?
|
||||
|
||||
No. The socket server adds minimal overhead, and the default polling interval is 1
|
||||
second, so waiting for dependencies adds at most a few seconds to startup.
|
||||
You are more likely to notice an improvement in startup times as it becomes easier to manage complex dependencies in parallel.
|
||||
|
||||
### How do units interact with each other?
|
||||
|
||||
Units with no dependencies run immediately and in parallel.
|
||||
Only units with unsatisfied dependencies wait for their dependencies.
|
||||
|
||||
### How long can a dependency take to complete?
|
||||
|
||||
By default, `coder exp sync start` has a 5-minute timeout to prevent indefinite hangs.
|
||||
Upon timeout, the command will exit with an error code and print `timeout waiting for dependencies of unit <unit_name>` to stderr.
|
||||
|
||||
You can adjust this timeout as necessary for long-running operations:
|
||||
|
||||
```bash
|
||||
coder exp sync start "long-operation" --timeout 10m
|
||||
```
|
||||
|
||||
### Is state stored between restarts?
|
||||
|
||||
No. Sync state is kept in-memory only and resets on workspace restart.
|
||||
This is intentional to ensure clean initialization on every start.
|
||||
@@ -23,8 +23,9 @@ Rules follow the format: `key=value [key=value ...]` with three supported keys:
|
||||
|
||||
```yaml
|
||||
allowlist:
|
||||
- domain=github.com # All methods, all paths for github.com
|
||||
- method=GET,POST domain=api.example.com # GET/POST to api.example.com
|
||||
- domain=github.com # All methods, all paths for github.com (exact match)
|
||||
- domain=*.github.com # All subdomains of github.com
|
||||
- method=GET,POST domain=api.example.com # GET/POST to api.example.com (exact match)
|
||||
- domain=api.example.com path=/users,/posts # Multiple paths
|
||||
- method=GET domain=github.com path=/api/* # All three keys
|
||||
```
|
||||
@@ -35,19 +36,20 @@ allowlist:
|
||||
|
||||
The `*` wildcard matches domain labels (parts separated by dots).
|
||||
|
||||
| Pattern | Matches | Does NOT Match |
|
||||
|----------------|------------------------------------------------------------------|--------------------------------------------------------------------------|
|
||||
| `*` | All domains | - |
|
||||
| `github.com` | `github.com`, `api.github.com`, `v1.api.github.com` (subdomains) | `github.io` (diff domain) |
|
||||
| `*.github.com` | `api.github.com`, `v1.api.github.com` (1+ subdomain levels) | `github.com` (base domain) |
|
||||
| `api.*.com` | `api.github.com`, `api.google.com` | `api.v1.github.com` (`*` in the middle matches exactly one domain label) |
|
||||
| `*.*.com` | `api.example.com`, `api.v1.github.com` | - |
|
||||
| `api.*` | ❌ **ERROR** - Cannot end with `*` | - |
|
||||
| Pattern | Matches | Does NOT Match |
|
||||
|----------------|-------------------------------------------------------------|--------------------------------------------------------------------------|
|
||||
| `*` | All domains | - |
|
||||
| `github.com` | `github.com` (exact match only) | `api.github.com`, `v1.api.github.com` (subdomains), `github.io` |
|
||||
| `*.github.com` | `api.github.com`, `v1.api.github.com` (1+ subdomain levels) | `github.com` (base domain) |
|
||||
| `api.*.com` | `api.github.com`, `api.google.com` | `api.v1.github.com` (`*` in the middle matches exactly one domain label) |
|
||||
| `*.*.com` | `api.example.com`, `api.v1.github.com` | - |
|
||||
| `api.*` | ❌ **ERROR** - Cannot end with `*` | - |
|
||||
|
||||
**Important**:
|
||||
|
||||
- Patterns without `*` at the start automatically match subdomains
|
||||
- Patterns without `*` match **exactly** (no automatic subdomain matching)
|
||||
- `*.example.com` matches one or more subdomain levels
|
||||
- To match both base domain and subdomains, use separate rules: `domain=github.com` and `domain=*.github.com`
|
||||
- Domain patterns **cannot end with asterisk**
|
||||
|
||||
---
|
||||
|
||||
@@ -667,6 +667,29 @@
|
||||
"description": "Log workspace processes",
|
||||
"path": "./admin/templates/extending-templates/process-logging.md",
|
||||
"state": ["premium"]
|
||||
},
|
||||
{
|
||||
"title": "Startup Dependencies",
|
||||
"description": "Coordinate workspace startup with dependency management",
|
||||
"path": "./admin/templates/startup-coordination/index.md",
|
||||
"state": ["early access"],
|
||||
"children": [
|
||||
{
|
||||
"title": "Usage",
|
||||
"description": "How to use startup coordination",
|
||||
"path": "./admin/templates/startup-coordination/usage.md"
|
||||
},
|
||||
{
|
||||
"title": "Troubleshooting",
|
||||
"description": "Troubleshoot startup coordination",
|
||||
"path": "./admin/templates/startup-coordination/troubleshooting.md"
|
||||
},
|
||||
{
|
||||
"title": "Examples",
|
||||
"description": "Examples of startup coordination",
|
||||
"path": "./admin/templates/startup-coordination/example.md"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -1583,6 +1606,11 @@
|
||||
"description": "Role sync settings to sync organization roles from an IdP.",
|
||||
"path": "reference/cli/organizations_settings_set_role-sync.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations settings set workspace-sharing",
|
||||
"description": "Workspace sharing settings for the organization.",
|
||||
"path": "reference/cli/organizations_settings_set_workspace-sharing.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations settings show",
|
||||
"description": "Outputs specified organization setting.",
|
||||
@@ -1603,6 +1631,11 @@
|
||||
"description": "Role sync settings to sync organization roles from an IdP.",
|
||||
"path": "reference/cli/organizations_settings_show_role-sync.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations settings show workspace-sharing",
|
||||
"description": "Workspace sharing settings for the organization.",
|
||||
"path": "reference/cli/organizations_settings_show_workspace-sharing.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations show",
|
||||
"description": "Show the organization. Using \"selected\" will show the selected organization from the \"--org\" flag. Using \"me\" will show all organizations you are a member of.",
|
||||
|
||||
Generated
+114
-34
@@ -20,15 +20,15 @@ curl -X GET http://coder-server:8080/api/v2/.well-known/oauth-authorization-serv
|
||||
{
|
||||
"authorization_endpoint": "string",
|
||||
"code_challenge_methods_supported": [
|
||||
"string"
|
||||
"S256"
|
||||
],
|
||||
"grant_types_supported": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"issuer": "string",
|
||||
"registration_endpoint": "string",
|
||||
"response_types_supported": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"revocation_endpoint": "string",
|
||||
"scopes_supported": [
|
||||
@@ -36,7 +36,7 @@ curl -X GET http://coder-server:8080/api/v2/.well-known/oauth-authorization-serv
|
||||
],
|
||||
"token_endpoint": "string",
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"string"
|
||||
"client_secret_basic"
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -1265,9 +1265,9 @@ curl -X GET http://coder-server:8080/api/v2/oauth2/authorize?client_id=string&st
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|----------|
|
||||
| `response_type` | `code` |
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|-----------------|
|
||||
| `response_type` | `code`, `token` |
|
||||
|
||||
### Responses
|
||||
|
||||
@@ -1301,9 +1301,9 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/authorize?client_id=string&s
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|----------|
|
||||
| `response_type` | `code` |
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|-----------------|
|
||||
| `response_type` | `code`, `token` |
|
||||
|
||||
### Responses
|
||||
|
||||
@@ -1346,7 +1346,7 @@ curl -X GET http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1355,17 +1355,15 @@ curl -X GET http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"redirect_uris": [
|
||||
"string"
|
||||
],
|
||||
"registration_access_token": [
|
||||
0
|
||||
],
|
||||
"registration_access_token": "string",
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1399,7 +1397,7 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1409,13 +1407,13 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_statement": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1442,7 +1440,7 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1451,17 +1449,15 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"redirect_uris": [
|
||||
"string"
|
||||
],
|
||||
"registration_access_token": [
|
||||
0
|
||||
],
|
||||
"registration_access_token": "string",
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1519,7 +1515,7 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1529,13 +1525,13 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"string"
|
||||
],
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_statement": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1562,7 +1558,7 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1574,12 +1570,12 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"registration_access_token": "string",
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1662,9 +1658,9 @@ grant_type: authorization_code
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Parameter | Value(s) |
|
||||
|----------------|---------------------------------------|
|
||||
| `» grant_type` | `authorization_code`, `refresh_token` |
|
||||
| Parameter | Value(s) |
|
||||
|----------------|-------------------------------------------------------------------------------------|
|
||||
| `» grant_type` | `authorization_code`, `client_credentials`, `implicit`, `password`, `refresh_token` |
|
||||
|
||||
### Example responses
|
||||
|
||||
@@ -2832,6 +2828,90 @@ curl -X PATCH http://coder-server:8080/api/v2/organizations/{organization}/setti
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Get workspace sharing settings for organization
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/settings/workspace-sharing \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`GET /organizations/{organization}/settings/workspace-sharing`
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|----------------|------|--------------|----------|-----------------|
|
||||
| `organization` | path | string(uuid) | true | Organization ID |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 200 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------------|
|
||||
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.WorkspaceSharingSettings](schemas.md#codersdkworkspacesharingsettings) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Update workspace sharing settings for organization
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X PATCH http://coder-server:8080/api/v2/organizations/{organization}/settings/workspace-sharing \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`PATCH /organizations/{organization}/settings/workspace-sharing`
|
||||
|
||||
> Body parameter
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|----------------|------|----------------------------------------------------------------------------------|----------|----------------------------|
|
||||
| `organization` | path | string(uuid) | true | Organization ID |
|
||||
| `body` | body | [codersdk.WorkspaceSharingSettings](schemas.md#codersdkworkspacesharingsettings) | true | Workspace sharing settings |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 200 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------------|
|
||||
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.WorkspaceSharingSettings](schemas.md#codersdkworkspacesharingsettings) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Fetch provisioner key details
|
||||
|
||||
### Code samples
|
||||
|
||||
Generated
+2
-1
@@ -191,7 +191,8 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
|
||||
Generated
+168
-95
@@ -396,7 +396,8 @@
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
}
|
||||
```
|
||||
|
||||
@@ -412,6 +413,7 @@
|
||||
| `openai` | [codersdk.AIBridgeOpenAIConfig](#codersdkaibridgeopenaiconfig) | false | | |
|
||||
| `rate_limit` | integer | false | | |
|
||||
| `retention` | integer | false | | |
|
||||
| `structured_logging` | boolean | false | | |
|
||||
|
||||
## codersdk.AIBridgeInterception
|
||||
|
||||
@@ -743,7 +745,8 @@
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -2658,7 +2661,8 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
@@ -3202,7 +3206,8 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
@@ -5183,15 +5188,15 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
{
|
||||
"authorization_endpoint": "string",
|
||||
"code_challenge_methods_supported": [
|
||||
"string"
|
||||
"S256"
|
||||
],
|
||||
"grant_types_supported": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"issuer": "string",
|
||||
"registration_endpoint": "string",
|
||||
"response_types_supported": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"revocation_endpoint": "string",
|
||||
"scopes_supported": [
|
||||
@@ -5199,25 +5204,25 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
],
|
||||
"token_endpoint": "string",
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"string"
|
||||
"client_secret_basic"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-----------------------------------------|-----------------|----------|--------------|-------------|
|
||||
| `authorization_endpoint` | string | false | | |
|
||||
| `code_challenge_methods_supported` | array of string | false | | |
|
||||
| `grant_types_supported` | array of string | false | | |
|
||||
| `issuer` | string | false | | |
|
||||
| `registration_endpoint` | string | false | | |
|
||||
| `response_types_supported` | array of string | false | | |
|
||||
| `revocation_endpoint` | string | false | | |
|
||||
| `scopes_supported` | array of string | false | | |
|
||||
| `token_endpoint` | string | false | | |
|
||||
| `token_endpoint_auth_methods_supported` | array of string | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-----------------------------------------|-------------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `authorization_endpoint` | string | false | | |
|
||||
| `code_challenge_methods_supported` | array of [codersdk.OAuth2PKCECodeChallengeMethod](#codersdkoauth2pkcecodechallengemethod) | false | | |
|
||||
| `grant_types_supported` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `issuer` | string | false | | |
|
||||
| `registration_endpoint` | string | false | | |
|
||||
| `response_types_supported` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `revocation_endpoint` | string | false | | |
|
||||
| `scopes_supported` | array of string | false | | |
|
||||
| `token_endpoint` | string | false | | |
|
||||
| `token_endpoint_auth_methods_supported` | array of [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
|
||||
## codersdk.OAuth2ClientConfiguration
|
||||
|
||||
@@ -5232,7 +5237,7 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -5241,45 +5246,43 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"redirect_uris": [
|
||||
"string"
|
||||
],
|
||||
"registration_access_token": [
|
||||
0
|
||||
],
|
||||
"registration_access_token": "string",
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|------------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of string | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | array of integer | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of string | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | string | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | string | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2ClientRegistrationRequest
|
||||
|
||||
@@ -5291,7 +5294,7 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -5301,37 +5304,37 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_statement": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-----------------|----------|--------------|-------------|
|
||||
| `client_name` | string | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of string | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `response_types` | array of string | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_statement` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | string | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `client_name` | string | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `response_types` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_statement` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2ClientRegistrationResponse
|
||||
|
||||
@@ -5347,7 +5350,7 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"string"
|
||||
"authorization_code"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -5359,41 +5362,41 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"registration_access_token": "string",
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"string"
|
||||
"code"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-----------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of string | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | string | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of string | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | string | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | string | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2Config
|
||||
|
||||
@@ -5457,6 +5460,20 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
| `device_flow` | boolean | false | | |
|
||||
| `enterprise_base_url` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2PKCECodeChallengeMethod
|
||||
|
||||
```json
|
||||
"S256"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------|
|
||||
| `S256`, `plain` |
|
||||
|
||||
## codersdk.OAuth2ProtectedResourceMetadata
|
||||
|
||||
```json
|
||||
@@ -5544,6 +5561,48 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
| `client_secret_full` | string | false | | |
|
||||
| `id` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2ProviderGrantType
|
||||
|
||||
```json
|
||||
"authorization_code"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-------------------------------------------------------------------------------------|
|
||||
| `authorization_code`, `client_credentials`, `implicit`, `password`, `refresh_token` |
|
||||
|
||||
## codersdk.OAuth2ProviderResponseType
|
||||
|
||||
```json
|
||||
"code"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------|
|
||||
| `code`, `token` |
|
||||
|
||||
## codersdk.OAuth2TokenEndpointAuthMethod
|
||||
|
||||
```json
|
||||
"client_secret_basic"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------------------------------------------|
|
||||
| `client_secret_basic`, `client_secret_post`, `none` |
|
||||
|
||||
## codersdk.OAuthConversionResponse
|
||||
|
||||
```json
|
||||
@@ -11666,6 +11725,20 @@ If the schedule is empty, the user will be updated to use the default schedule.|
|
||||
|--------------------|
|
||||
| ``, `admin`, `use` |
|
||||
|
||||
## codersdk.WorkspaceSharingSettings
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|--------------------|---------|----------|--------------|-------------|
|
||||
| `sharing_disabled` | boolean | false | | |
|
||||
|
||||
## codersdk.WorkspaceStatus
|
||||
|
||||
```json
|
||||
|
||||
@@ -24,3 +24,4 @@ coder organizations settings set
|
||||
| [<code>group-sync</code>](./organizations_settings_set_group-sync.md) | Group sync settings to sync groups from an IdP. |
|
||||
| [<code>role-sync</code>](./organizations_settings_set_role-sync.md) | Role sync settings to sync organization roles from an IdP. |
|
||||
| [<code>organization-sync</code>](./organizations_settings_set_organization-sync.md) | Organization sync settings to sync organization memberships from an IdP. |
|
||||
| [<code>workspace-sharing</code>](./organizations_settings_set_workspace-sharing.md) | Workspace sharing settings for the organization. |
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
<!-- DO NOT EDIT | GENERATED CONTENT -->
|
||||
# organizations settings set workspace-sharing
|
||||
|
||||
Workspace sharing settings for the organization.
|
||||
|
||||
Aliases:
|
||||
|
||||
* workspacesharing
|
||||
|
||||
## Usage
|
||||
|
||||
```console
|
||||
coder organizations settings set workspace-sharing
|
||||
```
|
||||
@@ -24,3 +24,4 @@ coder organizations settings show
|
||||
| [<code>group-sync</code>](./organizations_settings_show_group-sync.md) | Group sync settings to sync groups from an IdP. |
|
||||
| [<code>role-sync</code>](./organizations_settings_show_role-sync.md) | Role sync settings to sync organization roles from an IdP. |
|
||||
| [<code>organization-sync</code>](./organizations_settings_show_organization-sync.md) | Organization sync settings to sync organization memberships from an IdP. |
|
||||
| [<code>workspace-sharing</code>](./organizations_settings_show_workspace-sharing.md) | Workspace sharing settings for the organization. |
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
<!-- DO NOT EDIT | GENERATED CONTENT -->
|
||||
# organizations settings show workspace-sharing
|
||||
|
||||
Workspace sharing settings for the organization.
|
||||
|
||||
Aliases:
|
||||
|
||||
* workspacesharing
|
||||
|
||||
## Usage
|
||||
|
||||
```console
|
||||
coder organizations settings show workspace-sharing
|
||||
```
|
||||
Generated
+11
@@ -1836,6 +1836,17 @@ Maximum number of concurrent AI Bridge requests per replica. Set to 0 to disable
|
||||
|
||||
Maximum number of AI Bridge requests per second per replica. Set to 0 to disable (unlimited).
|
||||
|
||||
### --aibridge-structured-logging
|
||||
|
||||
| | |
|
||||
|-------------|-------------------------------------------------|
|
||||
| Type | <code>bool</code> |
|
||||
| Environment | <code>$CODER_AIBRIDGE_STRUCTURED_LOGGING</code> |
|
||||
| YAML | <code>aibridge.structuredLogging</code> |
|
||||
| Default | <code>false</code> |
|
||||
|
||||
Emit structured logs for AI Bridge interception records. Use this for exporting these records to external SIEM or observability systems.
|
||||
|
||||
### --aibridge-proxy-enabled
|
||||
|
||||
| | |
|
||||
|
||||
@@ -124,77 +124,6 @@ property in your `devcontainer.json`:
|
||||
This maps container ports to the parent workspace, which can then be forwarded
|
||||
using the main workspace agent.
|
||||
|
||||
## Docker Compose dev containers
|
||||
|
||||
Dev containers support Docker Compose for multi-container environments. When you
|
||||
define a dev container with `dockerComposeFile` and `service` properties, the
|
||||
devcontainer CLI orchestrates all services defined in your Compose file.
|
||||
|
||||
### Configuration
|
||||
|
||||
To use Docker Compose, your `devcontainer.json` must specify:
|
||||
|
||||
- `dockerComposeFile`: Path to your Docker Compose file(s)
|
||||
- `service`: The container Coder connects to (becomes the dev container sub-agent)
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "My Project",
|
||||
"dockerComposeFile": "docker-compose.yml",
|
||||
"service": "app",
|
||||
"workspaceFolder": "/workspace"
|
||||
}
|
||||
```
|
||||
|
||||
With a corresponding `docker-compose.yml`:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
app:
|
||||
image: mcr.microsoft.com/devcontainers/base:ubuntu
|
||||
volumes:
|
||||
- .:/workspace
|
||||
command: sleep infinity
|
||||
database:
|
||||
image: postgres:16
|
||||
environment:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
```
|
||||
|
||||
The `app` service becomes your dev container with full Coder integration (SSH,
|
||||
web terminal, VS Code). The `database` service runs as a sidecar container.
|
||||
|
||||
### Container communication
|
||||
|
||||
Containers in a Compose setup communicate via Docker's internal DNS. From the
|
||||
primary container, reach sidecar services by their service name:
|
||||
|
||||
```console
|
||||
psql -h database -U postgres
|
||||
```
|
||||
|
||||
### Accessing sidecar services
|
||||
|
||||
Since only the primary service container runs as a Coder sub-agent, you cannot
|
||||
SSH or port-forward directly to sidecar containers. Instead:
|
||||
|
||||
1. **From the dev container**: Connect to sidecars using their service name
|
||||
(e.g., `psql -h database`).
|
||||
|
||||
1. **From your local machine**: Access sidecar services through the primary
|
||||
container. For example, run a proxy command in the dev container, then
|
||||
port-forward that port.
|
||||
|
||||
### Limitations
|
||||
|
||||
- The `forwardPorts` property with `host:port` syntax (e.g., `"database:5432"`)
|
||||
for forwarding ports from sidecar containers to your local machine is not yet
|
||||
supported.
|
||||
- Only the primary service container has Coder agent integration.
|
||||
|
||||
For more details on Docker Compose dev containers, see the
|
||||
[Dev Container specification](https://containers.dev/implementors/spec/#docker-compose-based).
|
||||
|
||||
## Dev container features
|
||||
|
||||
You can use standard [dev container features](https://containers.dev/features)
|
||||
|
||||
@@ -60,7 +60,9 @@ as [JetBrains](./workspace-access/jetbrains/index.md) or
|
||||
|
||||
Once started, the Coder agent is responsible for running your workspace startup
|
||||
scripts. These may configure tools, service connections, or personalization with
|
||||
[dotfiles](./workspace-dotfiles.md).
|
||||
[dotfiles](./workspace-dotfiles.md). For complex initialization with multiple
|
||||
dependent scripts, see
|
||||
[Workspace Startup Coordination](../admin/templates/startup-coordination/index.md).
|
||||
|
||||
Once these steps have completed, your workspace will now be in the `Running`
|
||||
state. You can access it via any of the [supported methods](./index.md), stop it
|
||||
|
||||
@@ -11,8 +11,8 @@ RUN cargo install jj-cli typos-cli watchexec-cli
|
||||
FROM ubuntu:jammy@sha256:104ae83764a5119017b8e8d6218fa0832b09df65aae7d5a6de29a85d813da2fb AS go
|
||||
|
||||
# Install Go manually, so that we can control the version
|
||||
ARG GO_VERSION=1.24.10
|
||||
ARG GO_CHECKSUM="dd52b974e3d9c5a7bbfb222c685806def6be5d6f7efd10f9caa9ca1fa2f47955"
|
||||
ARG GO_VERSION=1.24.11
|
||||
ARG GO_CHECKSUM="bceca00afaac856bc48b4cc33db7cd9eb383c81811379faed3bdbc80edb0af65"
|
||||
|
||||
# Boring Go is needed to build FIPS-compliant binaries.
|
||||
RUN apt-get update && \
|
||||
|
||||
@@ -7,7 +7,6 @@ allowlist:
|
||||
- domain=dev.coder.com
|
||||
|
||||
# test domains
|
||||
- method=GET domain=google.com
|
||||
- method=GET domain=typicode.com
|
||||
|
||||
# domain used in coder task workspaces
|
||||
|
||||
+1
-15
@@ -290,11 +290,6 @@ data "coder_parameter" "ide_choices" {
|
||||
value = "jetbrains"
|
||||
icon = "/icon/jetbrains.svg"
|
||||
}
|
||||
option {
|
||||
name = "JetBrains Fleet"
|
||||
value = "fleet"
|
||||
icon = "/icon/fleet.svg"
|
||||
}
|
||||
option {
|
||||
name = "Cursor"
|
||||
value = "cursor"
|
||||
@@ -458,15 +453,6 @@ module "zed" {
|
||||
folder = local.repo_dir
|
||||
}
|
||||
|
||||
module "jetbrains-fleet" {
|
||||
count = contains(jsondecode(data.coder_parameter.ide_choices.value), "fleet") ? data.coder_workspace.me.start_count : 0
|
||||
source = "registry.coder.com/coder/jetbrains-fleet/coder"
|
||||
version = "1.0.2"
|
||||
agent_id = coder_agent.dev.id
|
||||
agent_name = "dev"
|
||||
folder = local.repo_dir
|
||||
}
|
||||
|
||||
module "devcontainers-cli" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
source = "dev.registry.coder.com/modules/devcontainers-cli/coder"
|
||||
@@ -904,7 +890,7 @@ module "claude-code" {
|
||||
source = "dev.registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.3.0"
|
||||
enable_boundary = true
|
||||
boundary_version = "v0.5.2"
|
||||
boundary_version = "v0.5.5"
|
||||
agent_id = coder_agent.dev.id
|
||||
workdir = local.repo_dir
|
||||
claude_code_version = "latest"
|
||||
|
||||
@@ -46,6 +46,10 @@ var (
|
||||
ErrNoExternalAuthLinkFound = xerrors.New("no external auth link found")
|
||||
)
|
||||
|
||||
const (
|
||||
InterceptionLogMarker = "interception log"
|
||||
)
|
||||
|
||||
var _ aibridged.DRPCServer = &Server{}
|
||||
|
||||
type store interface {
|
||||
@@ -73,7 +77,8 @@ type Server struct {
|
||||
logger slog.Logger
|
||||
externalAuthConfigs map[string]*externalauth.Config
|
||||
|
||||
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
|
||||
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
|
||||
structuredLogging bool
|
||||
}
|
||||
|
||||
func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, accessURL string,
|
||||
@@ -92,8 +97,9 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac
|
||||
srv := &Server{
|
||||
lifecycleCtx: lifecycleCtx,
|
||||
store: store,
|
||||
logger: logger.Named("aibridgedserver"),
|
||||
logger: logger,
|
||||
externalAuthConfigs: eac,
|
||||
structuredLogging: bridgeCfg.StructuredLogging.Value(),
|
||||
}
|
||||
|
||||
if bridgeCfg.InjectCoderMCPTools {
|
||||
@@ -123,13 +129,33 @@ func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterce
|
||||
return nil, xerrors.Errorf("empty API key ID")
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "interception_start"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("initiator_id", initID.String()),
|
||||
slog.F("api_key_id", in.ApiKeyId),
|
||||
slog.F("provider", in.Provider),
|
||||
slog.F("model", in.Model),
|
||||
slog.F("started_at", in.StartedAt.AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeInterception(ctx, database.InsertAIBridgeInterceptionParams{
|
||||
ID: intcID,
|
||||
APIKeyID: sql.NullString{String: in.ApiKeyId, Valid: true},
|
||||
InitiatorID: initID,
|
||||
Provider: in.Provider,
|
||||
Model: in.Model,
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
Metadata: out,
|
||||
StartedAt: in.StartedAt.AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -148,6 +174,14 @@ func (s *Server) RecordInterceptionEnded(ctx context.Context, in *proto.RecordIn
|
||||
return nil, xerrors.Errorf("invalid interception ID %q: %w", in.GetId(), err)
|
||||
}
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "interception_end"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("ended_at", in.EndedAt.AsTime()),
|
||||
)
|
||||
}
|
||||
|
||||
_, err = s.store.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intcID,
|
||||
EndedAt: in.EndedAt.AsTime(),
|
||||
@@ -168,18 +202,38 @@ func (s *Server) RecordTokenUsage(ctx context.Context, in *proto.RecordTokenUsag
|
||||
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "token_usage"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("msg_id", in.GetMsgId()),
|
||||
slog.F("input_tokens", in.GetInputTokens()),
|
||||
slog.F("output_tokens", in.GetOutputTokens()),
|
||||
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeTokenUsage(ctx, database.InsertAIBridgeTokenUsageParams{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
ProviderResponseID: in.GetMsgId(),
|
||||
InputTokens: in.GetInputTokens(),
|
||||
OutputTokens: in.GetOutputTokens(),
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
Metadata: out,
|
||||
CreatedAt: in.GetCreatedAt().AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert token usage: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RecordTokenUsageResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -192,17 +246,36 @@ func (s *Server) RecordPromptUsage(ctx context.Context, in *proto.RecordPromptUs
|
||||
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "prompt_usage"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("msg_id", in.GetMsgId()),
|
||||
slog.F("prompt", in.GetPrompt()),
|
||||
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeUserPrompt(ctx, database.InsertAIBridgeUserPromptParams{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
ProviderResponseID: in.GetMsgId(),
|
||||
Prompt: in.GetPrompt(),
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
Metadata: out,
|
||||
CreatedAt: in.GetCreatedAt().AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert user prompt: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RecordPromptUsageResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -215,6 +288,28 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
|
||||
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "tool_usage"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("msg_id", in.GetMsgId()),
|
||||
slog.F("tool", in.GetTool()),
|
||||
slog.F("input", in.GetInput()),
|
||||
slog.F("server_url", in.GetServerUrl()),
|
||||
slog.F("injected", in.GetInjected()),
|
||||
slog.F("invocation_error", in.GetInvocationError()),
|
||||
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeToolUsage(ctx, database.InsertAIBridgeToolUsageParams{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
@@ -224,12 +319,13 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
|
||||
Input: in.GetInput(),
|
||||
Injected: in.GetInjected(),
|
||||
InvocationError: sql.NullString{String: in.GetInvocationError(), Valid: in.InvocationError != nil},
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
Metadata: out,
|
||||
CreatedAt: in.GetCreatedAt().AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert tool usage: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RecordToolUsageResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -433,24 +529,16 @@ func getCoderMCPServerConfig(experiments codersdk.Experiments, accessURL string)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// marshalMetadata attempts to marshal the given metadata map into a
|
||||
// JSON-encoded byte slice. If the marshaling fails, the function logs a
|
||||
// warning and returns nil. The supplied context is only used for logging.
|
||||
func marshalMetadata(ctx context.Context, logger slog.Logger, in map[string]*anypb.Any) []byte {
|
||||
mdMap := make(map[string]any, len(in))
|
||||
func metadataToMap(in map[string]*anypb.Any) map[string]any {
|
||||
meta := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
var sv structpb.Value
|
||||
if err := v.UnmarshalTo(&sv); err == nil {
|
||||
mdMap[k] = sv.AsInterface()
|
||||
meta[k] = sv.AsInterface()
|
||||
}
|
||||
}
|
||||
out, err := json.Marshal(mdMap)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
return meta
|
||||
}
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
package aibridgedserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
)
|
||||
|
||||
func TestMarshalMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NilData", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
out := marshalMetadata(context.Background(), logger, nil)
|
||||
require.JSONEq(t, "{}", string(out))
|
||||
})
|
||||
|
||||
t.Run("WithData", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
list := structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
|
||||
structpb.NewStringValue("a"),
|
||||
structpb.NewNumberValue(1),
|
||||
structpb.NewBoolValue(false),
|
||||
}})
|
||||
obj := structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"a": structpb.NewStringValue("b"),
|
||||
"n": structpb.NewNumberValue(3),
|
||||
}})
|
||||
|
||||
nonValue := mustMarshalAny(t, &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"ignored": structpb.NewStringValue("yes"),
|
||||
}})
|
||||
invalid := &anypb.Any{TypeUrl: "type.googleapis.com/google.protobuf.Value", Value: []byte{0xff, 0x00}}
|
||||
|
||||
in := map[string]*anypb.Any{
|
||||
"null": mustMarshalAny(t, structpb.NewNullValue()),
|
||||
// Scalars
|
||||
"string": mustMarshalAny(t, structpb.NewStringValue("hello")),
|
||||
"bool": mustMarshalAny(t, structpb.NewBoolValue(true)),
|
||||
"number": mustMarshalAny(t, structpb.NewNumberValue(42)),
|
||||
// Complex types
|
||||
"list": mustMarshalAny(t, list),
|
||||
"object": mustMarshalAny(t, obj),
|
||||
// Extra valid entries
|
||||
"ok": mustMarshalAny(t, structpb.NewStringValue("present")),
|
||||
"nan": mustMarshalAny(t, structpb.NewNumberValue(math.NaN())),
|
||||
// Entries that should be ignored
|
||||
"invalid": invalid,
|
||||
"non_value": nonValue,
|
||||
}
|
||||
|
||||
out := marshalMetadata(context.Background(), logger, in)
|
||||
require.NotNil(t, out)
|
||||
var got map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &got))
|
||||
|
||||
expected := map[string]any{
|
||||
"string": "hello",
|
||||
"bool": true,
|
||||
"number": float64(42),
|
||||
"null": nil,
|
||||
"list": []any{"a", float64(1), false},
|
||||
"object": map[string]any{"a": "b", "n": float64(3)},
|
||||
"ok": "present",
|
||||
"nan": "NaN",
|
||||
}
|
||||
require.Equal(t, expected, got)
|
||||
})
|
||||
}
|
||||
|
||||
func mustMarshalAny(t testing.TB, m proto.Message) *anypb.Any {
|
||||
t.Helper()
|
||||
a, err := anypb.New(m)
|
||||
require.NoError(t, err)
|
||||
return a
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package aibridgedserver_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
@@ -20,6 +22,8 @@ import (
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogjson"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
@@ -832,3 +836,279 @@ func mustMarshalAny(t *testing.T, msg protobufproto.Message) *anypb.Any {
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// logLine represents a parsed JSON log entry.
|
||||
type logLine struct {
|
||||
Msg string `json:"msg"`
|
||||
Level string `json:"level"`
|
||||
Fields map[string]any `json:"fields"`
|
||||
}
|
||||
|
||||
// parseLogLines parses JSON log lines from a buffer.
|
||||
func parseLogLines(buf *bytes.Buffer) []logLine {
|
||||
var lines []logLine
|
||||
scanner := bufio.NewScanner(buf)
|
||||
for scanner.Scan() {
|
||||
var line logLine
|
||||
if err := json.Unmarshal(scanner.Bytes(), &line); err == nil {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// getLogLinesWithMessage returns all log lines with the given message.
|
||||
func getLogLinesWithMessage(lines []logLine, msg string) []logLine {
|
||||
var result []logLine
|
||||
for _, line := range lines {
|
||||
if line.Msg == msg {
|
||||
result = append(result, line)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestStructuredLogging(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadataProto := map[string]*anypb.Any{
|
||||
"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "value"}}),
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
structuredLogging bool
|
||||
expectedErr error
|
||||
setupMocks func(db *dbmock.MockStore, interceptionID uuid.UUID)
|
||||
recordFn func(srv *aibridgedserver.Server, ctx context.Context, interceptionID uuid.UUID) error
|
||||
expectedFields map[string]any
|
||||
}
|
||||
|
||||
interceptionID := uuid.UUID{1}
|
||||
initiatorID := uuid.UUID{2}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
name: "RecordInterception_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
|
||||
ID: intcID,
|
||||
InitiatorID: initiatorID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
|
||||
Id: intcID.String(),
|
||||
ApiKeyId: "api-key-123",
|
||||
InitiatorId: initiatorID.String(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4-opus",
|
||||
Metadata: metadataProto,
|
||||
StartedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "interception_start",
|
||||
"interception_id": interceptionID.String(),
|
||||
"initiator_id": initiatorID.String(),
|
||||
"provider": "anthropic",
|
||||
"model": "claude-4-opus",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordInterception_does_not_log_when_disabled",
|
||||
structuredLogging: false,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
|
||||
ID: intcID,
|
||||
InitiatorID: initiatorID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
|
||||
Id: intcID.String(),
|
||||
ApiKeyId: "api-key-123",
|
||||
InitiatorId: initiatorID.String(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4-opus",
|
||||
StartedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: nil, // No log expected.
|
||||
},
|
||||
{
|
||||
name: "RecordInterception_log_on_db_error",
|
||||
structuredLogging: true,
|
||||
expectedErr: sql.ErrConnDone,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{}, sql.ErrConnDone)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
|
||||
Id: intcID.String(),
|
||||
ApiKeyId: "api-key-123",
|
||||
InitiatorId: initiatorID.String(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4-opus",
|
||||
StartedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
// Even though the database call errored, we must still write the logs.
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "interception_start",
|
||||
"interception_id": interceptionID.String(),
|
||||
"initiator_id": initiatorID.String(),
|
||||
"provider": "anthropic",
|
||||
"model": "claude-4-opus",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordInterceptionEnded_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
|
||||
ID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterceptionEnded(ctx, &proto.RecordInterceptionEndedRequest{
|
||||
Id: intcID.String(),
|
||||
EndedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "interception_end",
|
||||
"interception_id": interceptionID.String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordTokenUsage_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeTokenUsage{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordTokenUsage(ctx, &proto.RecordTokenUsageRequest{
|
||||
InterceptionId: intcID.String(),
|
||||
MsgId: "msg_123",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
Metadata: metadataProto,
|
||||
CreatedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "token_usage",
|
||||
"interception_id": interceptionID.String(),
|
||||
"input_tokens": float64(100), // JSON numbers are float64.
|
||||
"output_tokens": float64(200),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordPromptUsage_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), gomock.Any()).Return(database.AIBridgeUserPrompt{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordPromptUsage(ctx, &proto.RecordPromptUsageRequest{
|
||||
InterceptionId: intcID.String(),
|
||||
MsgId: "msg_123",
|
||||
Prompt: "Hello, Claude!",
|
||||
Metadata: metadataProto,
|
||||
CreatedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "prompt_usage",
|
||||
"interception_id": interceptionID.String(),
|
||||
"prompt": "Hello, Claude!",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordToolUsage_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeToolUsage{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordToolUsage(ctx, &proto.RecordToolUsageRequest{
|
||||
InterceptionId: intcID.String(),
|
||||
MsgId: "msg_123",
|
||||
ServerUrl: strPtr("https://api.example.com"),
|
||||
Tool: "read_file",
|
||||
Input: `{"path": "/etc/hosts"}`,
|
||||
Injected: true,
|
||||
InvocationError: strPtr("permission denied"),
|
||||
Metadata: metadataProto,
|
||||
CreatedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "tool_usage",
|
||||
"interception_id": interceptionID.String(),
|
||||
"tool": "read_file",
|
||||
"input": `{"path": "/etc/hosts"}`,
|
||||
"injected": true,
|
||||
"invocation_error": "permission denied",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
buf := &bytes.Buffer{}
|
||||
logger := slog.Make(slogjson.Sink(buf)).Leveled(slog.LevelDebug)
|
||||
|
||||
tc.setupMocks(db, interceptionID)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{
|
||||
StructuredLogging: serpent.Bool(tc.structuredLogging),
|
||||
}, nil, requiredExperiments)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tc.recordFn(srv, ctx, interceptionID)
|
||||
if tc.expectedErr != nil {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
lines := parseLogLines(buf)
|
||||
if tc.expectedFields == nil {
|
||||
// No log expected (disabled or error case).
|
||||
require.Empty(t, lines)
|
||||
} else {
|
||||
matchedLines := getLogLinesWithMessage(lines, aibridgedserver.InterceptionLogMarker)
|
||||
require.Len(t, matchedLines, 1, "expected exactly one log line with message %q", aibridgedserver.InterceptionLogMarker)
|
||||
|
||||
fields := matchedLines[0].Fields
|
||||
for key, expected := range tc.expectedFields {
|
||||
require.Equal(t, expected, fields[key], "field %q mismatch", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,6 +148,10 @@ AI BRIDGE OPTIONS:
|
||||
Maximum number of AI Bridge requests per second per replica. Set to 0
|
||||
to disable (unlimited).
|
||||
|
||||
--aibridge-structured-logging bool, $CODER_AIBRIDGE_STRUCTURED_LOGGING (default: false)
|
||||
Emit structured logs for AI Bridge interception records. Use this for
|
||||
exporting these records to external SIEM or observability systems.
|
||||
|
||||
AI BRIDGE PROXY OPTIONS:
|
||||
--aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
|
||||
Path to the CA certificate file for AI Bridge Proxy.
|
||||
|
||||
@@ -361,6 +361,14 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
||||
|
||||
r.Get("/idpsync/available-fields", api.organizationIDPSyncClaimFields)
|
||||
r.Get("/idpsync/field-values", api.organizationIDPSyncClaimFieldValues)
|
||||
|
||||
r.Route("/workspace-sharing", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.RequireExperiment(api.AGPL.Experiments, codersdk.ExperimentWorkspaceSharing),
|
||||
)
|
||||
r.Get("/", api.workspaceSharingSettings)
|
||||
r.Patch("/", api.patchWorkspaceSharingSettings)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/rbac/rolestore"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary Get workspace sharing settings for organization
|
||||
// @ID get-workspace-sharing-settings-for-organization
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Enterprise
|
||||
// @Param organization path string true "Organization ID" format(uuid)
|
||||
// @Success 200 {object} codersdk.WorkspaceSharingSettings
|
||||
// @Router /organizations/{organization}/settings/workspace-sharing [get]
|
||||
func (api *API) workspaceSharingSettings(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
org := httpmw.OrganizationParam(r)
|
||||
|
||||
if !api.Authorize(r, policy.ActionRead, org) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: org.WorkspaceSharingDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
// @Summary Update workspace sharing settings for organization
|
||||
// @ID update-workspace-sharing-settings-for-organization
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Accept json
|
||||
// @Tags Enterprise
|
||||
// @Param organization path string true "Organization ID" format(uuid)
|
||||
// @Param request body codersdk.WorkspaceSharingSettings true "Workspace sharing settings"
|
||||
// @Success 200 {object} codersdk.WorkspaceSharingSettings
|
||||
// @Router /organizations/{organization}/settings/workspace-sharing [patch]
|
||||
func (api *API) patchWorkspaceSharingSettings(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
org := httpmw.OrganizationParam(r)
|
||||
auditor := *api.AGPL.Auditor.Load()
|
||||
aReq, commitAudit := audit.InitRequest[database.Organization](rw, &audit.RequestParams{
|
||||
Audit: auditor,
|
||||
Log: api.Logger,
|
||||
Request: r,
|
||||
Action: database.AuditActionWrite,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
aReq.Old = org
|
||||
defer commitAudit()
|
||||
|
||||
if !api.Authorize(r, policy.ActionUpdate, org) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.WorkspaceSharingSettings
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Database.InTx(func(tx database.Store) error {
|
||||
//nolint:gocritic // System context required to look up and reconcile the
|
||||
// organization-member system role; callers only need `organization:update`
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
// Serialize organization workspace-sharing updates with system role
|
||||
// reconciliation across coderd instances (e.g. during rolling restarts).
|
||||
// This prevents conflicting writes to the organization-member system role.
|
||||
// TODO(geokat): Consider finer-grained locks as we add more system roles.
|
||||
err := tx.AcquireLock(ctx, database.LockIDReconcileSystemRoles)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("acquire system roles reconciliation lock: %w", err)
|
||||
}
|
||||
|
||||
org, err = tx.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
WorkspaceSharingDisabled: req.SharingDisabled,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update organization workspace sharing settings: %w", err)
|
||||
}
|
||||
|
||||
role, err := database.ExpectOne(tx.CustomRoles(sysCtx, database.CustomRolesParams{
|
||||
LookupRoles: []database.NameOrganizationPair{
|
||||
{
|
||||
Name: rbac.RoleOrgMember(),
|
||||
OrganizationID: org.ID,
|
||||
},
|
||||
},
|
||||
// Satisfy linter that requires all fields to be set.
|
||||
OrganizationID: org.ID,
|
||||
ExcludeOrgRoles: false,
|
||||
IncludeSystemRoles: true,
|
||||
}))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get organization-member role: %w", err)
|
||||
}
|
||||
|
||||
_, _, err = rolestore.ReconcileOrgMemberRole(sysCtx, tx, role, req.SharingDisabled)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("reconcile organization-member role: %w", err)
|
||||
}
|
||||
|
||||
if req.SharingDisabled {
|
||||
err = tx.DeleteWorkspaceACLsByOrganization(sysCtx, org.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete workspace ACLs by organization: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating workspace sharing settings.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
aReq.New = org
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: org.WorkspaceSharingDisabled,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestWorkspaceSharingSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("DisabledDefaultsFalse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
client, first := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: dv,
|
||||
},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, first.OrganizationID)
|
||||
settings, err := memberClient.WorkspaceSharingSettings(ctx, first.OrganizationID.String())
|
||||
require.NoError(t, err)
|
||||
require.False(t, settings.SharingDisabled)
|
||||
})
|
||||
|
||||
t.Run("DisabledTogglePersists", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
client, first := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: dv,
|
||||
},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
orgAdminClient, _ := coderdtest.CreateAnotherUser(t, client, first.OrganizationID, rbac.ScopedRoleOrgAdmin(first.OrganizationID))
|
||||
settings, err := orgAdminClient.PatchWorkspaceSharingSettings(ctx, first.OrganizationID.String(), codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, settings.SharingDisabled)
|
||||
|
||||
settings, err = orgAdminClient.WorkspaceSharingSettings(ctx, first.OrganizationID.String())
|
||||
require.NoError(t, err)
|
||||
require.True(t, settings.SharingDisabled)
|
||||
|
||||
settings, err = orgAdminClient.PatchWorkspaceSharingSettings(ctx, first.OrganizationID.String(), codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, settings.SharingDisabled)
|
||||
})
|
||||
|
||||
t.Run("UpdateAuthz", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
client, first := coderdenttest.New(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: dv,
|
||||
},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, first.OrganizationID)
|
||||
_, err := memberClient.PatchWorkspaceSharingSettings(ctx, first.OrganizationID.String(), codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: true,
|
||||
})
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("AuditLog", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := audit.NewMock()
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
client, first := coderdenttest.New(t, &coderdenttest.Options{
|
||||
AuditLogging: true,
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: dv,
|
||||
Auditor: auditor,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureAuditLog: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
orgAdminClient, _ := coderdtest.CreateAnotherUser(t, client, first.OrganizationID, rbac.ScopedRoleOrgAdmin(first.OrganizationID))
|
||||
auditor.ResetLogs()
|
||||
_, err := orgAdminClient.PatchWorkspaceSharingSettings(ctx, first.OrganizationID.String(), codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, auditor.AuditLogs(), 1)
|
||||
alog := auditor.AuditLogs()[0]
|
||||
require.Equal(t, database.AuditActionWrite, alog.Action)
|
||||
require.Equal(t, database.ResourceTypeOrganization, alog.ResourceType)
|
||||
require.Equal(t, first.OrganizationID, alog.ResourceID)
|
||||
})
|
||||
|
||||
t.Run("ExperimentDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Note: NOT setting the experiment flag.
|
||||
client, first := coderdenttest.New(t, &coderdenttest.Options{})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
memberClient, _ := coderdtest.CreateAnotherUser(t, client, first.OrganizationID)
|
||||
_, err := memberClient.WorkspaceSharingSettings(ctx, first.OrganizationID.String())
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "requires enabling")
|
||||
require.Contains(t, apiErr.Message, "workspace-sharing")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWorkspaceSharingDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ACLEndpointsForbidden", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
client, db, owner := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: dv,
|
||||
},
|
||||
})
|
||||
|
||||
workspaceOwnerClient, workspaceOwner := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
ws := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: workspaceOwner.ID,
|
||||
OrganizationID: owner.OrganizationID,
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
orgAdminClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.ScopedRoleOrgAdmin(owner.OrganizationID))
|
||||
_, err := orgAdminClient.PatchWorkspaceSharingSettings(ctx, owner.OrganizationID.String(), codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reading the ACL list remains allowed even when workspace sharing is
|
||||
// disabled, but mutating it is forbidden.
|
||||
_, err = workspaceOwnerClient.WorkspaceACL(ctx, ws.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We don't allow mutating the ACL.
|
||||
assertSharingDisabled := func(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
|
||||
require.Equal(t, "Workspace sharing is disabled for this organization.", apiErr.Message)
|
||||
}
|
||||
|
||||
err = workspaceOwnerClient.UpdateWorkspaceACL(ctx, ws.ID, codersdk.UpdateWorkspaceACL{
|
||||
UserRoles: map[string]codersdk.WorkspaceRole{
|
||||
uuid.NewString(): codersdk.WorkspaceRoleUse,
|
||||
},
|
||||
})
|
||||
assertSharingDisabled(t, err)
|
||||
|
||||
err = workspaceOwnerClient.DeleteWorkspaceACL(ctx, ws.ID)
|
||||
assertSharingDisabled(t, err)
|
||||
})
|
||||
|
||||
t.Run("ACLsPurged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
|
||||
client, db, owner := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: dv,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureTemplateRBAC: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
workspaceOwnerClient, workspaceOwner := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
_, sharedUser := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
|
||||
// Create a group to test group ACL purging.
|
||||
group := coderdtest.CreateGroup(t, client, owner.OrganizationID, "test-group")
|
||||
|
||||
ws := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: workspaceOwner.ID,
|
||||
OrganizationID: owner.OrganizationID,
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
// Set both user and group ACLs.
|
||||
err := workspaceOwnerClient.UpdateWorkspaceACL(ctx, ws.ID, codersdk.UpdateWorkspaceACL{
|
||||
UserRoles: map[string]codersdk.WorkspaceRole{
|
||||
sharedUser.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
},
|
||||
GroupRoles: map[string]codersdk.WorkspaceRole{
|
||||
group.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
acl, err := workspaceOwnerClient.WorkspaceACL(ctx, ws.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, acl.Users, 1)
|
||||
require.Equal(t, sharedUser.ID, acl.Users[0].ID)
|
||||
require.Equal(t, codersdk.WorkspaceRoleUse, acl.Users[0].Role)
|
||||
require.Len(t, acl.Groups, 1)
|
||||
require.Equal(t, group.ID, acl.Groups[0].ID)
|
||||
require.Equal(t, codersdk.WorkspaceRoleUse, acl.Groups[0].Role)
|
||||
|
||||
orgAdminClient, _ := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID, rbac.ScopedRoleOrgAdmin(owner.OrganizationID))
|
||||
_, err = orgAdminClient.PatchWorkspaceSharingSettings(ctx, owner.OrganizationID.String(), codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = orgAdminClient.PatchWorkspaceSharingSettings(ctx, owner.OrganizationID.String(), codersdk.WorkspaceSharingSettings{
|
||||
SharingDisabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both user and group ACLs are purged.
|
||||
acl, err = workspaceOwnerClient.WorkspaceACL(ctx, ws.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, acl.Users)
|
||||
require.Empty(t, acl.Groups)
|
||||
|
||||
// Verify ACLs can be set again after re-enabling sharing.
|
||||
err = workspaceOwnerClient.UpdateWorkspaceACL(ctx, ws.ID, codersdk.UpdateWorkspaceACL{
|
||||
UserRoles: map[string]codersdk.WorkspaceRole{
|
||||
sharedUser.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
acl, err = workspaceOwnerClient.WorkspaceACL(ctx, ws.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, acl.Users, 1)
|
||||
require.Equal(t, sharedUser.ID, acl.Users[0].ID)
|
||||
})
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/coder/coder/v2
|
||||
|
||||
go 1.24.10
|
||||
go 1.24.11
|
||||
|
||||
// Required until a v3 of chroma is created to lazily initialize all XML files.
|
||||
// None of our dependencies seem to use the registries anyways, so this
|
||||
@@ -437,7 +437,7 @@ require (
|
||||
go.opentelemetry.io/collector/pdata/pprofile v0.121.0 // indirect
|
||||
go.opentelemetry.io/collector/semconv v0.123.0 // indirect
|
||||
go.opentelemetry.io/contrib v1.19.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
)
|
||||
|
||||
type RequestMode string
|
||||
|
||||
const (
|
||||
RequestModeBridge RequestMode = "bridge"
|
||||
RequestModeDirect RequestMode = "direct"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// Mode determines how requests are made.
|
||||
// "bridge": Create users in Coder and use their session tokens to make requests through AI Bridge.
|
||||
// "direct": Make requests directly to UpstreamURL without user creation.
|
||||
Mode RequestMode `json:"mode"`
|
||||
|
||||
// User is the configuration for the user to create.
|
||||
// Required in bridge mode.
|
||||
User createusers.Config `json:"user"`
|
||||
|
||||
// UpstreamURL is the URL to make requests to directly.
|
||||
// Only used in direct mode.
|
||||
UpstreamURL string `json:"upstream_url"`
|
||||
|
||||
// Provider is the API provider to use: "openai" or "anthropic".
|
||||
Provider string `json:"provider"`
|
||||
|
||||
// RequestCount is the number of requests to make per runner.
|
||||
RequestCount int `json:"request_count"`
|
||||
|
||||
// Stream indicates whether to use streaming requests.
|
||||
Stream bool `json:"stream"`
|
||||
|
||||
// RequestPayloadSize is the size in bytes of the request payload (user message content).
|
||||
// If 0, uses default message content.
|
||||
RequestPayloadSize int `json:"request_payload_size"`
|
||||
|
||||
// NumMessages is the number of messages to include in the conversation.
|
||||
// Messages alternate between user and assistant roles, always ending with user.
|
||||
// Must be greater than 0.
|
||||
NumMessages int `json:"num_messages"`
|
||||
|
||||
// HTTPTimeout is the timeout for individual HTTP requests to the upstream
|
||||
// provider. This is separate from the job timeout which controls the overall
|
||||
// test execution.
|
||||
HTTPTimeout time.Duration `json:"http_timeout"`
|
||||
|
||||
Metrics *Metrics `json:"-"`
|
||||
|
||||
// RequestBody is the pre-serialized JSON request body. This is generated
|
||||
// once by PrepareRequestBody and shared across all runners and requests.
|
||||
RequestBody []byte `json:"-"`
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
if c.Metrics == nil {
|
||||
return xerrors.New("metrics must be set")
|
||||
}
|
||||
|
||||
// Validate mode
|
||||
if c.Mode != RequestModeBridge && c.Mode != RequestModeDirect {
|
||||
return xerrors.New("mode must be either 'bridge' or 'direct'")
|
||||
}
|
||||
|
||||
if c.RequestCount <= 0 {
|
||||
return xerrors.New("request_count must be greater than 0")
|
||||
}
|
||||
|
||||
// Validate provider
|
||||
if c.Provider != "openai" && c.Provider != "anthropic" {
|
||||
return xerrors.New("provider must be either 'openai' or 'anthropic'")
|
||||
}
|
||||
|
||||
if c.Mode == RequestModeDirect {
|
||||
// In direct mode, UpstreamURL must be set.
|
||||
if c.UpstreamURL == "" {
|
||||
return xerrors.New("upstream_url must be set in direct mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// In bridge mode, User config is required.
|
||||
if c.User.OrganizationID == uuid.Nil {
|
||||
return xerrors.New("user organization_id must be set in bridge mode")
|
||||
}
|
||||
|
||||
if err := c.User.Validate(); err != nil {
|
||||
return xerrors.Errorf("user config: %w", err)
|
||||
}
|
||||
|
||||
if c.NumMessages <= 0 {
|
||||
return xerrors.New("num_messages must be greater than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) NewStrategy(client *codersdk.Client) requestModeStrategy {
|
||||
if c.Mode == RequestModeDirect {
|
||||
return newDirectStrategy(directStrategyConfig{
|
||||
UpstreamURL: c.UpstreamURL,
|
||||
})
|
||||
}
|
||||
|
||||
return newBridgeStrategy(bridgeStrategyConfig{
|
||||
Client: client,
|
||||
Provider: c.Provider,
|
||||
Metrics: c.Metrics,
|
||||
User: c.User,
|
||||
})
|
||||
}
|
||||
|
||||
// PrepareRequestBody generates the conversation and serializes the full request
|
||||
// body once. This should be called before creating Runners so that all runners
|
||||
// share the same pre-generated payload.
|
||||
func (c *Config) PrepareRequestBody() error {
|
||||
provider := NewProviderStrategy(c.Provider)
|
||||
model := provider.DefaultModel()
|
||||
|
||||
var formattedMessages []any
|
||||
if c.RequestPayloadSize > 0 {
|
||||
formattedMessages = generateConversation(provider, c.RequestPayloadSize, c.NumMessages)
|
||||
} else {
|
||||
messages := []message{{
|
||||
Role: "user",
|
||||
Content: "Hello from the bridge load generator.",
|
||||
}}
|
||||
formattedMessages = provider.formatMessages(messages)
|
||||
}
|
||||
|
||||
reqBody := provider.buildRequestBody(model, formattedMessages, c.Stream)
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
|
||||
c.RequestBody = bodyBytes
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
bridgeErrors *prometheus.CounterVec
|
||||
bridgeRequests *prometheus.CounterVec
|
||||
bridgeDuration prometheus.Histogram
|
||||
bridgeTokensTotal *prometheus.CounterVec
|
||||
}
|
||||
|
||||
func NewMetrics(reg prometheus.Registerer) *Metrics {
|
||||
if reg == nil {
|
||||
reg = prometheus.DefaultRegisterer
|
||||
}
|
||||
|
||||
errors := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_errors_total",
|
||||
Help: "Total number of bridge errors",
|
||||
}, []string{"action"})
|
||||
|
||||
requests := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_requests_total",
|
||||
Help: "Total number of bridge requests",
|
||||
}, []string{"status"})
|
||||
|
||||
duration := prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_request_duration_seconds",
|
||||
Help: "Duration of bridge requests in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
})
|
||||
|
||||
tokens := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "coderd",
|
||||
Subsystem: "scaletest",
|
||||
Name: "bridge_response_tokens_total",
|
||||
Help: "Total number of tokens in bridge responses",
|
||||
}, []string{"type"})
|
||||
|
||||
reg.MustRegister(errors, requests, duration, tokens)
|
||||
|
||||
return &Metrics{
|
||||
bridgeErrors: errors,
|
||||
bridgeRequests: requests,
|
||||
bridgeDuration: duration,
|
||||
bridgeTokensTotal: tokens,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Metrics) AddError(action string) {
|
||||
m.bridgeErrors.WithLabelValues(action).Inc()
|
||||
}
|
||||
|
||||
func (m *Metrics) AddRequest(status string) {
|
||||
m.bridgeRequests.WithLabelValues(status).Inc()
|
||||
}
|
||||
|
||||
func (m *Metrics) ObserveDuration(duration float64) {
|
||||
m.bridgeDuration.Observe(duration)
|
||||
}
|
||||
|
||||
func (m *Metrics) AddTokens(tokenType string, count int64) {
|
||||
m.bridgeTokensTotal.WithLabelValues(tokenType).Add(float64(count))
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProviderStrategy handles provider-specific message formatting for LLM APIs.
|
||||
type ProviderStrategy interface {
|
||||
DefaultModel() string
|
||||
formatMessages(messages []message) []any
|
||||
buildRequestBody(model string, messages []any, stream bool) map[string]any
|
||||
}
|
||||
|
||||
type message struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
func NewProviderStrategy(provider string) ProviderStrategy {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return &anthropicProvider{}
|
||||
default:
|
||||
return &openAIProvider{}
|
||||
}
|
||||
}
|
||||
|
||||
type openAIProvider struct{}
|
||||
|
||||
func (*openAIProvider) DefaultModel() string {
|
||||
return "gpt-4"
|
||||
}
|
||||
|
||||
func (*openAIProvider) formatMessages(messages []message) []any {
|
||||
formatted := make([]any, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
formatted = append(formatted, map[string]string{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
})
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
func (*openAIProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
|
||||
return map[string]any{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
}
|
||||
|
||||
type anthropicProvider struct{}
|
||||
|
||||
func (*anthropicProvider) DefaultModel() string {
|
||||
return "claude-3-opus-20240229"
|
||||
}
|
||||
|
||||
func (*anthropicProvider) formatMessages(messages []message) []any {
|
||||
formatted := make([]any, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
formatted = append(formatted, map[string]any{
|
||||
"role": msg.Role,
|
||||
"content": []map[string]string{
|
||||
{
|
||||
"type": "text",
|
||||
"text": msg.Content,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
func (*anthropicProvider) buildRequestBody(model string, messages []any, stream bool) map[string]any {
|
||||
return map[string]any{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": 1024,
|
||||
"stream": stream,
|
||||
}
|
||||
}
|
||||
|
||||
// generateConversation creates a conversation with alternating user/assistant
|
||||
// messages. The content is filled with repeated 'x' characters to reach
|
||||
// approximately the target size. The last message is always from "user" as
|
||||
// required by LLM APIs.
|
||||
func generateConversation(provider ProviderStrategy, targetSize int, numMessages int) []any {
|
||||
if targetSize <= 0 {
|
||||
return nil
|
||||
}
|
||||
if numMessages < 1 {
|
||||
numMessages = 1
|
||||
}
|
||||
|
||||
roles := []string{"user", "assistant"}
|
||||
messages := make([]message, numMessages)
|
||||
for i := range messages {
|
||||
messages[i].Role = roles[i%2]
|
||||
}
|
||||
// Ensure last message is from user (required for LLM APIs).
|
||||
if messages[len(messages)-1].Role != "user" {
|
||||
messages[len(messages)-1].Role = "user"
|
||||
}
|
||||
|
||||
overhead := measureJSONSize(provider.formatMessages(messages))
|
||||
|
||||
bytesPerMessage := targetSize - overhead
|
||||
if bytesPerMessage < 0 {
|
||||
bytesPerMessage = 0
|
||||
}
|
||||
|
||||
perMessage := bytesPerMessage / len(messages)
|
||||
remainder := bytesPerMessage % len(messages)
|
||||
|
||||
for i := range messages {
|
||||
size := perMessage
|
||||
if i == len(messages)-1 {
|
||||
size += remainder
|
||||
}
|
||||
messages[i].Content = strings.Repeat("x", size)
|
||||
}
|
||||
|
||||
return provider.formatMessages(messages)
|
||||
}
|
||||
|
||||
func measureJSONSize(v any) int {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
@@ -0,0 +1,391 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.14.0"
|
||||
"go.opentelemetry.io/otel/semconv/v1.14.0/httpconv"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/coder/v2/scaletest/loadtestutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type (
|
||||
tracingContextKey struct{}
|
||||
tracingContext struct {
|
||||
provider string
|
||||
model string
|
||||
stream bool
|
||||
requestNum int
|
||||
mode RequestMode
|
||||
}
|
||||
)
|
||||
|
||||
type tracingTransport struct {
|
||||
cfg Config
|
||||
underlying http.RoundTripper
|
||||
}
|
||||
|
||||
func newTracingTransport(cfg Config, underlying http.RoundTripper) *tracingTransport {
|
||||
if underlying == nil {
|
||||
underlying = http.DefaultTransport
|
||||
}
|
||||
return &tracingTransport{
|
||||
cfg: cfg,
|
||||
underlying: otelhttp.NewTransport(underlying),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
aibridgeCtx, hasAIBridgeCtx := req.Context().Value(tracingContextKey{}).(tracingContext)
|
||||
|
||||
resp, err := t.underlying.RoundTrip(req)
|
||||
|
||||
if hasAIBridgeCtx {
|
||||
ctx := req.Context()
|
||||
if resp != nil && resp.Request != nil {
|
||||
ctx = resp.Request.Context()
|
||||
}
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span.IsRecording() {
|
||||
span.SetAttributes(
|
||||
attribute.String("aibridge.provider", aibridgeCtx.provider),
|
||||
attribute.String("aibridge.model", aibridgeCtx.model),
|
||||
attribute.Bool("aibridge.stream", aibridgeCtx.stream),
|
||||
attribute.Int("aibridge.request_num", aibridgeCtx.requestNum),
|
||||
attribute.String("aibridge.mode", string(aibridgeCtx.mode)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
client *codersdk.Client
|
||||
cfg Config
|
||||
strategy requestModeStrategy
|
||||
providerStrategy ProviderStrategy
|
||||
|
||||
clock quartz.Clock
|
||||
httpClient *http.Client
|
||||
|
||||
requestCount int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
totalDuration time.Duration
|
||||
totalTokens int64
|
||||
}
|
||||
|
||||
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
|
||||
httpTimeout := cfg.HTTPTimeout
|
||||
if httpTimeout <= 0 {
|
||||
httpTimeout = 30 * time.Second
|
||||
}
|
||||
return &Runner{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
strategy: cfg.NewStrategy(client),
|
||||
providerStrategy: NewProviderStrategy(cfg.Provider),
|
||||
clock: quartz.NewReal(),
|
||||
httpClient: &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
Transport: newTracingTransport(cfg, http.DefaultTransport),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) WithClock(clock quartz.Clock) *Runner {
|
||||
r.clock = clock
|
||||
return r
|
||||
}
|
||||
|
||||
var (
|
||||
_ harness.Runnable = &Runner{}
|
||||
_ harness.Cleanable = &Runner{}
|
||||
_ harness.Collectable = &Runner{}
|
||||
)
|
||||
|
||||
func (r *Runner) Run(ctx context.Context, id string, logs io.Writer) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
logs = loadtestutil.NewSyncWriter(logs)
|
||||
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
|
||||
|
||||
requestURL, token, err := r.strategy.Setup(ctx, id, logs)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("strategy setup: %w", err)
|
||||
}
|
||||
|
||||
requestCount := r.cfg.RequestCount
|
||||
if requestCount <= 0 {
|
||||
requestCount = 1
|
||||
}
|
||||
|
||||
model := r.providerStrategy.DefaultModel()
|
||||
|
||||
logger.Info(ctx, "bridge runner is ready",
|
||||
slog.F("request_count", requestCount),
|
||||
slog.F("model", model),
|
||||
slog.F("stream", r.cfg.Stream),
|
||||
)
|
||||
|
||||
for i := 0; i < requestCount; i++ {
|
||||
if err := r.makeRequest(ctx, logger, requestURL, token, model, i); err != nil {
|
||||
logger.Warn(ctx, "bridge request failed",
|
||||
slog.F("request_num", i+1),
|
||||
slog.F("error_type", "request_failed"),
|
||||
slog.Error(err),
|
||||
)
|
||||
r.cfg.Metrics.AddError("request")
|
||||
r.cfg.Metrics.AddRequest("failure")
|
||||
r.failureCount++
|
||||
|
||||
// Continue making requests even if one fails
|
||||
continue
|
||||
}
|
||||
r.successCount++
|
||||
r.cfg.Metrics.AddRequest("success")
|
||||
r.requestCount++
|
||||
}
|
||||
|
||||
logger.Info(ctx, "bridge runner completed",
|
||||
slog.F("total_requests", r.requestCount),
|
||||
slog.F("success", r.successCount),
|
||||
slog.F("failure", r.failureCount),
|
||||
)
|
||||
|
||||
// Fail the run if any request failed
|
||||
if r.failureCount > 0 {
|
||||
return xerrors.Errorf("bridge runner failed: %d out of %d requests failed", r.failureCount, requestCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) makeRequest(ctx context.Context, logger slog.Logger, url, token, model string, requestNum int) error {
|
||||
start := r.clock.Now()
|
||||
|
||||
ctx = context.WithValue(ctx, tracingContextKey{}, tracingContext{
|
||||
provider: r.cfg.Provider,
|
||||
model: model,
|
||||
stream: r.cfg.Stream,
|
||||
requestNum: requestNum + 1,
|
||||
mode: r.cfg.Mode,
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(r.cfg.RequestBody))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "making bridge request",
|
||||
slog.F("url", url),
|
||||
slog.F("request_num", requestNum+1),
|
||||
slog.F("model", model),
|
||||
)
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
span := trace.SpanFromContext(req.Context())
|
||||
if span.IsRecording() {
|
||||
span.RecordError(err)
|
||||
}
|
||||
logger.Warn(ctx, "request failed during execution",
|
||||
slog.F("request_num", requestNum+1),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
span := trace.SpanFromContext(req.Context())
|
||||
if span.IsRecording() {
|
||||
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
|
||||
span.SetStatus(httpconv.ClientStatus(resp.StatusCode))
|
||||
}
|
||||
|
||||
duration := r.clock.Since(start)
|
||||
r.totalDuration += duration
|
||||
r.cfg.Metrics.ObserveDuration(duration.Seconds())
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err := xerrors.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
span.RecordError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if r.cfg.Stream {
|
||||
err := r.handleStreamingResponse(ctx, logger, resp)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.handleNonStreamingResponse(ctx, logger, resp)
|
||||
}
|
||||
|
||||
func (r *Runner) handleNonStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
if r.cfg.Provider == "anthropic" {
|
||||
return r.handleAnthropicResponse(ctx, logger, resp)
|
||||
}
|
||||
return r.handleOpenAIResponse(ctx, logger, resp)
|
||||
}
|
||||
|
||||
func (r *Runner) handleOpenAIResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
var response struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return xerrors.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(response.Choices) > 0 {
|
||||
assistantContent := response.Choices[0].Message.Content
|
||||
logger.Debug(ctx, "received response",
|
||||
slog.F("response_id", response.ID),
|
||||
slog.F("content_length", len(assistantContent)),
|
||||
)
|
||||
}
|
||||
|
||||
if response.Usage.TotalTokens > 0 {
|
||||
r.totalTokens += int64(response.Usage.TotalTokens)
|
||||
r.cfg.Metrics.AddTokens("input", int64(response.Usage.PromptTokens))
|
||||
r.cfg.Metrics.AddTokens("output", int64(response.Usage.CompletionTokens))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) handleAnthropicResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
var response struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return xerrors.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
var assistantContent string
|
||||
if len(response.Content) > 0 {
|
||||
assistantContent = response.Content[0].Text
|
||||
logger.Debug(ctx, "received response",
|
||||
slog.F("response_id", response.ID),
|
||||
slog.F("content_length", len(assistantContent)),
|
||||
)
|
||||
}
|
||||
|
||||
totalTokens := response.Usage.InputTokens + response.Usage.OutputTokens
|
||||
if totalTokens > 0 {
|
||||
r.totalTokens += int64(totalTokens)
|
||||
r.cfg.Metrics.AddTokens("input", int64(response.Usage.InputTokens))
|
||||
r.cfg.Metrics.AddTokens("output", int64(response.Usage.OutputTokens))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*Runner) handleStreamingResponse(ctx context.Context, logger slog.Logger, resp *http.Response) error {
|
||||
buf := make([]byte, 4096)
|
||||
totalRead := 0
|
||||
for {
|
||||
// Check for context cancellation before each read
|
||||
if ctx.Err() != nil {
|
||||
logger.Warn(ctx, "streaming response canceled",
|
||||
slog.F("bytes_read", totalRead),
|
||||
slog.Error(ctx.Err()),
|
||||
)
|
||||
return xerrors.Errorf("stream canceled: %w", ctx.Err())
|
||||
}
|
||||
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
totalRead += n
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
logger.Warn(ctx, "streaming response read canceled",
|
||||
slog.F("bytes_read", totalRead),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("stream read canceled: %w", err)
|
||||
}
|
||||
logger.Warn(ctx, "streaming response read error",
|
||||
slog.F("bytes_read", totalRead),
|
||||
slog.Error(err),
|
||||
)
|
||||
return xerrors.Errorf("read stream: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "received streaming response", slog.F("bytes_read", totalRead))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) Cleanup(ctx context.Context, id string, logs io.Writer) error {
|
||||
return r.strategy.Cleanup(ctx, id, logs)
|
||||
}
|
||||
|
||||
func (r *Runner) GetMetrics() map[string]any {
|
||||
avgDuration := time.Duration(0)
|
||||
if r.requestCount > 0 {
|
||||
avgDuration = r.totalDuration / time.Duration(r.requestCount)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"request_count": r.requestCount,
|
||||
"success_count": r.successCount,
|
||||
"failure_count": r.failureCount,
|
||||
"total_duration": r.totalDuration.String(),
|
||||
"avg_duration": avgDuration.String(),
|
||||
"total_tokens": r.totalTokens,
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user