Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ce2aed9002 |
@@ -4,7 +4,7 @@ description: |
|
||||
inputs:
|
||||
version:
|
||||
description: "The Go version to use."
|
||||
default: "1.24.11"
|
||||
default: "1.24.10"
|
||||
use-preinstalled-go:
|
||||
description: "Whether to use preinstalled Go."
|
||||
default: "false"
|
||||
|
||||
@@ -211,6 +211,14 @@ issues:
|
||||
- path: scripts/rules.go
|
||||
linters:
|
||||
- ALL
|
||||
# Boundary code is imported from github.com/coder/boundary and has different
|
||||
# lint standards. Suppress lint issues in this imported code.
|
||||
- path: enterprise/cli/boundary/
|
||||
linters:
|
||||
- revive
|
||||
- gocritic
|
||||
- gosec
|
||||
- errorlint
|
||||
|
||||
fix: true
|
||||
max-issues-per-linter: 0
|
||||
|
||||
@@ -69,9 +69,6 @@ MOST_GO_SRC_FILES := $(shell \
|
||||
# All the shell files in the repo, excluding ignored files.
|
||||
SHELL_SRC_FILES := $(shell find . $(FIND_EXCLUSIONS) -type f -name '*.sh')
|
||||
|
||||
MIGRATION_FILES := $(shell find ./coderd/database/migrations/ -maxdepth 1 $(FIND_EXCLUSIONS) -type f -name '*.sql')
|
||||
FIXTURE_FILES := $(shell find ./coderd/database/migrations/testdata/fixtures/ $(FIND_EXCLUSIONS) -type f -name '*.sql')
|
||||
|
||||
# Ensure we don't use the user's git configs which might cause side-effects
|
||||
GIT_FLAGS = GIT_CONFIG_GLOBAL=/dev/null GIT_CONFIG_SYSTEM=/dev/null
|
||||
|
||||
@@ -564,7 +561,7 @@ endif
|
||||
|
||||
# Note: we don't run zizmor in the lint target because it takes a while. CI
|
||||
# runs it explicitly.
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes lint/migrations
|
||||
lint: lint/shellcheck lint/go lint/ts lint/examples lint/helm lint/site-icons lint/markdown lint/actions/actionlint lint/check-scopes
|
||||
.PHONY: lint
|
||||
|
||||
lint/site-icons:
|
||||
@@ -622,12 +619,6 @@ lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
lint/migrations:
|
||||
./scripts/check_pg_schema.sh "Migrations" $(MIGRATION_FILES)
|
||||
./scripts/check_pg_schema.sh "Fixtures" $(FIXTURE_FILES)
|
||||
.PHONY: lint/migrations
|
||||
|
||||
# All files generated by the database should be added here, and this can be used
|
||||
# as a target for jobs that need to run after the database is generated.
|
||||
DB_GEN_FILES := \
|
||||
|
||||
+10
-4
@@ -1,12 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
boundarycli "github.com/coder/boundary/cli"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) boundary() *serpent.Command {
|
||||
cmd := boundarycli.BaseCommand() // Package coder/boundary/cli exports a "base command" designed to be integrated as a subcommand.
|
||||
cmd.Use += " [args...]" // The base command looks like `boundary -- command`. Serpent adds the flags piece, but we need to add the args.
|
||||
return cmd
|
||||
return &serpent.Command{
|
||||
Use: "boundary",
|
||||
Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests (enterprise)",
|
||||
Long: `boundary creates an isolated network environment for target processes. This is an enterprise feature.`,
|
||||
Handler: func(_ *serpent.Invocation) error {
|
||||
return xerrors.New("boundary is an enterprise feature; upgrade to use this command")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,15 +5,13 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
boundarycli "github.com/coder/boundary/cli"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// Actually testing the functionality of coder/boundary takes place in the
|
||||
// coder/boundary repo, since it's a dependency of coder.
|
||||
// Here we want to test basically that integrating it as a subcommand doesn't break anything.
|
||||
// Here we want to test that integrating boundary as a subcommand doesn't break anything.
|
||||
// The full boundary functionality is tested in enterprise/cli.
|
||||
func TestBoundarySubcommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
@@ -27,7 +25,5 @@ func TestBoundarySubcommand(t *testing.T) {
|
||||
}()
|
||||
|
||||
// Expect the --help output to include the short description.
|
||||
// We're simply confirming that `coder boundary --help` ran without a runtime error as
|
||||
// a good chunk of serpents self validation logic happens at runtime.
|
||||
pty.ExpectMatch(boundarycli.BaseCommand().Short)
|
||||
pty.ExpectMatch("Network isolation tool")
|
||||
}
|
||||
|
||||
@@ -68,8 +68,6 @@ func (r *RootCmd) scaletestCmd() *serpent.Command {
|
||||
r.scaletestTaskStatus(),
|
||||
r.scaletestSMTP(),
|
||||
r.scaletestPrebuilds(),
|
||||
r.scaletestBridge(),
|
||||
r.scaletestLLMMock(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/scaletest/bridge"
|
||||
"github.com/coder/coder/v2/scaletest/createusers"
|
||||
"github.com/coder/coder/v2/scaletest/harness"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) scaletestBridge() *serpent.Command {
|
||||
var (
|
||||
concurrentUsers int64
|
||||
noCleanup bool
|
||||
mode string
|
||||
upstreamURL string
|
||||
provider string
|
||||
requestsPerUser int64
|
||||
useStreamingAPI bool
|
||||
requestPayloadSize int64
|
||||
numMessages int64
|
||||
httpTimeout time.Duration
|
||||
|
||||
timeoutStrategy = &timeoutFlags{}
|
||||
cleanupStrategy = newScaletestCleanupStrategy()
|
||||
output = &scaletestOutputFlags{}
|
||||
prometheusFlags = &scaletestPrometheusFlags{}
|
||||
)
|
||||
|
||||
cmd := &serpent.Command{
|
||||
Use: "bridge",
|
||||
Short: "Generate load on the AI Bridge service.",
|
||||
Long: `Generate load for AI Bridge testing. Supports two modes: 'bridge' mode routes requests through the Coder AI Bridge, 'direct' mode makes requests directly to an upstream URL (useful for baseline comparisons).
|
||||
|
||||
Examples:
|
||||
# Test OpenAI API through bridge
|
||||
coder scaletest bridge --mode bridge --provider openai --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test Anthropic API through bridge
|
||||
coder scaletest bridge --mode bridge --provider anthropic --concurrent-users 10 --request-count 5 --num-messages 10
|
||||
|
||||
# Test directly against mock server
|
||||
coder scaletest bridge --mode direct --provider openai --upstream-url http://localhost:8080/v1/chat/completions
|
||||
`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
client, err := r.InitClient(inv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &codersdk.HeaderTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
Header: map[string][]string{
|
||||
codersdk.BypassRatelimitHeader: {"true"},
|
||||
},
|
||||
},
|
||||
}
|
||||
reg := prometheus.NewRegistry()
|
||||
metrics := bridge.NewMetrics(reg)
|
||||
|
||||
logger := inv.Logger
|
||||
prometheusSrvClose := ServeHandler(ctx, logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), prometheusFlags.Address, "prometheus")
|
||||
defer prometheusSrvClose()
|
||||
|
||||
defer func() {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Waiting %s for prometheus metrics to be scraped\n", prometheusFlags.Wait)
|
||||
<-time.After(prometheusFlags.Wait)
|
||||
}()
|
||||
|
||||
notifyCtx, stop := signal.NotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
ctx = notifyCtx
|
||||
|
||||
var userConfig createusers.Config
|
||||
if bridge.RequestMode(mode) == bridge.RequestModeBridge {
|
||||
me, err := requireAdmin(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(me.OrganizationIDs) == 0 {
|
||||
return xerrors.Errorf("admin user must have at least one organization")
|
||||
}
|
||||
userConfig = createusers.Config{
|
||||
OrganizationID: me.OrganizationIDs[0],
|
||||
}
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge mode: creating users and making requests through AI Bridge...")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stderr, "Direct mode: making requests directly to %s\n", upstreamURL)
|
||||
}
|
||||
|
||||
outputs, err := output.parse()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse output flags: %w", err)
|
||||
}
|
||||
|
||||
config := bridge.Config{
|
||||
Mode: bridge.RequestMode(mode),
|
||||
Metrics: metrics,
|
||||
Provider: provider,
|
||||
RequestCount: int(requestsPerUser),
|
||||
Stream: useStreamingAPI,
|
||||
RequestPayloadSize: int(requestPayloadSize),
|
||||
NumMessages: int(numMessages),
|
||||
HTTPTimeout: httpTimeout,
|
||||
UpstreamURL: upstreamURL,
|
||||
User: userConfig,
|
||||
}
|
||||
if err := config.Validate(); err != nil {
|
||||
return xerrors.Errorf("validate config: %w", err)
|
||||
}
|
||||
if err := config.PrepareRequestBody(); err != nil {
|
||||
return xerrors.Errorf("prepare request body: %w", err)
|
||||
}
|
||||
|
||||
th := harness.NewTestHarness(timeoutStrategy.wrapStrategy(harness.ConcurrentExecutionStrategy{}), cleanupStrategy.toStrategy())
|
||||
|
||||
for i := range concurrentUsers {
|
||||
id := strconv.Itoa(int(i))
|
||||
name := fmt.Sprintf("bridge-%s", id)
|
||||
var runner harness.Runnable = bridge.NewRunner(client, config)
|
||||
th.AddRun(name, id, runner)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Bridge scaletest configuration:")
|
||||
tw := tabwriter.NewWriter(inv.Stderr, 0, 0, 2, ' ', 0)
|
||||
for _, opt := range inv.Command.Options {
|
||||
if opt.Hidden || opt.ValueSource == serpent.ValueSourceNone {
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(tw, " %s:\t%s", opt.Name, opt.Value.String())
|
||||
if opt.ValueSource != serpent.ValueSourceDefault {
|
||||
_, _ = fmt.Fprintf(tw, "\t(from %s)", opt.ValueSource)
|
||||
}
|
||||
_, _ = fmt.Fprintln(tw)
|
||||
}
|
||||
_ = tw.Flush()
|
||||
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nRunning bridge scaletest...")
|
||||
testCtx, testCancel := timeoutStrategy.toContext(ctx)
|
||||
defer testCancel()
|
||||
err = th.Run(testCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("run test harness (harness failure, not a test failure): %w", err)
|
||||
}
|
||||
|
||||
// If the command was interrupted, skip stats.
|
||||
if notifyCtx.Err() != nil {
|
||||
return notifyCtx.Err()
|
||||
}
|
||||
|
||||
res := th.Results()
|
||||
|
||||
for _, o := range outputs {
|
||||
err = o.write(res, inv.Stdout)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCleanup {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...")
|
||||
cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx)
|
||||
defer cleanupCancel()
|
||||
err = th.Cleanup(cleanupCtx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cleanup tests: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if res.TotalFail > 0 {
|
||||
return xerrors.New("load test failed, see above for more details")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = serpent.OptionSet{
|
||||
{
|
||||
Flag: "concurrent-users",
|
||||
FlagShorthand: "c",
|
||||
Env: "CODER_SCALETEST_BRIDGE_CONCURRENT_USERS",
|
||||
Description: "Required: Number of concurrent users.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&concurrentUsers), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--concurrent-users must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Flag: "mode",
|
||||
Env: "CODER_SCALETEST_BRIDGE_MODE",
|
||||
Default: "direct",
|
||||
Description: "Request mode: 'bridge' (create users and use AI Bridge) or 'direct' (make requests directly to upstream-url).",
|
||||
Value: serpent.EnumOf(&mode, string(bridge.RequestModeBridge), string(bridge.RequestModeDirect)),
|
||||
},
|
||||
{
|
||||
Flag: "upstream-url",
|
||||
Env: "CODER_SCALETEST_BRIDGE_UPSTREAM_URL",
|
||||
Description: "URL to make requests to directly (required in direct mode, e.g., http://localhost:8080/v1/chat/completions).",
|
||||
Value: serpent.StringOf(&upstreamURL),
|
||||
},
|
||||
{
|
||||
Flag: "provider",
|
||||
Env: "CODER_SCALETEST_BRIDGE_PROVIDER",
|
||||
Default: "openai",
|
||||
Description: "API provider to use.",
|
||||
Value: serpent.EnumOf(&provider, "openai", "anthropic"),
|
||||
},
|
||||
{
|
||||
Flag: "request-count",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_COUNT",
|
||||
Default: "1",
|
||||
Description: "Number of sequential requests to make per runner.",
|
||||
Value: serpent.Validate(serpent.Int64Of(&requestsPerUser), func(value *serpent.Int64) error {
|
||||
if value == nil || value.Value() <= 0 {
|
||||
return xerrors.Errorf("--request-count must be greater than 0")
|
||||
}
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
{
|
||||
Flag: "stream",
|
||||
Env: "CODER_SCALETEST_BRIDGE_STREAM",
|
||||
Description: "Enable streaming requests.",
|
||||
Value: serpent.BoolOf(&useStreamingAPI),
|
||||
},
|
||||
{
|
||||
Flag: "request-payload-size",
|
||||
Env: "CODER_SCALETEST_BRIDGE_REQUEST_PAYLOAD_SIZE",
|
||||
Default: "1024",
|
||||
Description: "Size in bytes of the request payload (user message content). If 0, uses default message content.",
|
||||
Value: serpent.Int64Of(&requestPayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "num-messages",
|
||||
Env: "CODER_SCALETEST_BRIDGE_NUM_MESSAGES",
|
||||
Default: "1",
|
||||
Description: "Number of messages to include in the conversation.",
|
||||
Value: serpent.Int64Of(&numMessages),
|
||||
},
|
||||
{
|
||||
Flag: "no-cleanup",
|
||||
Env: "CODER_SCALETEST_NO_CLEANUP",
|
||||
Description: "Do not clean up resources after the test completes.",
|
||||
Value: serpent.BoolOf(&noCleanup),
|
||||
},
|
||||
{
|
||||
Flag: "http-timeout",
|
||||
Env: "CODER_SCALETEST_BRIDGE_HTTP_TIMEOUT",
|
||||
Default: "30s",
|
||||
Description: "Timeout for individual HTTP requests to the upstream provider.",
|
||||
Value: serpent.DurationOf(&httpTimeout),
|
||||
},
|
||||
}
|
||||
|
||||
timeoutStrategy.attach(&cmd.Options)
|
||||
cleanupStrategy.attach(&cmd.Options)
|
||||
output.attach(&cmd.Options)
|
||||
prometheusFlags.attach(&cmd.Options)
|
||||
return cmd
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
//go:build !slim
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/scaletest/llmmock"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) scaletestLLMMock() *serpent.Command {
|
||||
var (
|
||||
address string
|
||||
artificialLatency time.Duration
|
||||
responsePayloadSize int64
|
||||
|
||||
pprofEnable bool
|
||||
pprofAddress string
|
||||
|
||||
traceEnable bool
|
||||
)
|
||||
cmd := &serpent.Command{
|
||||
Use: "llm-mock",
|
||||
Short: "Start a mock LLM API server for testing",
|
||||
Long: `Start a mock LLM API server that simulates OpenAI and Anthropic APIs`,
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx, stop := signal.NotifyContext(inv.Context(), StopSignals...)
|
||||
defer stop()
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelInfo)
|
||||
|
||||
if pprofEnable {
|
||||
closePprof := ServeHandler(ctx, logger, nil, pprofAddress, "pprof")
|
||||
defer closePprof()
|
||||
logger.Info(ctx, "pprof server started", slog.F("address", pprofAddress))
|
||||
}
|
||||
|
||||
config := llmmock.Config{
|
||||
Address: address,
|
||||
Logger: logger,
|
||||
ArtificialLatency: artificialLatency,
|
||||
ResponsePayloadSize: int(responsePayloadSize),
|
||||
PprofEnable: pprofEnable,
|
||||
PprofAddress: pprofAddress,
|
||||
TraceEnable: traceEnable,
|
||||
}
|
||||
srv := new(llmmock.Server)
|
||||
|
||||
if err := srv.Start(ctx, config); err != nil {
|
||||
return xerrors.Errorf("start mock LLM server: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = srv.Stop()
|
||||
}()
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Mock LLM API server started on %s\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " OpenAI endpoint: %s/v1/chat/completions\n", srv.APIAddress())
|
||||
_, _ = fmt.Fprintf(inv.Stdout, " Anthropic endpoint: %s/v1/messages\n", srv.APIAddress())
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Options = []serpent.Option{
|
||||
{
|
||||
Flag: "address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ADDRESS",
|
||||
Default: "localhost",
|
||||
Description: "Address to bind the mock LLM API server. Can include a port (e.g., 'localhost:8080' or ':8080'). Uses a random port if no port is specified.",
|
||||
Value: serpent.StringOf(&address),
|
||||
},
|
||||
{
|
||||
Flag: "artificial-latency",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_ARTIFICIAL_LATENCY",
|
||||
Default: "0s",
|
||||
Description: "Artificial latency to add to each response (e.g., 100ms, 1s). Simulates slow upstream processing.",
|
||||
Value: serpent.DurationOf(&artificialLatency),
|
||||
},
|
||||
{
|
||||
Flag: "response-payload-size",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_RESPONSE_PAYLOAD_SIZE",
|
||||
Default: "0",
|
||||
Description: "Size in bytes of the response payload. If 0, uses default context-aware responses.",
|
||||
Value: serpent.Int64Of(&responsePayloadSize),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Serve pprof metrics on the address defined by pprof-address.",
|
||||
Value: serpent.BoolOf(&pprofEnable),
|
||||
},
|
||||
{
|
||||
Flag: "pprof-address",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_PPROF_ADDRESS",
|
||||
Default: "127.0.0.1:6060",
|
||||
Description: "The bind address to serve pprof.",
|
||||
Value: serpent.StringOf(&pprofAddress),
|
||||
},
|
||||
{
|
||||
Flag: "trace-enable",
|
||||
Env: "CODER_SCALETEST_LLM_MOCK_TRACE_ENABLE",
|
||||
Default: "false",
|
||||
Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.",
|
||||
Value: serpent.BoolOf(&traceEnable),
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -65,22 +65,6 @@ func (r *RootCmd) organizationSettings(orgContext *OrganizationContext) *serpent
|
||||
return cli.OrganizationIDPSyncSettings(ctx)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "workspace-sharing",
|
||||
Aliases: []string{"workspacesharing"},
|
||||
Short: "Workspace sharing settings for the organization.",
|
||||
Patch: func(ctx context.Context, cli *codersdk.Client, org uuid.UUID, input json.RawMessage) (any, error) {
|
||||
var req codersdk.WorkspaceSharingSettings
|
||||
err := json.Unmarshal(input, &req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshalling workspace sharing settings: %w", err)
|
||||
}
|
||||
return cli.PatchWorkspaceSharingSettings(ctx, org.String(), req)
|
||||
},
|
||||
Fetch: func(ctx context.Context, cli *codersdk.Client, org uuid.UUID) (any, error) {
|
||||
return cli.WorkspaceSharingSettings(ctx, org.String())
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd := &serpent.Command{
|
||||
Use: "settings",
|
||||
|
||||
@@ -15,7 +15,6 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
@@ -15,7 +15,6 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
@@ -15,7 +15,6 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
@@ -15,7 +15,6 @@ SUBCOMMANDS:
|
||||
memberships from an IdP.
|
||||
role-sync Role sync settings to sync organization roles from an
|
||||
IdP.
|
||||
workspace-sharing Workspace sharing settings for the organization.
|
||||
|
||||
———
|
||||
Run `coder --help` for a list of global options.
|
||||
|
||||
-4
@@ -147,10 +147,6 @@ AI BRIDGE OPTIONS:
|
||||
Maximum number of AI Bridge requests per second per replica. Set to 0
|
||||
to disable (unlimited).
|
||||
|
||||
--aibridge-structured-logging bool, $CODER_AIBRIDGE_STRUCTURED_LOGGING (default: false)
|
||||
Emit structured logs for AI Bridge interception records. Use this for
|
||||
exporting these records to external SIEM or observability systems.
|
||||
|
||||
AI BRIDGE PROXY OPTIONS:
|
||||
--aibridge-proxy-cert-file string, $CODER_AIBRIDGE_PROXY_CERT_FILE
|
||||
Path to the CA certificate file for AI Bridge Proxy.
|
||||
|
||||
-4
@@ -773,10 +773,6 @@ aibridge:
|
||||
# (unlimited).
|
||||
# (default: 0, type: int)
|
||||
rateLimit: 0
|
||||
# Emit structured logs for AI Bridge interception records. Use this for exporting
|
||||
# these records to external SIEM or observability systems.
|
||||
# (default: false, type: bool)
|
||||
structuredLogging: false
|
||||
aibridgeproxy:
|
||||
# Enable the AI Bridge MITM Proxy for intercepting and decrypting AI provider
|
||||
# requests.
|
||||
|
||||
Generated
+20
-165
@@ -2628,8 +2628,7 @@ const docTemplate = `{
|
||||
},
|
||||
{
|
||||
"enum": [
|
||||
"code",
|
||||
"token"
|
||||
"code"
|
||||
],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
@@ -2684,8 +2683,7 @@ const docTemplate = `{
|
||||
},
|
||||
{
|
||||
"enum": [
|
||||
"code",
|
||||
"token"
|
||||
"code"
|
||||
],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
@@ -2916,10 +2914,7 @@ const docTemplate = `{
|
||||
{
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
"refresh_token"
|
||||
],
|
||||
"type": "string",
|
||||
"description": "Grant type",
|
||||
@@ -4571,86 +4566,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/settings/workspace-sharing": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Enterprise"
|
||||
],
|
||||
"summary": "Get workspace sharing settings for organization",
|
||||
"operationId": "get-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patch": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Enterprise"
|
||||
],
|
||||
"summary": "Update workspace sharing settings for organization",
|
||||
"operationId": "update-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Workspace sharing settings",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/templates": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -12055,9 +11970,6 @@ const docTemplate = `{
|
||||
},
|
||||
"retention": {
|
||||
"type": "integer"
|
||||
},
|
||||
"structured_logging": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -15849,13 +15761,13 @@ const docTemplate = `{
|
||||
"code_challenge_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2PKCECodeChallengeMethod"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"grant_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"issuer": {
|
||||
@@ -15867,7 +15779,7 @@ const docTemplate = `{
|
||||
"response_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"revocation_endpoint": {
|
||||
@@ -15885,7 +15797,7 @@ const docTemplate = `{
|
||||
"token_endpoint_auth_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15917,7 +15829,7 @@ const docTemplate = `{
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -15939,7 +15851,10 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"registration_access_token": {
|
||||
"type": "string"
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"registration_client_uri": {
|
||||
"type": "string"
|
||||
@@ -15947,7 +15862,7 @@ const docTemplate = `{
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -15960,7 +15875,7 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -15985,7 +15900,7 @@ const docTemplate = `{
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -16009,7 +15924,7 @@ const docTemplate = `{
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -16025,7 +15940,7 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -16062,7 +15977,7 @@ const docTemplate = `{
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -16092,7 +16007,7 @@ const docTemplate = `{
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -16105,7 +16020,7 @@ const docTemplate = `{
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -16158,17 +16073,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2PKCECodeChallengeMethod": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"S256",
|
||||
"plain"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2PKCECodeChallengeMethodS256",
|
||||
"OAuth2PKCECodeChallengeMethodPlain"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProtectedResourceMetadata": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -16248,47 +16152,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2ProviderGrantType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderGrantTypeAuthorizationCode",
|
||||
"OAuth2ProviderGrantTypeRefreshToken",
|
||||
"OAuth2ProviderGrantTypePassword",
|
||||
"OAuth2ProviderGrantTypeClientCredentials",
|
||||
"OAuth2ProviderGrantTypeImplicit"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProviderResponseType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"code",
|
||||
"token"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderResponseTypeCode",
|
||||
"OAuth2ProviderResponseTypeToken"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2TokenEndpointAuthMethod": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretBasic",
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretPost",
|
||||
"OAuth2TokenEndpointAuthMethodNone"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuthConversionResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21588,14 +21451,6 @@ const docTemplate = `{
|
||||
"WorkspaceRoleDeleted"
|
||||
]
|
||||
},
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sharing_disabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.WorkspaceStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
+20
-146
@@ -2304,7 +2304,7 @@
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"enum": ["code", "token"],
|
||||
"enum": ["code"],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
"name": "response_type",
|
||||
@@ -2355,7 +2355,7 @@
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"enum": ["code", "token"],
|
||||
"enum": ["code"],
|
||||
"type": "string",
|
||||
"description": "Response type",
|
||||
"name": "response_type",
|
||||
@@ -2555,13 +2555,7 @@
|
||||
"in": "formData"
|
||||
},
|
||||
{
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
],
|
||||
"enum": ["authorization_code", "refresh_token"],
|
||||
"type": "string",
|
||||
"description": "Grant type",
|
||||
"name": "grant_type",
|
||||
@@ -4042,76 +4036,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/settings/workspace-sharing": {
|
||||
"get": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Enterprise"],
|
||||
"summary": "Get workspace sharing settings for organization",
|
||||
"operationId": "get-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patch": {
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Enterprise"],
|
||||
"summary": "Update workspace sharing settings for organization",
|
||||
"operationId": "update-workspace-sharing-settings-for-organization",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"format": "uuid",
|
||||
"description": "Organization ID",
|
||||
"name": "organization",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Workspace sharing settings",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceSharingSettings"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/organizations/{organization}/templates": {
|
||||
"get": {
|
||||
"security": [
|
||||
@@ -10707,9 +10631,6 @@
|
||||
},
|
||||
"retention": {
|
||||
"type": "integer"
|
||||
},
|
||||
"structured_logging": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14359,13 +14280,13 @@
|
||||
"code_challenge_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2PKCECodeChallengeMethod"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"grant_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"issuer": {
|
||||
@@ -14377,7 +14298,7 @@
|
||||
"response_types_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"revocation_endpoint": {
|
||||
@@ -14395,7 +14316,7 @@
|
||||
"token_endpoint_auth_methods_supported": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14427,7 +14348,7 @@
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -14449,7 +14370,10 @@
|
||||
}
|
||||
},
|
||||
"registration_access_token": {
|
||||
"type": "string"
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"registration_client_uri": {
|
||||
"type": "string"
|
||||
@@ -14457,7 +14381,7 @@
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -14470,7 +14394,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -14495,7 +14419,7 @@
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -14519,7 +14443,7 @@
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -14535,7 +14459,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -14572,7 +14496,7 @@
|
||||
"grant_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderGrantType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"jwks": {
|
||||
@@ -14602,7 +14526,7 @@
|
||||
"response_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2ProviderResponseType"
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"scope": {
|
||||
@@ -14615,7 +14539,7 @@
|
||||
"type": "string"
|
||||
},
|
||||
"token_endpoint_auth_method": {
|
||||
"$ref": "#/definitions/codersdk.OAuth2TokenEndpointAuthMethod"
|
||||
"type": "string"
|
||||
},
|
||||
"tos_uri": {
|
||||
"type": "string"
|
||||
@@ -14668,14 +14592,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2PKCECodeChallengeMethod": {
|
||||
"type": "string",
|
||||
"enum": ["S256", "plain"],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2PKCECodeChallengeMethodS256",
|
||||
"OAuth2PKCECodeChallengeMethodPlain"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProtectedResourceMetadata": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -14755,40 +14671,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.OAuth2ProviderGrantType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
"password",
|
||||
"client_credentials",
|
||||
"implicit"
|
||||
],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderGrantTypeAuthorizationCode",
|
||||
"OAuth2ProviderGrantTypeRefreshToken",
|
||||
"OAuth2ProviderGrantTypePassword",
|
||||
"OAuth2ProviderGrantTypeClientCredentials",
|
||||
"OAuth2ProviderGrantTypeImplicit"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2ProviderResponseType": {
|
||||
"type": "string",
|
||||
"enum": ["code", "token"],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2ProviderResponseTypeCode",
|
||||
"OAuth2ProviderResponseTypeToken"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuth2TokenEndpointAuthMethod": {
|
||||
"type": "string",
|
||||
"enum": ["client_secret_basic", "client_secret_post", "none"],
|
||||
"x-enum-varnames": [
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretBasic",
|
||||
"OAuth2TokenEndpointAuthMethodClientSecretPost",
|
||||
"OAuth2TokenEndpointAuthMethodNone"
|
||||
]
|
||||
},
|
||||
"codersdk.OAuthConversionResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19848,14 +19730,6 @@
|
||||
"WorkspaceRoleDeleted"
|
||||
]
|
||||
},
|
||||
"codersdk.WorkspaceSharingSettings": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sharing_disabled": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.WorkspaceStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
+4
-7
@@ -205,7 +205,7 @@ type Options struct {
|
||||
// tokens issued by and passed to the coordinator DRPC API.
|
||||
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
|
||||
HealthcheckTimeout time.Duration
|
||||
HealthcheckRefresh time.Duration
|
||||
WorkspaceProxiesFetchUpdater *atomic.Pointer[healthcheck.WorkspaceProxiesFetchUpdater]
|
||||
@@ -681,7 +681,7 @@ func New(options *Options) *API {
|
||||
}
|
||||
|
||||
if options.HealthcheckFunc == nil {
|
||||
options.HealthcheckFunc = func(ctx context.Context, apiKey string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
options.HealthcheckFunc = func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
// NOTE: dismissed healthchecks are marked in formatHealthcheck.
|
||||
// Not here, as this result gets cached.
|
||||
return healthcheck.Run(ctx, &healthcheck.ReportOptions{
|
||||
@@ -709,7 +709,6 @@ func New(options *Options) *API {
|
||||
StaleInterval: provisionerdserver.StaleInterval,
|
||||
// TimeNow set to default, see healthcheck/provisioner.go
|
||||
},
|
||||
Progress: progress,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -882,7 +881,6 @@ func New(options *Options) *API {
|
||||
loggermw.Logger(api.Logger),
|
||||
singleSlashMW,
|
||||
rolestore.CustomRoleMW,
|
||||
httpmw.HTTPRoute, // NB: prometheusMW depends on this middleware.
|
||||
prometheusMW,
|
||||
// Build-Version is helpful for debugging.
|
||||
func(next http.Handler) http.Handler {
|
||||
@@ -1861,9 +1859,8 @@ type API struct {
|
||||
// This is used to gate features that are not yet ready for production.
|
||||
Experiments codersdk.Experiments
|
||||
|
||||
healthCheckGroup *singleflight.Group[string, *healthsdk.HealthcheckReport]
|
||||
healthCheckCache atomic.Pointer[healthsdk.HealthcheckReport]
|
||||
healthCheckProgress healthcheck.Progress
|
||||
healthCheckGroup *singleflight.Group[string, *healthsdk.HealthcheckReport]
|
||||
healthCheckCache atomic.Pointer[healthsdk.HealthcheckReport]
|
||||
|
||||
statsReporter *workspacestats.Reporter
|
||||
|
||||
|
||||
@@ -69,7 +69,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/files"
|
||||
"github.com/coder/coder/v2/coderd/gitsshkey"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/jobreaper"
|
||||
"github.com/coder/coder/v2/coderd/notifications"
|
||||
@@ -132,7 +131,7 @@ type Options struct {
|
||||
CoordinatorResumeTokenProvider tailnet.ResumeTokenProvider
|
||||
ConnectionLogger connectionlog.ConnectionLogger
|
||||
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport
|
||||
HealthcheckFunc func(ctx context.Context, apiKey string) *healthsdk.HealthcheckReport
|
||||
HealthcheckTimeout time.Duration
|
||||
HealthcheckRefresh time.Duration
|
||||
|
||||
|
||||
@@ -1965,14 +1965,6 @@ func (q *querier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) erro
|
||||
return fetchAndExec(q.log, q.auth, policy.ActionShare, fetch, q.db.DeleteWorkspaceACLByID)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
// This is a system-only function.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteWorkspaceACLsByOrganization(ctx, organizationID)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
w, err := q.db.GetWorkspaceByID(ctx, arg.WorkspaceID)
|
||||
if err != nil {
|
||||
@@ -3600,7 +3592,7 @@ func (q *querier) GetWorkspaceACLByID(ctx context.Context, id uuid.UUID) (databa
|
||||
if err != nil {
|
||||
return database.GetWorkspaceACLByIDRow{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, workspace); err != nil {
|
||||
if err := q.authorizeContext(ctx, policy.ActionShare, workspace); err != nil {
|
||||
return database.GetWorkspaceACLByIDRow{}, err
|
||||
}
|
||||
return q.db.GetWorkspaceACLByID(ctx, id)
|
||||
@@ -5107,13 +5099,6 @@ func (q *querier) UpdateOrganizationDeletedByID(ctx context.Context, arg databas
|
||||
return deleteQ(q.log, q.auth, q.db.GetOrganizationByID, deleteF)(ctx, arg.ID)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg database.UpdateOrganizationWorkspaceSharingSettingsParams) (database.Organization, error) {
|
||||
fetch := func(ctx context.Context, arg database.UpdateOrganizationWorkspaceSharingSettingsParams) (database.Organization, error) {
|
||||
return q.db.GetOrganizationByID(ctx, arg.ID)
|
||||
}
|
||||
return updateWithReturn(q.log, q.auth, fetch, q.db.UpdateOrganizationWorkspaceSharingSettings)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
// Prebuild operation for canceling pending prebuild jobs from non-active template versions
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourcePrebuiltWorkspace); err != nil {
|
||||
|
||||
@@ -880,16 +880,6 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
dbm.EXPECT().InsertOrganization(gomock.Any(), arg).Return(database.Organization{ID: arg.ID, Name: arg.Name}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceOrganization, policy.ActionCreate)
|
||||
}))
|
||||
s.Run("UpdateOrganizationWorkspaceSharingSettings", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
org := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
arg := database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
WorkspaceSharingDisabled: true,
|
||||
}
|
||||
dbm.EXPECT().GetOrganizationByID(gomock.Any(), org.ID).Return(org, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateOrganizationWorkspaceSharingSettings(gomock.Any(), arg).Return(org, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(org, policy.ActionUpdate).Returns(org)
|
||||
}))
|
||||
s.Run("InsertOrganizationMember", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
o := testutil.Fake(s.T(), faker, database.Organization{})
|
||||
u := testutil.Fake(s.T(), faker, database.User{})
|
||||
@@ -1794,7 +1784,7 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
ws := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
dbM.EXPECT().GetWorkspaceByID(gomock.Any(), ws.ID).Return(ws, nil).AnyTimes()
|
||||
dbM.EXPECT().GetWorkspaceACLByID(gomock.Any(), ws.ID).Return(database.GetWorkspaceACLByIDRow{}, nil).AnyTimes()
|
||||
check.Args(ws.ID).Asserts(ws, policy.ActionRead)
|
||||
check.Args(ws.ID).Asserts(ws, policy.ActionShare)
|
||||
}))
|
||||
s.Run("UpdateWorkspaceACLByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
@@ -1809,11 +1799,6 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().DeleteWorkspaceACLByID(gomock.Any(), w.ID).Return(nil).AnyTimes()
|
||||
check.Args(w.ID).Asserts(w, policy.ActionShare)
|
||||
}))
|
||||
s.Run("DeleteWorkspaceACLsByOrganization", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
orgID := uuid.New()
|
||||
dbm.EXPECT().DeleteWorkspaceACLsByOrganization(gomock.Any(), orgID).Return(nil).AnyTimes()
|
||||
check.Args(orgID).Asserts(rbac.ResourceSystem, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
b := testutil.Fake(s.T(), faker, database.WorkspaceBuild{WorkspaceID: w.ID})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1055,20 +1055,6 @@ func (mr *MockStoreMockRecorder) DeleteWorkspaceACLByID(ctx, id any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLByID", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLByID), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteWorkspaceACLsByOrganization", ctx, organizationID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteWorkspaceACLsByOrganization indicates an expected call of DeleteWorkspaceACLsByOrganization.
|
||||
func (mr *MockStoreMockRecorder) DeleteWorkspaceACLsByOrganization(ctx, organizationID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkspaceACLsByOrganization", reflect.TypeOf((*MockStore)(nil).DeleteWorkspaceACLsByOrganization), ctx, organizationID)
|
||||
}
|
||||
|
||||
// DeleteWorkspaceAgentPortShare mocks base method.
|
||||
func (m *MockStore) DeleteWorkspaceAgentPortShare(ctx context.Context, arg database.DeleteWorkspaceAgentPortShareParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6659,21 +6645,6 @@ func (mr *MockStoreMockRecorder) UpdateOrganizationDeletedByID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOrganizationDeletedByID", reflect.TypeOf((*MockStore)(nil).UpdateOrganizationDeletedByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateOrganizationWorkspaceSharingSettings mocks base method.
|
||||
func (m *MockStore) UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg database.UpdateOrganizationWorkspaceSharingSettingsParams) (database.Organization, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateOrganizationWorkspaceSharingSettings", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Organization)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateOrganizationWorkspaceSharingSettings indicates an expected call of UpdateOrganizationWorkspaceSharingSettings.
|
||||
func (mr *MockStoreMockRecorder) UpdateOrganizationWorkspaceSharingSettings(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOrganizationWorkspaceSharingSettings", reflect.TypeOf((*MockStore)(nil).UpdateOrganizationWorkspaceSharingSettings), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdatePrebuildProvisionerJobWithCancel mocks base method.
|
||||
func (m *MockStore) UpdatePrebuildProvisionerJobWithCancel(ctx context.Context, arg database.UpdatePrebuildProvisionerJobWithCancelParams) ([]database.UpdatePrebuildProvisionerJobWithCancelRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
+685
-685
File diff suppressed because one or more lines are too long
@@ -1,34 +1,34 @@
|
||||
-- This is a deleted user that shares the same username and linked_id as the existing user below.
|
||||
-- Any future migrations need to handle this case.
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('a0061a8e-7db7-4585-838c-3116a003dd21', 'githubuser@coder.com', 'githubuser', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', true) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO organization_members VALUES ('a0061a8e-7db7-4585-838c-3116a003dd21', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO public.organization_members VALUES ('a0061a8e-7db7-4585-838c-3116a003dd21', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('a0061a8e-7db7-4585-838c-3116a003dd21', 'github', '100', '');
|
||||
|
||||
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'githubuser@coder.com', 'githubuser', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', false) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO organization_members VALUES ('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO public.organization_members VALUES ('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'github', '100', '');
|
||||
|
||||
-- Additionally, there is no unique constraint on user_id. So also add another user_link for the same user.
|
||||
-- This has happened on a production database.
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('fc1511ef-4fcf-4a3b-98a1-8df64160e35a', 'oidc', 'foo', '');
|
||||
|
||||
|
||||
-- Lastly, make 2 other users who have the same user link.
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('580ed397-727d-4aaf-950a-51f89f556c24', 'dup_link_a@coder.com', 'dupe_a', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', false) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO organization_members VALUES ('580ed397-727d-4aaf-950a-51f89f556c24', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO public.organization_members VALUES ('580ed397-727d-4aaf-950a-51f89f556c24', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('580ed397-727d-4aaf-950a-51f89f556c24', 'github', '500', '');
|
||||
|
||||
|
||||
INSERT INTO users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
INSERT INTO public.users(id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, deleted)
|
||||
VALUES ('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'dup_link_b@coder.com', 'dupe_b', '\x', '2022-11-02 13:05:21.445455+02', '2022-11-02 13:05:21.445455+02', 'active', '{}', false) ON CONFLICT DO NOTHING;
|
||||
INSERT INTO organization_members VALUES ('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
INSERT INTO public.organization_members VALUES ('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', '2022-11-02 13:05:21.447595+02', '2022-11-02 13:05:21.447595+02', '{}') ON CONFLICT DO NOTHING;
|
||||
INSERT INTO public.user_links(user_id, login_type, linked_id, oauth_access_token)
|
||||
VALUES('c813366b-2fde-45ae-920c-101c3ad6a1e1', 'github', '500', '');
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
INSERT INTO workspace_app_stats (
|
||||
INSERT INTO public.workspace_app_stats (
|
||||
id,
|
||||
user_id,
|
||||
workspace_id,
|
||||
|
||||
+1
-1
@@ -1,5 +1,5 @@
|
||||
INSERT INTO
|
||||
workspace_modules (
|
||||
public.workspace_modules (
|
||||
id,
|
||||
job_id,
|
||||
transition,
|
||||
|
||||
+8
-8
@@ -1,15 +1,15 @@
|
||||
INSERT INTO organizations (id, name, description, created_at, updated_at, is_default, display_name, icon) VALUES ('20362772-802a-4a72-8e4f-3648b4bfd168', 'strange_hopper58', 'wizardly_stonebraker60', '2025-02-07 07:46:19.507551 +00:00', '2025-02-07 07:46:19.507552 +00:00', false, 'competent_rhodes59', '');
|
||||
INSERT INTO public.organizations (id, name, description, created_at, updated_at, is_default, display_name, icon) VALUES ('20362772-802a-4a72-8e4f-3648b4bfd168', 'strange_hopper58', 'wizardly_stonebraker60', '2025-02-07 07:46:19.507551 +00:00', '2025-02-07 07:46:19.507552 +00:00', false, 'competent_rhodes59', '');
|
||||
|
||||
INSERT INTO users (id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at, quiet_hours_schedule, theme_preference, name, github_com_user_id, hashed_one_time_passcode, one_time_passcode_expires_at) VALUES ('6c353aac-20de-467b-bdfb-3c30a37adcd2', 'vigorous_murdock61', 'affectionate_hawking62', 'lqTu9C5363AwD7NVNH6noaGjp91XIuZJ', '2025-02-07 07:46:19.510861 +00:00', '2025-02-07 07:46:19.512949 +00:00', 'active', '{}', 'password', '', false, '0001-01-01 00:00:00.000000', '', '', 'vigilant_hugle63', null, null, null);
|
||||
INSERT INTO public.users (id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at, quiet_hours_schedule, theme_preference, name, github_com_user_id, hashed_one_time_passcode, one_time_passcode_expires_at) VALUES ('6c353aac-20de-467b-bdfb-3c30a37adcd2', 'vigorous_murdock61', 'affectionate_hawking62', 'lqTu9C5363AwD7NVNH6noaGjp91XIuZJ', '2025-02-07 07:46:19.510861 +00:00', '2025-02-07 07:46:19.512949 +00:00', 'active', '{}', 'password', '', false, '0001-01-01 00:00:00.000000', '', '', 'vigilant_hugle63', null, null, null);
|
||||
|
||||
INSERT INTO templates (id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level) VALUES ('6b298946-7a4f-47ac-9158-b03b08740a41', '2025-02-07 07:46:19.513317 +00:00', '2025-02-07 07:46:19.513317 +00:00', '20362772-802a-4a72-8e4f-3648b4bfd168', false, 'modest_leakey64', 'echo', 'e6cfa2a4-e4cf-4182-9e19-08b975682a28', 'upbeat_wright65', 604800000000000, '6c353aac-20de-467b-bdfb-3c30a37adcd2', 'nervous_keller66', '{}', '{"20362772-802a-4a72-8e4f-3648b4bfd168": ["read", "use"]}', 'determined_aryabhata67', false, true, true, 0, 0, 0, 0, 0, 0, false, '', 3600000000000, 'owner');
|
||||
INSERT INTO template_versions (id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id) VALUES ('af58bd62-428c-4c33-849b-d43a3be07d93', '6b298946-7a4f-47ac-9158-b03b08740a41', '20362772-802a-4a72-8e4f-3648b4bfd168', '2025-02-07 07:46:19.514782 +00:00', '2025-02-07 07:46:19.514782 +00:00', 'distracted_shockley68', 'sleepy_turing69', 'f2e2ea1c-5aa3-4a1d-8778-2e5071efae59', '6c353aac-20de-467b-bdfb-3c30a37adcd2', '[]', '', false, null);
|
||||
INSERT INTO public.templates (id, created_at, updated_at, organization_id, deleted, name, provisioner, active_version_id, description, default_ttl, created_by, icon, user_acl, group_acl, display_name, allow_user_cancel_workspace_jobs, allow_user_autostart, allow_user_autostop, failure_ttl, time_til_dormant, time_til_dormant_autodelete, autostop_requirement_days_of_week, autostop_requirement_weeks, autostart_block_days_of_week, require_active_version, deprecated, activity_bump, max_port_sharing_level) VALUES ('6b298946-7a4f-47ac-9158-b03b08740a41', '2025-02-07 07:46:19.513317 +00:00', '2025-02-07 07:46:19.513317 +00:00', '20362772-802a-4a72-8e4f-3648b4bfd168', false, 'modest_leakey64', 'echo', 'e6cfa2a4-e4cf-4182-9e19-08b975682a28', 'upbeat_wright65', 604800000000000, '6c353aac-20de-467b-bdfb-3c30a37adcd2', 'nervous_keller66', '{}', '{"20362772-802a-4a72-8e4f-3648b4bfd168": ["read", "use"]}', 'determined_aryabhata67', false, true, true, 0, 0, 0, 0, 0, 0, false, '', 3600000000000, 'owner');
|
||||
INSERT INTO public.template_versions (id, template_id, organization_id, created_at, updated_at, name, readme, job_id, created_by, external_auth_providers, message, archived, source_example_id) VALUES ('af58bd62-428c-4c33-849b-d43a3be07d93', '6b298946-7a4f-47ac-9158-b03b08740a41', '20362772-802a-4a72-8e4f-3648b4bfd168', '2025-02-07 07:46:19.514782 +00:00', '2025-02-07 07:46:19.514782 +00:00', 'distracted_shockley68', 'sleepy_turing69', 'f2e2ea1c-5aa3-4a1d-8778-2e5071efae59', '6c353aac-20de-467b-bdfb-3c30a37adcd2', '[]', '', false, null);
|
||||
|
||||
INSERT INTO template_version_presets (id, template_version_id, name, created_at) VALUES ('28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'af58bd62-428c-4c33-849b-d43a3be07d93', 'test', '0001-01-01 00:00:00.000000 +00:00');
|
||||
INSERT INTO public.template_version_presets (id, template_version_id, name, created_at) VALUES ('28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'af58bd62-428c-4c33-849b-d43a3be07d93', 'test', '0001-01-01 00:00:00.000000 +00:00');
|
||||
|
||||
-- Add presets with the same template version ID and name
|
||||
-- to ensure they're correctly handled by the 00031*_preset_prebuilds migration.
|
||||
INSERT INTO template_version_presets (
|
||||
INSERT INTO public.template_version_presets (
|
||||
id, template_version_id, name, created_at
|
||||
)
|
||||
VALUES (
|
||||
@@ -19,7 +19,7 @@ VALUES (
|
||||
'0001-01-01 00:00:00.000000 +00:00'
|
||||
);
|
||||
|
||||
INSERT INTO template_version_presets (
|
||||
INSERT INTO public.template_version_presets (
|
||||
id, template_version_id, name, created_at
|
||||
)
|
||||
VALUES (
|
||||
@@ -29,4 +29,4 @@ VALUES (
|
||||
'0001-01-01 00:00:00.000000 +00:00'
|
||||
);
|
||||
|
||||
INSERT INTO template_version_preset_parameters (id, template_version_preset_id, name, value) VALUES ('ea90ccd2-5024-459e-87e4-879afd24de0f', '28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'test', 'test');
|
||||
INSERT INTO public.template_version_preset_parameters (id, template_version_preset_id, name, value) VALUES ('ea90ccd2-5024-459e-87e4-879afd24de0f', '28b42cc0-c4fe-4907-a0fe-e4d20f1e9bfe', 'test', 'test');
|
||||
|
||||
+2
-2
@@ -1,4 +1,4 @@
|
||||
INSERT INTO tasks VALUES (
|
||||
INSERT INTO public.tasks VALUES (
|
||||
'f5a1c3e4-8b2d-4f6a-9d7e-2a8b5c9e1f3d', -- id
|
||||
'bb640d07-ca8a-4869-b6bc-ae61ebb2fda1', -- organization_id
|
||||
'30095c71-380b-457a-8995-97b8ee6e5307', -- owner_id
|
||||
@@ -11,7 +11,7 @@ INSERT INTO tasks VALUES (
|
||||
NULL -- deleted_at
|
||||
) ON CONFLICT DO NOTHING;
|
||||
|
||||
INSERT INTO task_workspace_apps VALUES (
|
||||
INSERT INTO public.task_workspace_apps VALUES (
|
||||
'f5a1c3e4-8b2d-4f6a-9d7e-2a8b5c9e1f3d', -- task_id
|
||||
'a8c0b8c5-c9a8-4f33-93a4-8142e6858244', -- workspace_build_id
|
||||
'8fa17bbd-c48c-44c7-91ae-d4acbc755fad', -- workspace_agent_id
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
INSERT INTO task_workspace_apps VALUES (
|
||||
INSERT INTO public.task_workspace_apps VALUES (
|
||||
'f5a1c3e4-8b2d-4f6a-9d7e-2a8b5c9e1f3d', -- task_id
|
||||
NULL, -- workspace_agent_id
|
||||
NULL, -- workspace_app_id
|
||||
|
||||
@@ -139,7 +139,6 @@ type sqlcQuerier interface {
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error
|
||||
DeleteWorkspaceAgentPortShare(ctx context.Context, arg DeleteWorkspaceAgentPortShareParams) error
|
||||
DeleteWorkspaceAgentPortSharesByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
DeleteWorkspaceSubAgentByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -678,7 +677,6 @@ type sqlcQuerier interface {
|
||||
UpdateOAuth2ProviderAppSecretByID(ctx context.Context, arg UpdateOAuth2ProviderAppSecretByIDParams) (OAuth2ProviderAppSecret, error)
|
||||
UpdateOrganization(ctx context.Context, arg UpdateOrganizationParams) (Organization, error)
|
||||
UpdateOrganizationDeletedByID(ctx context.Context, arg UpdateOrganizationDeletedByIDParams) error
|
||||
UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg UpdateOrganizationWorkspaceSharingSettingsParams) (Organization, error)
|
||||
// Cancels all pending provisioner jobs for prebuilt workspaces on a specific preset from an
|
||||
// inactive template version.
|
||||
// This is an optimization to clean up stale pending jobs.
|
||||
|
||||
@@ -2304,94 +2304,6 @@ func TestDeleteCustomRoleDoesNotDeleteSystemRole(t *testing.T) {
|
||||
require.True(t, roles[0].IsSystem)
|
||||
}
|
||||
|
||||
func TestUpdateOrganizationWorkspaceSharingSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
updated, err := db.UpdateOrganizationWorkspaceSharingSettings(ctx, database.UpdateOrganizationWorkspaceSharingSettingsParams{
|
||||
ID: org.ID,
|
||||
WorkspaceSharingDisabled: true,
|
||||
UpdatedAt: dbtime.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, updated.WorkspaceSharingDisabled)
|
||||
|
||||
got, err := db.GetOrganizationByID(ctx, org.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.WorkspaceSharingDisabled)
|
||||
}
|
||||
|
||||
func TestDeleteWorkspaceACLsByOrganization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
org1 := dbgen.Organization(t, db, database.Organization{})
|
||||
org2 := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
owner1 := dbgen.User(t, db, database.User{})
|
||||
owner2 := dbgen.User(t, db, database.User{})
|
||||
sharedUser := dbgen.User(t, db, database.User{})
|
||||
sharedGroup := dbgen.Group(t, db, database.Group{
|
||||
OrganizationID: org1.ID,
|
||||
})
|
||||
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: owner1.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org2.ID,
|
||||
UserID: owner2.ID,
|
||||
})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
OrganizationID: org1.ID,
|
||||
UserID: sharedUser.ID,
|
||||
})
|
||||
|
||||
ws1 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner1.ID,
|
||||
OrganizationID: org1.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
sharedUser.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
GroupACL: database.WorkspaceACL{
|
||||
sharedGroup.ID.String(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ws2 := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: owner2.ID,
|
||||
OrganizationID: org2.ID,
|
||||
UserACL: database.WorkspaceACL{
|
||||
uuid.NewString(): {
|
||||
Permissions: []policy.Action{policy.ActionRead},
|
||||
},
|
||||
},
|
||||
}).Do().Workspace
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
err := db.DeleteWorkspaceACLsByOrganization(ctx, org1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
got1, err := db.GetWorkspaceByID(ctx, ws1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, got1.UserACL)
|
||||
require.Empty(t, got1.GroupACL)
|
||||
|
||||
got2, err := db.GetWorkspaceByID(ctx, ws2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, got2.UserACL)
|
||||
}
|
||||
|
||||
func TestAuthorizedAuditLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -8197,41 +8197,6 @@ func (q *sqlQuerier) UpdateOrganizationDeletedByID(ctx context.Context, arg Upda
|
||||
return err
|
||||
}
|
||||
|
||||
const updateOrganizationWorkspaceSharingSettings = `-- name: UpdateOrganizationWorkspaceSharingSettings :one
|
||||
UPDATE
|
||||
organizations
|
||||
SET
|
||||
workspace_sharing_disabled = $1,
|
||||
updated_at = $2
|
||||
WHERE
|
||||
id = $3
|
||||
RETURNING id, name, description, created_at, updated_at, is_default, display_name, icon, deleted, workspace_sharing_disabled
|
||||
`
|
||||
|
||||
type UpdateOrganizationWorkspaceSharingSettingsParams struct {
|
||||
WorkspaceSharingDisabled bool `db:"workspace_sharing_disabled" json:"workspace_sharing_disabled"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateOrganizationWorkspaceSharingSettings(ctx context.Context, arg UpdateOrganizationWorkspaceSharingSettingsParams) (Organization, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateOrganizationWorkspaceSharingSettings, arg.WorkspaceSharingDisabled, arg.UpdatedAt, arg.ID)
|
||||
var i Organization
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Name,
|
||||
&i.Description,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.IsDefault,
|
||||
&i.DisplayName,
|
||||
&i.Icon,
|
||||
&i.Deleted,
|
||||
&i.WorkspaceSharingDisabled,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getParameterSchemasByJobID = `-- name: GetParameterSchemasByJobID :many
|
||||
SELECT
|
||||
id, created_at, job_id, name, description, default_source_scheme, default_source_value, allow_override_source, default_destination_scheme, allow_override_destination, default_refresh, redisplay_value, validation_error, validation_condition, validation_type_system, validation_value_type, index
|
||||
@@ -22186,21 +22151,6 @@ func (q *sqlQuerier) DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) e
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteWorkspaceACLsByOrganization = `-- name: DeleteWorkspaceACLsByOrganization :exec
|
||||
UPDATE
|
||||
workspaces
|
||||
SET
|
||||
group_acl = '{}'::jsonb,
|
||||
user_acl = '{}'::jsonb
|
||||
WHERE
|
||||
organization_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) DeleteWorkspaceACLsByOrganization(ctx context.Context, organizationID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteWorkspaceACLsByOrganization, organizationID)
|
||||
return err
|
||||
}
|
||||
|
||||
const favoriteWorkspace = `-- name: FavoriteWorkspace :exec
|
||||
UPDATE workspaces SET favorite = true WHERE id = $1
|
||||
`
|
||||
|
||||
@@ -143,13 +143,3 @@ WHERE
|
||||
id = @id AND
|
||||
is_default = false;
|
||||
|
||||
-- name: UpdateOrganizationWorkspaceSharingSettings :one
|
||||
UPDATE
|
||||
organizations
|
||||
SET
|
||||
workspace_sharing_disabled = @workspace_sharing_disabled,
|
||||
updated_at = @updated_at
|
||||
WHERE
|
||||
id = @id
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -947,15 +947,6 @@ SET
|
||||
WHERE
|
||||
id = @id;
|
||||
|
||||
-- name: DeleteWorkspaceACLsByOrganization :exec
|
||||
UPDATE
|
||||
workspaces
|
||||
SET
|
||||
group_acl = '{}'::jsonb,
|
||||
user_acl = '{}'::jsonb
|
||||
WHERE
|
||||
organization_id = @organization_id;
|
||||
|
||||
-- name: GetRegularWorkspaceCreateMetrics :many
|
||||
-- Count regular workspaces: only those whose first successful 'start' build
|
||||
-- was not initiated by the prebuild system user.
|
||||
|
||||
+2
-6
@@ -83,21 +83,17 @@ func (api *API) debugDeploymentHealth(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), api.Options.HealthcheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create and store progress tracker for timeout diagnostics.
|
||||
report := api.HealthcheckFunc(ctx, apiKey, &api.healthCheckProgress)
|
||||
report := api.HealthcheckFunc(ctx, apiKey)
|
||||
if report != nil { // Only store non-nil reports.
|
||||
api.healthCheckCache.Store(report)
|
||||
}
|
||||
api.healthCheckProgress.Reset()
|
||||
return report, nil
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
summary := api.healthCheckProgress.Summary()
|
||||
httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{
|
||||
Message: "Healthcheck timed out.",
|
||||
Detail: summary,
|
||||
Message: "Healthcheck is in progress and did not complete in time. Try again in a few seconds.",
|
||||
})
|
||||
return
|
||||
case res := <-resChan:
|
||||
|
||||
+17
-20
@@ -14,8 +14,6 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
@@ -30,7 +28,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
sessionToken string
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
calls.Add(1)
|
||||
assert.Equal(t, sessionToken, apiKey)
|
||||
return &healthsdk.HealthcheckReport{
|
||||
@@ -63,7 +61,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
sessionToken string
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
calls.Add(1)
|
||||
assert.Equal(t, sessionToken, apiKey)
|
||||
return &healthsdk.HealthcheckReport{
|
||||
@@ -95,14 +93,19 @@ func TestDebugHealth(t *testing.T) {
|
||||
// Need to ignore errors due to ctx timeout
|
||||
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
done = make(chan struct{})
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
Logger: &logger,
|
||||
HealthcheckTimeout: time.Second,
|
||||
HealthcheckFunc: func(_ context.Context, _ string, progress *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
progress.Start("test")
|
||||
<-done
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
HealthcheckTimeout: time.Microsecond,
|
||||
HealthcheckFunc: func(context.Context, string) *healthsdk.HealthcheckReport {
|
||||
t := time.NewTimer(time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
case <-t.C:
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
}
|
||||
},
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
@@ -112,14 +115,8 @@ func TestDebugHealth(t *testing.T) {
|
||||
res, err := client.Request(ctx, "GET", "/api/v2/debug/health", nil)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
close(done)
|
||||
bs, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err, "reading body")
|
||||
_, _ = io.ReadAll(res.Body)
|
||||
require.Equal(t, http.StatusServiceUnavailable, res.StatusCode)
|
||||
var sdkResp codersdk.Response
|
||||
require.NoError(t, json.Unmarshal(bs, &sdkResp), "unmarshaling sdk response")
|
||||
require.Equal(t, "Healthcheck timed out.", sdkResp.Message)
|
||||
require.Contains(t, sdkResp.Detail, "Still running: test (elapsed:")
|
||||
})
|
||||
|
||||
t.Run("Refresh", func(t *testing.T) {
|
||||
@@ -131,7 +128,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckRefresh: time.Microsecond,
|
||||
HealthcheckFunc: func(context.Context, string, *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(context.Context, string) *healthsdk.HealthcheckReport {
|
||||
calls <- struct{}{}
|
||||
return &healthsdk.HealthcheckReport{}
|
||||
},
|
||||
@@ -176,7 +173,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckRefresh: time.Hour,
|
||||
HealthcheckTimeout: time.Hour,
|
||||
HealthcheckFunc: func(context.Context, string, *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(context.Context, string) *healthsdk.HealthcheckReport {
|
||||
calls++
|
||||
return &healthsdk.HealthcheckReport{
|
||||
Time: time.Now(),
|
||||
@@ -210,7 +207,7 @@ func TestDebugHealth(t *testing.T) {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
sessionToken string
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string, _ *healthcheck.Progress) *healthsdk.HealthcheckReport {
|
||||
HealthcheckFunc: func(_ context.Context, apiKey string) *healthsdk.HealthcheckReport {
|
||||
assert.Equal(t, sessionToken, apiKey)
|
||||
return &healthsdk.HealthcheckReport{
|
||||
Time: time.Now(),
|
||||
|
||||
@@ -2,9 +2,6 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -13,91 +10,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Progress tracks the progress of healthcheck components for timeout
|
||||
// diagnostics. It records which checks have started and completed, along with
|
||||
// their durations, to provide useful information when a healthcheck times out.
|
||||
// The zero value is usable.
|
||||
type Progress struct {
|
||||
Clock quartz.Clock
|
||||
mu sync.Mutex
|
||||
checks map[string]*checkStatus
|
||||
}
|
||||
|
||||
type checkStatus struct {
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
}
|
||||
|
||||
// Start records that a check has started.
|
||||
func (p *Progress) Start(name string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.Clock == nil {
|
||||
p.Clock = quartz.NewReal()
|
||||
}
|
||||
if p.checks == nil {
|
||||
p.checks = make(map[string]*checkStatus)
|
||||
}
|
||||
p.checks[name] = &checkStatus{startedAt: p.Clock.Now()}
|
||||
}
|
||||
|
||||
// Complete records that a check has finished.
|
||||
func (p *Progress) Complete(name string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.Clock == nil {
|
||||
p.Clock = quartz.NewReal()
|
||||
}
|
||||
if p.checks == nil {
|
||||
p.checks = make(map[string]*checkStatus)
|
||||
}
|
||||
if p.checks[name] == nil {
|
||||
p.checks[name] = &checkStatus{startedAt: p.Clock.Now()}
|
||||
}
|
||||
p.checks[name].completedAt = p.Clock.Now()
|
||||
}
|
||||
|
||||
// Reset clears all recorded check statuses.
|
||||
func (p *Progress) Reset() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.checks = make(map[string]*checkStatus)
|
||||
}
|
||||
|
||||
// Summary returns a human-readable summary of check progress.
|
||||
// Example: "Completed: AccessURL (95ms), Database (120ms). Still running: DERP, Websocket"
|
||||
func (p *Progress) Summary() string {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
var completed, running []string
|
||||
for name, status := range p.checks {
|
||||
if status.completedAt.IsZero() {
|
||||
elapsed := p.Clock.Now().Sub(status.startedAt).Round(time.Millisecond)
|
||||
running = append(running, fmt.Sprintf("%s (elapsed: %dms)", name, elapsed.Milliseconds()))
|
||||
continue
|
||||
}
|
||||
duration := status.completedAt.Sub(status.startedAt).Round(time.Millisecond)
|
||||
completed = append(completed, fmt.Sprintf("%s (%dms)", name, duration.Milliseconds()))
|
||||
}
|
||||
|
||||
// Sort for consistent output.
|
||||
slices.Sort(completed)
|
||||
slices.Sort(running)
|
||||
|
||||
var parts []string
|
||||
if len(completed) > 0 {
|
||||
parts = append(parts, "Completed: "+strings.Join(completed, ", "))
|
||||
}
|
||||
if len(running) > 0 {
|
||||
parts = append(parts, "Still running: "+strings.Join(running, ", "))
|
||||
}
|
||||
return strings.Join(parts, ". ")
|
||||
}
|
||||
|
||||
type Checker interface {
|
||||
DERP(ctx context.Context, opts *derphealth.ReportOptions) healthsdk.DERPHealthReport
|
||||
AccessURL(ctx context.Context, opts *AccessURLReportOptions) healthsdk.AccessURLReport
|
||||
@@ -116,10 +30,6 @@ type ReportOptions struct {
|
||||
ProvisionerDaemons ProvisionerDaemonsReportDeps
|
||||
|
||||
Checker Checker
|
||||
|
||||
// Progress tracks healthcheck progress for timeout diagnostics.
|
||||
// If set, each check will record its start and completion time.
|
||||
Progress *Progress
|
||||
}
|
||||
|
||||
type defaultChecker struct{}
|
||||
@@ -179,10 +89,6 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("DERP")
|
||||
defer opts.Progress.Complete("DERP")
|
||||
}
|
||||
report.DERP = opts.Checker.DERP(ctx, &opts.DerpHealth)
|
||||
}()
|
||||
|
||||
@@ -195,10 +101,6 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("AccessURL")
|
||||
defer opts.Progress.Complete("AccessURL")
|
||||
}
|
||||
report.AccessURL = opts.Checker.AccessURL(ctx, &opts.AccessURL)
|
||||
}()
|
||||
|
||||
@@ -211,10 +113,6 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("Websocket")
|
||||
defer opts.Progress.Complete("Websocket")
|
||||
}
|
||||
report.Websocket = opts.Checker.Websocket(ctx, &opts.Websocket)
|
||||
}()
|
||||
|
||||
@@ -227,10 +125,6 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("Database")
|
||||
defer opts.Progress.Complete("Database")
|
||||
}
|
||||
report.Database = opts.Checker.Database(ctx, &opts.Database)
|
||||
}()
|
||||
|
||||
@@ -243,10 +137,6 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("WorkspaceProxy")
|
||||
defer opts.Progress.Complete("WorkspaceProxy")
|
||||
}
|
||||
report.WorkspaceProxy = opts.Checker.WorkspaceProxy(ctx, &opts.WorkspaceProxy)
|
||||
}()
|
||||
|
||||
@@ -259,10 +149,6 @@ func Run(ctx context.Context, opts *ReportOptions) *healthsdk.HealthcheckReport
|
||||
}
|
||||
}()
|
||||
|
||||
if opts.Progress != nil {
|
||||
opts.Progress.Start("ProvisionerDaemons")
|
||||
defer opts.Progress.Complete("ProvisionerDaemons")
|
||||
}
|
||||
report.ProvisionerDaemons = opts.Checker.ProvisionerDaemons(ctx, &opts.ProvisionerDaemons)
|
||||
}()
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package healthcheck_test
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@@ -11,7 +10,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/health"
|
||||
"github.com/coder/coder/v2/codersdk/healthsdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
type testChecker struct {
|
||||
@@ -535,69 +533,3 @@ func TestHealthcheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckProgress(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Summary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
|
||||
// Start some checks
|
||||
progress.Start("Database")
|
||||
progress.Start("DERP")
|
||||
progress.Start("AccessURL")
|
||||
|
||||
// Advance time to simulate check duration
|
||||
mClock.Advance(100 * time.Millisecond)
|
||||
|
||||
// Complete some checks
|
||||
progress.Complete("Database")
|
||||
progress.Complete("AccessURL")
|
||||
|
||||
summary := progress.Summary()
|
||||
|
||||
// Verify completed and running checks are listed with duration / elapsed
|
||||
assert.Equal(t, summary, "Completed: AccessURL (100ms), Database (100ms). Still running: DERP (elapsed: 100ms)")
|
||||
})
|
||||
|
||||
t.Run("EmptyProgress", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
summary := progress.Summary()
|
||||
|
||||
// Should be empty string when nothing tracked
|
||||
assert.Empty(t, summary)
|
||||
})
|
||||
|
||||
t.Run("AllCompleted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
progress.Start("Database")
|
||||
progress.Start("DERP")
|
||||
mClock.Advance(50 * time.Millisecond)
|
||||
progress.Complete("Database")
|
||||
progress.Complete("DERP")
|
||||
|
||||
summary := progress.Summary()
|
||||
assert.Equal(t, summary, "Completed: DERP (50ms), Database (50ms)")
|
||||
})
|
||||
|
||||
t.Run("AllRunning", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
progress := healthcheck.Progress{Clock: mClock}
|
||||
progress.Start("Database")
|
||||
progress.Start("DERP")
|
||||
|
||||
summary := progress.Summary()
|
||||
assert.Equal(t, summary, "Still running: DERP (elapsed: 0ms), Database (elapsed: 0ms)")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -493,10 +493,16 @@ func OneWayWebSocketEventSender(rw http.ResponseWriter, r *http.Request) (
|
||||
return sendEvent, closed, nil
|
||||
}
|
||||
|
||||
// OAuth2Error represents an OAuth2-compliant error response per RFC 6749.
|
||||
type OAuth2Error struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
// WriteOAuth2Error writes an OAuth2-compliant error response per RFC 6749.
|
||||
// This should be used for all OAuth2 endpoints (/oauth2/*) to ensure compliance.
|
||||
func WriteOAuth2Error(ctx context.Context, rw http.ResponseWriter, status int, errorCode codersdk.OAuth2ErrorCode, description string) {
|
||||
Write(ctx, rw, status, codersdk.OAuth2Error{
|
||||
func WriteOAuth2Error(ctx context.Context, rw http.ResponseWriter, status int, errorCode, description string) {
|
||||
Write(ctx, rw, status, OAuth2Error{
|
||||
Error: errorCode,
|
||||
ErrorDescription: description,
|
||||
})
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type (
|
||||
httpRouteInfoKey struct{}
|
||||
)
|
||||
|
||||
type httpRouteInfo struct {
|
||||
Route string
|
||||
Method string
|
||||
}
|
||||
|
||||
// ExtractHTTPRoute retrieves just the HTTP route pattern from context.
|
||||
// Returns empty string if not set.
|
||||
func ExtractHTTPRoute(ctx context.Context) string {
|
||||
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||
return ri.Route
|
||||
}
|
||||
|
||||
// ExtractHTTPMethod retrieves just the HTTP method from context.
|
||||
// Returns empty string if not set.
|
||||
func ExtractHTTPMethod(ctx context.Context) string {
|
||||
ri, _ := ctx.Value(httpRouteInfoKey{}).(httpRouteInfo)
|
||||
return ri.Method
|
||||
}
|
||||
|
||||
// HTTPRoute is middleware that stores the HTTP route pattern and method in
|
||||
// context for use by downstream handlers and services (e.g. prometheus).
|
||||
func HTTPRoute(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
route := getRoutePattern(r)
|
||||
ctx := context.WithValue(r.Context(), httpRouteInfoKey{}, httpRouteInfo{
|
||||
Route: route,
|
||||
Method: r.Method,
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func getRoutePattern(r *http.Request) string {
|
||||
rctx := chi.RouteContext(r.Context())
|
||||
if rctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
routePath := r.URL.Path
|
||||
if r.URL.RawPath != "" {
|
||||
routePath = r.URL.RawPath
|
||||
}
|
||||
|
||||
tctx := chi.NewRouteContext()
|
||||
routes := rctx.Routes
|
||||
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
||||
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
||||
// All other ones will be matched as "STATIC".
|
||||
if strings.HasPrefix(routePath, "/api/") {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return "STATIC"
|
||||
}
|
||||
|
||||
// tctx has the updated pattern, since Match mutates it
|
||||
return tctx.RoutePattern()
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestHTTPRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
reqFn func() *http.Request
|
||||
registerRoutes map[string]string
|
||||
mws []func(http.Handler) http.Handler
|
||||
expectedRoute string
|
||||
expectedMethod string
|
||||
}{
|
||||
{
|
||||
name: "without middleware",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{},
|
||||
expectedRoute: "",
|
||||
expectedMethod: "",
|
||||
},
|
||||
{
|
||||
name: "root",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "/",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
{
|
||||
name: "parameterized route",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPut, "/users/123", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodPut: "/users/{id}"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "/users/{id}",
|
||||
expectedMethod: http.MethodPut,
|
||||
},
|
||||
{
|
||||
name: "unknown",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/api/a", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/api/b"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "UNKNOWN",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
{
|
||||
name: "static",
|
||||
reqFn: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodGet, "/some/static/file.png", nil)
|
||||
},
|
||||
registerRoutes: map[string]string{http.MethodGet: "/"},
|
||||
mws: []func(http.Handler) http.Handler{httpmw.HTTPRoute},
|
||||
expectedRoute: "STATIC",
|
||||
expectedMethod: http.MethodGet,
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
r := chi.NewRouter()
|
||||
done := make(chan string)
|
||||
for _, mw := range tc.mws {
|
||||
r.Use(mw)
|
||||
}
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(done)
|
||||
method := httpmw.ExtractHTTPMethod(r.Context())
|
||||
route := httpmw.ExtractHTTPRoute(r.Context())
|
||||
assert.Equal(t, tc.expectedMethod, method, "expected method mismatch")
|
||||
assert.Equal(t, tc.expectedRoute, route, "expected route mismatch")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
for method, route := range tc.registerRoutes {
|
||||
r.MethodFunc(method, route, func(w http.ResponseWriter, r *http.Request) {})
|
||||
}
|
||||
req := tc.reqFn()
|
||||
r.ServeHTTP(httptest.NewRecorder(), req)
|
||||
_ = testutil.TryReceive(ctx, t, done)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -290,15 +290,15 @@ func (*codersdkErrorWriter) writeClientNotFound(ctx context.Context, rw http.Res
|
||||
type oauth2ErrorWriter struct{}
|
||||
|
||||
func (*oauth2ErrorWriter) writeMissingClientID(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, "Missing client_id parameter")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Missing client_id parameter")
|
||||
}
|
||||
|
||||
func (*oauth2ErrorWriter) writeInvalidClientID(ctx context.Context, rw http.ResponseWriter, _ error) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, codersdk.OAuth2ErrorCodeInvalidClient, "The client credentials are invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, "invalid_client", "The client credentials are invalid")
|
||||
}
|
||||
|
||||
func (*oauth2ErrorWriter) writeClientNotFound(ctx context.Context, rw http.ResponseWriter) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, codersdk.OAuth2ErrorCodeInvalidClient, "The client credentials are invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, "invalid_client", "The client credentials are invalid")
|
||||
}
|
||||
|
||||
// extractOAuth2ProviderAppBase is the internal implementation that uses the strategy pattern
|
||||
|
||||
@@ -3,8 +3,10 @@ package httpmw
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
@@ -69,7 +71,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
var (
|
||||
dist *prometheus.HistogramVec
|
||||
distOpts []string
|
||||
path = ExtractHTTPRoute(r.Context())
|
||||
path = getRoutePattern(r)
|
||||
)
|
||||
|
||||
// We want to count WebSockets separately.
|
||||
@@ -96,3 +98,29 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getRoutePattern(r *http.Request) string {
|
||||
rctx := chi.RouteContext(r.Context())
|
||||
if rctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
routePath := r.URL.Path
|
||||
if r.URL.RawPath != "" {
|
||||
routePath = r.URL.RawPath
|
||||
}
|
||||
|
||||
tctx := chi.NewRouteContext()
|
||||
routes := rctx.Routes
|
||||
if routes != nil && !routes.Match(tctx, r.Method, routePath) {
|
||||
// No matching pattern. /api/* requests will be matched as "UNKNOWN"
|
||||
// All other ones will be matched as "STATIC".
|
||||
if strings.HasPrefix(routePath, "/api/") {
|
||||
return "UNKNOWN"
|
||||
}
|
||||
return "STATIC"
|
||||
}
|
||||
|
||||
// tctx has the updated pattern, since Match mutates it
|
||||
return tctx.RoutePattern()
|
||||
}
|
||||
|
||||
@@ -29,9 +29,9 @@ func TestPrometheus(t *testing.T) {
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||
reg := prometheus.NewRegistry()
|
||||
httpmw.HTTPRoute(httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))).ServeHTTP(res, req)
|
||||
})).ServeHTTP(res, req)
|
||||
metrics, err := reg.Gather()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(metrics), 0)
|
||||
@@ -57,7 +57,7 @@ func TestPrometheus(t *testing.T) {
|
||||
wrappedHandler := promMW(testHandler)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(tracing.StatusWriterMiddleware, httpmw.HTTPRoute, promMW)
|
||||
r.Use(tracing.StatusWriterMiddleware, promMW)
|
||||
r.Get("/api/v2/build/{build}/logs", func(rw http.ResponseWriter, r *http.Request) {
|
||||
wrappedHandler.ServeHTTP(rw, r)
|
||||
})
|
||||
@@ -85,7 +85,7 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.With(httpmw.HTTPRoute).With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
||||
r.With(promMW).Get("/api/v2/users/{user}", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v2/users/john", nil)
|
||||
|
||||
@@ -115,7 +115,6 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -146,7 +145,6 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -175,7 +173,6 @@ func TestPrometheus(t *testing.T) {
|
||||
promMW := httpmw.Prometheus(reg)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(httpmw.HTTPRoute)
|
||||
r.Use(promMW)
|
||||
r.Get("/api/v2/workspaceagents/{workspaceagent}/pty", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ func TestOAuth2RegistrationErrorCodes(t *testing.T) {
|
||||
req: codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("test-client-%d", time.Now().UnixNano()),
|
||||
GrantTypes: []codersdk.OAuth2ProviderGrantType{"unsupported_grant_type"},
|
||||
GrantTypes: []string{"unsupported_grant_type"},
|
||||
},
|
||||
expectedError: "invalid_client_metadata",
|
||||
expectedCode: http.StatusBadRequest,
|
||||
@@ -109,7 +109,7 @@ func TestOAuth2RegistrationErrorCodes(t *testing.T) {
|
||||
req: codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("test-client-%d", time.Now().UnixNano()),
|
||||
ResponseTypes: []codersdk.OAuth2ProviderResponseType{"unsupported_response_type"},
|
||||
ResponseTypes: []string{"unsupported_response_type"},
|
||||
},
|
||||
expectedError: "invalid_client_metadata",
|
||||
expectedCode: http.StatusBadRequest,
|
||||
|
||||
@@ -44,10 +44,10 @@ func TestOAuth2AuthorizationServerMetadata(t *testing.T) {
|
||||
require.NotEmpty(t, metadata.Issuer)
|
||||
require.NotEmpty(t, metadata.AuthorizationEndpoint)
|
||||
require.NotEmpty(t, metadata.TokenEndpoint)
|
||||
require.Contains(t, metadata.ResponseTypesSupported, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeRefreshToken)
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
require.Contains(t, metadata.ResponseTypesSupported, "code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "authorization_code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "refresh_token")
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, "S256")
|
||||
// Supported scopes are published from the curated catalog
|
||||
require.Equal(t, rbac.ExternalScopeNames(), metadata.ScopesSupported)
|
||||
}
|
||||
|
||||
@@ -277,47 +277,47 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
grantTypes []codersdk.OAuth2ProviderGrantType
|
||||
grantTypes []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{},
|
||||
grantTypes: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidAuthorizationCode",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"authorization_code"},
|
||||
grantTypes: []string{"authorization_code"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidRefreshTokenAlone",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"refresh_token"},
|
||||
grantTypes: []string{"refresh_token"},
|
||||
expectError: true, // refresh_token requires authorization_code to be present
|
||||
},
|
||||
{
|
||||
name: "ValidMultiple",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"authorization_code", "refresh_token"},
|
||||
grantTypes: []string{"authorization_code", "refresh_token"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidUnsupported",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"client_credentials"},
|
||||
grantTypes: []string{"client_credentials"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidPassword",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"password"},
|
||||
grantTypes: []string{"password"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidImplicit",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"implicit"},
|
||||
grantTypes: []string{"implicit"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "MixedValidInvalid",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{"authorization_code", "client_credentials"},
|
||||
grantTypes: []string{"authorization_code", "client_credentials"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -352,32 +352,32 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseTypes []codersdk.OAuth2ProviderResponseType
|
||||
responseTypes []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{},
|
||||
responseTypes: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidCode",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"code"},
|
||||
responseTypes: []string{"code"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidToken",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"token"},
|
||||
responseTypes: []string{"token"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidImplicit",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"id_token"},
|
||||
responseTypes: []string{"id_token"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidMultiple",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"code", "token"},
|
||||
responseTypes: []string{"code", "token"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -412,7 +412,7 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authMethod codersdk.OAuth2TokenEndpointAuthMethod
|
||||
authMethod string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
@@ -659,14 +659,14 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should default to authorization_code
|
||||
require.Contains(t, config.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, config.GrantTypes, "authorization_code")
|
||||
|
||||
// Should default to code
|
||||
require.Contains(t, config.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Contains(t, config.ResponseTypes, "code")
|
||||
|
||||
// Should default to client_secret_basic or client_secret_post
|
||||
require.True(t, config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic ||
|
||||
config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost ||
|
||||
require.True(t, config.TokenEndpointAuthMethod == "client_secret_basic" ||
|
||||
config.TokenEndpointAuthMethod == "client_secret_post" ||
|
||||
config.TokenEndpointAuthMethod == "")
|
||||
|
||||
// Client secret should be generated
|
||||
|
||||
@@ -1329,10 +1329,10 @@ func TestOAuth2DynamicClientRegistration(t *testing.T) {
|
||||
require.Equal(t, int64(0), resp.ClientSecretExpiresAt) // Non-expiring
|
||||
|
||||
// Verify default values
|
||||
require.Contains(t, resp.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, resp.GrantTypes, codersdk.OAuth2ProviderGrantTypeRefreshToken)
|
||||
require.Contains(t, resp.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Equal(t, codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic, resp.TokenEndpointAuthMethod)
|
||||
require.Contains(t, resp.GrantTypes, "authorization_code")
|
||||
require.Contains(t, resp.GrantTypes, "refresh_token")
|
||||
require.Contains(t, resp.ResponseTypes, "code")
|
||||
require.Equal(t, "client_secret_basic", resp.TokenEndpointAuthMethod)
|
||||
|
||||
// Verify request values are preserved
|
||||
require.Equal(t, req.RedirectURIs, resp.RedirectURIs)
|
||||
@@ -1363,9 +1363,9 @@ func TestOAuth2DynamicClientRegistration(t *testing.T) {
|
||||
require.NotEmpty(t, resp.RegistrationClientURI)
|
||||
|
||||
// Should have defaults applied
|
||||
require.Contains(t, resp.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, resp.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Equal(t, codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic, resp.TokenEndpointAuthMethod)
|
||||
require.Contains(t, resp.GrantTypes, "authorization_code")
|
||||
require.Contains(t, resp.ResponseTypes, "code")
|
||||
require.Equal(t, "client_secret_basic", resp.TokenEndpointAuthMethod)
|
||||
})
|
||||
|
||||
t.Run("InvalidRedirectURI", func(t *testing.T) {
|
||||
|
||||
@@ -137,13 +137,13 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
|
||||
callbackURL, err := url.Parse(app.CallbackURL)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Failed to validate query parameters")
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, "server_error", "Failed to validate query parameters")
|
||||
return
|
||||
}
|
||||
|
||||
params, _, err := extractAuthorizeParams(r, callbackURL)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -151,10 +151,10 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
if params.codeChallenge != "" {
|
||||
// If code_challenge is provided but method is not, default to S256
|
||||
if params.codeChallengeMethod == "" {
|
||||
params.codeChallengeMethod = string(codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
params.codeChallengeMethod = "S256"
|
||||
}
|
||||
if err := codersdk.ValidatePKCECodeChallengeMethod(params.codeChallengeMethod); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
if params.codeChallengeMethod != "S256" {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Invalid code_challenge_method: only S256 is supported")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -162,7 +162,7 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
// TODO: Ignoring scope for now, but should look into implementing.
|
||||
code, err := GenerateSecret()
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Failed to generate OAuth2 app authorization code")
|
||||
httpapi.WriteOAuth2Error(r.Context(), rw, http.StatusInternalServerError, "server_error", "Failed to generate OAuth2 app authorization code")
|
||||
return
|
||||
}
|
||||
err = db.InTx(func(tx database.Store) error {
|
||||
@@ -202,7 +202,7 @@ func ProcessAuthorize(db database.Store) http.HandlerFunc {
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Failed to generate OAuth2 authorization code")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, "server_error", "Failed to generate OAuth2 authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -19,11 +19,11 @@ func GetAuthorizationServerMetadata(accessURL *url.URL) http.HandlerFunc {
|
||||
TokenEndpoint: accessURL.JoinPath("/oauth2/tokens").String(),
|
||||
RegistrationEndpoint: accessURL.JoinPath("/oauth2/register").String(), // RFC 7591
|
||||
RevocationEndpoint: accessURL.JoinPath("/oauth2/revoke").String(), // RFC 7009
|
||||
ResponseTypesSupported: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode},
|
||||
GrantTypesSupported: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
CodeChallengeMethodsSupported: []codersdk.OAuth2PKCECodeChallengeMethod{codersdk.OAuth2PKCECodeChallengeMethodS256},
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||
CodeChallengeMethodsSupported: []string{"S256"},
|
||||
ScopesSupported: rbac.ExternalScopeNames(),
|
||||
TokenEndpointAuthMethodsSupported: []codersdk.OAuth2TokenEndpointAuthMethod{codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost},
|
||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_post"},
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, metadata)
|
||||
}
|
||||
|
||||
@@ -32,10 +32,10 @@ func TestOAuth2AuthorizationServerMetadata(t *testing.T) {
|
||||
require.NotEmpty(t, metadata.Issuer)
|
||||
require.NotEmpty(t, metadata.AuthorizationEndpoint)
|
||||
require.NotEmpty(t, metadata.TokenEndpoint)
|
||||
require.Contains(t, metadata.ResponseTypesSupported, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, metadata.GrantTypesSupported, codersdk.OAuth2ProviderGrantTypeRefreshToken)
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, codersdk.OAuth2PKCECodeChallengeMethodS256)
|
||||
require.Contains(t, metadata.ResponseTypesSupported, "code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "authorization_code")
|
||||
require.Contains(t, metadata.GrantTypesSupported, "refresh_token")
|
||||
require.Contains(t, metadata.CodeChallengeMethodsSupported, "S256")
|
||||
// Supported scopes are published from the curated catalog
|
||||
require.Equal(t, rbac.ExternalScopeNames(), metadata.ScopesSupported)
|
||||
}
|
||||
|
||||
@@ -105,9 +105,8 @@ func GenerateState(t *testing.T) string {
|
||||
return base64.RawURLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// doAuthorizeRequest performs the OAuth2 authorization request and returns the response.
|
||||
// Caller is responsible for closing the response body.
|
||||
func doAuthorizeRequest(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams) *http.Response {
|
||||
// AuthorizeOAuth2App performs the OAuth2 authorization flow and returns the authorization code
|
||||
func AuthorizeOAuth2App(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams) string {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
@@ -124,8 +123,6 @@ func doAuthorizeRequest(t *testing.T, client *codersdk.Client, baseURL string, p
|
||||
|
||||
if params.CodeChallenge != "" {
|
||||
query.Set("code_challenge", params.CodeChallenge)
|
||||
}
|
||||
if params.CodeChallengeMethod != "" {
|
||||
query.Set("code_challenge_method", params.CodeChallengeMethod)
|
||||
}
|
||||
if params.Resource != "" {
|
||||
@@ -154,15 +151,6 @@ func doAuthorizeRequest(t *testing.T, client *codersdk.Client, baseURL string, p
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
require.NoError(t, err, "failed to perform authorization request")
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// AuthorizeOAuth2App performs the OAuth2 authorization flow and returns the authorization code
|
||||
func AuthorizeOAuth2App(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams) string {
|
||||
t.Helper()
|
||||
|
||||
resp := doAuthorizeRequest(t, client, baseURL, params)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should get a redirect response (either 302 Found or 307 Temporary Redirect)
|
||||
@@ -338,13 +326,3 @@ func CleanupOAuth2App(t *testing.T, client *codersdk.Client, appID uuid.UUID) {
|
||||
t.Logf("Warning: failed to cleanup OAuth2 app %s: %v", appID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeOAuth2AppExpectingError performs the OAuth2 authorization flow expecting an error
|
||||
func AuthorizeOAuth2AppExpectingError(t *testing.T, client *codersdk.Client, baseURL string, params AuthorizeParams, expectedStatusCode int) {
|
||||
t.Helper()
|
||||
|
||||
resp := doAuthorizeRequest(t, client, baseURL, params)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, expectedStatusCode, resp.StatusCode, "unexpected status code")
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/oauth2provider/oauth2providertest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestOAuth2AuthorizationServerMetadata(t *testing.T) {
|
||||
@@ -186,38 +185,6 @@ func TestOAuth2WithoutPKCE(t *testing.T) {
|
||||
require.NotEmpty(t, token.RefreshToken, "should receive refresh token")
|
||||
}
|
||||
|
||||
func TestOAuth2PKCEPlainMethodRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: false,
|
||||
})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Create OAuth2 app
|
||||
app, _ := oauth2providertest.CreateTestOAuth2App(t, client)
|
||||
t.Cleanup(func() {
|
||||
oauth2providertest.CleanupOAuth2App(t, client, app.ID)
|
||||
})
|
||||
|
||||
// Generate PKCE parameters but use "plain" method (should be rejected)
|
||||
_, codeChallenge := oauth2providertest.GeneratePKCE(t)
|
||||
state := oauth2providertest.GenerateState(t)
|
||||
|
||||
// Attempt authorization with plain method - should fail
|
||||
authParams := oauth2providertest.AuthorizeParams{
|
||||
ClientID: app.ID.String(),
|
||||
ResponseType: string(codersdk.OAuth2ProviderResponseTypeCode),
|
||||
RedirectURI: oauth2providertest.TestRedirectURI,
|
||||
State: state,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: string(codersdk.OAuth2PKCECodeChallengeMethodPlain),
|
||||
}
|
||||
|
||||
// Should get a 400 Bad Request
|
||||
oauth2providertest.AuthorizeOAuth2AppExpectingError(t, client, client.URL.String(), authParams, 400)
|
||||
}
|
||||
|
||||
func TestOAuth2ResourceParameter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/oauth2provider"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestVerifyPKCE(t *testing.T) {
|
||||
@@ -76,52 +75,3 @@ func TestPKCES256Generation(t *testing.T) {
|
||||
require.Equal(t, expectedChallenge, challenge)
|
||||
require.True(t, oauth2provider.VerifyPKCE(challenge, verifier))
|
||||
}
|
||||
|
||||
func TestValidatePKCECodeChallengeMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "EmptyIsValid",
|
||||
method: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "S256IsValid",
|
||||
method: string(codersdk.OAuth2PKCECodeChallengeMethodS256),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "PlainIsRejected",
|
||||
method: string(codersdk.OAuth2PKCECodeChallengeMethodPlain),
|
||||
expectError: true,
|
||||
errorContains: "plain",
|
||||
},
|
||||
{
|
||||
name: "UnknownIsRejected",
|
||||
method: "unknown_method",
|
||||
expectError: true,
|
||||
errorContains: "unsupported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := codersdk.ValidatePKCECodeChallengeMethod(tt.method)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
require.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("valid-grant-types-client-%d", time.Now().UnixNano()),
|
||||
GrantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
GrantTypes: []string{"authorization_code", "refresh_token"},
|
||||
}
|
||||
|
||||
resp, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
@@ -266,7 +266,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("invalid-grant-types-client-%d", time.Now().UnixNano()),
|
||||
GrantTypes: []codersdk.OAuth2ProviderGrantType{"unsupported_grant"},
|
||||
GrantTypes: []string{"unsupported_grant"},
|
||||
}
|
||||
|
||||
_, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
@@ -284,7 +284,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("valid-response-types-client-%d", time.Now().UnixNano()),
|
||||
ResponseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode},
|
||||
ResponseTypes: []string{"code"},
|
||||
}
|
||||
|
||||
resp, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
@@ -302,7 +302,7 @@ func TestOAuth2ClientRegistrationValidation(t *testing.T) {
|
||||
req := codersdk.OAuth2ClientRegistrationRequest{
|
||||
RedirectURIs: []string{"https://example.com/callback"},
|
||||
ClientName: fmt.Sprintf("invalid-response-types-client-%d", time.Now().UnixNano()),
|
||||
ResponseTypes: []codersdk.OAuth2ProviderResponseType{"token"}, // Not supported
|
||||
ResponseTypes: []string{"token"}, // Not supported
|
||||
}
|
||||
|
||||
_, err := client.PostOAuth2ClientRegistration(ctx, req)
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -86,9 +85,9 @@ func CreateDynamicClientRegistration(db database.Store, accessURL *url.URL, audi
|
||||
DynamicallyRegistered: sql.NullBool{Bool: true, Valid: true},
|
||||
ClientIDIssuedAt: sql.NullTime{Time: now, Valid: true},
|
||||
ClientSecretExpiresAt: sql.NullTime{}, // No expiration for now
|
||||
GrantTypes: slice.ToStrings(req.GrantTypes),
|
||||
ResponseTypes: slice.ToStrings(req.ResponseTypes),
|
||||
TokenEndpointAuthMethod: sql.NullString{String: string(req.TokenEndpointAuthMethod), Valid: true},
|
||||
GrantTypes: req.GrantTypes,
|
||||
ResponseTypes: req.ResponseTypes,
|
||||
TokenEndpointAuthMethod: sql.NullString{String: req.TokenEndpointAuthMethod, Valid: true},
|
||||
Scope: sql.NullString{String: req.Scope, Valid: true},
|
||||
Contacts: req.Contacts,
|
||||
ClientUri: sql.NullString{String: req.ClientURI, Valid: req.ClientURI != ""},
|
||||
@@ -155,9 +154,9 @@ func CreateDynamicClientRegistration(db database.Store, accessURL *url.URL, audi
|
||||
JWKS: app.Jwks.RawMessage,
|
||||
SoftwareID: app.SoftwareID.String,
|
||||
SoftwareVersion: app.SoftwareVersion.String,
|
||||
GrantTypes: slice.StringEnums[codersdk.OAuth2ProviderGrantType](app.GrantTypes),
|
||||
ResponseTypes: slice.StringEnums[codersdk.OAuth2ProviderResponseType](app.ResponseTypes),
|
||||
TokenEndpointAuthMethod: codersdk.OAuth2TokenEndpointAuthMethod(app.TokenEndpointAuthMethod.String),
|
||||
GrantTypes: app.GrantTypes,
|
||||
ResponseTypes: app.ResponseTypes,
|
||||
TokenEndpointAuthMethod: app.TokenEndpointAuthMethod.String,
|
||||
Scope: app.Scope.String,
|
||||
Contacts: app.Contacts,
|
||||
RegistrationAccessToken: registrationToken,
|
||||
@@ -218,12 +217,12 @@ func GetClientConfiguration(db database.Store) http.HandlerFunc {
|
||||
JWKS: app.Jwks.RawMessage,
|
||||
SoftwareID: app.SoftwareID.String,
|
||||
SoftwareVersion: app.SoftwareVersion.String,
|
||||
GrantTypes: slice.StringEnums[codersdk.OAuth2ProviderGrantType](app.GrantTypes),
|
||||
ResponseTypes: slice.StringEnums[codersdk.OAuth2ProviderResponseType](app.ResponseTypes),
|
||||
TokenEndpointAuthMethod: codersdk.OAuth2TokenEndpointAuthMethod(app.TokenEndpointAuthMethod.String),
|
||||
GrantTypes: app.GrantTypes,
|
||||
ResponseTypes: app.ResponseTypes,
|
||||
TokenEndpointAuthMethod: app.TokenEndpointAuthMethod.String,
|
||||
Scope: app.Scope.String,
|
||||
Contacts: app.Contacts,
|
||||
RegistrationAccessToken: "", // RFC 7592: Not returned in GET responses for security
|
||||
RegistrationAccessToken: nil, // RFC 7592: Not returned in GET responses for security
|
||||
RegistrationClientURI: app.RegistrationClientUri.String,
|
||||
}
|
||||
|
||||
@@ -304,9 +303,9 @@ func UpdateClientConfiguration(db database.Store, auditor *audit.Auditor, logger
|
||||
RedirectUris: req.RedirectURIs,
|
||||
ClientType: sql.NullString{String: req.DetermineClientType(), Valid: true},
|
||||
ClientSecretExpiresAt: sql.NullTime{}, // No expiration for now
|
||||
GrantTypes: slice.ToStrings(req.GrantTypes),
|
||||
ResponseTypes: slice.ToStrings(req.ResponseTypes),
|
||||
TokenEndpointAuthMethod: sql.NullString{String: string(req.TokenEndpointAuthMethod), Valid: true},
|
||||
GrantTypes: req.GrantTypes,
|
||||
ResponseTypes: req.ResponseTypes,
|
||||
TokenEndpointAuthMethod: sql.NullString{String: req.TokenEndpointAuthMethod, Valid: true},
|
||||
Scope: sql.NullString{String: req.Scope, Valid: true},
|
||||
Contacts: req.Contacts,
|
||||
ClientUri: sql.NullString{String: req.ClientURI, Valid: req.ClientURI != ""},
|
||||
@@ -342,12 +341,12 @@ func UpdateClientConfiguration(db database.Store, auditor *audit.Auditor, logger
|
||||
JWKS: updatedApp.Jwks.RawMessage,
|
||||
SoftwareID: updatedApp.SoftwareID.String,
|
||||
SoftwareVersion: updatedApp.SoftwareVersion.String,
|
||||
GrantTypes: slice.StringEnums[codersdk.OAuth2ProviderGrantType](updatedApp.GrantTypes),
|
||||
ResponseTypes: slice.StringEnums[codersdk.OAuth2ProviderResponseType](updatedApp.ResponseTypes),
|
||||
TokenEndpointAuthMethod: codersdk.OAuth2TokenEndpointAuthMethod(updatedApp.TokenEndpointAuthMethod.String),
|
||||
GrantTypes: updatedApp.GrantTypes,
|
||||
ResponseTypes: updatedApp.ResponseTypes,
|
||||
TokenEndpointAuthMethod: updatedApp.TokenEndpointAuthMethod.String,
|
||||
Scope: updatedApp.Scope.String,
|
||||
Contacts: updatedApp.Contacts,
|
||||
RegistrationAccessToken: "", // RFC 7592: Not returned for security
|
||||
RegistrationAccessToken: updatedApp.RegistrationAccessToken,
|
||||
RegistrationClientURI: updatedApp.RegistrationClientUri.String,
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -28,26 +27,6 @@ var (
|
||||
ErrInvalidTokenFormat = xerrors.New("invalid token format")
|
||||
)
|
||||
|
||||
func extractRevocationRequest(r *http.Request) (codersdk.OAuth2TokenRevocationRequest, error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return codersdk.OAuth2TokenRevocationRequest{}, xerrors.Errorf("invalid form data: %w", err)
|
||||
}
|
||||
|
||||
req := codersdk.OAuth2TokenRevocationRequest{
|
||||
Token: r.Form.Get("token"),
|
||||
TokenTypeHint: codersdk.OAuth2RevocationTokenTypeHint(r.Form.Get("token_type_hint")),
|
||||
ClientID: r.Form.Get("client_id"),
|
||||
ClientSecret: r.Form.Get("client_secret"),
|
||||
}
|
||||
|
||||
// RFC 7009 requires 'token' parameter.
|
||||
if req.Token == "" {
|
||||
return codersdk.OAuth2TokenRevocationRequest{}, xerrors.New("missing token parameter")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// RevokeToken implements RFC 7009 OAuth2 Token Revocation
|
||||
// Authentication is unique for this endpoint in that it does not use the
|
||||
// standard token authentication middleware. Instead, it expects the token that
|
||||
@@ -62,29 +41,35 @@ func RevokeToken(db database.Store, logger slog.Logger) http.HandlerFunc {
|
||||
|
||||
// RFC 7009 requires POST method with application/x-www-form-urlencoded
|
||||
if r.Method != http.MethodPost {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusMethodNotAllowed, codersdk.OAuth2ErrorCodeInvalidRequest, "Method not allowed")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusMethodNotAllowed, "invalid_request", "Method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
req, err := extractRevocationRequest(r)
|
||||
if err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, err.Error())
|
||||
if err := r.ParseForm(); err != nil {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Invalid form data")
|
||||
return
|
||||
}
|
||||
|
||||
// RFC 7009 requires 'token' parameter
|
||||
token := r.Form.Get("token")
|
||||
if token == "" {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Missing token parameter")
|
||||
return
|
||||
}
|
||||
|
||||
// Determine if this is a refresh token (starts with "coder_") or API key
|
||||
// APIKeys do not have the SecretIdentifier prefix.
|
||||
const coderPrefix = SecretIdentifier + "_"
|
||||
isRefreshToken := strings.HasPrefix(req.Token, coderPrefix)
|
||||
isRefreshToken := strings.HasPrefix(token, coderPrefix)
|
||||
|
||||
// Revoke the token with ownership verification
|
||||
err = db.InTx(func(tx database.Store) error {
|
||||
err := db.InTx(func(tx database.Store) error {
|
||||
if isRefreshToken {
|
||||
// Handle refresh token revocation
|
||||
return revokeRefreshTokenInTx(ctx, tx, req.Token, app.ID)
|
||||
return revokeRefreshTokenInTx(ctx, tx, token, app.ID)
|
||||
}
|
||||
// Handle API key revocation
|
||||
return revokeAPIKeyInTx(ctx, tx, req.Token, app.ID)
|
||||
return revokeAPIKeyInTx(ctx, tx, token, app.ID)
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrTokenNotBelongsToClient) {
|
||||
@@ -100,14 +85,14 @@ func RevokeToken(db database.Store, logger slog.Logger) http.HandlerFunc {
|
||||
logger.Debug(ctx, "token revocation failed: invalid token format",
|
||||
slog.F("client_id", app.ID.String()),
|
||||
slog.F("app_name", app.Name))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, "Invalid token format")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "Invalid token format")
|
||||
return
|
||||
}
|
||||
logger.Error(ctx, "token revocation failed with internal server error",
|
||||
slog.Error(err),
|
||||
slog.F("client_id", app.ID.String()),
|
||||
slog.F("app_name", app.Name))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, codersdk.OAuth2ErrorCodeServerError, "Internal server error")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusInternalServerError, "server_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
@@ -36,18 +38,28 @@ var (
|
||||
errInvalidResource = xerrors.New("invalid resource parameter")
|
||||
)
|
||||
|
||||
func extractTokenRequest(r *http.Request, callbackURL *url.URL) (codersdk.OAuth2TokenRequest, []codersdk.ValidationError, error) {
|
||||
type tokenParams struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
code string
|
||||
grantType codersdk.OAuth2ProviderGrantType
|
||||
redirectURL *url.URL
|
||||
refreshToken string
|
||||
codeVerifier string // PKCE verifier
|
||||
resource string // RFC 8707 resource for token binding
|
||||
scopes []string
|
||||
}
|
||||
|
||||
func extractTokenParams(r *http.Request, callbackURL *url.URL) (tokenParams, []codersdk.ValidationError, error) {
|
||||
p := httpapi.NewQueryParamParser()
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenRequest{}, nil, xerrors.Errorf("parse form: %w", err)
|
||||
return tokenParams{}, nil, xerrors.Errorf("parse form: %w", err)
|
||||
}
|
||||
|
||||
vals := r.Form
|
||||
p.RequiredNotEmpty("grant_type")
|
||||
grantType := httpapi.ParseCustom(p, vals, "", "grant_type", httpapi.ParseEnum[codersdk.OAuth2ProviderGrantType])
|
||||
|
||||
// Grant-type specific validation - must be called before parsing values.
|
||||
switch grantType {
|
||||
case codersdk.OAuth2ProviderGrantTypeRefreshToken:
|
||||
p.RequiredNotEmpty("refresh_token")
|
||||
@@ -55,23 +67,19 @@ func extractTokenRequest(r *http.Request, callbackURL *url.URL) (codersdk.OAuth2
|
||||
p.RequiredNotEmpty("client_secret", "client_id", "code")
|
||||
}
|
||||
|
||||
req := codersdk.OAuth2TokenRequest{
|
||||
GrantType: grantType,
|
||||
ClientID: p.String(vals, "", "client_id"),
|
||||
ClientSecret: p.String(vals, "", "client_secret"),
|
||||
Code: p.String(vals, "", "code"),
|
||||
RedirectURI: p.String(vals, "", "redirect_uri"),
|
||||
RefreshToken: p.String(vals, "", "refresh_token"),
|
||||
CodeVerifier: p.String(vals, "", "code_verifier"),
|
||||
Resource: p.String(vals, "", "resource"),
|
||||
Scope: p.String(vals, "", "scope"),
|
||||
params := tokenParams{
|
||||
clientID: p.String(vals, "", "client_id"),
|
||||
clientSecret: p.String(vals, "", "client_secret"),
|
||||
code: p.String(vals, "", "code"),
|
||||
grantType: grantType,
|
||||
redirectURL: p.RedirectURL(vals, callbackURL, "redirect_uri"),
|
||||
refreshToken: p.String(vals, "", "refresh_token"),
|
||||
codeVerifier: p.String(vals, "", "code_verifier"),
|
||||
resource: p.String(vals, "", "resource"),
|
||||
scopes: strings.Fields(strings.TrimSpace(p.String(vals, "", "scope"))),
|
||||
}
|
||||
|
||||
// Validate redirect URI - errors are added to p.Errors.
|
||||
_ = p.RedirectURL(vals, callbackURL, "redirect_uri")
|
||||
|
||||
// Validate resource parameter syntax (RFC 8707): must be absolute URI without fragment.
|
||||
if err := validateResourceParameter(req.Resource); err != nil {
|
||||
// Validate resource parameter syntax (RFC 8707): must be absolute URI without fragment
|
||||
if err := validateResourceParameter(params.resource); err != nil {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: "resource",
|
||||
Detail: "must be an absolute URI without fragment",
|
||||
@@ -80,9 +88,9 @@ func extractTokenRequest(r *http.Request, callbackURL *url.URL) (codersdk.OAuth2
|
||||
|
||||
p.ErrorExcessParams(vals)
|
||||
if len(p.Errors) > 0 {
|
||||
return codersdk.OAuth2TokenRequest{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors)
|
||||
return tokenParams{}, p.Errors, xerrors.Errorf("invalid query params: %w", p.Errors)
|
||||
}
|
||||
return req, nil, nil
|
||||
return params, nil, nil
|
||||
}
|
||||
|
||||
// Tokens
|
||||
@@ -102,13 +110,13 @@ func Tokens(db database.Store, lifetimes codersdk.SessionLifetime) http.HandlerF
|
||||
return
|
||||
}
|
||||
|
||||
req, validationErrs, err := extractTokenRequest(r, callbackURL)
|
||||
params, validationErrs, err := extractTokenParams(r, callbackURL)
|
||||
if err != nil {
|
||||
// Check for specific validation errors in priority order
|
||||
if slices.ContainsFunc(validationErrs, func(validationError codersdk.ValidationError) bool {
|
||||
return validationError.Field == "grant_type"
|
||||
}) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeUnsupportedGrantType, "The grant type is missing or unsupported")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "unsupported_grant_type", "The grant type is missing or unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -117,47 +125,47 @@ func Tokens(db database.Store, lifetimes codersdk.SessionLifetime) http.HandlerF
|
||||
if slices.ContainsFunc(validationErrs, func(validationError codersdk.ValidationError) bool {
|
||||
return validationError.Field == field
|
||||
}) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, fmt.Sprintf("Missing required parameter: %s", field))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", fmt.Sprintf("Missing required parameter: %s", field))
|
||||
return
|
||||
}
|
||||
}
|
||||
// Generic invalid request for other validation errors
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidRequest, "The request is missing required parameters or is otherwise malformed")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_request", "The request is missing required parameters or is otherwise malformed")
|
||||
return
|
||||
}
|
||||
|
||||
var token codersdk.OAuth2TokenResponse
|
||||
var token oauth2.Token
|
||||
//nolint:gocritic,revive // More cases will be added later.
|
||||
switch req.GrantType {
|
||||
switch params.grantType {
|
||||
// TODO: Client creds, device code.
|
||||
case codersdk.OAuth2ProviderGrantTypeRefreshToken:
|
||||
token, err = refreshTokenGrant(ctx, db, app, lifetimes, req)
|
||||
token, err = refreshTokenGrant(ctx, db, app, lifetimes, params)
|
||||
case codersdk.OAuth2ProviderGrantTypeAuthorizationCode:
|
||||
token, err = authorizationCodeGrant(ctx, db, app, lifetimes, req)
|
||||
token, err = authorizationCodeGrant(ctx, db, app, lifetimes, params)
|
||||
default:
|
||||
// This should handle truly invalid grant types
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeUnsupportedGrantType, fmt.Sprintf("The grant type %q is not supported", req.GrantType))
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "unsupported_grant_type", fmt.Sprintf("The grant type %q is not supported", params.grantType))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, errBadSecret) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, codersdk.OAuth2ErrorCodeInvalidClient, "The client credentials are invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusUnauthorized, "invalid_client", "The client credentials are invalid")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errBadCode) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidGrant, "The authorization code is invalid or expired")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_grant", "The authorization code is invalid or expired")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errInvalidPKCE) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidGrant, "The PKCE code verifier is invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_grant", "The PKCE code verifier is invalid")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errInvalidResource) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidTarget, "The resource parameter is invalid")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_target", "The resource parameter is invalid")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errBadToken) {
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, codersdk.OAuth2ErrorCodeInvalidGrant, "The refresh token is invalid or expired")
|
||||
httpapi.WriteOAuth2Error(ctx, rw, http.StatusBadRequest, "invalid_grant", "The refresh token is invalid or expired")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -174,77 +182,77 @@ func Tokens(db database.Store, lifetimes codersdk.SessionLifetime) http.HandlerF
|
||||
}
|
||||
}
|
||||
|
||||
func authorizationCodeGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, req codersdk.OAuth2TokenRequest) (codersdk.OAuth2TokenResponse, error) {
|
||||
func authorizationCodeGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, params tokenParams) (oauth2.Token, error) {
|
||||
// Validate the client secret.
|
||||
secret, err := ParseFormattedSecret(req.ClientSecret)
|
||||
secret, err := ParseFormattedSecret(params.clientSecret)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadSecret
|
||||
return oauth2.Token{}, errBadSecret
|
||||
}
|
||||
//nolint:gocritic // Users cannot read secrets so we must use the system.
|
||||
dbSecret, err := db.GetOAuth2ProviderAppSecretByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(secret.Prefix))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadSecret
|
||||
return oauth2.Token{}, errBadSecret
|
||||
}
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
equalSecret := apikey.ValidateHash(dbSecret.HashedSecret, secret.Secret)
|
||||
if !equalSecret {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadSecret
|
||||
return oauth2.Token{}, errBadSecret
|
||||
}
|
||||
|
||||
// Validate the authorization code.
|
||||
code, err := ParseFormattedSecret(req.Code)
|
||||
code, err := ParseFormattedSecret(params.code)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
return oauth2.Token{}, errBadCode
|
||||
}
|
||||
//nolint:gocritic // There is no user yet so we must use the system.
|
||||
dbCode, err := db.GetOAuth2ProviderAppCodeByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(code.Prefix))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
return oauth2.Token{}, errBadCode
|
||||
}
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
equalCode := apikey.ValidateHash(dbCode.HashedSecret, code.Secret)
|
||||
if !equalCode {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
return oauth2.Token{}, errBadCode
|
||||
}
|
||||
|
||||
// Ensure the code has not expired.
|
||||
if dbCode.ExpiresAt.Before(dbtime.Now()) {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadCode
|
||||
return oauth2.Token{}, errBadCode
|
||||
}
|
||||
|
||||
// Verify PKCE challenge if present
|
||||
if dbCode.CodeChallenge.Valid && dbCode.CodeChallenge.String != "" {
|
||||
if req.CodeVerifier == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
if params.codeVerifier == "" {
|
||||
return oauth2.Token{}, errInvalidPKCE
|
||||
}
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, req.CodeVerifier) {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidPKCE
|
||||
if !VerifyPKCE(dbCode.CodeChallenge.String, params.codeVerifier) {
|
||||
return oauth2.Token{}, errInvalidPKCE
|
||||
}
|
||||
}
|
||||
|
||||
// Verify resource parameter consistency (RFC 8707)
|
||||
if dbCode.ResourceUri.Valid && dbCode.ResourceUri.String != "" {
|
||||
// Resource was specified during authorization - it must match in token request
|
||||
if req.Resource == "" {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
if params.resource == "" {
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
}
|
||||
if req.Resource != dbCode.ResourceUri.String {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
if params.resource != dbCode.ResourceUri.String {
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
}
|
||||
} else if req.Resource != "" {
|
||||
} else if params.resource != "" {
|
||||
// Resource was not specified during authorization but is now provided
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
}
|
||||
|
||||
// Generate a refresh token.
|
||||
refreshToken, err := GenerateSecret()
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
// Generate the API key we will swap for the code.
|
||||
@@ -258,13 +266,13 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
|
||||
TokenName: tokenName,
|
||||
})
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
// Grab the user roles so we can perform the exchange as the user.
|
||||
actor, _, err := httpmw.UserRBACSubject(ctx, db, dbCode.UserID, rbac.ScopeAll)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
return oauth2.Token{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
}
|
||||
|
||||
// Do the actual token exchange in the database.
|
||||
@@ -316,47 +324,47 @@ func authorizationCodeGrant(ctx context.Context, db database.Store, app database
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
return codersdk.OAuth2TokenResponse{
|
||||
return oauth2.Token{
|
||||
AccessToken: sessionToken,
|
||||
TokenType: codersdk.OAuth2TokenTypeBearer,
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: refreshToken.Formatted,
|
||||
Expiry: key.ExpiresAt,
|
||||
ExpiresIn: int64(time.Until(key.ExpiresAt).Seconds()),
|
||||
Expiry: &key.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, req codersdk.OAuth2TokenRequest) (codersdk.OAuth2TokenResponse, error) {
|
||||
func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAuth2ProviderApp, lifetimes codersdk.SessionLifetime, params tokenParams) (oauth2.Token, error) {
|
||||
// Validate the token.
|
||||
token, err := ParseFormattedSecret(req.RefreshToken)
|
||||
token, err := ParseFormattedSecret(params.refreshToken)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
return oauth2.Token{}, errBadToken
|
||||
}
|
||||
//nolint:gocritic // There is no user yet so we must use the system.
|
||||
dbToken, err := db.GetOAuth2ProviderAppTokenByPrefix(dbauthz.AsSystemRestricted(ctx), []byte(token.Prefix))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
return oauth2.Token{}, errBadToken
|
||||
}
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
equal := apikey.ValidateHash(dbToken.RefreshHash, token.Secret)
|
||||
if !equal {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
return oauth2.Token{}, errBadToken
|
||||
}
|
||||
|
||||
// Ensure the token has not expired.
|
||||
if dbToken.ExpiresAt.Before(dbtime.Now()) {
|
||||
return codersdk.OAuth2TokenResponse{}, errBadToken
|
||||
return oauth2.Token{}, errBadToken
|
||||
}
|
||||
|
||||
// Verify resource parameter consistency for refresh tokens (RFC 8707)
|
||||
if req.Resource != "" {
|
||||
if params.resource != "" {
|
||||
// If resource is provided in refresh request, it must match the original token's audience
|
||||
if !dbToken.Audience.Valid || dbToken.Audience.String != req.Resource {
|
||||
return codersdk.OAuth2TokenResponse{}, errInvalidResource
|
||||
if !dbToken.Audience.Valid || dbToken.Audience.String != params.resource {
|
||||
return oauth2.Token{}, errInvalidResource
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,18 +372,18 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
|
||||
//nolint:gocritic // There is no user yet so we must use the system.
|
||||
prevKey, err := db.GetAPIKeyByID(dbauthz.AsSystemRestricted(ctx), dbToken.APIKeyID)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
actor, _, err := httpmw.UserRBACSubject(ctx, db, prevKey.UserID, rbac.ScopeAll)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
return oauth2.Token{}, xerrors.Errorf("fetch user actor: %w", err)
|
||||
}
|
||||
|
||||
// Generate a new refresh token.
|
||||
refreshToken, err := GenerateSecret()
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
// Generate the new API key.
|
||||
@@ -389,7 +397,7 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
|
||||
TokenName: tokenName,
|
||||
})
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
// Replace the token.
|
||||
@@ -429,15 +437,15 @@ func refreshTokenGrant(ctx context.Context, db database.Store, app database.OAut
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return codersdk.OAuth2TokenResponse{}, err
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
return codersdk.OAuth2TokenResponse{
|
||||
return oauth2.Token{
|
||||
AccessToken: sessionToken,
|
||||
TokenType: codersdk.OAuth2TokenTypeBearer,
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: refreshToken.Formatted,
|
||||
Expiry: key.ExpiresAt,
|
||||
ExpiresIn: int64(time.Until(key.ExpiresAt).Seconds()),
|
||||
Expiry: &key.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package oauth2provider
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -11,12 +10,6 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// parseScopes parses a space-delimited scope string into a slice of scopes
|
||||
// per RFC 6749.
|
||||
func parseScopes(scope string) []string {
|
||||
return strings.Fields(strings.TrimSpace(scope))
|
||||
}
|
||||
|
||||
// TestExtractTokenParams_Scopes tests OAuth2 scope parameter parsing
|
||||
// to ensure RFC 6749 compliance where scopes are space-delimited
|
||||
func TestExtractTokenParams_Scopes(t *testing.T) {
|
||||
@@ -122,15 +115,15 @@ func TestExtractTokenParams_Scopes(t *testing.T) {
|
||||
Form: form, // Form is the combination of PostForm and URL query
|
||||
}
|
||||
|
||||
// Extract token request
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
// Extract token params
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
|
||||
// Verify no errors occurred
|
||||
require.NoError(t, err, "extractTokenRequest should not return error for: %s", tc.description)
|
||||
require.NoError(t, err, "extractTokenParams should not return error for: %s", tc.description)
|
||||
require.Empty(t, validationErrs, "should have no validation errors for: %s", tc.description)
|
||||
|
||||
// Verify scopes match expected
|
||||
require.Equal(t, tc.expectedScopes, parseScopes(tokenReq.Scope), "scope parsing failed for: %s", tc.description)
|
||||
require.Equal(t, tc.expectedScopes, params.scopes, "scope parsing failed for: %s", tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -185,15 +178,15 @@ func TestExtractTokenParams_ScopesURLEncoded(t *testing.T) {
|
||||
Form: values,
|
||||
}
|
||||
|
||||
// Extract token request
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
// Extract token params
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
|
||||
// Verify no errors
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validationErrs)
|
||||
|
||||
// Verify scopes
|
||||
require.Equal(t, tc.expectedScopes, parseScopes(tokenReq.Scope))
|
||||
require.Equal(t, tc.expectedScopes, params.scopes)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -266,11 +259,11 @@ func TestExtractTokenParams_ScopesEdgeCases(t *testing.T) {
|
||||
Form: form,
|
||||
}
|
||||
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
|
||||
require.NoError(t, err, "extractTokenRequest should not error for: %s", tc.description)
|
||||
require.NoError(t, err, "extractTokenParams should not error for: %s", tc.description)
|
||||
require.Empty(t, validationErrs)
|
||||
require.Equal(t, tc.expectedScopes, parseScopes(tokenReq.Scope), "scope mismatch for: %s", tc.description)
|
||||
require.Equal(t, tc.expectedScopes, params.scopes, "scope mismatch for: %s", tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -361,10 +354,10 @@ func TestRefreshTokenGrant_Scopes(t *testing.T) {
|
||||
Form: form,
|
||||
}
|
||||
|
||||
tokenReq, validationErrs, err := extractTokenRequest(req, callbackURL)
|
||||
params, validationErrs, err := extractTokenParams(req, callbackURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, validationErrs)
|
||||
require.Equal(t, codersdk.OAuth2ProviderGrantTypeRefreshToken, tokenReq.GrantType)
|
||||
require.Equal(t, []string{"reduced:scope", "subset:scope"}, parseScopes(tokenReq.Scope))
|
||||
require.Equal(t, codersdk.OAuth2ProviderGrantTypeRefreshToken, params.grantType)
|
||||
require.Equal(t, []string{"reduced:scope", "subset:scope"}, params.scopes)
|
||||
}
|
||||
|
||||
@@ -277,47 +277,47 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
grantTypes []codersdk.OAuth2ProviderGrantType
|
||||
grantTypes []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{},
|
||||
grantTypes: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidAuthorizationCode",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode},
|
||||
grantTypes: []string{"authorization_code"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidRefreshTokenAlone",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
grantTypes: []string{"refresh_token"},
|
||||
expectError: true, // refresh_token requires authorization_code to be present
|
||||
},
|
||||
{
|
||||
name: "ValidMultiple",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeRefreshToken},
|
||||
grantTypes: []string{"authorization_code", "refresh_token"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidUnsupported",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeClientCredentials},
|
||||
grantTypes: []string{"client_credentials"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidPassword",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypePassword},
|
||||
grantTypes: []string{"password"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidImplicit",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeImplicit},
|
||||
grantTypes: []string{"implicit"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "MixedValidInvalid",
|
||||
grantTypes: []codersdk.OAuth2ProviderGrantType{codersdk.OAuth2ProviderGrantTypeAuthorizationCode, codersdk.OAuth2ProviderGrantTypeClientCredentials},
|
||||
grantTypes: []string{"authorization_code", "client_credentials"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -352,32 +352,32 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
responseTypes []codersdk.OAuth2ProviderResponseType
|
||||
responseTypes []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "DefaultEmpty",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{},
|
||||
responseTypes: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidCode",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode},
|
||||
responseTypes: []string{"code"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidToken",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeToken},
|
||||
responseTypes: []string{"token"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidIDToken",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{"id_token"}, // OIDC-specific, no constant
|
||||
name: "InvalidImplicit",
|
||||
responseTypes: []string{"id_token"},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidMultiple",
|
||||
responseTypes: []codersdk.OAuth2ProviderResponseType{codersdk.OAuth2ProviderResponseTypeCode, codersdk.OAuth2ProviderResponseTypeToken},
|
||||
responseTypes: []string{"code", "token"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
@@ -412,7 +412,7 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authMethod codersdk.OAuth2TokenEndpointAuthMethod
|
||||
authMethod string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
@@ -422,27 +422,27 @@ func TestOAuth2ClientMetadataValidation(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "ValidClientSecretBasic",
|
||||
authMethod: codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic,
|
||||
authMethod: "client_secret_basic",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidClientSecretPost",
|
||||
authMethod: codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost,
|
||||
authMethod: "client_secret_post",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidNone",
|
||||
authMethod: codersdk.OAuth2TokenEndpointAuthMethodNone,
|
||||
authMethod: "none",
|
||||
expectError: false, // "none" is valid for public clients per RFC 7591
|
||||
},
|
||||
{
|
||||
name: "InvalidPrivateKeyJWT",
|
||||
authMethod: "private_key_jwt", // OIDC-specific, no constant defined
|
||||
authMethod: "private_key_jwt",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidClientSecretJWT",
|
||||
authMethod: "client_secret_jwt", // OIDC-specific, no constant defined
|
||||
authMethod: "client_secret_jwt",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
@@ -659,14 +659,14 @@ func TestOAuth2ClientMetadataDefaults(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should default to authorization_code
|
||||
require.Contains(t, config.GrantTypes, codersdk.OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
require.Contains(t, config.GrantTypes, "authorization_code")
|
||||
|
||||
// Should default to code
|
||||
require.Contains(t, config.ResponseTypes, codersdk.OAuth2ProviderResponseTypeCode)
|
||||
require.Contains(t, config.ResponseTypes, "code")
|
||||
|
||||
// Should default to client_secret_basic or client_secret_post
|
||||
require.True(t, config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretBasic ||
|
||||
config.TokenEndpointAuthMethod == codersdk.OAuth2TokenEndpointAuthMethodClientSecretPost ||
|
||||
require.True(t, config.TokenEndpointAuthMethod == "client_secret_basic" ||
|
||||
config.TokenEndpointAuthMethod == "client_secret_post" ||
|
||||
config.TokenEndpointAuthMethod == "")
|
||||
|
||||
// Client secret should be generated
|
||||
|
||||
@@ -2344,10 +2344,6 @@ func (api *API) patchWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
defer commitAudit()
|
||||
aReq.Old = workspace.WorkspaceTable()
|
||||
|
||||
if !api.allowWorkspaceSharing(ctx, rw, workspace.OrganizationID) {
|
||||
return
|
||||
}
|
||||
|
||||
var req codersdk.UpdateWorkspaceACL
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -2444,10 +2440,6 @@ func (api *API) deleteWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
defer commitAuditor()
|
||||
aReq.Old = workspace.WorkspaceTable()
|
||||
|
||||
if !api.allowWorkspaceSharing(ctx, rw, workspace.OrganizationID) {
|
||||
return
|
||||
}
|
||||
|
||||
err := api.Database.InTx(func(tx database.Store) error {
|
||||
err := tx.DeleteWorkspaceACLByID(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
@@ -2471,27 +2463,6 @@ func (api *API) deleteWorkspaceACL(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
// allowWorkspaceSharing enforces the workspace-sharing gate for an
|
||||
// organization. It writes an HTTP error response and returns false if
|
||||
// sharing is disabled or the org lookup fails; otherwise it returns
|
||||
// true.
|
||||
func (api *API) allowWorkspaceSharing(ctx context.Context, rw http.ResponseWriter, organizationID uuid.UUID) bool {
|
||||
//nolint:gocritic // Use system context so this check doesn’t
|
||||
// depend on the caller having organization:read.
|
||||
org, err := api.Database.GetOrganizationByID(dbauthz.AsSystemRestricted(ctx), organizationID)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return false
|
||||
}
|
||||
if org.WorkspaceSharingDisabled {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Workspace sharing is disabled for this organization.",
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// workspacesData only returns the data the caller can access. If the caller
|
||||
// does not have the correct perms to read a given template, the template will
|
||||
// not be returned.
|
||||
|
||||
@@ -5266,66 +5266,7 @@ func TestDeleteWorkspaceACL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// `use`-role shares are granted `workspace:read` via the workspace RBAC ACL
|
||||
// list, so they should be able to read the ACL.
|
||||
//
|
||||
//nolint:tparallel,paralleltest // Test modifies a package global (rbac.workspaceACLDisabled).
|
||||
func TestWorkspaceReadCanListACL(t *testing.T) {
|
||||
// Be defensive by saving/restoring the modified package global.
|
||||
prevWorkspaceACLDisabled := rbac.WorkspaceACLDisabled()
|
||||
rbac.SetWorkspaceACLDisabled(false)
|
||||
t.Cleanup(func() { rbac.SetWorkspaceACLDisabled(prevWorkspaceACLDisabled) })
|
||||
|
||||
var (
|
||||
client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
dv.Experiments = []string{string(codersdk.ExperimentWorkspaceSharing)}
|
||||
}),
|
||||
})
|
||||
admin = coderdtest.CreateFirstUser(t, client)
|
||||
workspaceOwnerClient, workspaceOwner = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
sharedUserClientA, sharedUserA = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
_, sharedUserB = coderdtest.CreateAnotherUser(t, client, admin.OrganizationID)
|
||||
sharedGroup = dbgen.Group(t, db, database.Group{OrganizationID: admin.OrganizationID})
|
||||
workspace = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OwnerID: workspaceOwner.ID,
|
||||
OrganizationID: admin.OrganizationID,
|
||||
}).Do().Workspace
|
||||
)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
err := workspaceOwnerClient.UpdateWorkspaceACL(ctx, workspace.ID, codersdk.UpdateWorkspaceACL{
|
||||
UserRoles: map[string]codersdk.WorkspaceRole{
|
||||
sharedUserA.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
sharedUserB.ID.String(): codersdk.WorkspaceRoleAdmin,
|
||||
},
|
||||
GroupRoles: map[string]codersdk.WorkspaceRole{
|
||||
sharedGroup.ID.String(): codersdk.WorkspaceRoleUse,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
acl, err := sharedUserClientA.WorkspaceACL(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, acl.Users, 2)
|
||||
require.Len(t, acl.Groups, 1)
|
||||
|
||||
gotRoles := make(map[uuid.UUID]codersdk.WorkspaceRole, len(acl.Users))
|
||||
for _, u := range acl.Users {
|
||||
gotRoles[u.ID] = u.Role
|
||||
}
|
||||
require.Equal(t, codersdk.WorkspaceRoleUse, gotRoles[sharedUserA.ID])
|
||||
require.Equal(t, codersdk.WorkspaceRoleAdmin, gotRoles[sharedUserB.ID])
|
||||
|
||||
gotGroupRoles := make(map[uuid.UUID]codersdk.WorkspaceRole, len(acl.Groups))
|
||||
for _, g := range acl.Groups {
|
||||
gotGroupRoles[g.ID] = g.Role
|
||||
}
|
||||
require.Equal(t, codersdk.WorkspaceRoleUse, gotGroupRoles[sharedGroup.ID])
|
||||
}
|
||||
|
||||
// nolint:tparallel,paralleltest // Subtests modify a package global (rbac.workspaceACLDisabled).
|
||||
// nolint:tparallel,paralleltest // Subtests modify package global.
|
||||
func TestWorkspaceSharingDisabled(t *testing.T) {
|
||||
t.Run("CanAccessWhenEnabled", func(t *testing.T) {
|
||||
var (
|
||||
|
||||
@@ -3484,16 +3484,6 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "rateLimit",
|
||||
},
|
||||
{
|
||||
Name: "AI Bridge Structured Logging",
|
||||
Description: "Emit structured logs for AI Bridge interception records. Use this for exporting these records to external SIEM or observability systems.",
|
||||
Flag: "aibridge-structured-logging",
|
||||
Env: "CODER_AIBRIDGE_STRUCTURED_LOGGING",
|
||||
Value: &c.AI.BridgeConfig.StructuredLogging,
|
||||
Default: "false",
|
||||
Group: &deploymentGroupAIBridge,
|
||||
YAML: "structuredLogging",
|
||||
},
|
||||
|
||||
// AI Bridge Proxy Options
|
||||
{
|
||||
@@ -3620,7 +3610,6 @@ type AIBridgeConfig struct {
|
||||
Retention serpent.Duration `json:"retention" typescript:",notnull"`
|
||||
MaxConcurrency serpent.Int64 `json:"max_concurrency" typescript:",notnull"`
|
||||
RateLimit serpent.Int64 `json:"rate_limit" typescript:",notnull"`
|
||||
StructuredLogging serpent.Bool `json:"structured_logging" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIBridgeOpenAIConfig struct {
|
||||
|
||||
+88
-248
@@ -8,7 +8,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -187,22 +186,14 @@ func (c *Client) DeleteOAuth2ProviderAppSecret(ctx context.Context, appID uuid.U
|
||||
|
||||
type OAuth2ProviderGrantType string
|
||||
|
||||
// OAuth2ProviderGrantType values (RFC 6749).
|
||||
const (
|
||||
OAuth2ProviderGrantTypeAuthorizationCode OAuth2ProviderGrantType = "authorization_code"
|
||||
OAuth2ProviderGrantTypeRefreshToken OAuth2ProviderGrantType = "refresh_token"
|
||||
OAuth2ProviderGrantTypePassword OAuth2ProviderGrantType = "password"
|
||||
OAuth2ProviderGrantTypeClientCredentials OAuth2ProviderGrantType = "client_credentials"
|
||||
OAuth2ProviderGrantTypeImplicit OAuth2ProviderGrantType = "implicit"
|
||||
)
|
||||
|
||||
func (e OAuth2ProviderGrantType) Valid() bool {
|
||||
switch e {
|
||||
case OAuth2ProviderGrantTypeAuthorizationCode,
|
||||
OAuth2ProviderGrantTypeRefreshToken,
|
||||
OAuth2ProviderGrantTypePassword,
|
||||
OAuth2ProviderGrantTypeClientCredentials,
|
||||
OAuth2ProviderGrantTypeImplicit:
|
||||
case OAuth2ProviderGrantTypeAuthorizationCode, OAuth2ProviderGrantTypeRefreshToken:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -210,171 +201,19 @@ func (e OAuth2ProviderGrantType) Valid() bool {
|
||||
|
||||
type OAuth2ProviderResponseType string
|
||||
|
||||
// OAuth2ProviderResponseType values (RFC 6749).
|
||||
const (
|
||||
OAuth2ProviderResponseTypeCode OAuth2ProviderResponseType = "code"
|
||||
OAuth2ProviderResponseTypeToken OAuth2ProviderResponseType = "token"
|
||||
OAuth2ProviderResponseTypeCode OAuth2ProviderResponseType = "code"
|
||||
)
|
||||
|
||||
func (e OAuth2ProviderResponseType) Valid() bool {
|
||||
//nolint:gocritic,revive // More cases might be added later.
|
||||
switch e {
|
||||
case OAuth2ProviderResponseTypeCode, OAuth2ProviderResponseTypeToken:
|
||||
case OAuth2ProviderResponseTypeCode:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2TokenEndpointAuthMethod string
|
||||
|
||||
const (
|
||||
OAuth2TokenEndpointAuthMethodClientSecretBasic OAuth2TokenEndpointAuthMethod = "client_secret_basic"
|
||||
OAuth2TokenEndpointAuthMethodClientSecretPost OAuth2TokenEndpointAuthMethod = "client_secret_post"
|
||||
OAuth2TokenEndpointAuthMethodNone OAuth2TokenEndpointAuthMethod = "none"
|
||||
)
|
||||
|
||||
func (m OAuth2TokenEndpointAuthMethod) Valid() bool {
|
||||
switch m {
|
||||
case OAuth2TokenEndpointAuthMethodClientSecretBasic,
|
||||
OAuth2TokenEndpointAuthMethodClientSecretPost,
|
||||
OAuth2TokenEndpointAuthMethodNone:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2PKCECodeChallengeMethod string
|
||||
|
||||
// OAuth2PKCECodeChallengeMethod values (RFC 7636).
|
||||
const (
|
||||
OAuth2PKCECodeChallengeMethodS256 OAuth2PKCECodeChallengeMethod = "S256"
|
||||
OAuth2PKCECodeChallengeMethodPlain OAuth2PKCECodeChallengeMethod = "plain"
|
||||
)
|
||||
|
||||
func (m OAuth2PKCECodeChallengeMethod) Valid() bool {
|
||||
switch m {
|
||||
case OAuth2PKCECodeChallengeMethodS256, OAuth2PKCECodeChallengeMethodPlain:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2TokenType string
|
||||
|
||||
// OAuth2TokenType values (RFC 6749, RFC 9449).
|
||||
const (
|
||||
OAuth2TokenTypeBearer OAuth2TokenType = "Bearer"
|
||||
OAuth2TokenTypeDPoP OAuth2TokenType = "DPoP"
|
||||
)
|
||||
|
||||
func (t OAuth2TokenType) Valid() bool {
|
||||
switch t {
|
||||
case OAuth2TokenTypeBearer, OAuth2TokenTypeDPoP:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2RevocationTokenTypeHint string
|
||||
|
||||
const (
|
||||
OAuth2RevocationTokenTypeHintAccessToken OAuth2RevocationTokenTypeHint = "access_token"
|
||||
OAuth2RevocationTokenTypeHintRefreshToken OAuth2RevocationTokenTypeHint = "refresh_token"
|
||||
)
|
||||
|
||||
func (h OAuth2RevocationTokenTypeHint) Valid() bool {
|
||||
switch h {
|
||||
case OAuth2RevocationTokenTypeHintAccessToken, OAuth2RevocationTokenTypeHintRefreshToken:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type OAuth2ErrorCode string
|
||||
|
||||
// OAuth2 error codes per RFC 6749, RFC 7009, RFC 8707.
|
||||
// This is not comprehensive; it includes only codes relevant to this implementation.
|
||||
const (
|
||||
// RFC 6749 - Token endpoint errors.
|
||||
OAuth2ErrorCodeInvalidRequest OAuth2ErrorCode = "invalid_request"
|
||||
OAuth2ErrorCodeInvalidClient OAuth2ErrorCode = "invalid_client"
|
||||
OAuth2ErrorCodeInvalidGrant OAuth2ErrorCode = "invalid_grant"
|
||||
OAuth2ErrorCodeUnauthorizedClient OAuth2ErrorCode = "unauthorized_client"
|
||||
OAuth2ErrorCodeUnsupportedGrantType OAuth2ErrorCode = "unsupported_grant_type"
|
||||
OAuth2ErrorCodeInvalidScope OAuth2ErrorCode = "invalid_scope"
|
||||
|
||||
// RFC 6749 - Authorization endpoint errors.
|
||||
OAuth2ErrorCodeAccessDenied OAuth2ErrorCode = "access_denied"
|
||||
OAuth2ErrorCodeUnsupportedResponseType OAuth2ErrorCode = "unsupported_response_type"
|
||||
OAuth2ErrorCodeServerError OAuth2ErrorCode = "server_error"
|
||||
OAuth2ErrorCodeTemporarilyUnavailable OAuth2ErrorCode = "temporarily_unavailable"
|
||||
|
||||
// RFC 7009 - Token revocation errors.
|
||||
OAuth2ErrorCodeUnsupportedTokenType OAuth2ErrorCode = "unsupported_token_type"
|
||||
|
||||
// RFC 8707 - Resource indicator errors.
|
||||
OAuth2ErrorCodeInvalidTarget OAuth2ErrorCode = "invalid_target"
|
||||
)
|
||||
|
||||
func (c OAuth2ErrorCode) Valid() bool {
|
||||
switch c {
|
||||
case OAuth2ErrorCodeInvalidRequest,
|
||||
OAuth2ErrorCodeInvalidClient,
|
||||
OAuth2ErrorCodeInvalidGrant,
|
||||
OAuth2ErrorCodeUnauthorizedClient,
|
||||
OAuth2ErrorCodeUnsupportedGrantType,
|
||||
OAuth2ErrorCodeInvalidScope,
|
||||
OAuth2ErrorCodeAccessDenied,
|
||||
OAuth2ErrorCodeUnsupportedResponseType,
|
||||
OAuth2ErrorCodeServerError,
|
||||
OAuth2ErrorCodeTemporarilyUnavailable,
|
||||
OAuth2ErrorCodeUnsupportedTokenType,
|
||||
OAuth2ErrorCodeInvalidTarget:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// OAuth2Error represents an OAuth2-compliant error response per RFC 6749.
|
||||
type OAuth2Error struct {
|
||||
Error OAuth2ErrorCode `json:"error"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ErrorURI string `json:"error_uri,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2TokenRequest represents a token request per RFC 6749. The actual wire
|
||||
// format is application/x-www-form-urlencoded; this struct is for SDK docs.
|
||||
type OAuth2TokenRequest struct {
|
||||
GrantType OAuth2ProviderGrantType `json:"grant_type"`
|
||||
Code string `json:"code,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
CodeVerifier string `json:"code_verifier,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Resource string `json:"resource,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2TokenResponse represents a successful token response per RFC 6749.
|
||||
type OAuth2TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType OAuth2TokenType `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// Expiry is not part of RFC 6749 but is included for compatibility with
|
||||
// golang.org/x/oauth2.Token and clients that expect a timestamp.
|
||||
Expiry *time.Time `json:"expiry,omitempty" format:"date-time"`
|
||||
}
|
||||
|
||||
// OAuth2TokenRevocationRequest represents a token revocation request per RFC 7009.
|
||||
type OAuth2TokenRevocationRequest struct {
|
||||
Token string `json:"token"`
|
||||
TokenTypeHint OAuth2RevocationTokenTypeHint `json:"token_type_hint,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
}
|
||||
|
||||
// RevokeOAuth2Token revokes a specific OAuth2 token using RFC 7009 token revocation.
|
||||
func (c *Client) RevokeOAuth2Token(ctx context.Context, clientID uuid.UUID, token string) error {
|
||||
form := url.Values{}
|
||||
@@ -417,18 +256,18 @@ type OAuth2DeviceFlowCallbackResponse struct {
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
}
|
||||
|
||||
// OAuth2AuthorizationServerMetadata represents RFC 8414 OAuth 2.0 Authorization Server Metadata.
|
||||
// OAuth2AuthorizationServerMetadata represents RFC 8414 OAuth 2.0 Authorization Server Metadata
|
||||
type OAuth2AuthorizationServerMetadata struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
|
||||
ResponseTypesSupported []OAuth2ProviderResponseType `json:"response_types_supported"`
|
||||
GrantTypesSupported []OAuth2ProviderGrantType `json:"grant_types_supported,omitempty"`
|
||||
CodeChallengeMethodsSupported []OAuth2PKCECodeChallengeMethod `json:"code_challenge_methods_supported,omitempty"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
TokenEndpointAuthMethodsSupported []OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2ProtectedResourceMetadata represents RFC 9728 OAuth 2.0 Protected Resource Metadata
|
||||
@@ -439,50 +278,50 @@ type OAuth2ProtectedResourceMetadata struct {
|
||||
BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"`
|
||||
}
|
||||
|
||||
// OAuth2ClientRegistrationRequest represents RFC 7591 Dynamic Client Registration Request.
|
||||
// OAuth2ClientRegistrationRequest represents RFC 7591 Dynamic Client Registration Request
|
||||
type OAuth2ClientRegistrationRequest struct {
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
SoftwareStatement string `json:"software_statement,omitempty"`
|
||||
GrantTypes []OAuth2ProviderGrantType `json:"grant_types,omitempty"`
|
||||
ResponseTypes []OAuth2ProviderResponseType `json:"response_types,omitempty"`
|
||||
TokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_method,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
SoftwareStatement string `json:"software_statement,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
}
|
||||
|
||||
func (req OAuth2ClientRegistrationRequest) ApplyDefaults() OAuth2ClientRegistrationRequest {
|
||||
// Apply grant type defaults.
|
||||
// Apply grant type defaults
|
||||
if len(req.GrantTypes) == 0 {
|
||||
req.GrantTypes = []OAuth2ProviderGrantType{
|
||||
OAuth2ProviderGrantTypeAuthorizationCode,
|
||||
OAuth2ProviderGrantTypeRefreshToken,
|
||||
req.GrantTypes = []string{
|
||||
string(OAuth2ProviderGrantTypeAuthorizationCode),
|
||||
string(OAuth2ProviderGrantTypeRefreshToken),
|
||||
}
|
||||
}
|
||||
|
||||
// Apply response type defaults.
|
||||
// Apply response type defaults
|
||||
if len(req.ResponseTypes) == 0 {
|
||||
req.ResponseTypes = []OAuth2ProviderResponseType{
|
||||
OAuth2ProviderResponseTypeCode,
|
||||
req.ResponseTypes = []string{
|
||||
string(OAuth2ProviderResponseTypeCode),
|
||||
}
|
||||
}
|
||||
|
||||
// Apply token endpoint auth method default (RFC 7591 section 2).
|
||||
// Apply token endpoint auth method default (RFC 7591 section 2)
|
||||
if req.TokenEndpointAuthMethod == "" {
|
||||
// Default according to RFC 7591: "client_secret_basic" for confidential clients.
|
||||
// For public clients, should be explicitly set to "none".
|
||||
req.TokenEndpointAuthMethod = OAuth2TokenEndpointAuthMethodClientSecretBasic
|
||||
// Default according to RFC 7591: "client_secret_basic" for confidential clients
|
||||
// For public clients, should be explicitly set to "none"
|
||||
req.TokenEndpointAuthMethod = "client_secret_basic"
|
||||
}
|
||||
|
||||
// Apply client name default if not provided.
|
||||
// Apply client name default if not provided
|
||||
if req.ClientName == "" {
|
||||
req.ClientName = "Dynamically Registered Client"
|
||||
}
|
||||
@@ -538,29 +377,29 @@ func (req *OAuth2ClientRegistrationRequest) GenerateClientName() string {
|
||||
return "Dynamically Registered Client"
|
||||
}
|
||||
|
||||
// OAuth2ClientRegistrationResponse represents RFC 7591 Dynamic Client Registration Response.
|
||||
// OAuth2ClientRegistrationResponse represents RFC 7591 Dynamic Client Registration Response
|
||||
type OAuth2ClientRegistrationResponse struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []OAuth2ProviderGrantType `json:"grant_types"`
|
||||
ResponseTypes []OAuth2ProviderResponseType `json:"response_types"`
|
||||
TokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
}
|
||||
|
||||
// PostOAuth2ClientRegistration dynamically registers a new OAuth2 client (RFC 7591)
|
||||
@@ -627,26 +466,27 @@ func (c *Client) DeleteOAuth2ClientConfiguration(ctx context.Context, clientID s
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuth2ClientConfiguration represents RFC 7592 Client Read Response.
|
||||
// OAuth2ClientConfiguration represents RFC 7592 Client Configuration (for GET/PUT operations)
|
||||
// Same as OAuth2ClientRegistrationResponse but without client_secret in GET responses
|
||||
type OAuth2ClientConfiguration struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []OAuth2ProviderGrantType `json:"grant_types"`
|
||||
ResponseTypes []OAuth2ProviderResponseType `json:"response_types"`
|
||||
TokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS json.RawMessage `json:"jwks,omitempty" swaggertype:"object"`
|
||||
SoftwareID string `json:"software_id,omitempty"`
|
||||
SoftwareVersion string `json:"software_version,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RegistrationAccessToken []byte `json:"registration_access_token"`
|
||||
RegistrationClientURI string `json:"registration_client_uri"`
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (req *OAuth2ClientRegistrationRequest) Validate() error {
|
||||
}
|
||||
|
||||
// validateRedirectURIs validates redirect URIs according to RFC 7591, 8252
|
||||
func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndpointAuthMethod) error {
|
||||
func validateRedirectURIs(uris []string, tokenEndpointAuthMethod string) error {
|
||||
if len(uris) == 0 {
|
||||
return xerrors.New("at least one redirect URI is required")
|
||||
}
|
||||
@@ -115,7 +115,7 @@ func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndp
|
||||
}
|
||||
|
||||
// Determine if this is a public client based on token endpoint auth method
|
||||
isPublicClient := tokenEndpointAuthMethod == OAuth2TokenEndpointAuthMethodNone
|
||||
isPublicClient := tokenEndpointAuthMethod == "none"
|
||||
|
||||
// Handle different validation for public vs confidential clients
|
||||
if uri.Scheme == "http" || uri.Scheme == "https" {
|
||||
@@ -155,15 +155,23 @@ func validateRedirectURIs(uris []string, tokenEndpointAuthMethod OAuth2TokenEndp
|
||||
}
|
||||
|
||||
// validateGrantTypes validates OAuth2 grant types
|
||||
func validateGrantTypes(grantTypes []OAuth2ProviderGrantType) error {
|
||||
func validateGrantTypes(grantTypes []string) error {
|
||||
validGrants := []string{
|
||||
string(OAuth2ProviderGrantTypeAuthorizationCode),
|
||||
string(OAuth2ProviderGrantTypeRefreshToken),
|
||||
// Add more grant types as they are implemented
|
||||
// "client_credentials",
|
||||
// "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}
|
||||
|
||||
for _, grant := range grantTypes {
|
||||
if !isSupportedGrantType(grant) {
|
||||
if !slices.Contains(validGrants, grant) {
|
||||
return xerrors.Errorf("unsupported grant type: %s", grant)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure authorization_code is present if redirect_uris are specified
|
||||
hasAuthCode := slices.Contains(grantTypes, OAuth2ProviderGrantTypeAuthorizationCode)
|
||||
hasAuthCode := slices.Contains(grantTypes, string(OAuth2ProviderGrantTypeAuthorizationCode))
|
||||
if !hasAuthCode {
|
||||
return xerrors.New("authorization_code grant type is required when redirect_uris are specified")
|
||||
}
|
||||
@@ -171,18 +179,15 @@ func validateGrantTypes(grantTypes []OAuth2ProviderGrantType) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSupportedGrantType(grant OAuth2ProviderGrantType) bool {
|
||||
switch grant {
|
||||
case OAuth2ProviderGrantTypeAuthorizationCode, OAuth2ProviderGrantTypeRefreshToken:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// validateResponseTypes validates OAuth2 response types
|
||||
func validateResponseTypes(responseTypes []OAuth2ProviderResponseType) error {
|
||||
func validateResponseTypes(responseTypes []string) error {
|
||||
validResponses := []string{
|
||||
string(OAuth2ProviderResponseTypeCode),
|
||||
// Add more response types as they are implemented
|
||||
}
|
||||
|
||||
for _, responseType := range responseTypes {
|
||||
if !isSupportedResponseType(responseType) {
|
||||
if !slices.Contains(validResponses, responseType) {
|
||||
return xerrors.Errorf("unsupported response type: %s", responseType)
|
||||
}
|
||||
}
|
||||
@@ -190,39 +195,24 @@ func validateResponseTypes(responseTypes []OAuth2ProviderResponseType) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isSupportedResponseType(responseType OAuth2ProviderResponseType) bool {
|
||||
return responseType == OAuth2ProviderResponseTypeCode
|
||||
}
|
||||
|
||||
// validateTokenEndpointAuthMethod validates token endpoint authentication method
|
||||
func validateTokenEndpointAuthMethod(method OAuth2TokenEndpointAuthMethod) error {
|
||||
if !method.Valid() {
|
||||
func validateTokenEndpointAuthMethod(method string) error {
|
||||
validMethods := []string{
|
||||
"client_secret_post",
|
||||
"client_secret_basic",
|
||||
"none", // for public clients (RFC 7591)
|
||||
// Add more methods as they are implemented
|
||||
// "private_key_jwt",
|
||||
// "client_secret_jwt",
|
||||
}
|
||||
|
||||
if !slices.Contains(validMethods, method) {
|
||||
return xerrors.Errorf("unsupported token endpoint auth method: %s", method)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePKCECodeChallengeMethod validates PKCE code_challenge_method parameter.
|
||||
// Per OAuth 2.1, only S256 is supported; plain is rejected for security reasons.
|
||||
func ValidatePKCECodeChallengeMethod(method string) error {
|
||||
if method == "" {
|
||||
return nil // Optional, defaults to S256 if code_challenge is provided
|
||||
}
|
||||
|
||||
m := OAuth2PKCECodeChallengeMethod(method)
|
||||
|
||||
if m == OAuth2PKCECodeChallengeMethodPlain {
|
||||
return xerrors.New("code_challenge_method 'plain' is not supported; use 'S256'")
|
||||
}
|
||||
|
||||
if m != OAuth2PKCECodeChallengeMethodS256 {
|
||||
return xerrors.Errorf("unsupported code_challenge_method: %s", method)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateURIField validates a URI field
|
||||
func validateURIField(uriStr, fieldName string) error {
|
||||
if uriStr == "" {
|
||||
|
||||
@@ -146,27 +146,6 @@ func parseVariableValuesFromHCL(content []byte) ([]VariableValue, error) {
|
||||
return nil, err
|
||||
}
|
||||
stringData[attribute.Name] = string(m)
|
||||
case ctyType.IsObjectType() || ctyType.IsMapType():
|
||||
// In case of objects/maps, Coder only supports the map(string) type.
|
||||
result := map[string]string{}
|
||||
var err error
|
||||
_ = ctyValue.ForEachElement(func(key, val cty.Value) (stop bool) {
|
||||
if !val.Type().Equals(cty.String) {
|
||||
err = xerrors.Errorf("unsupported map value type for key %q: %s", key.AsString(), val.Type().GoString())
|
||||
return true
|
||||
}
|
||||
result[key.AsString()] = val.AsString()
|
||||
return false
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stringData[attribute.Name] = string(m)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unsupported value type (name: %s): %s", attribute.Name, ctyType.GoString())
|
||||
}
|
||||
|
||||
@@ -175,76 +175,3 @@ cores: 2`
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), `use the equals sign "=" to introduce the argument value`)
|
||||
}
|
||||
|
||||
func TestParseVariableValuesFromVarsFiles_MapOfStrings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given
|
||||
const (
|
||||
hclFilename = "file.tfvars"
|
||||
hclContent = `region = "us-east-1"
|
||||
default_tags = {
|
||||
owner_name = "John Doe"
|
||||
owner_email = "john@example.com"
|
||||
owner_slack = "@johndoe"
|
||||
}`
|
||||
)
|
||||
|
||||
// Prepare the .tfvars files
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "test-parse-variable-values-from-vars-files-map-of-strings-*")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
err = os.WriteFile(filepath.Join(tempDir, hclFilename), []byte(hclContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When
|
||||
actual, err := codersdk.ParseUserVariableValues([]string{
|
||||
filepath.Join(tempDir, hclFilename),
|
||||
}, "", nil)
|
||||
|
||||
// Then
|
||||
require.NoError(t, err)
|
||||
require.Len(t, actual, 2)
|
||||
|
||||
// Results are sorted by name
|
||||
require.Equal(t, "default_tags", actual[0].Name)
|
||||
require.JSONEq(t, `{"owner_email":"john@example.com","owner_name":"John Doe","owner_slack":"@johndoe"}`, actual[0].Value)
|
||||
require.Equal(t, "region", actual[1].Name)
|
||||
require.Equal(t, "us-east-1", actual[1].Value)
|
||||
}
|
||||
|
||||
func TestParseVariableValuesFromVarsFiles_MapWithNonStringValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given - a map with non-string values should error
|
||||
const (
|
||||
hclFilename = "file.tfvars"
|
||||
hclContent = `config = {
|
||||
name = "test"
|
||||
count = 5
|
||||
}`
|
||||
)
|
||||
|
||||
// Prepare the .tfvars files
|
||||
tempDir, err := os.MkdirTemp(os.TempDir(), "test-parse-variable-values-from-vars-files-map-non-string-*")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
})
|
||||
|
||||
err = os.WriteFile(filepath.Join(tempDir, hclFilename), []byte(hclContent), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
// When
|
||||
actual, err := codersdk.ParseUserVariableValues([]string{
|
||||
filepath.Join(tempDir, hclFilename),
|
||||
}, "", nil)
|
||||
|
||||
// Then
|
||||
require.Nil(t, actual)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "unsupported map value type")
|
||||
}
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// WorkspaceSharingSettings represents workspace sharing settings for an organization.
|
||||
type WorkspaceSharingSettings struct {
|
||||
SharingDisabled bool `json:"sharing_disabled"`
|
||||
}
|
||||
|
||||
// WorkspaceSharingSettings retrieves the workspace sharing settings for an organization.
|
||||
func (c *Client) WorkspaceSharingSettings(ctx context.Context, orgID string) (WorkspaceSharingSettings, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/organizations/%s/settings/workspace-sharing", orgID), nil)
|
||||
if err != nil {
|
||||
return WorkspaceSharingSettings{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return WorkspaceSharingSettings{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp WorkspaceSharingSettings
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// PatchWorkspaceSharingSettings modifies the workspace sharing settings for an organization.
|
||||
func (c *Client) PatchWorkspaceSharingSettings(ctx context.Context, orgID string, req WorkspaceSharingSettings) (WorkspaceSharingSettings, error) {
|
||||
res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/v2/organizations/%s/settings/workspace-sharing", orgID), req)
|
||||
if err != nil {
|
||||
return WorkspaceSharingSettings{}, err
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return WorkspaceSharingSettings{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp WorkspaceSharingSettings
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
@@ -23,9 +23,8 @@ Rules follow the format: `key=value [key=value ...]` with three supported keys:
|
||||
|
||||
```yaml
|
||||
allowlist:
|
||||
- domain=github.com # All methods, all paths for github.com (exact match)
|
||||
- domain=*.github.com # All subdomains of github.com
|
||||
- method=GET,POST domain=api.example.com # GET/POST to api.example.com (exact match)
|
||||
- domain=github.com # All methods, all paths for github.com
|
||||
- method=GET,POST domain=api.example.com # GET/POST to api.example.com
|
||||
- domain=api.example.com path=/users,/posts # Multiple paths
|
||||
- method=GET domain=github.com path=/api/* # All three keys
|
||||
```
|
||||
@@ -36,20 +35,19 @@ allowlist:
|
||||
|
||||
The `*` wildcard matches domain labels (parts separated by dots).
|
||||
|
||||
| Pattern | Matches | Does NOT Match |
|
||||
|----------------|-------------------------------------------------------------|--------------------------------------------------------------------------|
|
||||
| `*` | All domains | - |
|
||||
| `github.com` | `github.com` (exact match only) | `api.github.com`, `v1.api.github.com` (subdomains), `github.io` |
|
||||
| `*.github.com` | `api.github.com`, `v1.api.github.com` (1+ subdomain levels) | `github.com` (base domain) |
|
||||
| `api.*.com` | `api.github.com`, `api.google.com` | `api.v1.github.com` (`*` in the middle matches exactly one domain label) |
|
||||
| `*.*.com` | `api.example.com`, `api.v1.github.com` | - |
|
||||
| `api.*` | ❌ **ERROR** - Cannot end with `*` | - |
|
||||
| Pattern | Matches | Does NOT Match |
|
||||
|----------------|------------------------------------------------------------------|--------------------------------------------------------------------------|
|
||||
| `*` | All domains | - |
|
||||
| `github.com` | `github.com`, `api.github.com`, `v1.api.github.com` (subdomains) | `github.io` (diff domain) |
|
||||
| `*.github.com` | `api.github.com`, `v1.api.github.com` (1+ subdomain levels) | `github.com` (base domain) |
|
||||
| `api.*.com` | `api.github.com`, `api.google.com` | `api.v1.github.com` (`*` in the middle matches exactly one domain label) |
|
||||
| `*.*.com` | `api.example.com`, `api.v1.github.com` | - |
|
||||
| `api.*` | ❌ **ERROR** - Cannot end with `*` | - |
|
||||
|
||||
**Important**:
|
||||
|
||||
- Patterns without `*` match **exactly** (no automatic subdomain matching)
|
||||
- Patterns without `*` at the start automatically match subdomains
|
||||
- `*.example.com` matches one or more subdomain levels
|
||||
- To match both base domain and subdomains, use separate rules: `domain=github.com` and `domain=*.github.com`
|
||||
- Domain patterns **cannot end with asterisk**
|
||||
|
||||
---
|
||||
|
||||
@@ -1606,11 +1606,6 @@
|
||||
"description": "Role sync settings to sync organization roles from an IdP.",
|
||||
"path": "reference/cli/organizations_settings_set_role-sync.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations settings set workspace-sharing",
|
||||
"description": "Workspace sharing settings for the organization.",
|
||||
"path": "reference/cli/organizations_settings_set_workspace-sharing.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations settings show",
|
||||
"description": "Outputs specified organization setting.",
|
||||
@@ -1631,11 +1626,6 @@
|
||||
"description": "Role sync settings to sync organization roles from an IdP.",
|
||||
"path": "reference/cli/organizations_settings_show_role-sync.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations settings show workspace-sharing",
|
||||
"description": "Workspace sharing settings for the organization.",
|
||||
"path": "reference/cli/organizations_settings_show_workspace-sharing.md"
|
||||
},
|
||||
{
|
||||
"title": "organizations show",
|
||||
"description": "Show the organization. Using \"selected\" will show the selected organization from the \"--org\" flag. Using \"me\" will show all organizations you are a member of.",
|
||||
|
||||
Generated
+34
-114
@@ -20,15 +20,15 @@ curl -X GET http://coder-server:8080/api/v2/.well-known/oauth-authorization-serv
|
||||
{
|
||||
"authorization_endpoint": "string",
|
||||
"code_challenge_methods_supported": [
|
||||
"S256"
|
||||
"string"
|
||||
],
|
||||
"grant_types_supported": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"issuer": "string",
|
||||
"registration_endpoint": "string",
|
||||
"response_types_supported": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"revocation_endpoint": "string",
|
||||
"scopes_supported": [
|
||||
@@ -36,7 +36,7 @@ curl -X GET http://coder-server:8080/api/v2/.well-known/oauth-authorization-serv
|
||||
],
|
||||
"token_endpoint": "string",
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"client_secret_basic"
|
||||
"string"
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -1265,9 +1265,9 @@ curl -X GET http://coder-server:8080/api/v2/oauth2/authorize?client_id=string&st
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|-----------------|
|
||||
| `response_type` | `code`, `token` |
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|----------|
|
||||
| `response_type` | `code` |
|
||||
|
||||
### Responses
|
||||
|
||||
@@ -1301,9 +1301,9 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/authorize?client_id=string&s
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|-----------------|
|
||||
| `response_type` | `code`, `token` |
|
||||
| Parameter | Value(s) |
|
||||
|-----------------|----------|
|
||||
| `response_type` | `code` |
|
||||
|
||||
### Responses
|
||||
|
||||
@@ -1346,7 +1346,7 @@ curl -X GET http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1355,15 +1355,17 @@ curl -X GET http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"redirect_uris": [
|
||||
"string"
|
||||
],
|
||||
"registration_access_token": "string",
|
||||
"registration_access_token": [
|
||||
0
|
||||
],
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1397,7 +1399,7 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1407,13 +1409,13 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_statement": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1440,7 +1442,7 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1449,15 +1451,17 @@ curl -X PUT http://coder-server:8080/api/v2/oauth2/clients/{client_id} \
|
||||
"redirect_uris": [
|
||||
"string"
|
||||
],
|
||||
"registration_access_token": "string",
|
||||
"registration_access_token": [
|
||||
0
|
||||
],
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1515,7 +1519,7 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1525,13 +1529,13 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"string"
|
||||
],
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_statement": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1558,7 +1562,7 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -1570,12 +1574,12 @@ curl -X POST http://coder-server:8080/api/v2/oauth2/register \
|
||||
"registration_access_token": "string",
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
@@ -1658,9 +1662,9 @@ grant_type: authorization_code
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Parameter | Value(s) |
|
||||
|----------------|-------------------------------------------------------------------------------------|
|
||||
| `» grant_type` | `authorization_code`, `client_credentials`, `implicit`, `password`, `refresh_token` |
|
||||
| Parameter | Value(s) |
|
||||
|----------------|---------------------------------------|
|
||||
| `» grant_type` | `authorization_code`, `refresh_token` |
|
||||
|
||||
### Example responses
|
||||
|
||||
@@ -2828,90 +2832,6 @@ curl -X PATCH http://coder-server:8080/api/v2/organizations/{organization}/setti
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Get workspace sharing settings for organization
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X GET http://coder-server:8080/api/v2/organizations/{organization}/settings/workspace-sharing \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`GET /organizations/{organization}/settings/workspace-sharing`
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|----------------|------|--------------|----------|-----------------|
|
||||
| `organization` | path | string(uuid) | true | Organization ID |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 200 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------------|
|
||||
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.WorkspaceSharingSettings](schemas.md#codersdkworkspacesharingsettings) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Update workspace sharing settings for organization
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X PATCH http://coder-server:8080/api/v2/organizations/{organization}/settings/workspace-sharing \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`PATCH /organizations/{organization}/settings/workspace-sharing`
|
||||
|
||||
> Body parameter
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|----------------|------|----------------------------------------------------------------------------------|----------|----------------------------|
|
||||
| `organization` | path | string(uuid) | true | Organization ID |
|
||||
| `body` | body | [codersdk.WorkspaceSharingSettings](schemas.md#codersdkworkspacesharingsettings) | true | Workspace sharing settings |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 200 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------------|
|
||||
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.WorkspaceSharingSettings](schemas.md#codersdkworkspacesharingsettings) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## Fetch provisioner key details
|
||||
|
||||
### Code samples
|
||||
|
||||
Generated
+1
-2
@@ -191,8 +191,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
"retention": 0
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
|
||||
Generated
+95
-168
@@ -396,8 +396,7 @@
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
"retention": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -413,7 +412,6 @@
|
||||
| `openai` | [codersdk.AIBridgeOpenAIConfig](#codersdkaibridgeopenaiconfig) | false | | |
|
||||
| `rate_limit` | integer | false | | |
|
||||
| `retention` | integer | false | | |
|
||||
| `structured_logging` | boolean | false | | |
|
||||
|
||||
## codersdk.AIBridgeInterception
|
||||
|
||||
@@ -745,8 +743,7 @@
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
"retention": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -2661,8 +2658,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
"retention": 0
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
@@ -3206,8 +3202,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"key": "string"
|
||||
},
|
||||
"rate_limit": 0,
|
||||
"retention": 0,
|
||||
"structured_logging": true
|
||||
"retention": 0
|
||||
}
|
||||
},
|
||||
"allow_workspace_renames": true,
|
||||
@@ -5188,15 +5183,15 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
{
|
||||
"authorization_endpoint": "string",
|
||||
"code_challenge_methods_supported": [
|
||||
"S256"
|
||||
"string"
|
||||
],
|
||||
"grant_types_supported": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"issuer": "string",
|
||||
"registration_endpoint": "string",
|
||||
"response_types_supported": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"revocation_endpoint": "string",
|
||||
"scopes_supported": [
|
||||
@@ -5204,25 +5199,25 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
],
|
||||
"token_endpoint": "string",
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"client_secret_basic"
|
||||
"string"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-----------------------------------------|-------------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `authorization_endpoint` | string | false | | |
|
||||
| `code_challenge_methods_supported` | array of [codersdk.OAuth2PKCECodeChallengeMethod](#codersdkoauth2pkcecodechallengemethod) | false | | |
|
||||
| `grant_types_supported` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `issuer` | string | false | | |
|
||||
| `registration_endpoint` | string | false | | |
|
||||
| `response_types_supported` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `revocation_endpoint` | string | false | | |
|
||||
| `scopes_supported` | array of string | false | | |
|
||||
| `token_endpoint` | string | false | | |
|
||||
| `token_endpoint_auth_methods_supported` | array of [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-----------------------------------------|-----------------|----------|--------------|-------------|
|
||||
| `authorization_endpoint` | string | false | | |
|
||||
| `code_challenge_methods_supported` | array of string | false | | |
|
||||
| `grant_types_supported` | array of string | false | | |
|
||||
| `issuer` | string | false | | |
|
||||
| `registration_endpoint` | string | false | | |
|
||||
| `response_types_supported` | array of string | false | | |
|
||||
| `revocation_endpoint` | string | false | | |
|
||||
| `scopes_supported` | array of string | false | | |
|
||||
| `token_endpoint` | string | false | | |
|
||||
| `token_endpoint_auth_methods_supported` | array of string | false | | |
|
||||
|
||||
## codersdk.OAuth2ClientConfiguration
|
||||
|
||||
@@ -5237,7 +5232,7 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -5246,43 +5241,45 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"redirect_uris": [
|
||||
"string"
|
||||
],
|
||||
"registration_access_token": "string",
|
||||
"registration_access_token": [
|
||||
0
|
||||
],
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | string | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|------------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of string | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | array of integer | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of string | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | string | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2ClientRegistrationRequest
|
||||
|
||||
@@ -5294,7 +5291,7 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -5304,37 +5301,37 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_statement": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `client_name` | string | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `response_types` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_statement` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-----------------|----------|--------------|-------------|
|
||||
| `client_name` | string | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of string | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `response_types` | array of string | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_statement` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | string | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2ClientRegistrationResponse
|
||||
|
||||
@@ -5350,7 +5347,7 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"string"
|
||||
],
|
||||
"grant_types": [
|
||||
"authorization_code"
|
||||
"string"
|
||||
],
|
||||
"jwks": {},
|
||||
"jwks_uri": "string",
|
||||
@@ -5362,41 +5359,41 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
"registration_access_token": "string",
|
||||
"registration_client_uri": "string",
|
||||
"response_types": [
|
||||
"code"
|
||||
"string"
|
||||
],
|
||||
"scope": "string",
|
||||
"software_id": "string",
|
||||
"software_version": "string",
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
"token_endpoint_auth_method": "string",
|
||||
"tos_uri": "string"
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of [codersdk.OAuth2ProviderGrantType](#codersdkoauth2providergranttype) | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | string | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of [codersdk.OAuth2ProviderResponseType](#codersdkoauth2providerresponsetype) | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | [codersdk.OAuth2TokenEndpointAuthMethod](#codersdkoauth2tokenendpointauthmethod) | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------|-----------------|----------|--------------|-------------|
|
||||
| `client_id` | string | false | | |
|
||||
| `client_id_issued_at` | integer | false | | |
|
||||
| `client_name` | string | false | | |
|
||||
| `client_secret` | string | false | | |
|
||||
| `client_secret_expires_at` | integer | false | | |
|
||||
| `client_uri` | string | false | | |
|
||||
| `contacts` | array of string | false | | |
|
||||
| `grant_types` | array of string | false | | |
|
||||
| `jwks` | object | false | | |
|
||||
| `jwks_uri` | string | false | | |
|
||||
| `logo_uri` | string | false | | |
|
||||
| `policy_uri` | string | false | | |
|
||||
| `redirect_uris` | array of string | false | | |
|
||||
| `registration_access_token` | string | false | | |
|
||||
| `registration_client_uri` | string | false | | |
|
||||
| `response_types` | array of string | false | | |
|
||||
| `scope` | string | false | | |
|
||||
| `software_id` | string | false | | |
|
||||
| `software_version` | string | false | | |
|
||||
| `token_endpoint_auth_method` | string | false | | |
|
||||
| `tos_uri` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2Config
|
||||
|
||||
@@ -5460,20 +5457,6 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
| `device_flow` | boolean | false | | |
|
||||
| `enterprise_base_url` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2PKCECodeChallengeMethod
|
||||
|
||||
```json
|
||||
"S256"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------|
|
||||
| `S256`, `plain` |
|
||||
|
||||
## codersdk.OAuth2ProtectedResourceMetadata
|
||||
|
||||
```json
|
||||
@@ -5561,48 +5544,6 @@ Only certain features set these fields: - FeatureManagedAgentLimit|
|
||||
| `client_secret_full` | string | false | | |
|
||||
| `id` | string | false | | |
|
||||
|
||||
## codersdk.OAuth2ProviderGrantType
|
||||
|
||||
```json
|
||||
"authorization_code"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-------------------------------------------------------------------------------------|
|
||||
| `authorization_code`, `client_credentials`, `implicit`, `password`, `refresh_token` |
|
||||
|
||||
## codersdk.OAuth2ProviderResponseType
|
||||
|
||||
```json
|
||||
"code"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------|
|
||||
| `code`, `token` |
|
||||
|
||||
## codersdk.OAuth2TokenEndpointAuthMethod
|
||||
|
||||
```json
|
||||
"client_secret_basic"
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
#### Enumerated Values
|
||||
|
||||
| Value(s) |
|
||||
|-----------------------------------------------------|
|
||||
| `client_secret_basic`, `client_secret_post`, `none` |
|
||||
|
||||
## codersdk.OAuthConversionResponse
|
||||
|
||||
```json
|
||||
@@ -11725,20 +11666,6 @@ If the schedule is empty, the user will be updated to use the default schedule.|
|
||||
|--------------------|
|
||||
| ``, `admin`, `use` |
|
||||
|
||||
## codersdk.WorkspaceSharingSettings
|
||||
|
||||
```json
|
||||
{
|
||||
"sharing_disabled": true
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|--------------------|---------|----------|--------------|-------------|
|
||||
| `sharing_disabled` | boolean | false | | |
|
||||
|
||||
## codersdk.WorkspaceStatus
|
||||
|
||||
```json
|
||||
|
||||
@@ -24,4 +24,3 @@ coder organizations settings set
|
||||
| [<code>group-sync</code>](./organizations_settings_set_group-sync.md) | Group sync settings to sync groups from an IdP. |
|
||||
| [<code>role-sync</code>](./organizations_settings_set_role-sync.md) | Role sync settings to sync organization roles from an IdP. |
|
||||
| [<code>organization-sync</code>](./organizations_settings_set_organization-sync.md) | Organization sync settings to sync organization memberships from an IdP. |
|
||||
| [<code>workspace-sharing</code>](./organizations_settings_set_workspace-sharing.md) | Workspace sharing settings for the organization. |
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
<!-- DO NOT EDIT | GENERATED CONTENT -->
|
||||
# organizations settings set workspace-sharing
|
||||
|
||||
Workspace sharing settings for the organization.
|
||||
|
||||
Aliases:
|
||||
|
||||
* workspacesharing
|
||||
|
||||
## Usage
|
||||
|
||||
```console
|
||||
coder organizations settings set workspace-sharing
|
||||
```
|
||||
@@ -24,4 +24,3 @@ coder organizations settings show
|
||||
| [<code>group-sync</code>](./organizations_settings_show_group-sync.md) | Group sync settings to sync groups from an IdP. |
|
||||
| [<code>role-sync</code>](./organizations_settings_show_role-sync.md) | Role sync settings to sync organization roles from an IdP. |
|
||||
| [<code>organization-sync</code>](./organizations_settings_show_organization-sync.md) | Organization sync settings to sync organization memberships from an IdP. |
|
||||
| [<code>workspace-sharing</code>](./organizations_settings_show_workspace-sharing.md) | Workspace sharing settings for the organization. |
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
<!-- DO NOT EDIT | GENERATED CONTENT -->
|
||||
# organizations settings show workspace-sharing
|
||||
|
||||
Workspace sharing settings for the organization.
|
||||
|
||||
Aliases:
|
||||
|
||||
* workspacesharing
|
||||
|
||||
## Usage
|
||||
|
||||
```console
|
||||
coder organizations settings show workspace-sharing
|
||||
```
|
||||
Generated
-11
@@ -1836,17 +1836,6 @@ Maximum number of concurrent AI Bridge requests per replica. Set to 0 to disable
|
||||
|
||||
Maximum number of AI Bridge requests per second per replica. Set to 0 to disable (unlimited).
|
||||
|
||||
### --aibridge-structured-logging
|
||||
|
||||
| | |
|
||||
|-------------|-------------------------------------------------|
|
||||
| Type | <code>bool</code> |
|
||||
| Environment | <code>$CODER_AIBRIDGE_STRUCTURED_LOGGING</code> |
|
||||
| YAML | <code>aibridge.structuredLogging</code> |
|
||||
| Default | <code>false</code> |
|
||||
|
||||
Emit structured logs for AI Bridge interception records. Use this for exporting these records to external SIEM or observability systems.
|
||||
|
||||
### --aibridge-proxy-enabled
|
||||
|
||||
| | |
|
||||
|
||||
@@ -11,8 +11,8 @@ RUN cargo install jj-cli typos-cli watchexec-cli
|
||||
FROM ubuntu:jammy@sha256:104ae83764a5119017b8e8d6218fa0832b09df65aae7d5a6de29a85d813da2fb AS go
|
||||
|
||||
# Install Go manually, so that we can control the version
|
||||
ARG GO_VERSION=1.24.11
|
||||
ARG GO_CHECKSUM="bceca00afaac856bc48b4cc33db7cd9eb383c81811379faed3bdbc80edb0af65"
|
||||
ARG GO_VERSION=1.24.10
|
||||
ARG GO_CHECKSUM="dd52b974e3d9c5a7bbfb222c685806def6be5d6f7efd10f9caa9ca1fa2f47955"
|
||||
|
||||
# Boring Go is needed to build FIPS-compliant binaries.
|
||||
RUN apt-get update && \
|
||||
|
||||
@@ -7,6 +7,7 @@ allowlist:
|
||||
- domain=dev.coder.com
|
||||
|
||||
# test domains
|
||||
- method=GET domain=google.com
|
||||
- method=GET domain=typicode.com
|
||||
|
||||
# domain used in coder task workspaces
|
||||
|
||||
+15
-1
@@ -290,6 +290,11 @@ data "coder_parameter" "ide_choices" {
|
||||
value = "jetbrains"
|
||||
icon = "/icon/jetbrains.svg"
|
||||
}
|
||||
option {
|
||||
name = "JetBrains Fleet"
|
||||
value = "fleet"
|
||||
icon = "/icon/fleet.svg"
|
||||
}
|
||||
option {
|
||||
name = "Cursor"
|
||||
value = "cursor"
|
||||
@@ -453,6 +458,15 @@ module "zed" {
|
||||
folder = local.repo_dir
|
||||
}
|
||||
|
||||
module "jetbrains-fleet" {
|
||||
count = contains(jsondecode(data.coder_parameter.ide_choices.value), "fleet") ? data.coder_workspace.me.start_count : 0
|
||||
source = "registry.coder.com/coder/jetbrains-fleet/coder"
|
||||
version = "1.0.2"
|
||||
agent_id = coder_agent.dev.id
|
||||
agent_name = "dev"
|
||||
folder = local.repo_dir
|
||||
}
|
||||
|
||||
module "devcontainers-cli" {
|
||||
count = data.coder_workspace.me.start_count
|
||||
source = "dev.registry.coder.com/modules/devcontainers-cli/coder"
|
||||
@@ -890,7 +904,7 @@ module "claude-code" {
|
||||
source = "dev.registry.coder.com/coder/claude-code/coder"
|
||||
version = "4.3.0"
|
||||
enable_boundary = true
|
||||
boundary_version = "v0.5.5"
|
||||
boundary_version = "v0.5.2"
|
||||
agent_id = coder_agent.dev.id
|
||||
workdir = local.repo_dir
|
||||
claude_code_version = "latest"
|
||||
|
||||
@@ -46,10 +46,6 @@ var (
|
||||
ErrNoExternalAuthLinkFound = xerrors.New("no external auth link found")
|
||||
)
|
||||
|
||||
const (
|
||||
InterceptionLogMarker = "interception log"
|
||||
)
|
||||
|
||||
var _ aibridged.DRPCServer = &Server{}
|
||||
|
||||
type store interface {
|
||||
@@ -77,8 +73,7 @@ type Server struct {
|
||||
logger slog.Logger
|
||||
externalAuthConfigs map[string]*externalauth.Config
|
||||
|
||||
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
|
||||
structuredLogging bool
|
||||
coderMCPConfig *proto.MCPServerConfig // may be nil if not available
|
||||
}
|
||||
|
||||
func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, accessURL string,
|
||||
@@ -97,9 +92,8 @@ func NewServer(lifecycleCtx context.Context, store store, logger slog.Logger, ac
|
||||
srv := &Server{
|
||||
lifecycleCtx: lifecycleCtx,
|
||||
store: store,
|
||||
logger: logger,
|
||||
logger: logger.Named("aibridgedserver"),
|
||||
externalAuthConfigs: eac,
|
||||
structuredLogging: bridgeCfg.StructuredLogging.Value(),
|
||||
}
|
||||
|
||||
if bridgeCfg.InjectCoderMCPTools {
|
||||
@@ -129,33 +123,13 @@ func (s *Server) RecordInterception(ctx context.Context, in *proto.RecordInterce
|
||||
return nil, xerrors.Errorf("empty API key ID")
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "interception_start"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("initiator_id", initID.String()),
|
||||
slog.F("api_key_id", in.ApiKeyId),
|
||||
slog.F("provider", in.Provider),
|
||||
slog.F("model", in.Model),
|
||||
slog.F("started_at", in.StartedAt.AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeInterception(ctx, database.InsertAIBridgeInterceptionParams{
|
||||
ID: intcID,
|
||||
APIKeyID: sql.NullString{String: in.ApiKeyId, Valid: true},
|
||||
InitiatorID: initID,
|
||||
Provider: in.Provider,
|
||||
Model: in.Model,
|
||||
Metadata: out,
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
StartedAt: in.StartedAt.AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -174,14 +148,6 @@ func (s *Server) RecordInterceptionEnded(ctx context.Context, in *proto.RecordIn
|
||||
return nil, xerrors.Errorf("invalid interception ID %q: %w", in.GetId(), err)
|
||||
}
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "interception_end"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("ended_at", in.EndedAt.AsTime()),
|
||||
)
|
||||
}
|
||||
|
||||
_, err = s.store.UpdateAIBridgeInterceptionEnded(ctx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
ID: intcID,
|
||||
EndedAt: in.EndedAt.AsTime(),
|
||||
@@ -202,38 +168,18 @@ func (s *Server) RecordTokenUsage(ctx context.Context, in *proto.RecordTokenUsag
|
||||
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "token_usage"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("msg_id", in.GetMsgId()),
|
||||
slog.F("input_tokens", in.GetInputTokens()),
|
||||
slog.F("output_tokens", in.GetOutputTokens()),
|
||||
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeTokenUsage(ctx, database.InsertAIBridgeTokenUsageParams{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
ProviderResponseID: in.GetMsgId(),
|
||||
InputTokens: in.GetInputTokens(),
|
||||
OutputTokens: in.GetOutputTokens(),
|
||||
Metadata: out,
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
CreatedAt: in.GetCreatedAt().AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert token usage: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RecordTokenUsageResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -246,36 +192,17 @@ func (s *Server) RecordPromptUsage(ctx context.Context, in *proto.RecordPromptUs
|
||||
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "prompt_usage"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("msg_id", in.GetMsgId()),
|
||||
slog.F("prompt", in.GetPrompt()),
|
||||
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeUserPrompt(ctx, database.InsertAIBridgeUserPromptParams{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
ProviderResponseID: in.GetMsgId(),
|
||||
Prompt: in.GetPrompt(),
|
||||
Metadata: out,
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
CreatedAt: in.GetCreatedAt().AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert user prompt: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RecordPromptUsageResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -288,28 +215,6 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
|
||||
return nil, xerrors.Errorf("failed to parse interception_id %q: %w", in.GetInterceptionId(), err)
|
||||
}
|
||||
|
||||
metadata := metadataToMap(in.GetMetadata())
|
||||
|
||||
if s.structuredLogging {
|
||||
s.logger.Info(ctx, InterceptionLogMarker,
|
||||
slog.F("record_type", "tool_usage"),
|
||||
slog.F("interception_id", intcID.String()),
|
||||
slog.F("msg_id", in.GetMsgId()),
|
||||
slog.F("tool", in.GetTool()),
|
||||
slog.F("input", in.GetInput()),
|
||||
slog.F("server_url", in.GetServerUrl()),
|
||||
slog.F("injected", in.GetInjected()),
|
||||
slog.F("invocation_error", in.GetInvocationError()),
|
||||
slog.F("created_at", in.GetCreatedAt().AsTime()),
|
||||
slog.F("metadata", metadata),
|
||||
)
|
||||
}
|
||||
|
||||
out, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.store.InsertAIBridgeToolUsage(ctx, database.InsertAIBridgeToolUsageParams{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
@@ -319,13 +224,12 @@ func (s *Server) RecordToolUsage(ctx context.Context, in *proto.RecordToolUsageR
|
||||
Input: in.GetInput(),
|
||||
Injected: in.GetInjected(),
|
||||
InvocationError: sql.NullString{String: in.GetInvocationError(), Valid: in.InvocationError != nil},
|
||||
Metadata: out,
|
||||
Metadata: marshalMetadata(ctx, s.logger, in.GetMetadata()),
|
||||
CreatedAt: in.GetCreatedAt().AsTime(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert tool usage: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RecordToolUsageResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -529,16 +433,24 @@ func getCoderMCPServerConfig(experiments codersdk.Experiments, accessURL string)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func metadataToMap(in map[string]*anypb.Any) map[string]any {
|
||||
meta := make(map[string]any, len(in))
|
||||
// marshalMetadata attempts to marshal the given metadata map into a
|
||||
// JSON-encoded byte slice. If the marshaling fails, the function logs a
|
||||
// warning and returns nil. The supplied context is only used for logging.
|
||||
func marshalMetadata(ctx context.Context, logger slog.Logger, in map[string]*anypb.Any) []byte {
|
||||
mdMap := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
var sv structpb.Value
|
||||
if err := v.UnmarshalTo(&sv); err == nil {
|
||||
meta[k] = sv.AsInterface()
|
||||
mdMap[k] = sv.AsInterface()
|
||||
}
|
||||
}
|
||||
return meta
|
||||
out, err := json.Marshal(mdMap)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal aibridge metadata from proto to JSON", slog.F("metadata", in), slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package aibridgedserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
)
|
||||
|
||||
func TestMarshalMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NilData", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
out := marshalMetadata(context.Background(), logger, nil)
|
||||
require.JSONEq(t, "{}", string(out))
|
||||
})
|
||||
|
||||
t.Run("WithData", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
list := structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
|
||||
structpb.NewStringValue("a"),
|
||||
structpb.NewNumberValue(1),
|
||||
structpb.NewBoolValue(false),
|
||||
}})
|
||||
obj := structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"a": structpb.NewStringValue("b"),
|
||||
"n": structpb.NewNumberValue(3),
|
||||
}})
|
||||
|
||||
nonValue := mustMarshalAny(t, &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"ignored": structpb.NewStringValue("yes"),
|
||||
}})
|
||||
invalid := &anypb.Any{TypeUrl: "type.googleapis.com/google.protobuf.Value", Value: []byte{0xff, 0x00}}
|
||||
|
||||
in := map[string]*anypb.Any{
|
||||
"null": mustMarshalAny(t, structpb.NewNullValue()),
|
||||
// Scalars
|
||||
"string": mustMarshalAny(t, structpb.NewStringValue("hello")),
|
||||
"bool": mustMarshalAny(t, structpb.NewBoolValue(true)),
|
||||
"number": mustMarshalAny(t, structpb.NewNumberValue(42)),
|
||||
// Complex types
|
||||
"list": mustMarshalAny(t, list),
|
||||
"object": mustMarshalAny(t, obj),
|
||||
// Extra valid entries
|
||||
"ok": mustMarshalAny(t, structpb.NewStringValue("present")),
|
||||
"nan": mustMarshalAny(t, structpb.NewNumberValue(math.NaN())),
|
||||
// Entries that should be ignored
|
||||
"invalid": invalid,
|
||||
"non_value": nonValue,
|
||||
}
|
||||
|
||||
out := marshalMetadata(context.Background(), logger, in)
|
||||
require.NotNil(t, out)
|
||||
var got map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &got))
|
||||
|
||||
expected := map[string]any{
|
||||
"string": "hello",
|
||||
"bool": true,
|
||||
"number": float64(42),
|
||||
"null": nil,
|
||||
"list": []any{"a", float64(1), false},
|
||||
"object": map[string]any{"a": "b", "n": float64(3)},
|
||||
"ok": "present",
|
||||
"nan": "NaN",
|
||||
}
|
||||
require.Equal(t, expected, got)
|
||||
})
|
||||
}
|
||||
|
||||
func mustMarshalAny(t testing.TB, m proto.Message) *anypb.Any {
|
||||
t.Helper()
|
||||
a, err := anypb.New(m)
|
||||
require.NoError(t, err)
|
||||
return a
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
package aibridgedserver_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
@@ -22,8 +20,6 @@ import (
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogjson"
|
||||
"github.com/coder/coder/v2/coderd/apikey"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
@@ -836,279 +832,3 @@ func mustMarshalAny(t *testing.T, msg protobufproto.Message) *anypb.Any {
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
// logLine represents a parsed JSON log entry.
|
||||
type logLine struct {
|
||||
Msg string `json:"msg"`
|
||||
Level string `json:"level"`
|
||||
Fields map[string]any `json:"fields"`
|
||||
}
|
||||
|
||||
// parseLogLines parses JSON log lines from a buffer.
|
||||
func parseLogLines(buf *bytes.Buffer) []logLine {
|
||||
var lines []logLine
|
||||
scanner := bufio.NewScanner(buf)
|
||||
for scanner.Scan() {
|
||||
var line logLine
|
||||
if err := json.Unmarshal(scanner.Bytes(), &line); err == nil {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// getLogLinesWithMessage returns all log lines with the given message.
|
||||
func getLogLinesWithMessage(lines []logLine, msg string) []logLine {
|
||||
var result []logLine
|
||||
for _, line := range lines {
|
||||
if line.Msg == msg {
|
||||
result = append(result, line)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestStructuredLogging(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
metadataProto := map[string]*anypb.Any{
|
||||
"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "value"}}),
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
structuredLogging bool
|
||||
expectedErr error
|
||||
setupMocks func(db *dbmock.MockStore, interceptionID uuid.UUID)
|
||||
recordFn func(srv *aibridgedserver.Server, ctx context.Context, interceptionID uuid.UUID) error
|
||||
expectedFields map[string]any
|
||||
}
|
||||
|
||||
interceptionID := uuid.UUID{1}
|
||||
initiatorID := uuid.UUID{2}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
name: "RecordInterception_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
|
||||
ID: intcID,
|
||||
InitiatorID: initiatorID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
|
||||
Id: intcID.String(),
|
||||
ApiKeyId: "api-key-123",
|
||||
InitiatorId: initiatorID.String(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4-opus",
|
||||
Metadata: metadataProto,
|
||||
StartedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "interception_start",
|
||||
"interception_id": interceptionID.String(),
|
||||
"initiator_id": initiatorID.String(),
|
||||
"provider": "anthropic",
|
||||
"model": "claude-4-opus",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordInterception_does_not_log_when_disabled",
|
||||
structuredLogging: false,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
|
||||
ID: intcID,
|
||||
InitiatorID: initiatorID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
|
||||
Id: intcID.String(),
|
||||
ApiKeyId: "api-key-123",
|
||||
InitiatorId: initiatorID.String(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4-opus",
|
||||
StartedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: nil, // No log expected.
|
||||
},
|
||||
{
|
||||
name: "RecordInterception_log_on_db_error",
|
||||
structuredLogging: true,
|
||||
expectedErr: sql.ErrConnDone,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{}, sql.ErrConnDone)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
|
||||
Id: intcID.String(),
|
||||
ApiKeyId: "api-key-123",
|
||||
InitiatorId: initiatorID.String(),
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4-opus",
|
||||
StartedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
// Even though the database call errored, we must still write the logs.
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "interception_start",
|
||||
"interception_id": interceptionID.String(),
|
||||
"initiator_id": initiatorID.String(),
|
||||
"provider": "anthropic",
|
||||
"model": "claude-4-opus",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordInterceptionEnded_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
|
||||
ID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordInterceptionEnded(ctx, &proto.RecordInterceptionEndedRequest{
|
||||
Id: intcID.String(),
|
||||
EndedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "interception_end",
|
||||
"interception_id": interceptionID.String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordTokenUsage_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeTokenUsage{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordTokenUsage(ctx, &proto.RecordTokenUsageRequest{
|
||||
InterceptionId: intcID.String(),
|
||||
MsgId: "msg_123",
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
Metadata: metadataProto,
|
||||
CreatedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "token_usage",
|
||||
"interception_id": interceptionID.String(),
|
||||
"input_tokens": float64(100), // JSON numbers are float64.
|
||||
"output_tokens": float64(200),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordPromptUsage_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), gomock.Any()).Return(database.AIBridgeUserPrompt{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordPromptUsage(ctx, &proto.RecordPromptUsageRequest{
|
||||
InterceptionId: intcID.String(),
|
||||
MsgId: "msg_123",
|
||||
Prompt: "Hello, Claude!",
|
||||
Metadata: metadataProto,
|
||||
CreatedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "prompt_usage",
|
||||
"interception_id": interceptionID.String(),
|
||||
"prompt": "Hello, Claude!",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RecordToolUsage_logs_when_enabled",
|
||||
structuredLogging: true,
|
||||
setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
|
||||
db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeToolUsage{
|
||||
ID: uuid.New(),
|
||||
InterceptionID: intcID,
|
||||
}, nil)
|
||||
},
|
||||
recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
|
||||
_, err := srv.RecordToolUsage(ctx, &proto.RecordToolUsageRequest{
|
||||
InterceptionId: intcID.String(),
|
||||
MsgId: "msg_123",
|
||||
ServerUrl: strPtr("https://api.example.com"),
|
||||
Tool: "read_file",
|
||||
Input: `{"path": "/etc/hosts"}`,
|
||||
Injected: true,
|
||||
InvocationError: strPtr("permission denied"),
|
||||
Metadata: metadataProto,
|
||||
CreatedAt: timestamppb.Now(),
|
||||
})
|
||||
return err
|
||||
},
|
||||
expectedFields: map[string]any{
|
||||
"record_type": "tool_usage",
|
||||
"interception_id": interceptionID.String(),
|
||||
"tool": "read_file",
|
||||
"input": `{"path": "/etc/hosts"}`,
|
||||
"injected": true,
|
||||
"invocation_error": "permission denied",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
buf := &bytes.Buffer{}
|
||||
logger := slog.Make(slogjson.Sink(buf)).Leveled(slog.LevelDebug)
|
||||
|
||||
tc.setupMocks(db, interceptionID)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{
|
||||
StructuredLogging: serpent.Bool(tc.structuredLogging),
|
||||
}, nil, requiredExperiments)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tc.recordFn(srv, ctx, interceptionID)
|
||||
if tc.expectedErr != nil {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
lines := parseLogLines(buf)
|
||||
if tc.expectedFields == nil {
|
||||
// No log expected (disabled or error case).
|
||||
require.Empty(t, lines)
|
||||
} else {
|
||||
matchedLines := getLogLinesWithMessage(lines, aibridgedserver.InterceptionLogMarker)
|
||||
require.Len(t, matchedLines, 1, "expected exactly one log line with message %q", aibridgedserver.InterceptionLogMarker)
|
||||
|
||||
fields := matchedLines[0].Fields
|
||||
for key, expected := range tc.expectedFields {
|
||||
require.Equal(t, expected, fields[key], "field %q mismatch", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
import "log/slog"
|
||||
|
||||
// LogAuditor implements proxy.Auditor by logging to slog
|
||||
type LogAuditor struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewLogAuditor creates a new LogAuditor
|
||||
func NewLogAuditor(logger *slog.Logger) *LogAuditor {
|
||||
return &LogAuditor{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// AuditRequest logs the request using structured logging
|
||||
func (a *LogAuditor) AuditRequest(req Request) {
|
||||
if req.Allowed {
|
||||
a.logger.Info("ALLOW",
|
||||
"method", req.Method,
|
||||
"url", req.URL,
|
||||
"host", req.Host,
|
||||
"rule", req.Rule)
|
||||
} else {
|
||||
a.logger.Warn("DENY",
|
||||
"method", req.Method,
|
||||
"url", req.URL,
|
||||
"host", req.Host,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic
|
||||
package audit
|
||||
|
||||
import "testing"
|
||||
|
||||
// Stub test file - tests removed
|
||||
func TestStub(t *testing.T) {
|
||||
// This is a stub test
|
||||
t.Skip("stub test file")
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// MultiAuditor wraps multiple auditors and sends audit events to all of them.
|
||||
type MultiAuditor struct {
|
||||
auditors []Auditor
|
||||
}
|
||||
|
||||
// NewMultiAuditor creates a new MultiAuditor that sends to all provided auditors.
|
||||
func NewMultiAuditor(auditors ...Auditor) *MultiAuditor {
|
||||
return &MultiAuditor{auditors: auditors}
|
||||
}
|
||||
|
||||
// AuditRequest sends the request to all wrapped auditors.
|
||||
func (m *MultiAuditor) AuditRequest(req Request) {
|
||||
for _, a := range m.auditors {
|
||||
a.AuditRequest(req)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupAuditor creates and configures the appropriate auditors based on the
|
||||
// provided configuration. It always includes a LogAuditor for stderr logging,
|
||||
// and conditionally adds a SocketAuditor if audit logs are enabled and the
|
||||
// workspace agent's log proxy socket exists.
|
||||
func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string) (Auditor, error) {
|
||||
stderrAuditor := NewLogAuditor(logger)
|
||||
auditors := []Auditor{stderrAuditor}
|
||||
|
||||
if !disableAuditLogs {
|
||||
if logProxySocketPath == "" {
|
||||
return nil, xerrors.New("log proxy socket path is undefined")
|
||||
}
|
||||
// Since boundary is separately versioned from a Coder deployment, it's possible
|
||||
// Coder is on an older version that will not create the socket and listen for
|
||||
// the audit logs. Here we check for the socket to determine if the workspace
|
||||
// agent is on a new enough version to prevent boundary application log spam from
|
||||
// trying to connect to the agent. This assumes the agent will run and start the
|
||||
// log proxy server before boundary runs.
|
||||
_, err := os.Stat(logProxySocketPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, xerrors.Errorf("failed to stat log proxy socket: %v", err)
|
||||
}
|
||||
agentWillProxy := !os.IsNotExist(err)
|
||||
if agentWillProxy {
|
||||
socketAuditor := NewSocketAuditor(logger, logProxySocketPath)
|
||||
go socketAuditor.Loop(ctx)
|
||||
auditors = append(auditors, socketAuditor)
|
||||
} else {
|
||||
logger.Warn("Audit logs are disabled; workspace agent has not created log proxy socket",
|
||||
"socket", logProxySocketPath)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Audit logs are disabled by configuration")
|
||||
}
|
||||
|
||||
return NewMultiAuditor(auditors...), nil
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockAuditor struct {
|
||||
onAudit func(req Request)
|
||||
}
|
||||
|
||||
func (m *mockAuditor) AuditRequest(req Request) {
|
||||
if m.onAudit != nil {
|
||||
m.onAudit(req)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_DisabledAuditLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
auditor, err := SetupAuditor(ctx, logger, true, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
multi, ok := auditor.(*MultiAuditor)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MultiAuditor, got %T", auditor)
|
||||
}
|
||||
|
||||
if len(multi.auditors) != 1 {
|
||||
t.Errorf("expected 1 auditor, got %d", len(multi.auditors))
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[0].(*LogAuditor); !ok {
|
||||
t.Errorf("expected *LogAuditor, got %T", multi.auditors[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_EmptySocketPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := SetupAuditor(ctx, logger, false, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty socket path, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_SocketDoesNotExist(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
multi, ok := auditor.(*MultiAuditor)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MultiAuditor, got %T", auditor)
|
||||
}
|
||||
|
||||
if len(multi.auditors) != 1 {
|
||||
t.Errorf("expected 1 auditor, got %d", len(multi.auditors))
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[0].(*LogAuditor); !ok {
|
||||
t.Errorf("expected *LogAuditor, got %T", multi.auditors[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_SocketExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a temporary file to simulate the socket existing
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
f, err := os.Create(socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close temp file: %v", err)
|
||||
}
|
||||
|
||||
auditor, err := SetupAuditor(ctx, logger, false, socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
multi, ok := auditor.(*MultiAuditor)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MultiAuditor, got %T", auditor)
|
||||
}
|
||||
|
||||
if len(multi.auditors) != 2 {
|
||||
t.Errorf("expected 2 auditors, got %d", len(multi.auditors))
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[0].(*LogAuditor); !ok {
|
||||
t.Errorf("expected first auditor to be *LogAuditor, got %T", multi.auditors[0])
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[1].(*SocketAuditor); !ok {
|
||||
t.Errorf("expected second auditor to be *SocketAuditor, got %T", multi.auditors[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiAuditor_AuditRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var called1, called2 bool
|
||||
auditor1 := &mockAuditor{onAudit: func(req Request) { called1 = true }}
|
||||
auditor2 := &mockAuditor{onAudit: func(req Request) { called2 = true }}
|
||||
|
||||
multi := NewMultiAuditor(auditor1, auditor2)
|
||||
multi.AuditRequest(Request{Method: "GET", URL: "https://example.com"})
|
||||
|
||||
if !called1 {
|
||||
t.Error("expected first auditor to be called")
|
||||
}
|
||||
if !called2 {
|
||||
t.Error("expected second auditor to be called")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
type Auditor interface {
|
||||
AuditRequest(req Request)
|
||||
}
|
||||
|
||||
// Request represents information about an HTTP request for auditing
|
||||
type Request struct {
|
||||
Method string
|
||||
URL string // The fully qualified request URL (scheme, domain, optional path).
|
||||
Host string
|
||||
Allowed bool
|
||||
Rule string // The rule that matched (if any)
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// The batch size and timer duration are chosen to provide reasonable responsiveness
|
||||
// for consumers of the aggregated logs while still minimizing the agent <-> coderd
|
||||
// network I/O when an AI agent is actively making network requests.
|
||||
defaultBatchSize = 10
|
||||
defaultBatchTimerDuration = 5 * time.Second
|
||||
)
|
||||
|
||||
// SocketAuditor implements the Auditor interface. It sends logs to the
|
||||
// workspace agent's boundary log proxy socket. It queues logs and sends
|
||||
// them in batches using a batch size and timer. The internal queue operates
|
||||
// as a FIFO i.e., logs are sent in the order they are received and dropped
|
||||
// if the queue is full.
|
||||
type SocketAuditor struct {
|
||||
dial func() (net.Conn, error)
|
||||
logger *slog.Logger
|
||||
logCh chan *agentproto.BoundaryLog
|
||||
batchSize int
|
||||
batchTimerDuration time.Duration
|
||||
socketPath string
|
||||
|
||||
// onFlushAttempt is called after each flush attempt (intended for testing).
|
||||
onFlushAttempt func()
|
||||
}
|
||||
|
||||
// NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's
|
||||
// boundary log proxy socket after SocketAuditor.Loop is called. The socket path
|
||||
// is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath.
|
||||
func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor {
|
||||
// This channel buffer size intends to allow enough buffering for bursty
|
||||
// AI agent network requests while a batch is being sent to the workspace
|
||||
// agent.
|
||||
const logChBufSize = 2 * defaultBatchSize
|
||||
|
||||
return &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
return net.Dial("unix", socketPath)
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, logChBufSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: defaultBatchTimerDuration,
|
||||
socketPath: socketPath,
|
||||
}
|
||||
}
|
||||
|
||||
// AuditRequest implements the Auditor interface. It queues the log to be sent to the
|
||||
// agent in a batch.
|
||||
func (s *SocketAuditor) AuditRequest(req Request) {
|
||||
httpReq := &agentproto.BoundaryLog_HttpRequest{
|
||||
Method: req.Method,
|
||||
Url: req.URL,
|
||||
}
|
||||
// Only include the matched rule for allowed requests. Boundary is deny by
|
||||
// default, so rules are what allow requests.
|
||||
if req.Allowed {
|
||||
httpReq.MatchedRule = req.Rule
|
||||
}
|
||||
|
||||
log := &agentproto.BoundaryLog{
|
||||
Allowed: req.Allowed,
|
||||
Time: timestamppb.Now(),
|
||||
Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq},
|
||||
}
|
||||
|
||||
select {
|
||||
case s.logCh <- log:
|
||||
default:
|
||||
s.logger.Warn("audit log dropped, channel full")
|
||||
}
|
||||
}
|
||||
|
||||
// flushErr represents an error from flush, distinguishing between
|
||||
// permanent errors (bad data) and transient errors (network issues).
|
||||
type flushErr struct {
|
||||
err error
|
||||
permanent bool
|
||||
}
|
||||
|
||||
func (e *flushErr) Error() string { return e.err.Error() }
|
||||
|
||||
// flush sends the current batch of logs to the given connection.
|
||||
func flush(conn net.Conn, logs []*agentproto.BoundaryLog) *flushErr {
|
||||
if len(logs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &agentproto.ReportBoundaryLogsRequest{
|
||||
Logs: logs,
|
||||
}
|
||||
|
||||
data, err := proto.Marshal(req)
|
||||
if err != nil {
|
||||
return &flushErr{err: err, permanent: true}
|
||||
}
|
||||
|
||||
err = codec.WriteFrame(conn, codec.TagV1, data)
|
||||
if err != nil {
|
||||
return &flushErr{err: xerrors.Errorf("write frame: %x", err)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop handles the I/O to send audit logs to the agent.
|
||||
func (s *SocketAuditor) Loop(ctx context.Context) {
|
||||
var conn net.Conn
|
||||
batch := make([]*agentproto.BoundaryLog, 0, s.batchSize)
|
||||
t := time.NewTimer(0)
|
||||
t.Stop()
|
||||
|
||||
// connect attempts to establish a connection to the socket.
|
||||
connect := func() {
|
||||
if conn != nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
conn, err = s.dial()
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to connect to audit socket", "path", s.socketPath, "error", err)
|
||||
conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeConn closes the current connection if open.
|
||||
closeConn := func() {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// clearBatch resets the length of the batch and frees memory while preserving
|
||||
// the batch slice backing array.
|
||||
clearBatch := func() {
|
||||
for i := range len(batch) {
|
||||
batch[i] = nil
|
||||
}
|
||||
batch = batch[:0]
|
||||
}
|
||||
|
||||
// doFlush flushes the batch and handles errors by reconnecting.
|
||||
doFlush := func() {
|
||||
t.Stop()
|
||||
defer func() {
|
||||
if s.onFlushAttempt != nil {
|
||||
s.onFlushAttempt()
|
||||
}
|
||||
}()
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
connect()
|
||||
if conn == nil {
|
||||
// No connection: logs will be retried on next flush.
|
||||
s.logger.Warn("no connection to flush; resetting batch timer",
|
||||
"duration_sec", s.batchTimerDuration.Seconds(),
|
||||
"batch_size", len(batch))
|
||||
// Reset the timer so we aren't stuck waiting for the batch to fill
|
||||
// or a new log to arrive before the next attempt.
|
||||
t.Reset(s.batchTimerDuration)
|
||||
return
|
||||
}
|
||||
|
||||
if err := flush(conn, batch); err != nil {
|
||||
if err.permanent {
|
||||
// Data error: discard batch to avoid infinite retries.
|
||||
s.logger.Warn("dropping batch due to data error on flush attempt",
|
||||
"error", err, "batch_size", len(batch))
|
||||
clearBatch()
|
||||
} else {
|
||||
// Network error: close connection but keep batch and retry.
|
||||
s.logger.Warn("failed to flush audit logs; resetting batch timer to reconnect and retry",
|
||||
"error", err, "duration_sec", s.batchTimerDuration.Seconds(),
|
||||
"batch_size", len(batch))
|
||||
closeConn()
|
||||
// Reset the timer so we aren't stuck waiting for a new log to
|
||||
// arrive before the next attempt.
|
||||
t.Reset(s.batchTimerDuration)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
clearBatch()
|
||||
}
|
||||
|
||||
connect()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Drain any pending logs before the last flush. Not concerned about
|
||||
// growing the batch slice here since we're exiting.
|
||||
drain:
|
||||
for {
|
||||
select {
|
||||
case log := <-s.logCh:
|
||||
batch = append(batch, log)
|
||||
default:
|
||||
break drain
|
||||
}
|
||||
}
|
||||
|
||||
doFlush()
|
||||
closeConn()
|
||||
return
|
||||
case <-t.C:
|
||||
doFlush()
|
||||
case log := <-s.logCh:
|
||||
// If batch is at capacity, attempt flushing first and drop the log if
|
||||
// the batch still full.
|
||||
if len(batch) >= s.batchSize {
|
||||
doFlush()
|
||||
if len(batch) >= s.batchSize {
|
||||
s.logger.Warn("audit log dropped, batch full")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
batch = append(batch, log)
|
||||
|
||||
if len(batch) == 1 {
|
||||
t.Reset(s.batchTimerDuration)
|
||||
}
|
||||
|
||||
if len(batch) >= s.batchSize {
|
||||
doFlush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,373 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
)
|
||||
|
||||
func TestSocketAuditor_AuditRequest_QueuesLog(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := setupSocketAuditor(t)
|
||||
|
||||
auditor.AuditRequest(Request{
|
||||
Method: "GET",
|
||||
URL: "https://example.com",
|
||||
Host: "example.com",
|
||||
Allowed: true,
|
||||
Rule: "allow-all",
|
||||
})
|
||||
|
||||
select {
|
||||
case log := <-auditor.logCh:
|
||||
if log.Allowed != true {
|
||||
t.Errorf("expected Allowed=true, got %v", log.Allowed)
|
||||
}
|
||||
httpReq := log.GetHttpRequest()
|
||||
if httpReq == nil {
|
||||
t.Fatal("expected HttpRequest, got nil")
|
||||
}
|
||||
if httpReq.Method != "GET" {
|
||||
t.Errorf("expected Method=GET, got %s", httpReq.Method)
|
||||
}
|
||||
if httpReq.Url != "https://example.com" {
|
||||
t.Errorf("expected URL=https://example.com, got %s", httpReq.Url)
|
||||
}
|
||||
// Rule should be set for allowed requests
|
||||
if httpReq.MatchedRule != "allow-all" {
|
||||
t.Errorf("unexpected MatchedRule %v", httpReq.MatchedRule)
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected log in channel, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_AuditRequest_AllowIncludesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := setupSocketAuditor(t)
|
||||
|
||||
auditor.AuditRequest(Request{
|
||||
Method: "POST",
|
||||
URL: "https://evil.com",
|
||||
Host: "evil.com",
|
||||
Allowed: true,
|
||||
Rule: "allow-evil",
|
||||
})
|
||||
|
||||
select {
|
||||
case log := <-auditor.logCh:
|
||||
if log.Allowed != true {
|
||||
t.Errorf("expected Allowed=false, got %v", log.Allowed)
|
||||
}
|
||||
httpReq := log.GetHttpRequest()
|
||||
if httpReq == nil {
|
||||
t.Fatal("expected HttpRequest, got nil")
|
||||
}
|
||||
if httpReq.MatchedRule != "allow-evil" {
|
||||
t.Errorf("expected MatchedRule=allow-evil, got %s", httpReq.MatchedRule)
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected log in channel, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_AuditRequest_DropsWhenFull(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := setupSocketAuditor(t)
|
||||
|
||||
// Fill the channel (capacity is 2*batchSize = 20)
|
||||
for i := 0; i < 2*auditor.batchSize; i++ {
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
}
|
||||
|
||||
// This should not block and drop the log
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://dropped.com", Allowed: true})
|
||||
|
||||
// Drain the channel and verify all entries are from the original batch (dropped.com was dropped)
|
||||
for i := 0; i < 2*auditor.batchSize; i++ {
|
||||
v := <-auditor.logCh
|
||||
resource, ok := v.Resource.(*agentproto.BoundaryLog_HttpRequest_)
|
||||
if !ok {
|
||||
t.Fatal("unexpected resource type")
|
||||
}
|
||||
if resource.HttpRequest.Url != "https://example.com" {
|
||||
t.Errorf("expected batch to be FIFO, got %s", resource.HttpRequest.Url)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case v := <-auditor.logCh:
|
||||
t.Errorf("expected empty channel, got %v", v)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_FlushesOnBatchSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor, serverConn := setupTestAuditor(t)
|
||||
auditor.batchTimerDuration = time.Hour // Ensure timer doesn't interfere with the test
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
go auditor.Loop(t.Context())
|
||||
|
||||
// Send exactly a full batch of logs to trigger a flush
|
||||
for i := 0; i < auditor.batchSize; i++ {
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
}
|
||||
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != auditor.batchSize {
|
||||
t.Errorf("expected %d logs, got %d", auditor.batchSize, len(req.Logs))
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_FlushesOnTimer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor, serverConn := setupTestAuditor(t)
|
||||
auditor.batchTimerDuration = 3 * time.Second
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
go auditor.Loop(t.Context())
|
||||
|
||||
// A single log should start the timer
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
|
||||
// Should flush after the timer duration elapses
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != 1 {
|
||||
t.Errorf("expected 1 log, got %d", len(req.Logs))
|
||||
}
|
||||
case <-time.After(2 * auditor.batchTimerDuration):
|
||||
t.Fatal("timeout waiting for timer flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_FlushesOnContextCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor, serverConn := setupTestAuditor(t)
|
||||
// Make the timer long to always exercise the context cancellation case
|
||||
auditor.batchTimerDuration = time.Hour
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
auditor.Loop(ctx)
|
||||
}()
|
||||
|
||||
// Send a log but don't fill the batch
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != 1 {
|
||||
t.Errorf("expected 1 log, got %d", len(req.Logs))
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for shutdown flush")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_RetriesOnConnectionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
err := clientConn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("close client connection: %v", err)
|
||||
}
|
||||
err = serverConn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("close server connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
var dialCount atomic.Int32
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
auditor := &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
// First dial attempt fails, subsequent ones succeed
|
||||
if dialCount.Add(1) == 1 {
|
||||
return nil, xerrors.New("connection refused")
|
||||
}
|
||||
return clientConn, nil
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: time.Hour, // Ensure timer doesn't interfere with the test
|
||||
}
|
||||
|
||||
// Set up hook to detect flush attempts
|
||||
flushed := make(chan struct{}, 1)
|
||||
auditor.onFlushAttempt = func() {
|
||||
select {
|
||||
case flushed <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
go auditor.Loop(t.Context())
|
||||
|
||||
// Send batchSize+1 logs so we can verify the last log here gets dropped.
|
||||
for i := 0; i < auditor.batchSize+1; i++ {
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://servernotup.com", Allowed: true})
|
||||
}
|
||||
|
||||
// Wait for the first flush attempt (which will fail)
|
||||
select {
|
||||
case <-flushed:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for first flush attempt")
|
||||
}
|
||||
|
||||
// Send one more log - batch is at capacity, so this triggers flush first
|
||||
// The flush succeeds (dial now works), sending the retained batch.
|
||||
auditor.AuditRequest(Request{Method: "POST", URL: "https://serverup.com", Allowed: true})
|
||||
|
||||
// Should receive the retained batch (the new log goes into a fresh batch)
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != auditor.batchSize {
|
||||
t.Errorf("expected %d logs from retry, got %d", auditor.batchSize, len(req.Logs))
|
||||
}
|
||||
for _, log := range req.Logs {
|
||||
resource, ok := log.Resource.(*agentproto.BoundaryLog_HttpRequest_)
|
||||
if !ok {
|
||||
t.Fatal("unexpected resource type")
|
||||
}
|
||||
if resource.HttpRequest.Url != "https://servernotup.com" {
|
||||
t.Errorf("expected URL https://servernotup.com, got %v", resource.HttpRequest.Url)
|
||||
}
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for retry flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlush_EmptyBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := flush(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for empty batch, got %v", err)
|
||||
}
|
||||
|
||||
err = flush(nil, []*agentproto.BoundaryLog{})
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for empty slice, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// setupSocketAuditor creates a SocketAuditor for tests that only exercise
|
||||
// the queueing behavior (no connection needed).
|
||||
func setupSocketAuditor(t *testing.T) *SocketAuditor {
|
||||
t.Helper()
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
return &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
return nil, xerrors.New("not connected")
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: defaultBatchTimerDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// setupTestAuditor creates a SocketAuditor with an in-memory connection using
|
||||
// net.Pipe(). Returns the auditor and the server-side connection for reading.
|
||||
func setupTestAuditor(t *testing.T) (*SocketAuditor, net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
err := clientConn.Close()
|
||||
if err != nil {
|
||||
t.Error("Failed to close client connection", "error", err)
|
||||
}
|
||||
err = serverConn.Close()
|
||||
if err != nil {
|
||||
t.Error("Failed to close server connection", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
auditor := &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
return clientConn, nil
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: defaultBatchTimerDuration,
|
||||
}
|
||||
|
||||
return auditor, serverConn
|
||||
}
|
||||
|
||||
// readFromConn reads length-prefixed protobuf messages from a connection and
|
||||
// sends them to the received channel.
|
||||
func readFromConn(t *testing.T, conn net.Conn, received chan<- *agentproto.ReportBoundaryLogsRequest) {
|
||||
t.Helper()
|
||||
|
||||
buf := make([]byte, 1<<10)
|
||||
for {
|
||||
tag, data, err := codec.ReadFrame(conn, buf)
|
||||
if err != nil {
|
||||
return // connection closed
|
||||
}
|
||||
|
||||
if tag != codec.TagV1 {
|
||||
t.Errorf("invalid tag: %d", tag)
|
||||
}
|
||||
|
||||
var req agentproto.ReportBoundaryLogsRequest
|
||||
if err := proto.Unmarshal(data, &req); err != nil {
|
||||
t.Errorf("failed to unmarshal: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
received <- &req
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
//go:build linux
|
||||
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
|
||||
package boundary
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/log"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/run"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// printVersion prints version information.
|
||||
func printVersion(version string) {
|
||||
fmt.Println(version)
|
||||
}
|
||||
|
||||
// NewCommand creates and returns the root serpent command
|
||||
func NewCommand(version string) *serpent.Command {
|
||||
// To make the top level boundary command, we just make some minor changes to the base command
|
||||
cmd := BaseCommand(version)
|
||||
cmd.Use = "boundary [flags] -- command [args...]" // Add the flags and args pieces to usage.
|
||||
|
||||
// Add example usage to the long description. This is different from usage as a subcommand because it
|
||||
// may be called something different when used as a subcommand / there will be a leading binary (i.e. `coder boundary` vs. `boundary`).
|
||||
cmd.Long += `Examples:
|
||||
# Allow only requests to github.com
|
||||
boundary --allow "domain=github.com" -- curl https://github.com
|
||||
|
||||
# Monitor all requests to specific domains (allow only those)
|
||||
boundary --allow "domain=github.com path=/api/issues/*" --allow "method=GET,HEAD domain=github.com" -- npm install
|
||||
|
||||
# Use allowlist from config file with additional CLI allow rules
|
||||
boundary --allow "domain=example.com" -- curl https://example.com
|
||||
|
||||
# Block everything by default (implicit)`
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Base command returns the boundary serpent command without the information involved in making it the
|
||||
// *top level* serpent command. We are creating this split to make it easier to integrate into the coder
|
||||
// CLI if needed.
|
||||
func BaseCommand(version string) *serpent.Command {
|
||||
cliConfig := config.CliConfig{}
|
||||
var showVersion serpent.Bool
|
||||
|
||||
// Set default config path if file exists - serpent will load it automatically
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
defaultPath := filepath.Join(home, ".config", "coder_boundary", "config.yaml")
|
||||
if _, err := os.Stat(defaultPath); err == nil {
|
||||
cliConfig.Config = serpent.YAMLConfigPath(defaultPath)
|
||||
}
|
||||
}
|
||||
|
||||
return &serpent.Command{
|
||||
Use: "boundary",
|
||||
Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests",
|
||||
Long: `boundary creates an isolated network environment for target processes, intercepting HTTP/HTTPS traffic through a transparent proxy that enforces user-defined allow rules.`,
|
||||
Options: []serpent.Option{
|
||||
{
|
||||
Flag: "config",
|
||||
Env: "BOUNDARY_CONFIG",
|
||||
Description: "Path to YAML config file.",
|
||||
Value: &cliConfig.Config,
|
||||
YAML: "",
|
||||
},
|
||||
{
|
||||
Flag: "allow",
|
||||
Env: "BOUNDARY_ALLOW",
|
||||
Description: "Allow rule (repeatable). These are merged with allowlist from config file. Format: \"pattern\" or \"METHOD[,METHOD] pattern\".",
|
||||
Value: &cliConfig.AllowStrings,
|
||||
YAML: "", // CLI only, not loaded from YAML
|
||||
},
|
||||
{
|
||||
Flag: "allowlist",
|
||||
Description: "Allowlist rules from config file (YAML only).",
|
||||
Value: &cliConfig.AllowListStrings,
|
||||
YAML: "allowlist",
|
||||
Hidden: true, // Hidden because it's primarily for YAML config
|
||||
},
|
||||
{
|
||||
Flag: "log-level",
|
||||
Env: "BOUNDARY_LOG_LEVEL",
|
||||
Description: "Set log level (error, warn, info, debug).",
|
||||
Default: "warn",
|
||||
Value: &cliConfig.LogLevel,
|
||||
YAML: "log_level",
|
||||
},
|
||||
{
|
||||
Flag: "log-dir",
|
||||
Env: "BOUNDARY_LOG_DIR",
|
||||
Description: "Set a directory to write logs to rather than stderr.",
|
||||
Value: &cliConfig.LogDir,
|
||||
YAML: "log_dir",
|
||||
},
|
||||
{
|
||||
Flag: "proxy-port",
|
||||
Env: "PROXY_PORT",
|
||||
Description: "Set a port for HTTP proxy.",
|
||||
Default: "8080",
|
||||
Value: &cliConfig.ProxyPort,
|
||||
YAML: "proxy_port",
|
||||
},
|
||||
{
|
||||
Flag: "pprof",
|
||||
Env: "BOUNDARY_PPROF",
|
||||
Description: "Enable pprof profiling server.",
|
||||
Value: &cliConfig.PprofEnabled,
|
||||
YAML: "pprof_enabled",
|
||||
},
|
||||
{
|
||||
Flag: "pprof-port",
|
||||
Env: "BOUNDARY_PPROF_PORT",
|
||||
Description: "Set port for pprof profiling server.",
|
||||
Default: "6060",
|
||||
Value: &cliConfig.PprofPort,
|
||||
YAML: "pprof_port",
|
||||
},
|
||||
{
|
||||
Flag: "configure-dns-for-local-stub-resolver",
|
||||
Env: "BOUNDARY_CONFIGURE_DNS_FOR_LOCAL_STUB_RESOLVER",
|
||||
Description: "Configure DNS for local stub resolver (e.g., systemd-resolved). Only needed when /etc/resolv.conf contains nameserver 127.0.0.53.",
|
||||
Value: &cliConfig.ConfigureDNSForLocalStubResolver,
|
||||
YAML: "configure_dns_for_local_stub_resolver",
|
||||
},
|
||||
{
|
||||
Flag: "jail-type",
|
||||
Env: "BOUNDARY_JAIL_TYPE",
|
||||
Description: "Jail type to use for network isolation. Options: nsjail (default), landjail.",
|
||||
Default: "nsjail",
|
||||
Value: &cliConfig.JailType,
|
||||
YAML: "jail_type",
|
||||
},
|
||||
{
|
||||
Flag: "disable-audit-logs",
|
||||
Env: "DISABLE_AUDIT_LOGS",
|
||||
Description: "Disable sending of audit logs to the workspace agent when set to true.",
|
||||
Value: &cliConfig.DisableAuditLogs,
|
||||
YAML: "disable_audit_logs",
|
||||
},
|
||||
{
|
||||
Flag: "log-proxy-socket-path",
|
||||
Description: "Path to the socket where the boundary log proxy server listens for audit logs.",
|
||||
// Important: this default must be the same default path used by the
|
||||
// workspace agent to ensure agreement of the default socket path without
|
||||
// explicit configuration.
|
||||
Default: boundarylogproxy.DefaultSocketPath(),
|
||||
// Important: this must be the same variable name used by the workspace agent
|
||||
// to allow a single environment variable to configure both boundary and the
|
||||
// workspace agent.
|
||||
Env: "CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH",
|
||||
Value: &cliConfig.LogProxySocketPath,
|
||||
YAML: "", // CLI only, not loaded from YAML
|
||||
},
|
||||
{
|
||||
Flag: "version",
|
||||
Description: "Print version information and exit.",
|
||||
Value: &showVersion,
|
||||
YAML: "", // CLI only
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
// Handle --version flag early
|
||||
if showVersion.Value() {
|
||||
printVersion(version)
|
||||
return nil
|
||||
}
|
||||
appConfig, err := config.NewAppConfigFromCliConfig(cliConfig, inv.Args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse cli config file: %v", err)
|
||||
}
|
||||
|
||||
// Get command arguments
|
||||
if len(appConfig.TargetCMD) == 0 {
|
||||
return fmt.Errorf("no command specified")
|
||||
}
|
||||
|
||||
logger, err := log.SetupLogging(appConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not set up logging: %v", err)
|
||||
}
|
||||
|
||||
appConfigInJSON, err := json.Marshal(appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Debug("Application config", "config", appConfigInJSON)
|
||||
|
||||
return run.Run(inv.Context(), logger, appConfig)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
//go:build !linux
|
||||
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
|
||||
package boundary
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// BaseCommand returns the boundary serpent command. On non-Linux platforms,
|
||||
// boundary is not supported and returns an error.
|
||||
func BaseCommand(_ string) *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "boundary",
|
||||
Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests",
|
||||
Long: `boundary creates an isolated network environment for target processes, intercepting HTTP/HTTPS traffic through a transparent proxy that enforces user-defined allow rules.`,
|
||||
Handler: func(_ *serpent.Invocation) error {
|
||||
return xerrors.Errorf("boundary is only supported on Linux (current OS: %s)", runtime.GOOS)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// JailType represents the type of jail to use for network isolation
|
||||
type JailType string
|
||||
|
||||
const (
|
||||
NSJailType JailType = "nsjail"
|
||||
LandjailType JailType = "landjail"
|
||||
)
|
||||
|
||||
func NewJailTypeFromString(str string) (JailType, error) {
|
||||
switch str {
|
||||
case "nsjail":
|
||||
return NSJailType, nil
|
||||
case "landjail":
|
||||
return LandjailType, nil
|
||||
default:
|
||||
return NSJailType, xerrors.Errorf("invalid JailType: %s", str)
|
||||
}
|
||||
}
|
||||
|
||||
// AllowStringsArray is a custom type that implements pflag.Value to support
|
||||
// repeatable --allow flags without splitting on commas. This allows comma-separated
|
||||
// paths within a single allow rule (e.g., "path=/todos/1,/todos/2").
|
||||
type AllowStringsArray []string
|
||||
|
||||
var _ pflag.Value = (*AllowStringsArray)(nil)
|
||||
|
||||
// Set implements pflag.Value. It appends the value to the slice without splitting on commas.
|
||||
func (a *AllowStringsArray) Set(value string) error {
|
||||
*a = append(*a, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String implements pflag.Value.
|
||||
func (a AllowStringsArray) String() string {
|
||||
return strings.Join(a, ",")
|
||||
}
|
||||
|
||||
// Type implements pflag.Value.
|
||||
func (a AllowStringsArray) Type() string {
|
||||
return "string"
|
||||
}
|
||||
|
||||
// Value returns the underlying slice of strings.
|
||||
func (a AllowStringsArray) Value() []string {
|
||||
return []string(a)
|
||||
}
|
||||
|
||||
type CliConfig struct {
|
||||
Config serpent.YAMLConfigPath `yaml:"-"`
|
||||
AllowListStrings serpent.StringArray `yaml:"allowlist"` // From config file
|
||||
AllowStrings AllowStringsArray `yaml:"-"` // From CLI flags only
|
||||
LogLevel serpent.String `yaml:"log_level"`
|
||||
LogDir serpent.String `yaml:"log_dir"`
|
||||
ProxyPort serpent.Int64 `yaml:"proxy_port"`
|
||||
PprofEnabled serpent.Bool `yaml:"pprof_enabled"`
|
||||
PprofPort serpent.Int64 `yaml:"pprof_port"`
|
||||
ConfigureDNSForLocalStubResolver serpent.Bool `yaml:"configure_dns_for_local_stub_resolver"`
|
||||
JailType serpent.String `yaml:"jail_type"`
|
||||
DisableAuditLogs serpent.Bool `yaml:"disable_audit_logs"`
|
||||
LogProxySocketPath serpent.String `yaml:"log_proxy_socket_path"`
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
AllowRules []string
|
||||
LogLevel string
|
||||
LogDir string
|
||||
ProxyPort int64
|
||||
PprofEnabled bool
|
||||
PprofPort int64
|
||||
ConfigureDNSForLocalStubResolver bool
|
||||
JailType JailType
|
||||
TargetCMD []string
|
||||
UserInfo *UserInfo
|
||||
DisableAuditLogs bool
|
||||
LogProxySocketPath string
|
||||
}
|
||||
|
||||
func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, error) {
|
||||
// Merge allowlist from config file with allow from CLI flags
|
||||
allowListStrings := cfg.AllowListStrings.Value()
|
||||
allowStrings := cfg.AllowStrings.Value()
|
||||
|
||||
// Combine allowlist (config file) with allow (CLI flags)
|
||||
allAllowStrings := append(allowListStrings, allowStrings...)
|
||||
|
||||
jailType, err := NewJailTypeFromString(cfg.JailType.Value())
|
||||
if err != nil {
|
||||
return AppConfig{}, err
|
||||
}
|
||||
|
||||
userInfo := GetUserInfo()
|
||||
|
||||
return AppConfig{
|
||||
AllowRules: allAllowStrings,
|
||||
LogLevel: cfg.LogLevel.Value(),
|
||||
LogDir: cfg.LogDir.Value(),
|
||||
ProxyPort: cfg.ProxyPort.Value(),
|
||||
PprofEnabled: cfg.PprofEnabled.Value(),
|
||||
PprofPort: cfg.PprofPort.Value(),
|
||||
ConfigureDNSForLocalStubResolver: cfg.ConfigureDNSForLocalStubResolver.Value(),
|
||||
JailType: jailType,
|
||||
TargetCMD: targetCMD,
|
||||
UserInfo: userInfo,
|
||||
DisableAuditLogs: cfg.DisableAuditLogs.Value(),
|
||||
LogProxySocketPath: cfg.LogProxySocketPath.Value(),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
CAKeyName = "ca-key.pem"
|
||||
CACertName = "ca-cert.pem"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
SudoUser string
|
||||
Uid int
|
||||
Gid int
|
||||
HomeDir string
|
||||
ConfigDir string
|
||||
}
|
||||
|
||||
// GetUserInfo returns information about the current user, handling sudo scenarios
|
||||
func GetUserInfo() *UserInfo {
|
||||
// Only consider SUDO_USER if we're actually running with elevated privileges
|
||||
// In environments like Coder workspaces, SUDO_USER may be set to 'root'
|
||||
// but we're not actually running under sudo
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" && os.Geteuid() == 0 && sudoUser != "root" {
|
||||
// We're actually running under sudo with a non-root original user
|
||||
user, err := user.Lookup(sudoUser)
|
||||
if err != nil {
|
||||
return getCurrentUserInfo() // Fallback to current user
|
||||
}
|
||||
|
||||
uid, _ := strconv.Atoi(os.Getenv("SUDO_UID"))
|
||||
gid, _ := strconv.Atoi(os.Getenv("SUDO_GID"))
|
||||
|
||||
// If we couldn't get UID/GID from env, parse from user info
|
||||
if uid == 0 {
|
||||
if parsedUID, err := strconv.Atoi(user.Uid); err == nil {
|
||||
uid = parsedUID
|
||||
}
|
||||
}
|
||||
if gid == 0 {
|
||||
if parsedGID, err := strconv.Atoi(user.Gid); err == nil {
|
||||
gid = parsedGID
|
||||
}
|
||||
}
|
||||
|
||||
configDir := getConfigDir(user.HomeDir)
|
||||
|
||||
return &UserInfo{
|
||||
SudoUser: sudoUser,
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
HomeDir: user.HomeDir,
|
||||
ConfigDir: configDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Not actually running under sudo, use current user
|
||||
return getCurrentUserInfo()
|
||||
}
|
||||
|
||||
// getCurrentUserInfo gets information for the current user
|
||||
func getCurrentUserInfo() *UserInfo {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
// Fallback with empty values if we can't get user info
|
||||
return &UserInfo{}
|
||||
}
|
||||
|
||||
uid, _ := strconv.Atoi(currentUser.Uid)
|
||||
gid, _ := strconv.Atoi(currentUser.Gid)
|
||||
|
||||
configDir := getConfigDir(currentUser.HomeDir)
|
||||
|
||||
return &UserInfo{
|
||||
SudoUser: currentUser.Username,
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
HomeDir: currentUser.HomeDir,
|
||||
ConfigDir: configDir,
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigDir determines the config directory based on XDG_CONFIG_HOME or fallback
|
||||
func getConfigDir(homeDir string) string {
|
||||
// Use XDG_CONFIG_HOME if set, otherwise fallback to ~/.config/coder_boundary
|
||||
if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" {
|
||||
return filepath.Join(xdgConfigHome, "coder_boundary")
|
||||
}
|
||||
return filepath.Join(homeDir, ".config", "coder_boundary")
|
||||
}
|
||||
|
||||
func (u *UserInfo) CAKeyPath() string {
|
||||
return filepath.Join(u.ConfigDir, CAKeyName)
|
||||
}
|
||||
|
||||
func (u *UserInfo) CACertPath() string {
|
||||
return filepath.Join(u.ConfigDir, CACertName)
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/landlock-lsm/go-landlock/landlock"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/util"
|
||||
)
|
||||
|
||||
type LandlockConfig struct {
|
||||
// TODO(yevhenii):
|
||||
// - should it be able to bind to any port?
|
||||
// - should it be able to connect to any port on localhost?
|
||||
// BindTCPPorts []int
|
||||
ConnectTCPPorts []int
|
||||
}
|
||||
|
||||
func ApplyLandlockRestrictions(logger *slog.Logger, cfg LandlockConfig) error {
|
||||
// Get the Landlock version which works for Kernel 6.7+
|
||||
llCfg := landlock.V4
|
||||
|
||||
// Collect our rules
|
||||
var netRules []landlock.Rule
|
||||
|
||||
// Add rules for TCP connections
|
||||
for _, port := range cfg.ConnectTCPPorts {
|
||||
logger.Debug("Adding TCP connect port", "port", port)
|
||||
netRules = append(netRules, landlock.ConnectTCP(uint16(port)))
|
||||
}
|
||||
|
||||
err := llCfg.RestrictNet(netRules...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply Landlock network restrictions: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunChild(logger *slog.Logger, config config.AppConfig) error {
|
||||
landjailCfg := LandlockConfig{
|
||||
ConnectTCPPorts: []int{int(config.ProxyPort)},
|
||||
}
|
||||
|
||||
err := ApplyLandlockRestrictions(logger, landjailCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply Landlock network restrictions: %v", err)
|
||||
}
|
||||
|
||||
// Build command
|
||||
cmd := exec.Command(config.TargetCMD[0], config.TargetCMD[1:]...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
logger.Info("Executing target command", "command", config.TargetCMD)
|
||||
|
||||
// Run the command - this will block until it completes
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
// Check if this is a normal exit with non-zero status code
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode := exitError.ExitCode()
|
||||
logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
|
||||
return fmt.Errorf("command exited with code %d", exitCode)
|
||||
}
|
||||
// This is an unexpected error
|
||||
logger.Error("Command execution failed", "error", err)
|
||||
return fmt.Errorf("command execution failed: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("Command completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns environment variables intended to be set on the child process,
|
||||
// so they can later be inherited by the target process.
|
||||
func getEnvsForTargetProcess(configDir string, caCertPath string, httpProxyPort int) []string {
|
||||
e := os.Environ()
|
||||
|
||||
proxyAddr := fmt.Sprintf("http://localhost:%d", httpProxyPort)
|
||||
e = util.MergeEnvs(e, map[string]string{
|
||||
// Set standard CA certificate environment variables for common tools
|
||||
// This makes tools like curl, git, etc. trust our dynamically generated CA
|
||||
"SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools
|
||||
"SSL_CERT_DIR": configDir, // OpenSSL certificate directory
|
||||
"CURL_CA_BUNDLE": caCertPath, // curl
|
||||
"GIT_SSL_CAINFO": caCertPath, // Git
|
||||
"REQUESTS_CA_BUNDLE": caCertPath, // Python requests
|
||||
"NODE_EXTRA_CA_CERTS": caCertPath, // Node.js
|
||||
|
||||
"HTTP_PROXY": proxyAddr,
|
||||
"HTTPS_PROXY": proxyAddr,
|
||||
"http_proxy": proxyAddr,
|
||||
"https_proxy": proxyAddr,
|
||||
})
|
||||
|
||||
return e
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/proxy"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
)
|
||||
|
||||
type LandJail struct {
|
||||
proxyServer *proxy.Server
|
||||
logger *slog.Logger
|
||||
config config.AppConfig
|
||||
}
|
||||
|
||||
func NewLandJail(
|
||||
ruleEngine rulesengine.Engine,
|
||||
auditor audit.Auditor,
|
||||
tlsConfig *tls.Config,
|
||||
logger *slog.Logger,
|
||||
config config.AppConfig,
|
||||
) (*LandJail, error) {
|
||||
// Create proxy server
|
||||
proxyServer := proxy.NewProxyServer(proxy.Config{
|
||||
HTTPPort: int(config.ProxyPort),
|
||||
RuleEngine: ruleEngine,
|
||||
Auditor: auditor,
|
||||
Logger: logger,
|
||||
TLSConfig: tlsConfig,
|
||||
PprofEnabled: config.PprofEnabled,
|
||||
PprofPort: int(config.PprofPort),
|
||||
})
|
||||
|
||||
return &LandJail{
|
||||
config: config,
|
||||
proxyServer: proxyServer,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *LandJail) Run(ctx context.Context) error {
|
||||
b.logger.Info("Start landjail manager")
|
||||
err := b.startProxy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start landjail manager: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
b.logger.Info("Stop landjail manager")
|
||||
err := b.stopProxy()
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to stop landjail manager", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
err := b.RunChildProcess(os.Args)
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to run child process", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Setup signal handling BEFORE any setup
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Wait for signal or context cancellation
|
||||
select {
|
||||
case sig := <-sigChan:
|
||||
b.logger.Info("Received signal, shutting down...", "signal", sig)
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
// Context canceled by command completion
|
||||
b.logger.Info("Command completed, shutting down...")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LandJail) RunChildProcess(command []string) error {
|
||||
childCmd := b.getChildCommand(command)
|
||||
|
||||
b.logger.Debug("Executing command in boundary", "command", strings.Join(os.Args, " "))
|
||||
err := childCmd.Start()
|
||||
if err != nil {
|
||||
b.logger.Error("Command failed to start", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
b.logger.Debug("waiting on a child process to finish")
|
||||
err = childCmd.Wait()
|
||||
if err != nil {
|
||||
// Check if this is a normal exit with non-zero status code
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode := exitError.ExitCode()
|
||||
// Log at debug level for non-zero exits (normal behavior)
|
||||
b.logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
|
||||
return err
|
||||
}
|
||||
|
||||
// This is an unexpected error (not just a non-zero exit)
|
||||
b.logger.Error("Command execution failed", "error", err)
|
||||
return err
|
||||
}
|
||||
b.logger.Debug("Command completed successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LandJail) getChildCommand(command []string) *exec.Cmd {
|
||||
cmd := exec.Command(command[0], command[1:]...)
|
||||
// Set env vars for the child process; they will be inherited by the target process.
|
||||
cmd.Env = getEnvsForTargetProcess(b.config.UserInfo.ConfigDir, b.config.UserInfo.CACertPath(), int(b.config.ProxyPort))
|
||||
cmd.Env = append(cmd.Env, "CHILD=true")
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (b *LandJail) startProxy() error {
|
||||
// Start proxy server in background
|
||||
err := b.proxyServer.Start()
|
||||
if err != nil {
|
||||
b.logger.Error("Proxy server error", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Give proxy time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LandJail) stopProxy() error {
|
||||
// Stop proxy server
|
||||
if b.proxyServer != nil {
|
||||
err := b.proxyServer.Stop()
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to stop proxy server", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/tls"
|
||||
)
|
||||
|
||||
func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
|
||||
if len(config.AllowRules) == 0 {
|
||||
logger.Warn("No allow rules specified; all network traffic will be denied by default")
|
||||
}
|
||||
|
||||
// Parse allow rules
|
||||
allowRules, err := rulesengine.ParseAllowSpecs(config.AllowRules)
|
||||
if err != nil {
|
||||
logger.Error("Failed to parse allow rules", "error", err)
|
||||
return fmt.Errorf("failed to parse allow rules: %v", err)
|
||||
}
|
||||
|
||||
// Create rule engine
|
||||
ruleEngine := rulesengine.NewRuleEngine(allowRules, logger)
|
||||
|
||||
// Create auditor
|
||||
auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup auditor: %v", err)
|
||||
}
|
||||
|
||||
// Create TLS certificate manager
|
||||
certManager, err := tls.NewCertificateManager(tls.Config{
|
||||
Logger: logger,
|
||||
ConfigDir: config.UserInfo.ConfigDir,
|
||||
Uid: config.UserInfo.Uid,
|
||||
Gid: config.UserInfo.Gid,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to create certificate manager", "error", err)
|
||||
return fmt.Errorf("failed to create certificate manager: %v", err)
|
||||
}
|
||||
|
||||
// Setup TLS to get cert path for jailer
|
||||
tlsConfig, err := certManager.SetupTLSAndWriteCACert()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS and CA certificate: %v", err)
|
||||
}
|
||||
|
||||
landjail, err := NewLandJail(ruleEngine, auditor, tlsConfig, logger, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create landjail: %v", err)
|
||||
}
|
||||
|
||||
return landjail.Run(ctx)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
)
|
||||
|
||||
func isChild() bool {
|
||||
return os.Getenv("CHILD") == "true"
|
||||
}
|
||||
|
||||
// Run is the main entry point that determines whether to execute as a parent or child process.
|
||||
// If running as a child (CHILD env var is set), it applies landlock restrictions
|
||||
// and executes the target command. Otherwise, it runs as the parent process, sets up the proxy server,
|
||||
// and manages the child process lifecycle.
|
||||
func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
|
||||
if isChild() {
|
||||
return RunChild(logger, config)
|
||||
}
|
||||
|
||||
return RunParent(ctx, logger, config)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user