Compare commits

..

3 Commits

Author SHA1 Message Date
Garrett Delfosse ce2aed9002 feat: move boundary code from coder/boundary to enterprise/cli/boundary 2026-01-14 13:07:05 -05:00
Steven Masley 8d6a202ee4 chore: git ignore jetbrains run configs (#21497)
Jetbrains ide users can save their debug/test run configs to `.run`.
2026-01-14 06:51:35 -06:00
Sas Swart ffa83a4ebc docs: add documentation for coder script ordering (#21090)
This Pull request adds documentation and guidance for the Coder script
ordering feature. We:
* explain the use case, benefits, and requirements.
* provide example configuration snippets
* discuss best practices and troubleshooting

---------

Co-authored-by: Cian Johnston <cian@coder.com>
Co-authored-by: DevCats <christofer@coder.com>
2026-01-14 14:40:38 +02:00
68 changed files with 8152 additions and 1045 deletions
+1
View File
@@ -3,6 +3,7 @@
.eslintcache
.gitpod.yml
.idea
.run
**/*.swp
gotests.coverage
gotests.xml
+8
View File
@@ -211,6 +211,14 @@ issues:
- path: scripts/rules.go
linters:
- ALL
# Boundary code is imported from github.com/coder/boundary and has different
# lint standards. Suppress lint issues in this imported code.
- path: enterprise/cli/boundary/
linters:
- revive
- gocritic
- gosec
- errorlint
fix: true
max-issues-per-linter: 0
+10 -4
View File
@@ -1,12 +1,18 @@
package cli
import (
boundarycli "github.com/coder/boundary/cli"
"golang.org/x/xerrors"
"github.com/coder/serpent"
)
func (*RootCmd) boundary() *serpent.Command {
cmd := boundarycli.BaseCommand() // Package coder/boundary/cli exports a "base command" designed to be integrated as a subcommand.
cmd.Use += " [args...]" // The base command looks like `boundary -- command`. Serpent adds the flags piece, but we need to add the args.
return cmd
return &serpent.Command{
Use: "boundary",
Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests (enterprise)",
Long: `boundary creates an isolated network environment for target processes. This is an enterprise feature.`,
Handler: func(_ *serpent.Invocation) error {
return xerrors.New("boundary is an enterprise feature; upgrade to use this command")
},
}
}
+3 -7
View File
@@ -5,15 +5,13 @@ import (
"github.com/stretchr/testify/assert"
boundarycli "github.com/coder/boundary/cli"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
// Actually testing the functionality of coder/boundary takes place in the
// coder/boundary repo, since it's a dependency of coder.
// Here we want to test basically that integrating it as a subcommand doesn't break anything.
// Here we want to test that integrating boundary as a subcommand doesn't break anything.
// The full boundary functionality is tested in enterprise/cli.
func TestBoundarySubcommand(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
@@ -27,7 +25,5 @@ func TestBoundarySubcommand(t *testing.T) {
}()
// Expect the --help output to include the short description.
// We're simply confirming that `coder boundary --help` ran without a runtime error as
// a good chunk of serpents self validation logic happens at runtime.
pty.ExpectMatch(boundarycli.BaseCommand().Short)
pty.ExpectMatch("Network isolation tool")
}
-1
View File
@@ -402,7 +402,6 @@ func WorkspaceAgentDevcontainer(t testing.TB, db database.Store, orig database.W
Name: []string{takeFirst(orig.Name, testutil.GetRandomName(t))},
WorkspaceFolder: []string{takeFirst(orig.WorkspaceFolder, "/workspace")},
ConfigPath: []string{takeFirst(orig.ConfigPath, "")},
SubagentID: []uuid.UUID{takeFirst(orig.SubagentID, uuid.NullUUID{}).UUID},
})
require.NoError(t, err, "insert workspace agent devcontainer")
return devcontainers[0]
+1 -2
View File
@@ -2505,8 +2505,7 @@ CREATE TABLE workspace_agent_devcontainers (
created_at timestamp with time zone DEFAULT now() NOT NULL,
workspace_folder text NOT NULL,
config_path text NOT NULL,
name text NOT NULL,
subagent_id uuid
name text NOT NULL
);
COMMENT ON TABLE workspace_agent_devcontainers IS 'Workspace agent devcontainer configuration';
@@ -1,2 +0,0 @@
ALTER TABLE workspace_agent_devcontainers
DROP COLUMN subagent_id;
@@ -1,2 +0,0 @@
ALTER TABLE workspace_agent_devcontainers
ADD COLUMN subagent_id UUID;
+1 -2
View File
@@ -4743,8 +4743,7 @@ type WorkspaceAgentDevcontainer struct {
// Path to devcontainer.json.
ConfigPath string `db:"config_path" json:"config_path"`
// The name of the Dev Container.
Name string `db:"name" json:"name"`
SubagentID uuid.NullUUID `db:"subagent_id" json:"subagent_id"`
Name string `db:"name" json:"name"`
}
type WorkspaceAgentLog struct {
-96
View File
@@ -7989,99 +7989,3 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
require.NoError(t, err)
require.Len(t, remaining, len(unexpiredTimes))
}
// TestWorkspaceAgentDevcontainersSubagentID verifies that a devcontainer row
// round-trips its subagent_id through insert and retrieval, and documents the
// edge case that passing uuid.Nil stores a zero UUID rather than SQL NULL.
func TestWorkspaceAgentDevcontainersSubagentID(t *testing.T) {
	t.Parallel()
	db, _ := dbtestutil.NewDB(t)
	// Setup: create workspace agent
	// (full fixture chain: org -> user -> template -> version -> workspace
	// -> job -> build -> resource -> agent).
	org := dbgen.Organization(t, db, database.Organization{})
	user := dbgen.User(t, db, database.User{})
	tpl := dbgen.Template(t, db, database.Template{
		OrganizationID: org.ID,
		CreatedBy:      user.ID,
	})
	tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
		OrganizationID: org.ID,
		TemplateID:     uuid.NullUUID{UUID: tpl.ID, Valid: true},
		CreatedBy:      user.ID,
	})
	ws := dbgen.Workspace(t, db, database.WorkspaceTable{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		TemplateID:     tpl.ID,
	})
	job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
		OrganizationID: org.ID,
	})
	build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
		WorkspaceID:       ws.ID,
		JobID:             job.ID,
		TemplateVersionID: tv.ID,
	})
	res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
		JobID: build.JobID,
	})
	agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
		ResourceID: res.ID,
	})
	// Create a subagent that will be referenced
	// (a child agent whose ParentID points at the agent above).
	subagent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
		ResourceID: res.ID,
		ParentID:   uuid.NullUUID{UUID: agent.ID, Valid: true},
	})
	t.Run("InsertWithSubagentID", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		// Insert one devcontainer that references the subagent.
		devcontainers, err := db.InsertWorkspaceAgentDevcontainers(ctx, database.InsertWorkspaceAgentDevcontainersParams{
			WorkspaceAgentID: agent.ID,
			CreatedAt:        dbtime.Now(),
			ID:               []uuid.UUID{uuid.New()},
			Name:             []string{"test-devcontainer"},
			WorkspaceFolder:  []string{"/workspace"},
			ConfigPath:       []string{"/workspace/.devcontainer/devcontainer.json"},
			SubagentID:       []uuid.UUID{subagent.ID},
		})
		require.NoError(t, err)
		require.Len(t, devcontainers, 1)
		require.True(t, devcontainers[0].SubagentID.Valid)
		require.Equal(t, subagent.ID, devcontainers[0].SubagentID.UUID)
		// Verify retrieval
		// (the subagent ID must survive a read back by agent ID).
		retrieved, err := db.GetWorkspaceAgentDevcontainersByAgentID(ctx, agent.ID)
		require.NoError(t, err)
		require.Len(t, retrieved, 1)
		require.True(t, retrieved[0].SubagentID.Valid)
		require.Equal(t, subagent.ID, retrieved[0].SubagentID.UUID)
	})
	t.Run("InsertWithNilSubagentID", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		// Create a separate agent for this subtest
		// so its devcontainers don't collide with the other subtest's reads.
		agent2 := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
			ResourceID: res.ID,
		})
		// When uuid.Nil is passed, it stores the zero UUID (not NULL).
		// This matches the provisionerdserver behavior.
		devcontainers, err := db.InsertWorkspaceAgentDevcontainers(ctx, database.InsertWorkspaceAgentDevcontainersParams{
			WorkspaceAgentID: agent2.ID,
			CreatedAt:        dbtime.Now(),
			ID:               []uuid.UUID{uuid.New()},
			Name:             []string{"no-subagent"},
			WorkspaceFolder:  []string{"/workspace"},
			ConfigPath:       []string{""},
			SubagentID:       []uuid.UUID{uuid.Nil},
		})
		require.NoError(t, err)
		require.Len(t, devcontainers, 1)
		// uuid.Nil is stored as a zero UUID, not NULL.
		require.Equal(t, uuid.Nil, devcontainers[0].SubagentID.UUID)
	})
}
+4 -9
View File
@@ -17336,7 +17336,7 @@ func (q *sqlQuerier) ValidateUserIDs(ctx context.Context, userIds []uuid.UUID) (
const getWorkspaceAgentDevcontainersByAgentID = `-- name: GetWorkspaceAgentDevcontainersByAgentID :many
SELECT
id, workspace_agent_id, created_at, workspace_folder, config_path, name, subagent_id
id, workspace_agent_id, created_at, workspace_folder, config_path, name
FROM
workspace_agent_devcontainers
WHERE
@@ -17361,7 +17361,6 @@ func (q *sqlQuerier) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context
&i.WorkspaceFolder,
&i.ConfigPath,
&i.Name,
&i.SubagentID,
); err != nil {
return nil, err
}
@@ -17378,16 +17377,15 @@ func (q *sqlQuerier) GetWorkspaceAgentDevcontainersByAgentID(ctx context.Context
const insertWorkspaceAgentDevcontainers = `-- name: InsertWorkspaceAgentDevcontainers :many
INSERT INTO
workspace_agent_devcontainers (workspace_agent_id, created_at, id, name, workspace_folder, config_path, subagent_id)
workspace_agent_devcontainers (workspace_agent_id, created_at, id, name, workspace_folder, config_path)
SELECT
$1::uuid AS workspace_agent_id,
$2::timestamptz AS created_at,
unnest($3::uuid[]) AS id,
unnest($4::text[]) AS name,
unnest($5::text[]) AS workspace_folder,
unnest($6::text[]) AS config_path,
unnest($7::uuid[]) AS subagent_id
RETURNING workspace_agent_devcontainers.id, workspace_agent_devcontainers.workspace_agent_id, workspace_agent_devcontainers.created_at, workspace_agent_devcontainers.workspace_folder, workspace_agent_devcontainers.config_path, workspace_agent_devcontainers.name, workspace_agent_devcontainers.subagent_id
unnest($6::text[]) AS config_path
RETURNING workspace_agent_devcontainers.id, workspace_agent_devcontainers.workspace_agent_id, workspace_agent_devcontainers.created_at, workspace_agent_devcontainers.workspace_folder, workspace_agent_devcontainers.config_path, workspace_agent_devcontainers.name
`
type InsertWorkspaceAgentDevcontainersParams struct {
@@ -17397,7 +17395,6 @@ type InsertWorkspaceAgentDevcontainersParams struct {
Name []string `db:"name" json:"name"`
WorkspaceFolder []string `db:"workspace_folder" json:"workspace_folder"`
ConfigPath []string `db:"config_path" json:"config_path"`
SubagentID []uuid.UUID `db:"subagent_id" json:"subagent_id"`
}
func (q *sqlQuerier) InsertWorkspaceAgentDevcontainers(ctx context.Context, arg InsertWorkspaceAgentDevcontainersParams) ([]WorkspaceAgentDevcontainer, error) {
@@ -17408,7 +17405,6 @@ func (q *sqlQuerier) InsertWorkspaceAgentDevcontainers(ctx context.Context, arg
pq.Array(arg.Name),
pq.Array(arg.WorkspaceFolder),
pq.Array(arg.ConfigPath),
pq.Array(arg.SubagentID),
)
if err != nil {
return nil, err
@@ -17424,7 +17420,6 @@ func (q *sqlQuerier) InsertWorkspaceAgentDevcontainers(ctx context.Context, arg
&i.WorkspaceFolder,
&i.ConfigPath,
&i.Name,
&i.SubagentID,
); err != nil {
return nil, err
}
@@ -1,14 +1,13 @@
-- name: InsertWorkspaceAgentDevcontainers :many
INSERT INTO
workspace_agent_devcontainers (workspace_agent_id, created_at, id, name, workspace_folder, config_path, subagent_id)
workspace_agent_devcontainers (workspace_agent_id, created_at, id, name, workspace_folder, config_path)
SELECT
@workspace_agent_id::uuid AS workspace_agent_id,
@created_at::timestamptz AS created_at,
unnest(@id::uuid[]) AS id,
unnest(@name::text[]) AS name,
unnest(@workspace_folder::text[]) AS workspace_folder,
unnest(@config_path::text[]) AS config_path,
unnest(@subagent_id::uuid[]) AS subagent_id
unnest(@config_path::text[]) AS config_path
RETURNING workspace_agent_devcontainers.*;
-- name: GetWorkspaceAgentDevcontainersByAgentID :many
+20 -269
View File
@@ -2897,7 +2897,6 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
devcontainerNames = make([]string, 0, len(devcontainers))
devcontainerWorkspaceFolders = make([]string, 0, len(devcontainers))
devcontainerConfigPaths = make([]string, 0, len(devcontainers))
devcontainerSubagentIDs = make([]uuid.UUID, 0, len(devcontainers))
)
for _, dc := range devcontainers {
id := uuid.New()
@@ -2906,22 +2905,6 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
devcontainerWorkspaceFolders = append(devcontainerWorkspaceFolders, dc.WorkspaceFolder)
devcontainerConfigPaths = append(devcontainerConfigPaths, dc.ConfigPath)
var subagentID uuid.UUID
hasSubagentID := len(dc.SubagentId) > 0
if hasSubagentID {
subagentID, err = uuid.FromBytes(dc.SubagentId)
if err != nil {
return xerrors.Errorf("parse devcontainer %q subagent_id: %w", dc.Name, err)
}
}
if hasSubagentID && (len(dc.Apps) > 0 || len(dc.Scripts) > 0 || len(dc.Envs) > 0) {
subagentID, err = insertDevcontainerSubagent(ctx, db, subagentID, dc, prAgent, agentID, resource.ID, snapshot)
if err != nil {
return err
}
}
devcontainerSubagentIDs = append(devcontainerSubagentIDs, subagentID)
// Add a log source and script for each devcontainer so we can
// track logs and timings for each devcontainer.
displayName := fmt.Sprintf("Dev Container (%s)", dc.Name)
@@ -2949,7 +2932,6 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
Name: devcontainerNames,
WorkspaceFolder: devcontainerWorkspaceFolders,
ConfigPath: devcontainerConfigPaths,
SubagentID: devcontainerSubagentIDs,
})
if err != nil {
return xerrors.Errorf("insert agent devcontainer: %w", err)
@@ -3001,18 +2983,35 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
}
appSlugs[slug] = struct{}{}
health := database.WorkspaceAppHealthDisabled
if app.Healthcheck == nil {
app.Healthcheck = &sdkproto.Healthcheck{}
}
health := appHealthFromHealthcheck(app.Healthcheck)
sharingLevel := appSharingLevelToDatabase(app.SharingLevel)
openIn := appOpenInToDatabase(app.OpenIn)
if app.Healthcheck.Url != "" {
health = database.WorkspaceAppHealthInitializing
}
sharingLevel := database.AppSharingLevelOwner
switch app.SharingLevel {
case sdkproto.AppSharingLevel_AUTHENTICATED:
sharingLevel = database.AppSharingLevelAuthenticated
case sdkproto.AppSharingLevel_PUBLIC:
sharingLevel = database.AppSharingLevelPublic
}
displayGroup := sql.NullString{
Valid: app.Group != "",
String: app.Group,
}
openIn := database.WorkspaceAppOpenInSlimWindow
switch app.OpenIn {
case sdkproto.AppOpenIn_TAB:
openIn = database.WorkspaceAppOpenInTab
case sdkproto.AppOpenIn_SLIM_WINDOW:
openIn = database.WorkspaceAppOpenInSlimWindow
}
var appID string
if app.Id == "" || app.Id == uuid.Nil.String() {
appID = uuid.NewString()
@@ -3363,251 +3362,3 @@ func convertDisplayApps(apps *sdkproto.DisplayApps) []database.DisplayApp {
}
return dapps
}
// appSharingLevelToDatabase maps a proto app sharing level onto its database
// equivalent. Any level that is not AUTHENTICATED or PUBLIC falls back to
// owner-only sharing.
func appSharingLevelToDatabase(level sdkproto.AppSharingLevel) database.AppSharingLevel {
	if level == sdkproto.AppSharingLevel_AUTHENTICATED {
		return database.AppSharingLevelAuthenticated
	}
	if level == sdkproto.AppSharingLevel_PUBLIC {
		return database.AppSharingLevelPublic
	}
	// Default: restrict the app to its owner.
	return database.AppSharingLevelOwner
}
// appOpenInToDatabase maps a proto app open_in setting onto its database
// equivalent. Anything other than TAB resolves to the slim-window default.
func appOpenInToDatabase(openIn sdkproto.AppOpenIn) database.WorkspaceAppOpenIn {
	if openIn == sdkproto.AppOpenIn_TAB {
		return database.WorkspaceAppOpenInTab
	}
	return database.WorkspaceAppOpenInSlimWindow
}
// appHealthFromHealthcheck derives an app's initial health status from its
// healthcheck configuration: apps with a healthcheck URL start out
// "initializing"; apps without one have health reporting disabled.
func appHealthFromHealthcheck(hc *sdkproto.Healthcheck) database.WorkspaceAppHealth {
	// A nil healthcheck or an empty URL both mean no health monitoring.
	if hc == nil || hc.Url == "" {
		return database.WorkspaceAppHealthDisabled
	}
	return database.WorkspaceAppHealthInitializing
}
// insertDevcontainerSubagent creates a subagent for a devcontainer with its apps, scripts, and envs.
// If subagentID is uuid.Nil, a new UUID will be generated.
// The subagent is inserted as a child of parentAgentID and copies the parent
// agent's architecture, operating system, connection timeout, and
// troubleshooting URL. Returns the (possibly generated) subagent ID.
func insertDevcontainerSubagent(
	ctx context.Context,
	db database.Store,
	subagentID uuid.UUID, // pre-assigned ID, or uuid.Nil to generate one here
	dc *sdkproto.Devcontainer,
	parentAgent *sdkproto.Agent,
	parentAgentID uuid.UUID,
	resourceID uuid.UUID,
	snapshot *telemetry.Snapshot, // telemetry sink for the inserted apps
) (uuid.UUID, error) {
	if subagentID == uuid.Nil {
		subagentID = uuid.New()
	}
	// Flatten the devcontainer env list into a map; a later entry with the
	// same name overwrites an earlier one.
	subAgentEnvs := make(map[string]string, len(dc.Envs))
	for _, env := range dc.Envs {
		subAgentEnvs[env.Name] = env.Value
	}
	// Only marshal when envs exist; otherwise the column stays NULL
	// (zero-valued NullRawMessage).
	var subAgentEnvsJSON pqtype.NullRawMessage
	if len(subAgentEnvs) > 0 {
		envJSON, err := json.Marshal(subAgentEnvs)
		if err != nil {
			return uuid.Nil, xerrors.Errorf("marshal devcontainer %q envs: %w", dc.Name, err)
		}
		subAgentEnvsJSON = pqtype.NullRawMessage{RawMessage: envJSON, Valid: true}
	}
	_, err := db.InsertWorkspaceAgent(ctx, database.InsertWorkspaceAgentParams{
		ID:                       subagentID,
		ParentID:                 uuid.NullUUID{Valid: true, UUID: parentAgentID},
		CreatedAt:                dbtime.Now(),
		UpdatedAt:                dbtime.Now(),
		ResourceID:               resourceID,
		Name:                     dc.Name,
		AuthToken:                uuid.New(), // fresh token for the subagent
		AuthInstanceID:           sql.NullString{},
		Architecture:             parentAgent.Architecture,
		EnvironmentVariables:     subAgentEnvsJSON,
		Directory:                dc.WorkspaceFolder,
		OperatingSystem:          parentAgent.OperatingSystem,
		ConnectionTimeoutSeconds: parentAgent.GetConnectionTimeoutSeconds(),
		TroubleshootingURL:       parentAgent.GetTroubleshootingUrl(),
		MOTDFile:                 "",
		DisplayApps:              []database.DisplayApp{},
		InstanceMetadata:         pqtype.NullRawMessage{},
		ResourceMetadata:         pqtype.NullRawMessage{},
		DisplayOrder:             0,
		APIKeyScope:              database.AgentKeyScopeEnumAll,
	})
	if err != nil {
		return uuid.Nil, xerrors.Errorf("insert devcontainer %q subagent: %w", dc.Name, err)
	}
	// The agent row must exist before its apps and scripts reference it.
	if err := insertDevcontainerSubagentApps(ctx, db, dc, subagentID, snapshot); err != nil {
		return uuid.Nil, err
	}
	if err := insertDevcontainerSubagentScripts(ctx, db, dc, subagentID); err != nil {
		return uuid.Nil, err
	}
	return subagentID, nil
}
// insertDevcontainerSubagentApps inserts workspace apps for a devcontainer subagent.
// Each app must carry a slug matching provisioner.AppSlugRegex. Every upserted
// app is also appended to the telemetry snapshot. Returns on the first error;
// apps inserted before the failure are not rolled back here.
func insertDevcontainerSubagentApps(
	ctx context.Context,
	db database.Store,
	dc *sdkproto.Devcontainer,
	subagentID uuid.UUID,
	snapshot *telemetry.Snapshot,
) error {
	for _, app := range dc.Apps {
		// Validate the slug before touching the database.
		slug := app.Slug
		if slug == "" {
			return xerrors.Errorf("devcontainer %q app must have a slug set", dc.Name)
		}
		if !provisioner.AppSlugRegex.MatchString(slug) {
			return xerrors.Errorf("devcontainer %q app slug %q does not match regex %q", dc.Name, slug, provisioner.AppSlugRegex.String())
		}
		// Normalize a missing healthcheck to the zero value so the field
		// reads below (Url, Interval, Threshold) are nil-safe.
		if app.Healthcheck == nil {
			app.Healthcheck = &sdkproto.Healthcheck{}
		}
		health := appHealthFromHealthcheck(app.Healthcheck)
		sharingLevel := appSharingLevelToDatabase(app.SharingLevel)
		openIn := appOpenInToDatabase(app.OpenIn)
		// Group is optional; an empty string maps to SQL NULL.
		displayGroup := sql.NullString{
			Valid:  app.Group != "",
			String: app.Group,
		}
		// Use the provisioner-supplied app ID when set and non-zero;
		// otherwise generate a fresh one.
		appID := uuid.New()
		if app.Id != "" && app.Id != uuid.Nil.String() {
			var err error
			appID, err = uuid.Parse(app.Id)
			if err != nil {
				return xerrors.Errorf("parse devcontainer %q app uuid: %w", dc.Name, err)
			}
		}
		dbApp, err := db.UpsertWorkspaceApp(ctx, database.UpsertWorkspaceAppParams{
			ID:          appID,
			CreatedAt:   dbtime.Now(),
			AgentID:     subagentID,
			Slug:        slug,
			DisplayName: app.DisplayName,
			Icon:        app.Icon,
			// Empty command/URL map to SQL NULL rather than empty strings.
			Command: sql.NullString{
				String: app.Command,
				Valid:  app.Command != "",
			},
			Url: sql.NullString{
				String: app.Url,
				Valid:  app.Url != "",
			},
			External:             app.External,
			Subdomain:            app.Subdomain,
			SharingLevel:         sharingLevel,
			HealthcheckUrl:       app.Healthcheck.Url,
			HealthcheckInterval:  app.Healthcheck.Interval,
			HealthcheckThreshold: app.Healthcheck.Threshold,
			Health:               health,
			// #nosec G115 - Order represents a display order value that's always small and fits in int32
			DisplayOrder: int32(app.Order),
			DisplayGroup: displayGroup,
			Hidden:       app.Hidden,
			OpenIn:       openIn,
			Tooltip:      app.Tooltip,
		})
		if err != nil {
			return xerrors.Errorf("upsert devcontainer %q app: %w", dc.Name, err)
		}
		snapshot.WorkspaceApps = append(snapshot.WorkspaceApps, telemetry.ConvertWorkspaceApp(dbApp))
	}
	return nil
}
// insertDevcontainerSubagentScripts inserts scripts and log sources for a devcontainer subagent.
// It is a no-op when the devcontainer declares no scripts. One log source is
// created per script, and both sets of rows are written with a single bulk
// insert each.
func insertDevcontainerSubagentScripts(
	ctx context.Context,
	db database.Store,
	dc *sdkproto.Devcontainer,
	subagentID uuid.UUID,
) error {
	if len(dc.Scripts) == 0 {
		return nil
	}
	// Column-oriented argument slices for the two bulk inserts below.
	var (
		logSourceIDs       = make([]uuid.UUID, 0, len(dc.Scripts))
		logSourceNames     = make([]string, 0, len(dc.Scripts))
		logSourceIcons     = make([]string, 0, len(dc.Scripts))
		scriptIDs          = make([]uuid.UUID, 0, len(dc.Scripts))
		scriptLogPaths     = make([]string, 0, len(dc.Scripts))
		scriptSources      = make([]string, 0, len(dc.Scripts))
		scriptCron         = make([]string, 0, len(dc.Scripts))
		scriptTimeout      = make([]int32, 0, len(dc.Scripts))
		scriptStartBlock   = make([]bool, 0, len(dc.Scripts))
		scriptRunOnStart   = make([]bool, 0, len(dc.Scripts))
		scriptRunOnStop    = make([]bool, 0, len(dc.Scripts))
		scriptDisplayNames = make([]string, 0, len(dc.Scripts))
	)
	for _, script := range dc.Scripts {
		// Each script gets its own log source; the script row at the same
		// index references this ID.
		logSourceID := uuid.New()
		logSourceIDs = append(logSourceIDs, logSourceID)
		logSourceNames = append(logSourceNames, script.DisplayName)
		logSourceIcons = append(logSourceIcons, script.Icon)
		scriptIDs = append(scriptIDs, uuid.New())
		scriptLogPaths = append(scriptLogPaths, script.LogPath)
		scriptSources = append(scriptSources, script.Script)
		scriptCron = append(scriptCron, script.Cron)
		scriptTimeout = append(scriptTimeout, script.TimeoutSeconds)
		scriptStartBlock = append(scriptStartBlock, script.StartBlocksLogin)
		scriptRunOnStart = append(scriptRunOnStart, script.RunOnStart)
		scriptRunOnStop = append(scriptRunOnStop, script.RunOnStop)
		scriptDisplayNames = append(scriptDisplayNames, script.DisplayName)
	}
	// Insert log sources first: each script row references a log source ID.
	_, err := db.InsertWorkspaceAgentLogSources(ctx, database.InsertWorkspaceAgentLogSourcesParams{
		WorkspaceAgentID: subagentID,
		ID:               logSourceIDs,
		CreatedAt:        dbtime.Now(),
		DisplayName:      logSourceNames,
		Icon:             logSourceIcons,
	})
	if err != nil {
		return xerrors.Errorf("insert devcontainer %q subagent log sources: %w", dc.Name, err)
	}
	_, err = db.InsertWorkspaceAgentScripts(ctx, database.InsertWorkspaceAgentScriptsParams{
		WorkspaceAgentID: subagentID,
		LogSourceID:      logSourceIDs,
		LogPath:          scriptLogPaths,
		CreatedAt:        dbtime.Now(),
		Script:           scriptSources,
		Cron:             scriptCron,
		TimeoutSeconds:   scriptTimeout,
		StartBlocksLogin: scriptStartBlock,
		RunOnStart:       scriptRunOnStart,
		RunOnStop:        scriptRunOnStop,
		DisplayName:      scriptDisplayNames,
		ID:               scriptIDs,
	})
	if err != nil {
		return xerrors.Errorf("insert devcontainer %q subagent scripts: %w", dc.Name, err)
	}
	return nil
}
@@ -3706,7 +3706,6 @@ func TestInsertWorkspaceResource(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{})
subagentID := uuid.New()
err := insert(db, job.ID, &sdkproto.Resource{
Name: "something",
Type: "aws_instance",
@@ -3715,7 +3714,6 @@ func TestInsertWorkspaceResource(t *testing.T) {
Devcontainers: []*sdkproto.Devcontainer{
{Name: "foo", WorkspaceFolder: "/workspace1"},
{Name: "bar", WorkspaceFolder: "/workspace2", ConfigPath: "/workspace2/.devcontainer/devcontainer.json"},
{Name: "baz", WorkspaceFolder: "/workspace3", SubagentId: subagentID[:]},
},
}},
})
@@ -3725,33 +3723,20 @@ func TestInsertWorkspaceResource(t *testing.T) {
require.Len(t, resources, 1)
agents, err := db.GetWorkspaceAgentsByResourceIDs(ctx, []uuid.UUID{resources[0].ID})
require.NoError(t, err)
// Expect 2 agents: the parent agent "dev" and the subagent "baz".
require.Len(t, agents, 2)
// Find the parent agent (no parent ID).
var agent database.WorkspaceAgent
for _, a := range agents {
if !a.ParentID.Valid {
agent = a
break
}
}
require.Equal(t, "dev", agent.Name)
require.Len(t, agents, 1)
agent := agents[0]
devcontainers, err := db.GetWorkspaceAgentDevcontainersByAgentID(ctx, agent.ID)
sort.Slice(devcontainers, func(i, j int) bool {
return devcontainers[i].Name > devcontainers[j].Name
})
require.NoError(t, err)
require.Len(t, devcontainers, 3)
require.Len(t, devcontainers, 2)
require.Equal(t, "foo", devcontainers[0].Name)
require.Equal(t, "/workspace1", devcontainers[0].WorkspaceFolder)
require.Equal(t, "", devcontainers[0].ConfigPath)
require.Equal(t, uuid.Nil, devcontainers[0].SubagentID.UUID)
require.Equal(t, "baz", devcontainers[1].Name)
require.Equal(t, "/workspace3", devcontainers[1].WorkspaceFolder)
require.Equal(t, subagentID, devcontainers[1].SubagentID.UUID)
require.Equal(t, "bar", devcontainers[2].Name)
require.Equal(t, "/workspace2", devcontainers[2].WorkspaceFolder)
require.Equal(t, "/workspace2/.devcontainer/devcontainer.json", devcontainers[2].ConfigPath)
require.Equal(t, "bar", devcontainers[1].Name)
require.Equal(t, "/workspace2", devcontainers[1].WorkspaceFolder)
require.Equal(t, "/workspace2/.devcontainer/devcontainer.json", devcontainers[1].ConfigPath)
})
}
@@ -0,0 +1,215 @@
# Workspace Startup Coordination Examples
## Script Example
This example shows a complete, production-ready script that starts Claude Code
only after a repository has been cloned. It includes error handling, graceful
degradation, and cleanup on exit:
```bash
#!/bin/bash
set -euo pipefail
UNIT_NAME="claude-code"
DEPENDENCIES="git-clone"
REPO_DIR="/workspace/repo"
# Track if sync started successfully
SYNC_STARTED=0
# Declare dependencies
if [ -n "$DEPENDENCIES" ]; then
if command -v coder > /dev/null 2>&1; then
IFS=',' read -ra DEPS <<< "$DEPENDENCIES"
for dep in "${DEPS[@]}"; do
dep=$(echo "$dep" | xargs)
if [ -n "$dep" ]; then
echo "Waiting for dependency: $dep"
coder exp sync want "$UNIT_NAME" "$dep" > /dev/null 2>&1 || \
echo "Warning: Failed to register dependency $dep, continuing..."
fi
done
else
echo "Coder CLI not found, running without sync coordination"
fi
fi
# Start sync and track success
if [ -n "$UNIT_NAME" ]; then
if command -v coder > /dev/null 2>&1; then
if coder exp sync start "$UNIT_NAME" > /dev/null 2>&1; then
SYNC_STARTED=1
echo "Started sync: $UNIT_NAME"
else
echo "Sync start failed or not available, continuing without sync..."
fi
fi
fi
# Ensure completion on exit (even if script fails)
cleanup_sync() {
if [ "$SYNC_STARTED" -eq 1 ] && [ -n "$UNIT_NAME" ]; then
echo "Completing sync: $UNIT_NAME"
coder exp sync complete "$UNIT_NAME" > /dev/null 2>&1 || \
echo "Warning: Sync complete failed, but continuing..."
fi
}
trap cleanup_sync EXIT
# Now do the actual work
echo "Repository cloned, starting Claude Code"
cd "$REPO_DIR"
claude
```
This script demonstrates several [best practices](./usage.md#best-practices):
- Checking for Coder CLI availability before using sync commands
- Tracking whether `coder exp sync` started successfully
- Using `trap` to ensure completion even if the script exits early
- Graceful degradation when `coder exp sync` isn't available
- Redirecting `coder exp sync` output to reduce noise in logs
## Template Migration Example
Below is a simple example Docker template that clones [Miguel Grinberg's example Flask repo](https://github.com/miguelgrinberg/microblog/) using the [`git-clone` module](https://registry.coder.com/modules/coder/git-clone) and installs the required dependencies for the project:
- Python development headers (required for building some Python packages)
- Python dependencies from the project's `requirements.txt`
We've omitted some details (such as persistent storage) for brevity, but these are easily added.
### Before
```terraform
data "coder_provisioner" "me" {}
data "coder_workspace" "me" {}
data "coder_workspace_owner" "me" {}
resource "docker_container" "workspace" {
count = data.coder_workspace.me.start_count
image = "codercom/enterprise-base:ubuntu"
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
entrypoint = ["sh", "-c", coder_agent.main.init_script]
env = [
"CODER_AGENT_TOKEN=${coder_agent.main.token}",
]
}
resource "coder_agent" "main" {
arch = data.coder_provisioner.me.arch
os = "linux"
}
module "git-clone" {
count = data.coder_workspace.me.start_count
source = "registry.coder.com/coder/git-clone/coder"
version = "1.2.3"
agent_id = coder_agent.main.id
url = "https://github.com/miguelgrinberg/microblog"
}
resource "coder_script" "setup" {
count = data.coder_workspace.me.start_count
agent_id = coder_agent.main.id
display_name = "Installing Dependencies"
run_on_start = true
script = <<EOT
sudo apt-get update
sudo apt-get install --yes python-dev-is-python3
cd ${module.git-clone[count.index].repo_dir}
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
EOT
}
```
We can make the following observations about the template above:
1. There is a race between cloning the repository and the `pip install` commands, which can lead to failed workspace startups in some cases.
2. The `apt` commands can run independently of the `git clone` command, meaning that there is a potential speedup here.
Based on the above, we can improve both the startup time and reliability of the template by splitting the monolithic startup script into multiple independent scripts:
- Install `apt` dependencies
- Install `pip` dependencies (depends on the `git-clone` module and the above step)
### After
Here is the updated version of the template:
```terraform
data "coder_provisioner" "me" {}
data "coder_workspace" "me" {}
data "coder_workspace_owner" "me" {}
resource "docker_container" "workspace" {
count = data.coder_workspace.me.start_count
image = "codercom/enterprise-base:ubuntu"
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
entrypoint = ["sh", "-c", coder_agent.main.init_script]
env = [
"CODER_AGENT_TOKEN=${coder_agent.main.token}",
"CODER_AGENT_SOCKET_SERVER_ENABLED=true"
]
}
resource "coder_agent" "main" {
arch = data.coder_provisioner.me.arch
os = "linux"
}
module "git-clone" {
count = data.coder_workspace.me.start_count
source = "registry.coder.com/coder/git-clone/coder"
version = "1.2.3"
agent_id = coder_agent.main.id
url = "https://github.com/miguelgrinberg/microblog/"
post_clone_script = <<-EOT
coder exp sync start git-clone && coder exp sync complete git-clone
EOT
}
resource "coder_script" "apt-install" {
count = data.coder_workspace.me.start_count
agent_id = coder_agent.main.id
display_name = "Installing APT Dependencies"
run_on_start = true
script = <<EOT
trap 'coder exp sync complete apt-install' EXIT
coder exp sync start apt-install
sudo apt-get update
sudo apt-get install --yes python-dev-is-python3
EOT
}
resource "coder_script" "pip-install" {
count = data.coder_workspace.me.start_count
agent_id = coder_agent.main.id
display_name = "Installing Python Dependencies"
run_on_start = true
script = <<EOT
trap 'coder exp sync complete pip-install' EXIT
coder exp sync want pip-install git-clone apt-install
coder exp sync start pip-install
cd ${module.git-clone[count.index].repo_dir}
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
EOT
}
```
A short summary of the changes:
- We've added `CODER_AGENT_SOCKET_SERVER_ENABLED=true` to the environment variables of the Docker container in which the Coder agent runs.
- We've broken the monolithic "setup" script into two separate scripts: one for the `apt` commands, and one for the `pip` commands.
- In each script, we've added a `coder exp sync start $SCRIPT_NAME` command to mark the startup script as started.
- We've also added an exit trap to ensure that we mark the startup scripts as completed. Without this, the `coder exp sync wait` command would eventually time out.
- We have used the `post_clone_script` feature of the `git-clone` module to allow waiting on the Git repository clone.
- In the `pip-install` script, we have declared a dependency on both `git-clone` and `apt-install`.
With these changes, the startup time has been reduced significantly, and the race between cloning the repository and installing its dependencies has been eliminated.
@@ -0,0 +1,50 @@
# Workspace Startup Coordination
> [!NOTE]
> This feature is experimental and may change without notice in future releases.
When workspaces start, scripts often need to run in a specific order.
For example, an IDE or coding agent might need the repository cloned
before it can start. Without explicit coordination, these scripts can
race against each other, leading to startup failures and inconsistent
workspace states.
Coder's workspace startup coordination feature lets you declare
dependencies between startup scripts and ensure they run in the correct order.
This eliminates race conditions and makes workspace startup predictable and
reliable.
## Why use this?
Simply placing all of your workspace initialization logic in a single script works, but leads to slow workspace startup times.
Breaking this out into multiple independent `coder_script` resources improves startup times by allowing the scripts to run in parallel.
However, this can lead to intermittent failures between dependent scripts due to timing issues.
Up until now, template authors have had to rely on manual coordination methods (for example, touching a file upon completion).
The goal of startup script coordination is to provide a single reliable source of truth for coordination between workspace startup scripts.
## Quick Start
To start using workspace startup coordination, follow these steps:
1. Set the environment variable `CODER_AGENT_SOCKET_SERVER_ENABLED=true` in your template to enable the agent socket server. The environment variable *must* be readable to the agent process. For example, in a template using the `kreuzwerker/docker` provider:
```terraform
resource "docker_container" "workspace" {
image = "codercom/enterprise-base:ubuntu"
env = [
"CODER_AGENT_TOKEN=${coder_agent.main.token}",
"CODER_AGENT_SOCKET_SERVER_ENABLED=true",
]
}
```
1. Add calls to `coder exp sync (start|complete)` in your startup scripts where required:
```bash
trap 'coder exp sync complete my-script' EXIT
coder exp sync want my-script my-other-script
coder exp sync start my-script
# Existing startup logic
```
For more information, refer to the [usage documentation](./usage.md), [troubleshooting documentation](./troubleshooting.md), or view our [examples](./example.md).
@@ -0,0 +1,98 @@
# Workspace Startup Coordination Troubleshooting
> [!NOTE]
> This feature is experimental and may change without notice in future releases.
## Test Sync Availability
From a workspace terminal, test if sync is working using `coder exp sync ping`:
```bash
coder exp sync ping
```
* If sync is working, expect the output to be `Success`.
* Otherwise, you will see an error message similar to the below:
```bash
error: connect to agent socket: connect to socket: dial unix /tmp/coder-agent.sock: connect: permission denied
```
## Check Unit Status
You can check the status of a specific unit using `coder exp sync status`:
```bash
coder exp sync status git-clone
```
If the unit exists, you will see output similar to the below:
```bash
# coder exp sync status git-clone
Unit: git-clone
Status: completed
Ready: true
```
If the unit is not known to the agent, you will see output similar to the below:
```bash
# coder exp sync status doesnotexist
Unit: doesnotexist
Status: not registered
Ready: true
Dependencies:
No dependencies found
```
## Common Issues
### Socket not enabled
If the Coder Agent Socket Server is not enabled, you will see an error message similar to the below when running `coder exp sync ping`:
```bash
error: connect to agent socket: connect to socket: dial unix /tmp/coder-agent.sock: connect: no such file or directory
```
Verify `CODER_AGENT_SOCKET_SERVER_ENABLED=true` is set in the Coder agent's environment:
```bash
tr '\0' '\n' < /proc/$(pidof -s coder)/environ | grep CODER_AGENT_SOCKET_SERVER_ENABLED
```
If the output of the above command is empty, review your template and ensure that the environment variable is set such that it is readable by the Coder agent process. Setting it on the `coder_agent` resource directly is **not** sufficient.
## Workspace startup script hangs
If the workspace startup scripts appear to 'hang', one or more of your startup scripts may be waiting for a dependency that never completes.
* Inside the workspace, review `/tmp/coder-script-*.log` for more details on your script's execution.
> **Tip:** add `set -x` to the top of your script to enable debug mode and update/restart the workspace.
* Review your template and verify that `coder exp sync complete <unit>` is called after the script completes, e.g. with an exit trap.
* View the unit status using `coder exp sync status <unit>`.
## Workspace startup scripts fail
If the workspace startup scripts fail:
* Review `/tmp/coder-script-*.log` inside the workspace for script errors.
* Verify the Coder CLI is available in `$PATH` inside the workspace:
```bash
command -v coder
```
## Cycle detected
If you see an error similar to the below in your startup script logs, you have defined a cyclic dependency:
```bash
error: declare dependency failed: cannot add dependency: adding edge for unit "bar": failed to add dependency
adding edge (bar -> foo): cycle detected
```
To fix this, review your dependency declarations and redesign them to remove the cycle. It may help to draw out the dependency graph to find
the cycle.
@@ -0,0 +1,283 @@
# Workspace Startup Coordination Usage
> [!NOTE]
> This feature is experimental and may change without notice in future releases.
Startup coordination is built around the concept of **units**. You declare units in your Coder workspace template using the `coder exp sync` command in `coder_script` resources. When the Coder agent starts, it keeps an in-memory directed acyclic graph (DAG) of all units of which it is aware. When you need to synchronize with another unit, you can use `coder exp sync start $UNIT_NAME` to block until all dependencies of that unit have been marked complete.
## What is a unit?
A **unit** is a named phase of work, typically corresponding to a script or initialization
task.
- Units **may** declare dependencies on other units, creating an explicit ordering for workspace initialization.
- Units **must** be registered before they can be marked as complete.
- Units **may** be marked as dependencies before they are registered.
- Units **must not** declare cyclic dependencies. Attempting to create a cyclic dependency will result in an error.
## Requirements
> [!IMPORTANT]
> The `coder exp sync` command is only available in Coder v2.30 and later.
To use startup dependencies in your templates, you must:
- Enable the Coder Agent Socket Server.
- Modify your workspace startup scripts to run in parallel and declare dependencies as required using `coder exp sync`.
### Enable the Coder Agent Socket Server
The agent socket server provides the communication layer for startup
coordination. To enable it, set `CODER_AGENT_SOCKET_SERVER_ENABLED=true` in the environment in which the agent is running.
The exact method for doing this depends on your infrastructure platform:
<div class="tabs">
#### Docker / Podman
```hcl
resource "docker_container" "workspace" {
count = data.coder_workspace.me.start_count
image = "codercom/enterprise-base:ubuntu"
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
env = [
"CODER_AGENT_SOCKET_SERVER_ENABLED=true"
]
command = ["sh", "-c", coder_agent.main.init_script]
}
```
#### Kubernetes
```hcl
resource "kubernetes_pod" "main" {
count = data.coder_workspace.me.start_count
metadata {
name = "coder-${data.coder_workspace_owner.me.name}-${lower(data.coder_workspace.me.name)}"
namespace = var.workspaces_namespace
}
spec {
container {
name = "dev"
image = "codercom/enterprise-base:ubuntu"
command = ["sh", "-c", coder_agent.main.init_script]
env {
name = "CODER_AGENT_SOCKET_SERVER_ENABLED"
value = "true"
}
}
}
}
```
#### AWS EC2 / VMs
For virtual machines, pass the environment variable through cloud-init or your
provisioning system:
```hcl
locals {
agent_env = {
"CODER_AGENT_SOCKET_SERVER_ENABLED" = "true"
}
}
# In your cloud-init userdata template:
# %{ for key, value in local.agent_env ~}
# export ${key}="${value}"
# %{ endfor ~}
```
</div>
### Declare Dependencies in your Workspace Startup Scripts
<div class="tabs">
#### Single Dependency
Here's a simple example of a script that depends on another unit completing
first:
```bash
#!/bin/bash
UNIT_NAME="my-setup"
# Declare dependency on git-clone
coder exp sync want "$UNIT_NAME" "git-clone"
# Wait for dependencies and mark as started
coder exp sync start "$UNIT_NAME"
# Do your work here
echo "Running after git-clone completes"
# Signal completion
coder exp sync complete "$UNIT_NAME"
```
This script will wait until the `git-clone` unit completes before starting its
own work.
#### Multiple Dependencies
If your unit depends on multiple other units, you can declare all dependencies
before starting:
```bash
#!/bin/bash
UNIT_NAME="my-app"
DEPENDENCIES="git-clone,env-setup,database-migration"
# Declare all dependencies
if [ -n "$DEPENDENCIES" ]; then
IFS=',' read -ra DEPS <<< "$DEPENDENCIES"
for dep in "${DEPS[@]}"; do
dep=$(echo "$dep" | xargs) # Trim whitespace
if [ -n "$dep" ]; then
coder exp sync want "$UNIT_NAME" "$dep"
fi
done
fi
# Wait for all dependencies
coder exp sync start "$UNIT_NAME"
# Your work here
echo "All dependencies satisfied, starting application"
# Signal completion
coder exp sync complete "$UNIT_NAME"
```
</div>
## Best Practices
### Test your changes before rolling out to all users
Before rolling out to all users:
1. Create a test workspace from the updated template
2. Check workspace build logs for sync messages
3. Verify all units reach "completed" status
4. Test workspace functionality
Once you're satisfied, [promote the new template version](../../../reference/cli/templates_versions_promote.md).
### Handle missing CLI gracefully
Not all workspaces will have the Coder CLI available in `$PATH`. Check for availability of the Coder CLI before using
sync commands:
```bash
if command -v coder > /dev/null 2>&1; then
coder exp sync start "$UNIT_NAME"
else
echo "Coder CLI not available, continuing without coordination"
fi
```
### Complete units that start successfully
Units **must** call `coder exp sync complete` to unblock dependent units. Use `trap` to ensure
completion even if your script exits early or encounters errors:
```bash
SYNC_STARTED=0
if coder exp sync start "$UNIT_NAME"; then
SYNC_STARTED=1
fi
cleanup_sync() {
if [ "$SYNC_STARTED" -eq 1 ]; then
coder exp sync complete "$UNIT_NAME"
fi
}
trap cleanup_sync EXIT
```
### Use descriptive unit names
Names should explain what the unit does, not its position in a sequence:
- Good: `git-clone`, `env-setup`, `database-migration`
- Avoid: `step1`, `init`, `script-1`
### Prefix your units with a unique name
When using `coder exp sync` in modules, note that unit names like `git-clone` might be common. Prefix the name of your module to your units to
ensure that your unit does not conflict with others.
- Good: `<module>.git-clone`, `<module>.claude`
- Bad: `git-clone`, `claude`
### Document dependencies
Add comments explaining why dependencies exist:
```hcl
resource "coder_script" "ide_setup" {
# Depends on git-clone because we need .vscode/extensions.json
# Depends on env-setup because we need $NODE_PATH configured
script = <<-EOT
coder exp sync want "ide-setup" "git-clone"
coder exp sync want "ide-setup" "env-setup"
# ...
EOT
}
```
### Avoid circular dependencies
The Coder Agent detects and rejects circular dependencies, but they indicate a design problem:
```bash
# This will fail
coder exp sync want "unit-a" "unit-b"
coder exp sync want "unit-b" "unit-a"
```
## Frequently Asked Questions
### How do I identify scripts that can benefit from startup coordination?
Look for these patterns in existing templates:
- `sleep` commands used to order scripts
- Using files to coordinate startup between scripts (e.g. `touch /tmp/startup-complete`)
- Scripts that fail intermittently on startup
- Comments like "must run after X" or "wait for Y"
### Will this slow down my workspace?
No. The socket server adds minimal overhead, and the default polling interval is 1
second, so waiting for dependencies adds at most a few seconds to startup.
You are more likely to notice an improvement in startup times as it becomes easier to manage complex dependencies in parallel.
### How do units interact with each other?
Units with no dependencies run immediately and in parallel.
Only units with unsatisfied dependencies wait for their dependencies.
### How long can a dependency take to complete?
By default, `coder exp sync start` has a 5-minute timeout to prevent indefinite hangs.
Upon timeout, the command will exit with an error code and print `timeout waiting for dependencies of unit <unit_name>` to stderr.
You can adjust this timeout as necessary for long-running operations:
```bash
coder exp sync start "long-operation" --timeout 10m
```
### Is state stored between restarts?
No. Sync state is kept in-memory only and resets on workspace restart.
This is intentional to ensure clean initialization on every start.
+23
View File
@@ -667,6 +667,29 @@
"description": "Log workspace processes",
"path": "./admin/templates/extending-templates/process-logging.md",
"state": ["premium"]
},
{
"title": "Startup Dependencies",
"description": "Coordinate workspace startup with dependency management",
"path": "./admin/templates/startup-coordination/index.md",
"state": ["early access"],
"children": [
{
"title": "Usage",
"description": "How to use startup coordination",
"path": "./admin/templates/startup-coordination/usage.md"
},
{
"title": "Troubleshooting",
"description": "Troubleshoot startup coordination",
"path": "./admin/templates/startup-coordination/troubleshooting.md"
},
{
"title": "Examples",
"description": "Examples of startup coordination",
"path": "./admin/templates/startup-coordination/example.md"
}
]
}
]
},
+3 -1
View File
@@ -60,7 +60,9 @@ as [JetBrains](./workspace-access/jetbrains/index.md) or
Once started, the Coder agent is responsible for running your workspace startup
scripts. These may configure tools, service connections, or personalization with
[dotfiles](./workspace-dotfiles.md).
[dotfiles](./workspace-dotfiles.md). For complex initialization with multiple
dependent scripts, see
[Workspace Startup Coordination](../admin/templates/startup-coordination/index.md).
Once these steps have completed, your workspace will now be in the `Running`
state. You can access it via any of the [supported methods](./index.md), stop it
@@ -0,0 +1,33 @@
//nolint:revive,gocritic,errname,unconvert
package audit
import "log/slog"
// LogAuditor implements proxy.Auditor by logging to slog: one Info record
// per allowed request (including the matched rule) and one Warn record per
// denied request.
type LogAuditor struct {
	// logger receives the structured ALLOW/DENY records.
	logger *slog.Logger
}
// NewLogAuditor returns a LogAuditor that writes audit records to logger.
func NewLogAuditor(logger *slog.Logger) *LogAuditor {
	return &LogAuditor{logger: logger}
}
// AuditRequest logs the request using structured logging: an Info entry with
// the matched rule for allowed requests, a Warn entry (no rule) for denied
// requests.
func (a *LogAuditor) AuditRequest(req Request) {
	if !req.Allowed {
		a.logger.Warn("DENY",
			"method", req.Method,
			"url", req.URL,
			"host", req.Host,
		)
		return
	}
	a.logger.Info("ALLOW",
		"method", req.Method,
		"url", req.URL,
		"host", req.Host,
		"rule", req.Rule)
}
@@ -0,0 +1,10 @@
//nolint:paralleltest,testpackage,revive,gocritic
package audit
import "testing"
// Stub test file - tests removed
// TestStub exists only so the package has a test target; it skips
// immediately and exercises no behavior.
func TestStub(t *testing.T) {
	// This is a stub test
	t.Skip("stub test file")
}
@@ -0,0 +1,65 @@
//nolint:revive,gocritic,errname,unconvert
package audit
import (
"context"
"log/slog"
"os"
"golang.org/x/xerrors"
)
// MultiAuditor wraps multiple auditors and sends audit events to all of them.
type MultiAuditor struct {
	// auditors is the fan-out list; events are delivered in slice order.
	auditors []Auditor
}
// NewMultiAuditor creates a MultiAuditor that fans events out to every
// provided auditor.
func NewMultiAuditor(auditors ...Auditor) *MultiAuditor {
	m := &MultiAuditor{}
	m.auditors = auditors
	return m
}
// AuditRequest sends the request to all wrapped auditors, in the order they
// were registered.
func (m *MultiAuditor) AuditRequest(req Request) {
	for i := range m.auditors {
		m.auditors[i].AuditRequest(req)
	}
}
// SetupAuditor creates and configures the appropriate auditors based on the
// provided configuration. It always includes a LogAuditor writing to logger,
// and conditionally adds a SocketAuditor when audit logs are enabled and the
// workspace agent's log proxy socket exists at logProxySocketPath.
//
// When a SocketAuditor is added, its Loop goroutine is started and runs
// until ctx is canceled. The returned Auditor is always a MultiAuditor.
// An error is returned only when audit logs are enabled but the socket path
// is empty, or when the socket path cannot be stat'ed.
func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string) (Auditor, error) {
	stderrAuditor := NewLogAuditor(logger)
	auditors := []Auditor{stderrAuditor}
	if disableAuditLogs {
		logger.Warn("Audit logs are disabled by configuration")
		return NewMultiAuditor(auditors...), nil
	}
	if logProxySocketPath == "" {
		return nil, xerrors.New("log proxy socket path is undefined")
	}
	// Since boundary is separately versioned from a Coder deployment, it's possible
	// Coder is on an older version that will not create the socket and listen for
	// the audit logs. Here we check for the socket to determine if the workspace
	// agent is on a new enough version to prevent boundary application log spam from
	// trying to connect to the agent. This assumes the agent will run and start the
	// log proxy server before boundary runs.
	_, err := os.Stat(logProxySocketPath)
	if err != nil && !os.IsNotExist(err) {
		// Wrap with %w (not %v) so the underlying stat error stays
		// inspectable via errors.Is/As.
		return nil, xerrors.Errorf("failed to stat log proxy socket: %w", err)
	}
	if os.IsNotExist(err) {
		logger.Warn("Audit logs are disabled; workspace agent has not created log proxy socket",
			"socket", logProxySocketPath)
		return NewMultiAuditor(auditors...), nil
	}
	socketAuditor := NewSocketAuditor(logger, logProxySocketPath)
	go socketAuditor.Loop(ctx)
	auditors = append(auditors, socketAuditor)
	return NewMultiAuditor(auditors...), nil
}
@@ -0,0 +1,143 @@
//nolint:paralleltest,testpackage,revive,gocritic
package audit
import (
"context"
"io"
"log/slog"
"os"
"path/filepath"
"testing"
)
// mockAuditor is a test double for Auditor that records calls via a hook.
type mockAuditor struct {
	// onAudit, when non-nil, receives every request passed to AuditRequest.
	onAudit func(req Request)
}
// AuditRequest implements Auditor; it forwards req to onAudit when set.
func (m *mockAuditor) AuditRequest(req Request) {
	if m.onAudit != nil {
		m.onAudit(req)
	}
}
// TestSetupAuditor_DisabledAuditLogs verifies that disabling audit logs
// yields a MultiAuditor containing only the log-based auditor.
func TestSetupAuditor_DisabledAuditLogs(t *testing.T) {
	t.Parallel()
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	auditor, err := SetupAuditor(context.Background(), logger, true, "")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	multi, ok := auditor.(*MultiAuditor)
	if !ok {
		t.Fatalf("expected *MultiAuditor, got %T", auditor)
	}
	if got := len(multi.auditors); got != 1 {
		t.Errorf("expected 1 auditor, got %d", got)
	}
	if _, ok := multi.auditors[0].(*LogAuditor); !ok {
		t.Errorf("expected *LogAuditor, got %T", multi.auditors[0])
	}
}
// TestSetupAuditor_EmptySocketPath verifies that enabling audit logs with an
// empty socket path is rejected.
func TestSetupAuditor_EmptySocketPath(t *testing.T) {
	t.Parallel()
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	if _, err := SetupAuditor(context.Background(), logger, false, ""); err == nil {
		t.Fatal("expected error for empty socket path, got nil")
	}
}
// TestSetupAuditor_SocketDoesNotExist verifies that a missing socket path
// degrades gracefully to a log-only auditor instead of returning an error.
func TestSetupAuditor_SocketDoesNotExist(t *testing.T) {
	t.Parallel()
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	auditor, err := SetupAuditor(context.Background(), logger, false, "/nonexistent/socket/path")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	multi, ok := auditor.(*MultiAuditor)
	if !ok {
		t.Fatalf("expected *MultiAuditor, got %T", auditor)
	}
	if got := len(multi.auditors); got != 1 {
		t.Errorf("expected 1 auditor, got %d", got)
	}
	if _, ok := multi.auditors[0].(*LogAuditor); !ok {
		t.Errorf("expected *LogAuditor, got %T", multi.auditors[0])
	}
}
// TestSetupAuditor_SocketExists verifies that when the log proxy socket path
// exists, a SocketAuditor is added alongside the LogAuditor.
func TestSetupAuditor_SocketExists(t *testing.T) {
	t.Parallel()
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// A plain file stands in for the socket: setup only stats the path, it
	// does not connect during SetupAuditor itself.
	socketPath := filepath.Join(t.TempDir(), "test.sock")
	f, err := os.Create(socketPath)
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	if err := f.Close(); err != nil {
		t.Fatalf("failed to close temp file: %v", err)
	}
	auditor, err := SetupAuditor(ctx, logger, false, socketPath)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	multi, ok := auditor.(*MultiAuditor)
	if !ok {
		t.Fatalf("expected *MultiAuditor, got %T", auditor)
	}
	if got := len(multi.auditors); got != 2 {
		t.Errorf("expected 2 auditors, got %d", got)
	}
	if _, ok := multi.auditors[0].(*LogAuditor); !ok {
		t.Errorf("expected first auditor to be *LogAuditor, got %T", multi.auditors[0])
	}
	if _, ok := multi.auditors[1].(*SocketAuditor); !ok {
		t.Errorf("expected second auditor to be *SocketAuditor, got %T", multi.auditors[1])
	}
}
func TestMultiAuditor_AuditRequest(t *testing.T) {
t.Parallel()
var called1, called2 bool
auditor1 := &mockAuditor{onAudit: func(req Request) { called1 = true }}
auditor2 := &mockAuditor{onAudit: func(req Request) { called2 = true }}
multi := NewMultiAuditor(auditor1, auditor2)
multi.AuditRequest(Request{Method: "GET", URL: "https://example.com"})
if !called1 {
t.Error("expected first auditor to be called")
}
if !called2 {
t.Error("expected second auditor to be called")
}
}
+15
View File
@@ -0,0 +1,15 @@
//nolint:revive,gocritic,errname,unconvert
package audit
// Auditor consumes audit events for HTTP requests handled by boundary.
type Auditor interface {
	// AuditRequest records a single allow/deny decision for a request.
	AuditRequest(req Request)
}
// Request represents information about an HTTP request for auditing
type Request struct {
	Method string
	URL string // The fully qualified request URL (scheme, domain, optional path).
	Host string
	// Allowed reports whether the request was permitted.
	Allowed bool
	Rule string // The rule that matched (if any)
}
@@ -0,0 +1,247 @@
//nolint:revive,gocritic,errname,unconvert
package audit
import (
"context"
"log/slog"
"net"
"time"
"golang.org/x/xerrors"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
agentproto "github.com/coder/coder/v2/agent/proto"
)
const (
	// The batch size and timer duration are chosen to provide reasonable responsiveness
	// for consumers of the aggregated logs while still minimizing the agent <-> coderd
	// network I/O when an AI agent is actively making network requests.

	// defaultBatchSize is the number of queued logs that triggers an
	// immediate flush.
	defaultBatchSize = 10
	// defaultBatchTimerDuration bounds how long a partial batch may wait
	// before being flushed.
	defaultBatchTimerDuration = 5 * time.Second
)
// SocketAuditor implements the Auditor interface. It sends logs to the
// workspace agent's boundary log proxy socket. It queues logs and sends
// them in batches using a batch size and timer. The internal queue operates
// as a FIFO i.e., logs are sent in the order they are received and dropped
// if the queue is full.
type SocketAuditor struct {
	// dial opens a connection to the agent socket; injectable for tests.
	dial func() (net.Conn, error)
	logger *slog.Logger
	// logCh buffers logs between AuditRequest producers and the Loop consumer.
	logCh chan *agentproto.BoundaryLog
	// batchSize is the queue length that triggers an immediate flush.
	batchSize int
	// batchTimerDuration bounds how long a non-empty batch waits before a flush.
	batchTimerDuration time.Duration
	// socketPath is retained for log messages when dialing fails.
	socketPath string
	// onFlushAttempt is called after each flush attempt (intended for testing).
	onFlushAttempt func()
}
// NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's
// boundary log proxy socket at socketPath once SocketAuditor.Loop is called.
// (The previous doc comment referenced EnvAuditSocketPath and
// defaultAuditSocketPath, which are not used here: the path is an explicit
// parameter.)
func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor {
	// This channel buffer size intends to allow enough buffering for bursty
	// AI agent network requests while a batch is being sent to the workspace
	// agent.
	const logChBufSize = 2 * defaultBatchSize
	return &SocketAuditor{
		dial: func() (net.Conn, error) {
			return net.Dial("unix", socketPath)
		},
		logger:             logger,
		logCh:              make(chan *agentproto.BoundaryLog, logChBufSize),
		batchSize:          defaultBatchSize,
		batchTimerDuration: defaultBatchTimerDuration,
		socketPath:         socketPath,
	}
}
// AuditRequest implements the Auditor interface. It converts req into a
// BoundaryLog and enqueues it for batched delivery to the agent, dropping
// the log when the queue is full.
func (s *SocketAuditor) AuditRequest(req Request) {
	httpReq := &agentproto.BoundaryLog_HttpRequest{
		Method: req.Method,
		Url:    req.URL,
	}
	// Boundary is deny-by-default, so a matched rule is only meaningful for
	// allowed requests.
	if req.Allowed {
		httpReq.MatchedRule = req.Rule
	}
	entry := &agentproto.BoundaryLog{
		Allowed:  req.Allowed,
		Time:     timestamppb.Now(),
		Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq},
	}
	select {
	case s.logCh <- entry:
	default:
		s.logger.Warn("audit log dropped, channel full")
	}
}
// flushErr represents an error from flush, distinguishing between
// permanent errors (bad data) and transient errors (network issues).
type flushErr struct {
	err error
	// permanent is true when retrying the same batch can never succeed
	// (e.g. the batch failed to marshal) and it should be dropped.
	permanent bool
}
// Error implements the error interface by delegating to the wrapped error.
func (e *flushErr) Error() string { return e.err.Error() }
// flush sends the current batch of logs to the given connection.
//
// It returns nil on success or when there is nothing to send. A marshal
// failure is reported as permanent (the batch is bad data and should be
// dropped); a write failure is transient (the caller may reconnect and
// retry the same batch).
func flush(conn net.Conn, logs []*agentproto.BoundaryLog) *flushErr {
	if len(logs) == 0 {
		return nil
	}
	req := &agentproto.ReportBoundaryLogsRequest{
		Logs: logs,
	}
	data, err := proto.Marshal(req)
	if err != nil {
		return &flushErr{err: err, permanent: true}
	}
	err = codec.WriteFrame(conn, codec.TagV1, data)
	if err != nil {
		// Use %w (the original %x hex-formatted the error) so the network
		// error is wrapped and remains inspectable via errors.Is/As.
		return &flushErr{err: xerrors.Errorf("write frame: %w", err)}
	}
	return nil
}
// Loop handles the I/O to send audit logs to the agent.
//
// It runs until ctx is canceled, accumulating logs from logCh into a batch
// and flushing when the batch fills or the batch timer fires. On
// cancellation it drains any queued logs, performs a final flush, and
// closes the connection. Connection failures are retried on the batch
// timer; a batch is only discarded on a permanent (data) error.
func (s *SocketAuditor) Loop(ctx context.Context) {
	var conn net.Conn
	batch := make([]*agentproto.BoundaryLog, 0, s.batchSize)
	// The timer is created stopped; it is armed when the first log enters an
	// empty batch, and after failed flush attempts to schedule a retry.
	t := time.NewTimer(0)
	t.Stop()
	// connect attempts to establish a connection to the socket.
	connect := func() {
		if conn != nil {
			return
		}
		var err error
		conn, err = s.dial()
		if err != nil {
			s.logger.Warn("failed to connect to audit socket", "path", s.socketPath, "error", err)
			conn = nil
		}
	}
	// closeConn closes the current connection if open.
	closeConn := func() {
		if conn != nil {
			_ = conn.Close()
			conn = nil
		}
	}
	// clearBatch resets the length of the batch and frees memory while preserving
	// the batch slice backing array.
	clearBatch := func() {
		for i := range len(batch) {
			batch[i] = nil
		}
		batch = batch[:0]
	}
	// doFlush flushes the batch and handles errors by reconnecting.
	doFlush := func() {
		t.Stop()
		// Notify the test hook regardless of outcome, including early returns.
		defer func() {
			if s.onFlushAttempt != nil {
				s.onFlushAttempt()
			}
		}()
		if len(batch) == 0 {
			return
		}
		connect()
		if conn == nil {
			// No connection: logs will be retried on next flush.
			s.logger.Warn("no connection to flush; resetting batch timer",
				"duration_sec", s.batchTimerDuration.Seconds(),
				"batch_size", len(batch))
			// Reset the timer so we aren't stuck waiting for the batch to fill
			// or a new log to arrive before the next attempt.
			t.Reset(s.batchTimerDuration)
			return
		}
		if err := flush(conn, batch); err != nil {
			if err.permanent {
				// Data error: discard batch to avoid infinite retries.
				s.logger.Warn("dropping batch due to data error on flush attempt",
					"error", err, "batch_size", len(batch))
				clearBatch()
			} else {
				// Network error: close connection but keep batch and retry.
				s.logger.Warn("failed to flush audit logs; resetting batch timer to reconnect and retry",
					"error", err, "duration_sec", s.batchTimerDuration.Seconds(),
					"batch_size", len(batch))
				closeConn()
				// Reset the timer so we aren't stuck waiting for a new log to
				// arrive before the next attempt.
				t.Reset(s.batchTimerDuration)
			}
			return
		}
		clearBatch()
	}
	connect()
	for {
		select {
		case <-ctx.Done():
			// Drain any pending logs before the last flush. Not concerned about
			// growing the batch slice here since we're exiting.
		drain:
			for {
				select {
				case log := <-s.logCh:
					batch = append(batch, log)
				default:
					break drain
				}
			}
			doFlush()
			closeConn()
			return
		case <-t.C:
			doFlush()
		case log := <-s.logCh:
			// If batch is at capacity, attempt flushing first and drop the log if
			// the batch is still full.
			if len(batch) >= s.batchSize {
				doFlush()
				if len(batch) >= s.batchSize {
					s.logger.Warn("audit log dropped, batch full")
					continue
				}
			}
			batch = append(batch, log)
			// First entry into an empty batch arms the flush timer.
			if len(batch) == 1 {
				t.Reset(s.batchTimerDuration)
			}
			if len(batch) >= s.batchSize {
				doFlush()
			}
		}
	}
}
@@ -0,0 +1,373 @@
//nolint:paralleltest,testpackage,revive,gocritic
package audit
import (
"context"
"io"
"log/slog"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/xerrors"
"google.golang.org/protobuf/proto"
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
agentproto "github.com/coder/coder/v2/agent/proto"
)
// TestSocketAuditor_AuditRequest_QueuesLog verifies that an allowed request
// is converted to a BoundaryLog and queued with its fields and matched rule.
func TestSocketAuditor_AuditRequest_QueuesLog(t *testing.T) {
	t.Parallel()
	auditor := setupSocketAuditor(t)
	auditor.AuditRequest(Request{
		Method:  "GET",
		URL:     "https://example.com",
		Host:    "example.com",
		Allowed: true,
		Rule:    "allow-all",
	})
	var log *agentproto.BoundaryLog
	select {
	case log = <-auditor.logCh:
	default:
		t.Fatal("expected log in channel, got none")
	}
	if !log.Allowed {
		t.Errorf("expected Allowed=true, got %v", log.Allowed)
	}
	httpReq := log.GetHttpRequest()
	if httpReq == nil {
		t.Fatal("expected HttpRequest, got nil")
	}
	if httpReq.Method != "GET" {
		t.Errorf("expected Method=GET, got %s", httpReq.Method)
	}
	if httpReq.Url != "https://example.com" {
		t.Errorf("expected URL=https://example.com, got %s", httpReq.Url)
	}
	// Rule should be set for allowed requests
	if httpReq.MatchedRule != "allow-all" {
		t.Errorf("unexpected MatchedRule %v", httpReq.MatchedRule)
	}
}
// TestSocketAuditor_AuditRequest_AllowIncludesRule verifies that the matched
// rule is attached to the queued log for allowed requests.
func TestSocketAuditor_AuditRequest_AllowIncludesRule(t *testing.T) {
	t.Parallel()
	auditor := setupSocketAuditor(t)
	auditor.AuditRequest(Request{
		Method:  "POST",
		URL:     "https://evil.com",
		Host:    "evil.com",
		Allowed: true,
		Rule:    "allow-evil",
	})
	select {
	case log := <-auditor.logCh:
		if log.Allowed != true {
			// The request above is allowed, so the failure message must say
			// so (it previously claimed "expected Allowed=false").
			t.Errorf("expected Allowed=true, got %v", log.Allowed)
		}
		httpReq := log.GetHttpRequest()
		if httpReq == nil {
			t.Fatal("expected HttpRequest, got nil")
		}
		if httpReq.MatchedRule != "allow-evil" {
			t.Errorf("expected MatchedRule=allow-evil, got %s", httpReq.MatchedRule)
		}
	default:
		t.Fatal("expected log in channel, got none")
	}
}
// TestSocketAuditor_AuditRequest_DropsWhenFull verifies FIFO ordering and
// that new logs are dropped (not blocked on) once the queue is at capacity.
func TestSocketAuditor_AuditRequest_DropsWhenFull(t *testing.T) {
	t.Parallel()
	auditor := setupSocketAuditor(t)
	capacity := 2 * auditor.batchSize
	// Fill the channel (capacity is 2*batchSize = 20)
	for i := 0; i < capacity; i++ {
		auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
	}
	// This should not block and drop the log
	auditor.AuditRequest(Request{Method: "GET", URL: "https://dropped.com", Allowed: true})
	// Drain the channel and verify all entries are from the original batch (dropped.com was dropped)
	for i := 0; i < capacity; i++ {
		entry := <-auditor.logCh
		resource, ok := entry.Resource.(*agentproto.BoundaryLog_HttpRequest_)
		if !ok {
			t.Fatal("unexpected resource type")
		}
		if resource.HttpRequest.Url != "https://example.com" {
			t.Errorf("expected batch to be FIFO, got %s", resource.HttpRequest.Url)
		}
	}
	select {
	case v := <-auditor.logCh:
		t.Errorf("expected empty channel, got %v", v)
	default:
	}
}
// TestSocketAuditor_Loop_FlushesOnBatchSize verifies that a full batch is
// flushed immediately, without waiting for the batch timer.
func TestSocketAuditor_Loop_FlushesOnBatchSize(t *testing.T) {
	t.Parallel()
	auditor, serverConn := setupTestAuditor(t)
	auditor.batchTimerDuration = time.Hour // Ensure timer doesn't interfere with the test
	received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
	go readFromConn(t, serverConn, received)
	go auditor.Loop(t.Context())
	// Send exactly a full batch of logs to trigger a flush
	for range auditor.batchSize {
		auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
	}
	select {
	case req := <-received:
		if got := len(req.Logs); got != auditor.batchSize {
			t.Errorf("expected %d logs, got %d", auditor.batchSize, got)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for flush")
	}
}
// TestSocketAuditor_Loop_FlushesOnTimer verifies that a partial batch is
// flushed once the batch timer fires.
func TestSocketAuditor_Loop_FlushesOnTimer(t *testing.T) {
	t.Parallel()
	auditor, serverConn := setupTestAuditor(t)
	auditor.batchTimerDuration = 3 * time.Second
	received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
	go readFromConn(t, serverConn, received)
	go auditor.Loop(t.Context())
	// A single log should start the timer
	auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
	// Should flush after the timer duration elapses
	deadline := time.After(2 * auditor.batchTimerDuration)
	select {
	case req := <-received:
		if got := len(req.Logs); got != 1 {
			t.Errorf("expected 1 log, got %d", got)
		}
	case <-deadline:
		t.Fatal("timeout waiting for timer flush")
	}
}
// TestSocketAuditor_Loop_FlushesOnContextCancel verifies that pending logs
// are drained and flushed when the loop's context is canceled.
func TestSocketAuditor_Loop_FlushesOnContextCancel(t *testing.T) {
	t.Parallel()
	auditor, serverConn := setupTestAuditor(t)
	// Make the timer long to always exercise the context cancellation case
	auditor.batchTimerDuration = time.Hour
	received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
	go readFromConn(t, serverConn, received)
	ctx, cancel := context.WithCancel(t.Context())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		auditor.Loop(ctx)
	}()
	// Send a log but don't fill the batch
	auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
	cancel()
	select {
	case req := <-received:
		if got := len(req.Logs); got != 1 {
			t.Errorf("expected 1 log, got %d", got)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for shutdown flush")
	}
	wg.Wait()
}
// TestSocketAuditor_Loop_RetriesOnConnectionFailure verifies that a batch
// whose flush fails (first dial attempt errors) is retained and re-sent on
// the next flush attempt, and that logs queued beyond capacity are dropped
// rather than corrupting the retained batch.
func TestSocketAuditor_Loop_RetriesOnConnectionFailure(t *testing.T) {
	t.Parallel()
	clientConn, serverConn := net.Pipe()
	t.Cleanup(func() {
		err := clientConn.Close()
		if err != nil {
			t.Errorf("close client connection: %v", err)
		}
		err = serverConn.Close()
		if err != nil {
			t.Errorf("close server connection: %v", err)
		}
	})
	var dialCount atomic.Int32
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	auditor := &SocketAuditor{
		dial: func() (net.Conn, error) {
			// First dial attempt fails, subsequent ones succeed
			if dialCount.Add(1) == 1 {
				return nil, xerrors.New("connection refused")
			}
			return clientConn, nil
		},
		logger:             logger,
		logCh:              make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
		batchSize:          defaultBatchSize,
		batchTimerDuration: time.Hour, // Ensure timer doesn't interfere with the test
	}
	// Set up hook to detect flush attempts (non-blocking send so repeated
	// attempts don't stall the auditor).
	flushed := make(chan struct{}, 1)
	auditor.onFlushAttempt = func() {
		select {
		case flushed <- struct{}{}:
		default:
		}
	}
	received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
	go readFromConn(t, serverConn, received)
	go auditor.Loop(t.Context())
	// Send batchSize+1 logs so we can verify the last log here gets dropped.
	for i := 0; i < auditor.batchSize+1; i++ {
		auditor.AuditRequest(Request{Method: "GET", URL: "https://servernotup.com", Allowed: true})
	}
	// Wait for the first flush attempt (which will fail)
	select {
	case <-flushed:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for first flush attempt")
	}
	// Send one more log - batch is at capacity, so this triggers flush first
	// The flush succeeds (dial now works), sending the retained batch.
	auditor.AuditRequest(Request{Method: "POST", URL: "https://serverup.com", Allowed: true})
	// Should receive the retained batch (the new log goes into a fresh batch)
	select {
	case req := <-received:
		if len(req.Logs) != auditor.batchSize {
			t.Errorf("expected %d logs from retry, got %d", auditor.batchSize, len(req.Logs))
		}
		// Every retained log must be from the original (failed) batch.
		for _, log := range req.Logs {
			resource, ok := log.Resource.(*agentproto.BoundaryLog_HttpRequest_)
			if !ok {
				t.Fatal("unexpected resource type")
			}
			if resource.HttpRequest.Url != "https://servernotup.com" {
				t.Errorf("expected URL https://servernotup.com, got %v", resource.HttpRequest.Url)
			}
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for retry flush")
	}
}
// TestFlush_EmptyBatch verifies that flushing a nil or empty batch is a
// no-op that never returns an error (no connection is needed, so a nil
// conn is passed).
func TestFlush_EmptyBatch(t *testing.T) {
	t.Parallel()
	if err := flush(nil, nil); err != nil {
		t.Errorf("expected nil error for empty batch, got %v", err)
	}
	if err := flush(nil, []*agentproto.BoundaryLog{}); err != nil {
		t.Errorf("expected nil error for empty slice, got %v", err)
	}
}
// setupSocketAuditor creates a SocketAuditor for tests that only exercise
// the queueing behavior (no connection needed); its dial function always
// fails so any accidental flush is visible.
func setupSocketAuditor(t *testing.T) *SocketAuditor {
	t.Helper()
	failDial := func() (net.Conn, error) {
		return nil, xerrors.New("not connected")
	}
	return &SocketAuditor{
		dial:               failDial,
		logger:             slog.New(slog.NewTextHandler(io.Discard, nil)),
		logCh:              make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
		batchSize:          defaultBatchSize,
		batchTimerDuration: defaultBatchTimerDuration,
	}
}
// setupTestAuditor creates a SocketAuditor with an in-memory connection using
// net.Pipe(). Returns the auditor and the server-side connection for reading.
// Both pipe ends are closed via t.Cleanup when the test finishes.
func setupTestAuditor(t *testing.T) (*SocketAuditor, net.Conn) {
	t.Helper()
	clientConn, serverConn := net.Pipe()
	t.Cleanup(func() {
		if err := clientConn.Close(); err != nil {
			t.Error("Failed to close client connection", "error", err)
		}
		if err := serverConn.Close(); err != nil {
			t.Error("Failed to close server connection", "error", err)
		}
	})
	auditor := &SocketAuditor{
		dial: func() (net.Conn, error) {
			return clientConn, nil
		},
		logger:             slog.New(slog.NewTextHandler(io.Discard, nil)),
		logCh:              make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
		batchSize:          defaultBatchSize,
		batchTimerDuration: defaultBatchTimerDuration,
	}
	return auditor, serverConn
}
// readFromConn reads length-prefixed protobuf messages from a connection and
// sends them to the received channel. It loops until the connection is
// closed (ReadFrame returning an error is treated as normal termination).
func readFromConn(t *testing.T, conn net.Conn, received chan<- *agentproto.ReportBoundaryLogsRequest) {
	t.Helper()
	// 1 KiB scratch buffer for frame payloads; assumed large enough for
	// the test batches sent here.
	buf := make([]byte, 1<<10)
	for {
		tag, data, err := codec.ReadFrame(conn, buf)
		if err != nil {
			return // connection closed
		}
		// Unexpected tags are reported but decoding is still attempted.
		if tag != codec.TagV1 {
			t.Errorf("invalid tag: %d", tag)
		}
		var req agentproto.ReportBoundaryLogsRequest
		if err := proto.Unmarshal(data, &req); err != nil {
			t.Errorf("failed to unmarshal: %v", err)
			return
		}
		received <- &req
	}
}
+200
View File
@@ -0,0 +1,200 @@
//go:build linux
//nolint:revive,gocritic,errname,unconvert
package boundary
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/coder/coder/v2/agent/boundarylogproxy"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
"github.com/coder/coder/v2/enterprise/cli/boundary/log"
"github.com/coder/coder/v2/enterprise/cli/boundary/run"
"github.com/coder/serpent"
)
// printVersion prints version information.
func printVersion(version string) {
	fmt.Printf("%s\n", version)
}
// NewCommand creates and returns the root serpent command.
// It decorates BaseCommand with the usage string and example text that only
// make sense when boundary is the top-level binary (as opposed to being
// mounted as a `coder boundary` subcommand).
func NewCommand(version string) *serpent.Command {
	// To make the top level boundary command, we just make some minor changes to the base command
	cmd := BaseCommand(version)
	cmd.Use = "boundary [flags] -- command [args...]" // Add the flags and args pieces to usage.
	// Add example usage to the long description. This is different from usage as a subcommand because it
	// may be called something different when used as a subcommand / there will be a leading binary (i.e. `coder boundary` vs. `boundary`).
	cmd.Long += `Examples:
# Allow only requests to github.com
boundary --allow "domain=github.com" -- curl https://github.com
# Monitor all requests to specific domains (allow only those)
boundary --allow "domain=github.com path=/api/issues/*" --allow "method=GET,HEAD domain=github.com" -- npm install
# Use allowlist from config file with additional CLI allow rules
boundary --allow "domain=example.com" -- curl https://example.com
# Block everything by default (implicit)`
	return cmd
}
// BaseCommand returns the boundary serpent command without the information involved in making it the
// *top level* serpent command. We are creating this split to make it easier to integrate into the coder
// CLI if needed.
func BaseCommand(version string) *serpent.Command {
	cliConfig := config.CliConfig{}
	var showVersion serpent.Bool
	// Set default config path if file exists - serpent will load it automatically
	if home, err := os.UserHomeDir(); err == nil {
		defaultPath := filepath.Join(home, ".config", "coder_boundary", "config.yaml")
		if _, err := os.Stat(defaultPath); err == nil {
			cliConfig.Config = serpent.YAMLConfigPath(defaultPath)
		}
	}
	return &serpent.Command{
		Use:   "boundary",
		Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests",
		Long:  `boundary creates an isolated network environment for target processes, intercepting HTTP/HTTPS traffic through a transparent proxy that enforces user-defined allow rules.`,
		Options: []serpent.Option{
			{
				Flag:        "config",
				Env:         "BOUNDARY_CONFIG",
				Description: "Path to YAML config file.",
				Value:       &cliConfig.Config,
				YAML:        "",
			},
			{
				Flag:        "allow",
				Env:         "BOUNDARY_ALLOW",
				Description: "Allow rule (repeatable). These are merged with allowlist from config file. Format: \"pattern\" or \"METHOD[,METHOD] pattern\".",
				Value:       &cliConfig.AllowStrings,
				YAML:        "", // CLI only, not loaded from YAML
			},
			{
				Flag:        "allowlist",
				Description: "Allowlist rules from config file (YAML only).",
				Value:       &cliConfig.AllowListStrings,
				YAML:        "allowlist",
				Hidden:      true, // Hidden because it's primarily for YAML config
			},
			{
				Flag:        "log-level",
				Env:         "BOUNDARY_LOG_LEVEL",
				Description: "Set log level (error, warn, info, debug).",
				Default:     "warn",
				Value:       &cliConfig.LogLevel,
				YAML:        "log_level",
			},
			{
				Flag:        "log-dir",
				Env:         "BOUNDARY_LOG_DIR",
				Description: "Set a directory to write logs to rather than stderr.",
				Value:       &cliConfig.LogDir,
				YAML:        "log_dir",
			},
			{
				Flag:        "proxy-port",
				Env:         "PROXY_PORT",
				Description: "Set a port for HTTP proxy.",
				Default:     "8080",
				Value:       &cliConfig.ProxyPort,
				YAML:        "proxy_port",
			},
			{
				Flag:        "pprof",
				Env:         "BOUNDARY_PPROF",
				Description: "Enable pprof profiling server.",
				Value:       &cliConfig.PprofEnabled,
				YAML:        "pprof_enabled",
			},
			{
				Flag:        "pprof-port",
				Env:         "BOUNDARY_PPROF_PORT",
				Description: "Set port for pprof profiling server.",
				Default:     "6060",
				Value:       &cliConfig.PprofPort,
				YAML:        "pprof_port",
			},
			{
				Flag:        "configure-dns-for-local-stub-resolver",
				Env:         "BOUNDARY_CONFIGURE_DNS_FOR_LOCAL_STUB_RESOLVER",
				Description: "Configure DNS for local stub resolver (e.g., systemd-resolved). Only needed when /etc/resolv.conf contains nameserver 127.0.0.53.",
				Value:       &cliConfig.ConfigureDNSForLocalStubResolver,
				YAML:        "configure_dns_for_local_stub_resolver",
			},
			{
				Flag:        "jail-type",
				Env:         "BOUNDARY_JAIL_TYPE",
				Description: "Jail type to use for network isolation. Options: nsjail (default), landjail.",
				Default:     "nsjail",
				Value:       &cliConfig.JailType,
				YAML:        "jail_type",
			},
			{
				Flag:        "disable-audit-logs",
				Env:         "DISABLE_AUDIT_LOGS",
				Description: "Disable sending of audit logs to the workspace agent when set to true.",
				Value:       &cliConfig.DisableAuditLogs,
				YAML:        "disable_audit_logs",
			},
			{
				Flag:        "log-proxy-socket-path",
				Description: "Path to the socket where the boundary log proxy server listens for audit logs.",
				// Important: this default must be the same default path used by the
				// workspace agent to ensure agreement of the default socket path without
				// explicit configuration.
				Default: boundarylogproxy.DefaultSocketPath(),
				// Important: this must be the same variable name used by the workspace agent
				// to allow a single environment variable to configure both boundary and the
				// workspace agent.
				Env:   "CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH",
				Value: &cliConfig.LogProxySocketPath,
				YAML:  "", // CLI only, not loaded from YAML
			},
			{
				Flag:        "version",
				Description: "Print version information and exit.",
				Value:       &showVersion,
				YAML:        "", // CLI only
			},
		},
		Handler: func(inv *serpent.Invocation) error {
			// Handle --version flag early
			if showVersion.Value() {
				printVersion(version)
				return nil
			}
			appConfig, err := config.NewAppConfigFromCliConfig(cliConfig, inv.Args)
			if err != nil {
				return fmt.Errorf("failed to parse cli config file: %v", err)
			}
			// Get command arguments
			if len(appConfig.TargetCMD) == 0 {
				return fmt.Errorf("no command specified")
			}
			logger, err := log.SetupLogging(appConfig)
			if err != nil {
				return fmt.Errorf("could not set up logging: %v", err)
			}
			appConfigInJSON, err := json.Marshal(appConfig)
			if err != nil {
				return err
			}
			logger.Debug("Application config", "config", appConfigInJSON)
			return run.Run(inv.Context(), logger, appConfig)
		},
	}
}
+26
View File
@@ -0,0 +1,26 @@
//go:build !linux
//nolint:revive,gocritic,errname,unconvert
package boundary
import (
"runtime"
"golang.org/x/xerrors"
"github.com/coder/serpent"
)
// BaseCommand returns the boundary serpent command. On non-Linux platforms,
// boundary is not supported and returns an error.
func BaseCommand(_ string) *serpent.Command {
	unsupported := func(_ *serpent.Invocation) error {
		return xerrors.Errorf("boundary is only supported on Linux (current OS: %s)", runtime.GOOS)
	}
	return &serpent.Command{
		Use:     "boundary",
		Short:   "Network isolation tool for monitoring and restricting HTTP/HTTPS requests",
		Long:    `boundary creates an isolated network environment for target processes, intercepting HTTP/HTTPS traffic through a transparent proxy that enforces user-defined allow rules.`,
		Handler: unsupported,
	}
}
+119
View File
@@ -0,0 +1,119 @@
//nolint:revive,gocritic,errname,unconvert
package config
import (
"strings"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
"github.com/coder/serpent"
)
// JailType represents the type of jail to use for network isolation.
type JailType string

const (
	NSJailType   JailType = "nsjail"
	LandjailType JailType = "landjail"
)

// NewJailTypeFromString parses str into a JailType. Unknown values return the
// nsjail default alongside an error.
func NewJailTypeFromString(str string) (JailType, error) {
	candidate := JailType(str)
	if candidate == NSJailType || candidate == LandjailType {
		return candidate, nil
	}
	return NSJailType, xerrors.Errorf("invalid JailType: %s", str)
}
// AllowStringsArray is a custom type that implements pflag.Value to support
// repeatable --allow flags without splitting on commas. This allows comma-separated
// paths within a single allow rule (e.g., "path=/todos/1,/todos/2").
type AllowStringsArray []string

// Compile-time check that AllowStringsArray satisfies pflag.Value.
var _ pflag.Value = (*AllowStringsArray)(nil)

// Set implements pflag.Value. It appends the value to the slice without splitting on commas.
func (a *AllowStringsArray) Set(value string) error {
	*a = append(*a, value)
	return nil
}

// String implements pflag.Value. Values are joined with commas for display
// purposes only; this form is never parsed back.
func (a AllowStringsArray) String() string {
	return strings.Join(a, ",")
}

// Type implements pflag.Value.
func (a AllowStringsArray) Type() string {
	return "string"
}

// Value returns the underlying slice of strings.
func (a AllowStringsArray) Value() []string {
	return []string(a)
}
// CliConfig holds the raw, serpent-typed configuration as collected from CLI
// flags, environment variables, and the optional YAML config file. It is
// resolved into an AppConfig by NewAppConfigFromCliConfig.
type CliConfig struct {
	Config           serpent.YAMLConfigPath `yaml:"-"`
	AllowListStrings serpent.StringArray    `yaml:"allowlist"` // From config file
	AllowStrings     AllowStringsArray      `yaml:"-"`         // From CLI flags only
	LogLevel         serpent.String         `yaml:"log_level"`
	LogDir           serpent.String         `yaml:"log_dir"`
	ProxyPort        serpent.Int64          `yaml:"proxy_port"`
	PprofEnabled     serpent.Bool           `yaml:"pprof_enabled"`
	PprofPort        serpent.Int64          `yaml:"pprof_port"`
	// Only needed when /etc/resolv.conf points at a local stub resolver.
	ConfigureDNSForLocalStubResolver serpent.Bool   `yaml:"configure_dns_for_local_stub_resolver"`
	JailType                         serpent.String `yaml:"jail_type"`
	DisableAuditLogs                 serpent.Bool   `yaml:"disable_audit_logs"`
	LogProxySocketPath               serpent.String `yaml:"log_proxy_socket_path"`
}
// AppConfig is the fully-resolved runtime configuration used by the rest of
// the boundary code: plain Go types, merged allow rules, the parsed jail
// type, the target command to execute, and the resolved user information.
type AppConfig struct {
	AllowRules                       []string // merged config-file allowlist + CLI --allow rules
	LogLevel                         string
	LogDir                           string
	ProxyPort                        int64
	PprofEnabled                     bool
	PprofPort                        int64
	ConfigureDNSForLocalStubResolver bool
	JailType                         JailType
	TargetCMD                        []string // command (argv) to run inside the jail
	UserInfo                         *UserInfo
	DisableAuditLogs                 bool
	LogProxySocketPath               string
}
// NewAppConfigFromCliConfig resolves the raw CLI/YAML configuration into the
// runtime AppConfig, merging allow rules from the config file and CLI flags
// and parsing the jail type. It returns an error only when the jail type is
// invalid.
func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, error) {
	// Merge allowlist from config file with allow from CLI flags.
	allowListStrings := cfg.AllowListStrings.Value()
	allowStrings := cfg.AllowStrings.Value()
	// Combine into a freshly allocated slice. A plain
	// append(allowListStrings, allowStrings...) could mutate the backing
	// array of the config-file slice when it has spare capacity.
	allAllowStrings := make([]string, 0, len(allowListStrings)+len(allowStrings))
	allAllowStrings = append(allAllowStrings, allowListStrings...)
	allAllowStrings = append(allAllowStrings, allowStrings...)
	jailType, err := NewJailTypeFromString(cfg.JailType.Value())
	if err != nil {
		return AppConfig{}, err
	}
	userInfo := GetUserInfo()
	return AppConfig{
		AllowRules:                       allAllowStrings,
		LogLevel:                         cfg.LogLevel.Value(),
		LogDir:                           cfg.LogDir.Value(),
		ProxyPort:                        cfg.ProxyPort.Value(),
		PprofEnabled:                     cfg.PprofEnabled.Value(),
		PprofPort:                        cfg.PprofPort.Value(),
		ConfigureDNSForLocalStubResolver: cfg.ConfigureDNSForLocalStubResolver.Value(),
		JailType:                         jailType,
		TargetCMD:                        targetCMD,
		UserInfo:                         userInfo,
		DisableAuditLogs:                 cfg.DisableAuditLogs.Value(),
		LogProxySocketPath:               cfg.LogProxySocketPath.Value(),
	}, nil
}
+103
View File
@@ -0,0 +1,103 @@
//nolint:revive,gocritic,errname,unconvert
package config
import (
"os"
"os/user"
"path/filepath"
"strconv"
)
// File names for the CA key pair stored under the user's config directory.
const (
	CAKeyName  = "ca-key.pem"
	CACertName = "ca-cert.pem"
)

// UserInfo describes the effective user boundary should operate as,
// resolved by GetUserInfo (taking sudo into account).
type UserInfo struct {
	SudoUser  string // username (original user when running under sudo)
	Uid       int
	Gid       int
	HomeDir   string
	ConfigDir string // directory for boundary's CA files and config
}
// GetUserInfo returns information about the current user, handling sudo
// scenarios: when genuinely running under sudo as root on behalf of a
// non-root user, it resolves that original user's identity and home; in all
// other cases it falls back to the current user.
func GetUserInfo() *UserInfo {
	// Only consider SUDO_USER if we're actually running with elevated privileges.
	// In environments like Coder workspaces, SUDO_USER may be set to 'root'
	// but we're not actually running under sudo.
	if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" && os.Geteuid() == 0 && sudoUser != "root" {
		// We're actually running under sudo with a non-root original user.
		// (Renamed from `user` to avoid shadowing the os/user package.)
		sudoUserInfo, err := user.Lookup(sudoUser)
		if err != nil {
			return getCurrentUserInfo() // Fallback to current user
		}
		uid, _ := strconv.Atoi(os.Getenv("SUDO_UID"))
		gid, _ := strconv.Atoi(os.Getenv("SUDO_GID"))
		// If we couldn't get UID/GID from env, parse from user info
		if uid == 0 {
			if parsedUID, err := strconv.Atoi(sudoUserInfo.Uid); err == nil {
				uid = parsedUID
			}
		}
		if gid == 0 {
			if parsedGID, err := strconv.Atoi(sudoUserInfo.Gid); err == nil {
				gid = parsedGID
			}
		}
		configDir := getConfigDir(sudoUserInfo.HomeDir)
		return &UserInfo{
			SudoUser:  sudoUser,
			Uid:       uid,
			Gid:       gid,
			HomeDir:   sudoUserInfo.HomeDir,
			ConfigDir: configDir,
		}
	}
	// Not actually running under sudo, use current user
	return getCurrentUserInfo()
}
// getCurrentUserInfo gets information for the current (non-sudo) user,
// returning a zero-valued UserInfo when the user database is unavailable.
func getCurrentUserInfo() *UserInfo {
	cur, err := user.Current()
	if err != nil {
		// Fallback with empty values if we can't get user info
		return &UserInfo{}
	}
	uid, _ := strconv.Atoi(cur.Uid)
	gid, _ := strconv.Atoi(cur.Gid)
	return &UserInfo{
		SudoUser:  cur.Username,
		Uid:       uid,
		Gid:       gid,
		HomeDir:   cur.HomeDir,
		ConfigDir: getConfigDir(cur.HomeDir),
	}
}
// getConfigDir determines the config directory: $XDG_CONFIG_HOME/coder_boundary
// when XDG_CONFIG_HOME is set, otherwise <homeDir>/.config/coder_boundary.
func getConfigDir(homeDir string) string {
	base := os.Getenv("XDG_CONFIG_HOME")
	if base == "" {
		base = filepath.Join(homeDir, ".config")
	}
	return filepath.Join(base, "coder_boundary")
}
// CAKeyPath returns the path of the CA private key file inside the user's
// config directory.
func (u *UserInfo) CAKeyPath() string {
	return filepath.Join(u.ConfigDir, CAKeyName)
}

// CACertPath returns the path of the CA certificate file inside the user's
// config directory.
func (u *UserInfo) CACertPath() string {
	return filepath.Join(u.ConfigDir, CACertName)
}
+105
View File
@@ -0,0 +1,105 @@
//go:build linux
package landjail
import (
"fmt"
"log/slog"
"os"
"os/exec"
"github.com/landlock-lsm/go-landlock/landlock"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
"github.com/coder/coder/v2/enterprise/cli/boundary/util"
)
// LandlockConfig describes the network restrictions applied to the child
// process. Currently only outbound TCP connect ports are allowed.
type LandlockConfig struct {
	// TODO(yevhenii):
	// - should it be able to bind to any port?
	// - should it be able to connect to any port on localhost?
	// BindTCPPorts []int
	ConnectTCPPorts []int // TCP ports the jailed process may connect to
}
// ApplyLandlockRestrictions applies Landlock network restrictions to the
// current process: only outbound TCP connections to the ports listed in cfg
// are permitted. Restrictions are irreversible for the process lifetime.
func ApplyLandlockRestrictions(logger *slog.Logger, cfg LandlockConfig) error {
	// Get the Landlock version which works for Kernel 6.7+
	llCfg := landlock.V4
	// Collect our rules
	var netRules []landlock.Rule
	// Add rules for TCP connections
	for _, port := range cfg.ConnectTCPPorts {
		logger.Debug("Adding TCP connect port", "port", port)
		netRules = append(netRules, landlock.ConnectTCP(uint16(port)))
	}
	err := llCfg.RestrictNet(netRules...)
	if err != nil {
		return fmt.Errorf("failed to apply Landlock network restrictions: %w", err)
	}
	return nil
}
// RunChild is the landjail child-process entry point: it restricts this
// process's network access to the parent's proxy port via Landlock, then
// execs the target command with stdio passed through, blocking until the
// command exits. A non-zero exit is surfaced as an error with the exit code.
func RunChild(logger *slog.Logger, config config.AppConfig) error {
	// Only the local proxy port may be dialed; all other traffic is denied.
	landjailCfg := LandlockConfig{
		ConnectTCPPorts: []int{int(config.ProxyPort)},
	}
	err := ApplyLandlockRestrictions(logger, landjailCfg)
	if err != nil {
		return fmt.Errorf("failed to apply Landlock network restrictions: %v", err)
	}
	// Build command
	cmd := exec.Command(config.TargetCMD[0], config.TargetCMD[1:]...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	logger.Info("Executing target command", "command", config.TargetCMD)
	// Run the command - this will block until it completes
	err = cmd.Run()
	if err != nil {
		// Check if this is a normal exit with non-zero status code
		if exitError, ok := err.(*exec.ExitError); ok {
			exitCode := exitError.ExitCode()
			logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
			return fmt.Errorf("command exited with code %d", exitCode)
		}
		// This is an unexpected error
		logger.Error("Command execution failed", "error", err)
		return fmt.Errorf("command execution failed: %v", err)
	}
	logger.Debug("Command completed successfully")
	return nil
}
// getEnvsForTargetProcess returns environment variables intended to be set on
// the child process, so they can later be inherited by the target process.
// It overlays CA-trust and proxy settings onto the current environment.
func getEnvsForTargetProcess(configDir string, caCertPath string, httpProxyPort int) []string {
	proxyAddr := fmt.Sprintf("http://localhost:%d", httpProxyPort)
	overrides := map[string]string{
		// Set standard CA certificate environment variables for common tools.
		// This makes tools like curl, git, etc. trust our dynamically generated CA.
		"SSL_CERT_FILE":       caCertPath, // OpenSSL/LibreSSL-based tools
		"SSL_CERT_DIR":        configDir,  // OpenSSL certificate directory
		"CURL_CA_BUNDLE":      caCertPath, // curl
		"GIT_SSL_CAINFO":      caCertPath, // Git
		"REQUESTS_CA_BUNDLE":  caCertPath, // Python requests
		"NODE_EXTRA_CA_CERTS": caCertPath, // Node.js
		"HTTP_PROXY":          proxyAddr,
		"HTTPS_PROXY":         proxyAddr,
		"http_proxy":          proxyAddr,
		"https_proxy":         proxyAddr,
	}
	return util.MergeEnvs(os.Environ(), overrides)
}
+167
View File
@@ -0,0 +1,167 @@
//go:build linux
package landjail
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"time"
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
"github.com/coder/coder/v2/enterprise/cli/boundary/proxy"
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
)
// LandJail owns the parent-side lifecycle of a landjail run: it starts the
// MITM proxy server, spawns the child process (which applies Landlock
// restrictions on itself), and coordinates shutdown.
type LandJail struct {
	proxyServer *proxy.Server
	logger      *slog.Logger
	config      config.AppConfig
}
// NewLandJail constructs a LandJail whose embedded proxy server is built from
// the supplied rule engine, auditor, TLS config, and application config.
func NewLandJail(
	ruleEngine rulesengine.Engine,
	auditor audit.Auditor,
	tlsConfig *tls.Config,
	logger *slog.Logger,
	config config.AppConfig,
) (*LandJail, error) {
	// Assemble the proxy configuration up front, then build the server.
	proxyCfg := proxy.Config{
		HTTPPort:     int(config.ProxyPort),
		RuleEngine:   ruleEngine,
		Auditor:      auditor,
		Logger:       logger,
		TLSConfig:    tlsConfig,
		PprofEnabled: config.PprofEnabled,
		PprofPort:    int(config.PprofPort),
	}
	jail := &LandJail{
		config:      config,
		proxyServer: proxy.NewProxyServer(proxyCfg),
		logger:      logger,
	}
	return jail, nil
}
// Run starts the proxy, launches the child process in a goroutine, and then
// blocks until either an interrupt/terminate signal arrives or the child
// completes (which cancels the derived context). The proxy is stopped on the
// way out via the deferred cleanup.
func (b *LandJail) Run(ctx context.Context) error {
	b.logger.Info("Start landjail manager")
	err := b.startProxy()
	if err != nil {
		return fmt.Errorf("failed to start landjail manager: %v", err)
	}
	defer func() {
		b.logger.Info("Stop landjail manager")
		err := b.stopProxy()
		if err != nil {
			b.logger.Error("Failed to stop landjail manager", "error", err)
		}
	}()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Child lifecycle: cancel the context when the child finishes so the
	// select below unblocks.
	go func() {
		defer cancel()
		err := b.RunChildProcess(os.Args)
		if err != nil {
			b.logger.Error("Failed to run child process", "error", err)
		}
	}()
	// Setup signal handling BEFORE any setup
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	// Wait for signal or context cancellation
	select {
	case sig := <-sigChan:
		b.logger.Info("Received signal, shutting down...", "signal", sig)
		cancel()
	case <-ctx.Done():
		// Context canceled by command completion
		b.logger.Info("Command completed, shutting down...")
	}
	return nil
}
// RunChildProcess re-invokes this binary (command is os.Args) as the jailed
// child and waits for it to finish. A non-zero child exit is returned as the
// *exec.ExitError (logged at debug, since that is normal command behavior);
// other failures are logged as errors.
func (b *LandJail) RunChildProcess(command []string) error {
	childCmd := b.getChildCommand(command)
	b.logger.Debug("Executing command in boundary", "command", strings.Join(os.Args, " "))
	err := childCmd.Start()
	if err != nil {
		b.logger.Error("Command failed to start", "error", err)
		return err
	}
	b.logger.Debug("waiting on a child process to finish")
	err = childCmd.Wait()
	if err != nil {
		// Check if this is a normal exit with non-zero status code
		if exitError, ok := err.(*exec.ExitError); ok {
			exitCode := exitError.ExitCode()
			// Log at debug level for non-zero exits (normal behavior)
			b.logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
			return err
		}
		// This is an unexpected error (not just a non-zero exit)
		b.logger.Error("Command execution failed", "error", err)
		return err
	}
	b.logger.Debug("Command completed successfully")
	return nil
}
// getChildCommand builds the exec.Cmd used to re-invoke this binary as the
// jailed child process: stdio is passed through, CHILD=true marks the child,
// and Pdeathsig ensures the child dies if the parent does.
func (b *LandJail) getChildCommand(command []string) *exec.Cmd {
	bin, args := command[0], command[1:]
	cmd := exec.Command(bin, args...)
	// Env vars for the child process; they will be inherited by the target process.
	env := getEnvsForTargetProcess(b.config.UserInfo.ConfigDir, b.config.UserInfo.CACertPath(), int(b.config.ProxyPort))
	cmd.Env = append(env, "CHILD=true")
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Pdeathsig: syscall.SIGTERM,
	}
	return cmd
}
// startProxy starts the proxy server and gives it a short grace period to
// begin listening before the child is launched.
func (b *LandJail) startProxy() error {
	// Start proxy server in background
	err := b.proxyServer.Start()
	if err != nil {
		b.logger.Error("Proxy server error", "error", err)
		return err
	}
	// Give proxy time to start.
	// NOTE(review): fixed sleep rather than a readiness check — presumably
	// Start() returns before the listener is ready; confirm and consider
	// polling the port instead.
	time.Sleep(100 * time.Millisecond)
	return nil
}
// stopProxy stops the proxy server if one was created and returns any error
// from the shutdown. Previously the error was logged here and nil was always
// returned, which made the error-logging branch in Run's deferred cleanup
// unreachable; returning the error lets the caller handle it once.
func (b *LandJail) stopProxy() error {
	if b.proxyServer == nil {
		return nil
	}
	return b.proxyServer.Stop()
}
@@ -0,0 +1,61 @@
//go:build linux
package landjail
import (
"context"
"fmt"
"log/slog"
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
"github.com/coder/coder/v2/enterprise/cli/boundary/tls"
)
// RunParent is the landjail parent-process entry point: it parses allow
// rules, wires up the rules engine, auditor, and TLS certificate manager,
// writes the CA certificate, and then runs the LandJail (proxy + child
// process) until completion or shutdown.
func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
	if len(config.AllowRules) == 0 {
		logger.Warn("No allow rules specified; all network traffic will be denied by default")
	}
	// Parse allow rules
	allowRules, err := rulesengine.ParseAllowSpecs(config.AllowRules)
	if err != nil {
		logger.Error("Failed to parse allow rules", "error", err)
		return fmt.Errorf("failed to parse allow rules: %v", err)
	}
	// Create rule engine
	ruleEngine := rulesengine.NewRuleEngine(allowRules, logger)
	// Create auditor
	auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath)
	if err != nil {
		return fmt.Errorf("failed to setup auditor: %v", err)
	}
	// Create TLS certificate manager (CA files land in the user's config dir,
	// owned by the resolved uid/gid so the target process can read them).
	certManager, err := tls.NewCertificateManager(tls.Config{
		Logger:    logger,
		ConfigDir: config.UserInfo.ConfigDir,
		Uid:       config.UserInfo.Uid,
		Gid:       config.UserInfo.Gid,
	})
	if err != nil {
		logger.Error("Failed to create certificate manager", "error", err)
		return fmt.Errorf("failed to create certificate manager: %v", err)
	}
	// Setup TLS to get cert path for jailer
	tlsConfig, err := certManager.SetupTLSAndWriteCACert()
	if err != nil {
		return fmt.Errorf("failed to setup TLS and CA certificate: %v", err)
	}
	landjail, err := NewLandJail(ruleEngine, auditor, tlsConfig, logger, config)
	if err != nil {
		return fmt.Errorf("failed to create landjail: %v", err)
	}
	return landjail.Run(ctx)
}
+27
View File
@@ -0,0 +1,27 @@
//go:build linux
package landjail
import (
"context"
"log/slog"
"os"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
)
// isChild reports whether this process was launched as the jailed child
// (the parent sets CHILD=true in the child's environment).
func isChild() bool {
	value, ok := os.LookupEnv("CHILD")
	return ok && value == "true"
}
// Run is the main entry point that determines whether to execute as a parent or child process.
// If running as a child (CHILD env var is set), it applies landlock restrictions
// and executes the target command. Otherwise, it runs as the parent process, sets up the proxy server,
// and manages the child process lifecycle.
func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
	if isChild() {
		// Child path never uses ctx: the command runs to completion.
		return RunChild(logger, config)
	}
	return RunParent(ctx, logger, config)
}
+62
View File
@@ -0,0 +1,62 @@
//nolint:revive,gocritic,errname,unconvert
package log
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"strings"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
)
// SetupLogging creates a slog logger honoring config.LogLevel (error, warn,
// info, debug; anything else falls back to warn) and, when config.LogDir is
// set, writing to a fresh per-invocation file in that directory instead of
// stderr.
//
// Errors are now wrapped with %w (instead of %v) so callers can inspect the
// underlying os error via errors.Is/errors.As.
func SetupLogging(config config.AppConfig) (*slog.Logger, error) {
	var level slog.Level
	switch strings.ToLower(config.LogLevel) {
	case "error":
		level = slog.LevelError
	case "warn":
		level = slog.LevelWarn
	case "info":
		level = slog.LevelInfo
	case "debug":
		level = slog.LevelDebug
	default:
		level = slog.LevelWarn // Default to warn if invalid level
	}
	logTarget := os.Stderr
	logDir := config.LogDir
	if logDir != "" {
		// Set up the logging directory if it doesn't exist yet
		if err := os.MkdirAll(logDir, 0o755); err != nil {
			return nil, xerrors.Errorf("could not set up log dir %s: %w", logDir, err)
		}
		// Create a logfile (timestamp and pid to avoid race conditions with multiple boundary calls running)
		logFilePath := fmt.Sprintf("boundary-%s-%d.log",
			time.Now().Format("2006-01-02_15-04-05"),
			os.Getpid())
		logFile, err := os.Create(filepath.Join(logDir, logFilePath))
		if err != nil {
			return nil, xerrors.Errorf("could not create log file %s: %w", logFilePath, err)
		}
		// Set the log target to the file rather than stderr.
		// The file is intentionally left open for the process lifetime.
		logTarget = logFile
	}
	// Create a standard slog logger with the appropriate level
	handler := slog.NewTextHandler(logTarget, &slog.HandlerOptions{
		Level: level,
	})
	return slog.New(handler), nil
}
@@ -0,0 +1,111 @@
//go:build linux
package nsjail_manager
import (
"context"
"fmt"
"log/slog"
"os"
"os/exec"
"syscall"
"time"
"github.com/cenkalti/backoff/v5"
"golang.org/x/sys/unix"
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager/nsjail"
)
// waitForInterface waits for a network interface to appear in the namespace.
// It retries checking for the interface with exponential backoff (50ms
// initial, 500ms cap) up to the specified timeout.
func waitForInterface(interfaceName string, timeout time.Duration) error {
	b := backoff.NewExponentialBackOff()
	b.InitialInterval = 50 * time.Millisecond
	b.MaxInterval = 500 * time.Millisecond
	b.Multiplier = 2.0
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	// Probe with `ip link show`; CAP_NET_ADMIN is granted so the probe works
	// from the restricted environment.
	operation := func() (bool, error) {
		cmd := exec.Command("ip", "link", "show", interfaceName)
		cmd.SysProcAttr = &syscall.SysProcAttr{
			AmbientCaps: []uintptr{uintptr(unix.CAP_NET_ADMIN)},
		}
		err := cmd.Run()
		if err != nil {
			return false, fmt.Errorf("interface %s not found: %w", interfaceName, err)
		}
		// Interface exists
		return true, nil
	}
	_, err := backoff.Retry(ctx, operation, backoff.WithBackOff(b))
	if err != nil {
		return fmt.Errorf("interface %s did not appear within %v: %w", interfaceName, timeout, err)
	}
	return nil
}
// RunChild is the nsjail child-process entry point: it waits for the parent
// to move the veth interface into this namespace, configures namespace
// networking (and optionally DNS for a local stub resolver), then execs the
// target command with stdio passed through.
//
// NOTE: a non-zero target exit calls os.Exit with the same code, so deferred
// functions in this process do not run in that case.
func RunChild(logger *slog.Logger, targetCMD []string) error {
	logger.Info("boundary CHILD process is started")
	vethNetJail := os.Getenv("VETH_JAIL_NAME")
	if vethNetJail == "" {
		return fmt.Errorf("VETH_JAIL_NAME environment variable is not set")
	}
	// Wait for the veth interface to be moved into the namespace by the parent process
	if err := waitForInterface(vethNetJail, 5*time.Second); err != nil {
		return fmt.Errorf("failed to wait for interface %s: %w", vethNetJail, err)
	}
	err := nsjail.SetupChildNetworking(vethNetJail)
	if err != nil {
		return fmt.Errorf("failed to setup child networking: %v", err)
	}
	logger.Info("child networking is successfully configured")
	if os.Getenv("CONFIGURE_DNS_FOR_LOCAL_STUB_RESOLVER") == "true" {
		err = nsjail.ConfigureDNSForLocalStubResolver()
		if err != nil {
			return fmt.Errorf("failed to configure DNS in namespace: %v", err)
		}
		logger.Info("DNS in namespace is configured successfully")
	}
	// Program to run
	bin := targetCMD[0]
	args := targetCMD[1:]
	cmd := exec.Command(bin, args...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	// Kill the target if this process dies.
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Pdeathsig: syscall.SIGTERM,
	}
	err = cmd.Run()
	if err != nil {
		// Check if this is a normal exit with non-zero status code
		if exitError, ok := err.(*exec.ExitError); ok {
			exitCode := exitError.ExitCode()
			// Log at debug level for non-zero exits (normal behavior)
			logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
			// Exit with the same code as the command - don't log as error
			// This is normal behavior (commands can exit with any code)
			os.Exit(exitCode)
		}
		// This is an unexpected error (not just a non-zero exit)
		// Only log actual errors like "command not found" or "permission denied"
		logger.Error("Command execution failed", "error", err)
		return err
	}
	// Command exited successfully
	logger.Debug("Command completed successfully")
	return nil
}
@@ -0,0 +1,162 @@
//go:build linux
package nsjail_manager
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"time"
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager/nsjail"
"github.com/coder/coder/v2/enterprise/cli/boundary/proxy"
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
)
// NSJailManager coordinates the pieces of a namespace-based jail: the Jailer
// that configures host/namespace networking, the intercepting proxy server
// that all jailed TCP traffic is redirected to, and the app configuration
// used to wire them together.
type NSJailManager struct {
	jailer      nsjail.Jailer // platform jail implementation (host network setup, child command)
	proxyServer *proxy.Server // HTTP/TLS-intercepting proxy started alongside the child
	logger      *slog.Logger
	config      config.AppConfig
}
// NewNSJailManager builds an NSJailManager and its embedded proxy server from
// the given rule engine, auditor, TLS configuration, and jailer.
//
// The error result is currently always nil; the signature leaves room for
// future construction failures.
func NewNSJailManager(
	ruleEngine rulesengine.Engine,
	auditor audit.Auditor,
	tlsConfig *tls.Config,
	jailer nsjail.Jailer,
	logger *slog.Logger,
	config config.AppConfig,
) (*NSJailManager, error) {
	// Create proxy server
	proxyServer := proxy.NewProxyServer(proxy.Config{
		HTTPPort:     int(config.ProxyPort),
		RuleEngine:   ruleEngine,
		Auditor:      auditor,
		Logger:       logger,
		TLSConfig:    tlsConfig,
		PprofEnabled: config.PprofEnabled,
		PprofPort:    int(config.PprofPort),
	})
	return &NSJailManager{
		config:      config,
		jailer:      jailer,
		proxyServer: proxyServer,
		logger:      logger,
	}, nil
}
// Run starts the namespace-jail manager: it configures host networking and the
// proxy, launches the child process in a goroutine, and then blocks until the
// child finishes or a SIGINT/SIGTERM arrives. Proxy shutdown and host network
// teardown are always attempted (via defer) before returning.
func (b *NSJailManager) Run(ctx context.Context) error {
	b.logger.Info("Start namespace-jail manager")
	err := b.setupHostAndStartProxy()
	if err != nil {
		return fmt.Errorf("failed to start namespace-jail manager: %v", err)
	}
	defer func() {
		b.logger.Info("Stop namespace-jail manager")
		err := b.stopProxyAndCleanupHost()
		if err != nil {
			b.logger.Error("Failed to stop namespace-jail manager", "error", err)
		}
	}()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		defer cancel()
		// The child re-executes this same binary (os.Args); cancel the context
		// once it exits so the select below unblocks.
		b.RunChildProcess(os.Args)
	}()
	// Register signal handling so SIGINT/SIGTERM triggers an orderly shutdown.
	// NOTE(review): this runs after setupHostAndStartProxy, so a signal
	// delivered during host setup is not caught here — confirm intent.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	// Wait for signal or context cancellation
	select {
	case sig := <-sigChan:
		b.logger.Info("Received signal, shutting down...", "signal", sig)
		cancel()
	case <-ctx.Done():
		// Context canceled by command completion
		b.logger.Info("Command completed, shutting down...")
	}
	return nil
}
// RunChildProcess starts the given command inside the jail (via the Jailer),
// completes host-side namespace wiring once the child's PID is known, and
// waits for the command to exit. Failures are logged rather than returned;
// the caller only observes completion.
func (b *NSJailManager) RunChildProcess(command []string) {
	cmd := b.jailer.Command(command)
	b.logger.Debug("Executing command in boundary", "command", strings.Join(os.Args, " "))
	err := cmd.Start()
	if err != nil {
		b.logger.Error("Command failed to start", "error", err)
		return
	}
	// The jail-side veth can only be moved into the child's network namespace
	// once the child's PID exists, so this must run after Start.
	err = b.jailer.ConfigureHostNsCommunication(cmd.Process.Pid)
	if err != nil {
		b.logger.Error("configuration after command execution failed", "error", err)
		return
	}
	b.logger.Debug("waiting on a child process to finish")
	err = cmd.Wait()
	if err != nil {
		// Check if this is a normal exit with non-zero status code
		if exitError, ok := err.(*exec.ExitError); ok {
			exitCode := exitError.ExitCode()
			// Log at debug level for non-zero exits (normal behavior)
			b.logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
		} else {
			// This is an unexpected error (not just a non-zero exit),
			// e.g. "command not found" or the process being killed.
			b.logger.Error("Command execution failed", "error", err)
		}
		return
	}
	b.logger.Debug("Command completed successfully")
}
// setupHostAndStartProxy configures host-side networking via the jailer and
// then starts the intercepting proxy server in the background.
func (b *NSJailManager) setupHostAndStartProxy() error {
	// Configure the jailer (network isolation)
	err := b.jailer.ConfigureHost()
	if err != nil {
		return fmt.Errorf("failed to start jailer: %v", err)
	}
	// Start proxy server in background
	err = b.proxyServer.Start()
	if err != nil {
		b.logger.Error("Proxy server error", "error", err)
		return err
	}
	// Give proxy time to start
	// NOTE(review): Start returns after its listener is created, so this
	// sleep looks unnecessary — confirm before removing.
	time.Sleep(100 * time.Millisecond)
	return nil
}
// stopProxyAndCleanupHost shuts down the proxy server (when one exists) and
// then releases the jailer's host-side resources. A proxy shutdown failure is
// logged but does not prevent jailer cleanup; only the jailer's Close error
// is surfaced to the caller.
func (b *NSJailManager) stopProxyAndCleanupHost() error {
	if srv := b.proxyServer; srv != nil {
		if stopErr := srv.Stop(); stopErr != nil {
			b.logger.Error("Failed to stop proxy server", "error", stopErr)
		}
	}
	return b.jailer.Close()
}
@@ -0,0 +1,90 @@
//go:build linux
package nsjail
import (
"fmt"
"log"
"os/exec"
"strings"
"syscall"
)
// command pairs an *exec.Cmd with the ambient capabilities it should run with
// and a human-readable description used in error and log messages.
type command struct {
	description string
	cmd         *exec.Cmd
	ambientCaps []uintptr
	// If ignoreErr isn't empty and this specific error occurs, suppress it
	// (don't log it, don't return it).
	ignoreErr string
}
// newCommand builds a command with no ignorable-error pattern: every failure
// is reported.
func newCommand(
	description string,
	cmd *exec.Cmd,
	ambientCaps []uintptr,
) *command {
	const noIgnoredError = ""
	return newCommandWithIgnoreErr(description, cmd, ambientCaps, noIgnoredError)
}
// newCommandWithIgnoreErr builds a command whose execution errors matching
// ignoreErr (substring match) will be suppressed by the runner.
func newCommandWithIgnoreErr(
	description string,
	cmd *exec.Cmd,
	ambientCaps []uintptr,
	ignoreErr string,
) *command {
	c := command{
		description: description,
		cmd:         cmd,
		ambientCaps: ambientCaps,
		ignoreErr:   ignoreErr,
	}
	return &c
}
// isIgnorableError reports whether the given error text matches this
// command's configured ignorable-error substring.
func (cmd *command) isIgnorableError(err string) bool {
	if cmd.ignoreErr == "" {
		return false
	}
	return strings.Contains(err, cmd.ignoreErr)
}
// commandRunner executes a fixed sequence of commands in declaration order.
type commandRunner struct {
	commands []*command
}
// newCommandRunner wraps the given command sequence in a runner.
func newCommandRunner(commands []*command) *commandRunner {
	r := commandRunner{commands: commands}
	return &r
}
// run executes each command in order, applying its ambient capabilities, and
// stops at the first non-ignorable failure. The failing command's description
// and combined output are included in the returned error.
func (r *commandRunner) run() error {
	for _, command := range r.commands {
		command.cmd.SysProcAttr = &syscall.SysProcAttr{
			AmbientCaps: command.ambientCaps,
		}
		output, err := command.cmd.CombinedOutput()
		// The ignorable pattern may appear either in the Go error string or
		// in the command's own output, so both are checked.
		if err != nil && !command.isIgnorableError(err.Error()) && !command.isIgnorableError(string(output)) {
			return fmt.Errorf("failed to %s: %v, output: %s", command.description, err, output)
		}
	}
	return nil
}
// runIgnoreErrors executes each command in order, applying its ambient
// capabilities. Unlike run, a non-ignorable failure is logged and execution
// continues with the remaining commands; the returned error is always nil.
// This is intended for best-effort teardown paths where a partial failure
// must not abort the rest of the cleanup.
//
// Fix: the previous version logged the same failure three times (a bare
// "err:" line, an empty line, and the full message) and carried a redundant
// trailing continue; it now emits a single log line with full context.
func (r *commandRunner) runIgnoreErrors() error {
	for _, command := range r.commands {
		command.cmd.SysProcAttr = &syscall.SysProcAttr{
			AmbientCaps: command.ambientCaps,
		}
		output, err := command.cmd.CombinedOutput()
		if err != nil && !command.isIgnorableError(err.Error()) && !command.isIgnorableError(string(output)) {
			log.Printf("failed to %s: %v, output: %s", command.description, err, output)
		}
	}
	return nil
}
@@ -0,0 +1,28 @@
//go:build linux
package nsjail
import (
"os"
"github.com/coder/coder/v2/enterprise/cli/boundary/util"
)
// getEnvsForTargetProcess returns environment variables intended to be set on
// the child process, so they can later be inherited by the target process.
// The current process environment is extended with CA-trust variables so
// common tools accept the dynamically generated certificate authority.
func getEnvsForTargetProcess(configDir string, caCertPath string) []string {
	caTrustVars := map[string]string{
		// Standard CA certificate environment variables for common tools,
		// making curl, git, etc. trust our dynamically generated CA.
		"SSL_CERT_FILE":       caCertPath, // OpenSSL/LibreSSL-based tools
		"SSL_CERT_DIR":        configDir,  // OpenSSL certificate directory
		"CURL_CA_BUNDLE":      caCertPath, // curl
		"GIT_SSL_CAINFO":      caCertPath, // Git
		"REQUESTS_CA_BUNDLE":  caCertPath, // Python requests
		"NODE_EXTRA_CA_CERTS": caCertPath, // Node.js
	}
	return util.MergeEnvs(os.Environ(), caTrustVars)
}
@@ -0,0 +1,145 @@
//go:build linux
package nsjail
import (
"fmt"
"log/slog"
"os"
"os/exec"
"syscall"
"golang.org/x/sys/unix"
)
// Jailer abstracts the platform-specific jail implementation. The expected
// call order (as used by NSJailManager) is: ConfigureHost, Command to build
// the child process, ConfigureHostNsCommunication with the running child's
// PID, and finally Close to tear everything down.
type Jailer interface {
	ConfigureHost() error
	Command(command []string) *exec.Cmd
	ConfigureHostNsCommunication(processPID int) error
	Close() error
}
// Config carries the parameters needed to construct a jailer.
type Config struct {
	Logger *slog.Logger
	// HttpProxyPort is the port the transparent proxy listens on; all jail
	// TCP traffic is redirected here by iptables.
	HttpProxyPort int
	// HomeDir — NOTE(review): not read by NewLinuxJail; confirm whether it is
	// still needed.
	HomeDir string
	// ConfigDir holds CA material and is exported as SSL_CERT_DIR to the child.
	ConfigDir string
	// CACertPath is the generated CA certificate path exported to the child.
	CACertPath string
	// ConfigureDNSForLocalStubResolver enables DNS redirection for hosts
	// running a local stub resolver (e.g. systemd-resolved).
	ConfigureDNSForLocalStubResolver bool
}
// LinuxJail implements Jailer using Linux network namespaces
type LinuxJail struct {
	logger       *slog.Logger
	vethHostName string // Host-side veth interface name for iptables rules; set in ConfigureHost
	vethJailName string // Jail-side veth interface name for iptables rules; set in ConfigureHost
	httpProxyPort int
	configDir     string
	caCertPath    string
	configureDNSForLocalStubResolver bool
}
// NewLinuxJail constructs a LinuxJail from config. The error result is
// currently always nil; the veth interface names are generated later during
// ConfigureHost.
func NewLinuxJail(config Config) (*LinuxJail, error) {
	return &LinuxJail{
		logger:                           config.Logger,
		httpProxyPort:                    config.HttpProxyPort,
		configDir:                        config.ConfigDir,
		caCertPath:                       config.CACertPath,
		configureDNSForLocalStubResolver: config.ConfigureDNSForLocalStubResolver,
	}, nil
}
// ConfigureHost prepares the jail environment before the target process is
// launched. It creates the veth pair and installs iptables rules on the host.
// At this stage, the target PID and its netns are not yet known.
func (l *LinuxJail) ConfigureHost() error {
	if err := l.configureHostNetworkBeforeCmdExec(); err != nil {
		return err
	}
	if err := l.configureIptables(); err != nil {
		return fmt.Errorf("failed to configure iptables: %v", err)
	}
	return nil
}
// Command returns an exec.Cmd that will run the given argv inside fresh user
// and network namespaces, with the environment prepared so the target process
// trusts the proxy's CA and the child half of the jail can finish its own
// network setup.
func (l *LinuxJail) Command(command []string) *exec.Cmd {
	l.logger.Debug("Creating command with namespace")
	cmd := exec.Command(command[0], command[1:]...)

	// Environment for the child; inherited by the target process.
	env := getEnvsForTargetProcess(l.configDir, l.caCertPath)
	env = append(env, "CHILD=true")
	env = append(env, fmt.Sprintf("VETH_JAIL_NAME=%v", l.vethJailName))
	if l.configureDNSForLocalStubResolver {
		env = append(env, "CONFIGURE_DNS_FOR_LOCAL_STUB_RESOLVER=true")
	}
	cmd.Env = env

	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	uid := os.Getuid()
	gid := os.Getgid()
	l.logger.Debug("os.Getuid()", "os.Getuid()", uid)
	l.logger.Debug("os.Getgid()", "os.Getgid()", gid)

	// New user + network namespaces with a 1:1 uid/gid mapping so the child
	// keeps the caller's identity; CAP_NET_ADMIN lets the child configure its
	// side of the veth pair, and Pdeathsig terminates it if the parent dies.
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET,
		UidMappings: []syscall.SysProcIDMap{
			{ContainerID: uid, HostID: uid, Size: 1},
		},
		GidMappings: []syscall.SysProcIDMap{
			{ContainerID: gid, HostID: gid, Size: 1},
		},
		AmbientCaps: []uintptr{unix.CAP_NET_ADMIN},
		Pdeathsig:   syscall.SIGTERM,
	}
	return cmd
}
// ConfigureHostNsCommunication finalizes host-side networking after the
// target process has started: it moves the jail-side veth interface into the
// network namespace of the process identified by pidInt. The process must
// already be running so that its PID (and therefore its netns) exists.
func (l *LinuxJail) ConfigureHostNsCommunication(pidInt int) error {
	pid := fmt.Sprintf("%v", pidInt)
	// Move the jail-side veth into the target namespace. Once moved, the
	// interface is visible only inside the jail's netns, and the jail
	// configures its end of the pair (address, routes) independently.
	moveVeth := newCommand(
		"Move jail-side veth into network namespace",
		exec.Command("ip", "link", "set", l.vethJailName, "netns", pid),
		[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
	)
	return newCommandRunner([]*command{moveVeth}).run()
}
// Close tears down the iptables rules and host networking created for the
// jail. Teardown is best-effort: failures are logged and remaining steps are
// still attempted, and the method always returns nil.
func (l *LinuxJail) Close() error {
	if err := l.cleanupIptables(); err != nil {
		l.logger.Error("Failed to clean up iptables rules", "error", err)
		// Continue with other cleanup even if this fails
	}
	if err := l.cleanupNetworking(); err != nil {
		l.logger.Error("Failed to clean up networking", "error", err)
		// Continue with other cleanup even if this fails
	}
	return nil
}
@@ -0,0 +1,52 @@
//go:build linux
package nsjail
import (
"os/exec"
"golang.org/x/sys/unix"
)
// ConfigureDNSForLocalStubResolver configures DNS redirection from the network namespace
// to the host's local stub resolver. This function should only be called when the host
// runs a local stub resolver such as systemd-resolved, and /etc/resolv.conf contains
// "nameserver 127.0.0.53" (listening on localhost). It redirects DNS requests from the
// namespace to the host by setting up iptables NAT rules. Additionally, /etc/systemd/resolved.conf
// should be configured with DNSStubListener=yes and DNSStubListenerExtra=192.168.100.1:53
// to listen on the additional server address.
// NOTE: it's called inside the network namespace, so these rules affect only
// the jail's own view of the network.
func ConfigureDNSForLocalStubResolver() error {
	runner := newCommandRunner([]*command{
		// Redirect all DNS queries inside the namespace to the host DNS listener.
		// Needed because systemd-resolved listens on a host-side IP, not inside the namespace.
		newCommand(
			"Redirect DNS queries (DNAT 53 → host DNS)",
			exec.Command("iptables", "-t", "nat", "-A", "OUTPUT", "-p", "udp", "--dport", "53", "-j", "DNAT", "--to-destination", "192.168.100.1:53"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Rewrite the SOURCE IP of redirected DNS packets.
		// Required because DNS queries originating as 127.0.0.1 inside the namespace
		// must not leave the namespace with a loopback source (kernel drops them).
		// SNAT ensures packets arrive at systemd-resolved with a valid, routable source.
		newCommand(
			"Fix DNS source IP (SNAT 127.0.0.x → 192.168.100.2)",
			exec.Command("iptables", "-t", "nat", "-A", "POSTROUTING", "-p", "udp", "--dport", "53", "-d", "192.168.100.1", "-j", "SNAT", "--to-source", "192.168.100.2"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Allow packets destined for 127.0.0.0/8 to go through routing and NAT.
		// Without this, DNS queries to 127.0.0.53 never hit iptables OUTPUT
		// and cannot be redirected to the host.
		newCommand(
			"Allow loopback-destined traffic to pass through NAT (route_localnet)",
			// TODO(yevhenii): consider replacing with specific interfaces instead of all
			exec.Command("sysctl", "-w", "net.ipv4.conf.all.route_localnet=1"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
	})
	if err := runner.run(); err != nil {
		return err
	}
	return nil
}
@@ -0,0 +1,185 @@
//go:build linux
package nsjail
import (
"fmt"
"os/exec"
"time"
"golang.org/x/sys/unix"
)
// configureHostNetworkBeforeCmdExec prepares host-side networking before the target
// process is started. At this point the target process is not running, so its PID and network
// namespace ID are not yet known.
func (l *LinuxJail) configureHostNetworkBeforeCmdExec() error {
	// Create veth pair with short names (Linux interface names limited to 15 chars)
	// Generate unique ID to avoid conflicts.
	// NOTE(review): the 7-digit nanosecond-derived suffix makes collisions
	// unlikely but not impossible for concurrent jails — confirm acceptable.
	uniqueID := fmt.Sprintf("%d", time.Now().UnixNano()%10000000) // 7 digits max
	vethHostName := fmt.Sprintf("veth_h_%s", uniqueID)            // veth_h_1234567 = 14 chars
	vethJailName := fmt.Sprintf("veth_n_%s", uniqueID)            // veth_n_1234567 = 14 chars
	// Store veth interface names for iptables rules and later cleanup
	l.vethHostName = vethHostName
	l.vethJailName = vethJailName
	runner := newCommandRunner([]*command{
		// Create a virtual Ethernet (veth) pair that forms a point-to-point link
		// between the host and the jail namespace. One end stays on the host,
		// the other will be moved into the jail. This provides a dedicated,
		// isolated L2 network for the jail.
		newCommand(
			"Create hostjail veth interface pair",
			exec.Command("ip", "link", "add", vethHostName, "type", "veth", "peer", "name", vethJailName),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Assign an IP address to the host side of the veth pair. The /24 mask
		// implicitly defines the jail's entire subnet as 192.168.100.0/24.
		// The host address (192.168.100.1) becomes the default gateway for
		// processes inside the jail and is used by NAT and interception rules
		// to route traffic out of the namespace.
		newCommand(
			"Assign IP to host-side veth",
			exec.Command("ip", "addr", "add", "192.168.100.1/24", "dev", vethHostName),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		newCommand(
			"Activate host-side veth interface",
			exec.Command("ip", "link", "set", vethHostName, "up"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
	})
	if err := runner.run(); err != nil {
		return err
	}
	return nil
}
// configureIptables configures iptables rules for comprehensive TCP traffic interception.
func (l *LinuxJail) configureIptables() error {
	runner := newCommandRunner([]*command{
		// Enable IPv4 packet forwarding so the host can route packets between
		// the jail's veth interface and the outside network. Without this,
		// NAT and forwarding rules would have no effect because the kernel
		// would drop transit packets.
		newCommand(
			"enable IP forwarding",
			exec.Command("sysctl", "-w", "net.ipv4.ip_forward=1"),
			[]uintptr{},
		),
		// Apply source NAT (MASQUERADE) for all traffic leaving the jail's
		// private subnet. This rewrites the source IP of packets originating
		// from 192.168.100.0/24 to the host's external interface IP. It enables:
		//
		//   - outbound connectivity for jailed processes,
		//   - correct return routing from external endpoints,
		//   - avoidance of static IP assignment for the host interface.
		//
		// MASQUERADE is used instead of SNAT so it works even when the host IP
		// changes dynamically.
		newCommand(
			"NAT rules for outgoing traffic (MASQUERADE for return traffic)",
			exec.Command("iptables", "-t", "nat", "-A", "POSTROUTING", "-s", "192.168.100.0/24", "-j", "MASQUERADE"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Redirect ALL TCP traffic coming from the jail's veth interface
		// to the local HTTP/TLS-intercepting proxy. This causes *every* TCP
		// connection (HTTP, HTTPS, plain TCP protocols) initiated by jailed
		// processes to be transparently intercepted.
		//
		// The HTTP proxy will intelligently handle both HTTP and TLS traffic.
		//
		// PREROUTING is used so redirection happens before routing decisions.
		// REDIRECT rewrites the destination IP to 127.0.0.1 and the destination
		// port to the HTTP proxy's port, forcing traffic through the proxy without
		// requiring any configuration inside the jail.
		newCommand(
			"Route ALL TCP traffic to HTTP proxy",
			exec.Command("iptables", "-t", "nat", "-A", "PREROUTING", "-i", l.vethHostName, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpProxyPort)),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Allow forwarding of non-TCP packets originating from the jail's subnet.
		// This rule is primarily needed for traffic that is *not* intercepted by
		// the TCP REDIRECT rule — for example:
		//
		//   - DNS queries (UDP/53)
		//   - ICMP (ping, errors)
		//   - Any other UDP or non-TCP protocols
		//
		// Redirected TCP flows never reach the FORWARD chain (they are locally
		// redirected in PREROUTING), so this rule does not apply to TCP traffic.
		newCommand(
			"Allow outbound non-TCP traffic from jail subnet",
			exec.Command("iptables", "-A", "FORWARD", "-s", "192.168.100.0/24", "-j", "ACCEPT"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Allow forwarding of return traffic destined for the jail's subnet for
		// non-TCP flows. This complements the previous FORWARD rule and ensures
		// that responses to DNS (UDP) or ICMP packets can reach the jail.
		//
		// As with the previous rule, this has no effect on TCP traffic because
		// all TCP connections from the jail are intercepted and redirected to
		// the local proxy before reaching the forwarding path.
		newCommand(
			"Allow inbound return traffic to jail subnet (non-TCP)",
			exec.Command("iptables", "-A", "FORWARD", "-d", "192.168.100.0/24", "-j", "ACCEPT"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
	})
	if err := runner.run(); err != nil {
		return err
	}
	l.logger.Debug("Comprehensive TCP boundarying enabled", "interface", l.vethHostName, "proxy_port", l.httpProxyPort)
	return nil
}
// cleanupNetworking deletes the veth pair created during host setup
// (removing the host end removes both ends). A "Cannot find device" error —
// the interface is already gone — is suppressed; other failures are logged
// by the runner but never returned.
func (l *LinuxJail) cleanupNetworking() error {
	deleteVeth := newCommandWithIgnoreErr(
		"delete veth pair",
		exec.Command("ip", "link", "del", l.vethHostName),
		[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		"Cannot find device",
	)
	return newCommandRunner([]*command{deleteVeth}).runIgnoreErrors()
}
// cleanupIptables removes the iptables rules installed by configureIptables.
// Each command mirrors an "-A" (append) rule with "-D" (delete); failures are
// logged by the runner but never returned, so cleanup is best-effort.
func (l *LinuxJail) cleanupIptables() error {
	runner := newCommandRunner([]*command{
		newCommand(
			"Remove: NAT rules for outgoing traffic (MASQUERADE for return traffic)",
			exec.Command("iptables", "-t", "nat", "-D", "POSTROUTING", "-s", "192.168.100.0/24", "-j", "MASQUERADE"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		newCommand(
			"Remove: Route ALL TCP traffic to HTTP proxy",
			exec.Command("iptables", "-t", "nat", "-D", "PREROUTING", "-i", l.vethHostName, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpProxyPort)),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		newCommand(
			"Remove: Allow outbound non-TCP traffic from jail subnet",
			exec.Command("iptables", "-D", "FORWARD", "-s", "192.168.100.0/24", "-j", "ACCEPT"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		newCommand(
			"Remove: Allow inbound return traffic to jail subnet (non-TCP)",
			exec.Command("iptables", "-D", "FORWARD", "-d", "192.168.100.0/24", "-j", "ACCEPT"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
	})
	if err := runner.runIgnoreErrors(); err != nil {
		return err
	}
	return nil
}
@@ -0,0 +1,56 @@
//go:build linux
package nsjail
import (
"os/exec"
"golang.org/x/sys/unix"
)
// SetupChildNetworking configures networking within the target process's network
// namespace. This runs inside the child process after it has been
// created and moved to its own network namespace.
func SetupChildNetworking(vethNetJail string) error {
	runner := newCommandRunner([]*command{
		// Assign an IP address to the jail-side veth interface. The /24 mask
		// matches the subnet defined on the host side (192.168.100.0/24),
		// ensuring both interfaces appear on the same L2 network. This address
		// (192.168.100.2) will serve as the jail's primary outbound source IP.
		newCommand(
			"Assign IP to jail-side veth",
			exec.Command("ip", "addr", "add", "192.168.100.2/24", "dev", vethNetJail),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Bring the jail-side veth interface up. Until the interface is set UP,
		// the jail cannot send or receive any packets on this link, even if the
		// IP address and routes are configured correctly.
		newCommand(
			"Activate jail-side veth interface",
			exec.Command("ip", "link", "set", vethNetJail, "up"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Enable the loopback interface inside the jail so that processes can
		// talk to themselves over 127.0.0.1 (new network namespaces start
		// with loopback down).
		newCommand(
			"Enable loopback interface in jail",
			exec.Command("ip", "link", "set", "lo", "up"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
		// Set the default route for all outbound traffic inside the jail. The
		// gateway is the host-side veth address (192.168.100.1), which performs
		// NAT and transparent TCP interception. This ensures that packets not
		// destined for the jail subnet are routed to the host for processing.
		newCommand(
			"Configure default gateway for jail",
			exec.Command("ip", "route", "add", "default", "via", "192.168.100.1"),
			[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
		),
	})
	if err := runner.run(); err != nil {
		return err
	}
	return nil
}
@@ -0,0 +1,76 @@
//go:build linux
package nsjail_manager
import (
"context"
"fmt"
"log/slog"
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager/nsjail"
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
"github.com/coder/coder/v2/enterprise/cli/boundary/tls"
)
// RunParent runs the parent side of the boundary: it builds the rule engine,
// auditor, TLS certificate manager, and Linux jailer from config, then hands
// them to an NSJailManager which starts the proxy and launches the child
// process. It blocks until the manager finishes.
func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
	if len(config.AllowRules) == 0 {
		logger.Warn("No allow rules specified; all network traffic will be denied by default")
	}
	// Parse allow rules
	allowRules, err := rulesengine.ParseAllowSpecs(config.AllowRules)
	if err != nil {
		logger.Error("Failed to parse allow rules", "error", err)
		return fmt.Errorf("failed to parse allow rules: %v", err)
	}
	// Create rule engine
	ruleEngine := rulesengine.NewRuleEngine(allowRules, logger)
	// Create auditor
	auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath)
	if err != nil {
		return fmt.Errorf("failed to setup auditor: %v", err)
	}
	// Create TLS certificate manager
	certManager, err := tls.NewCertificateManager(tls.Config{
		Logger:    logger,
		ConfigDir: config.UserInfo.ConfigDir,
		Uid:       config.UserInfo.Uid,
		Gid:       config.UserInfo.Gid,
	})
	if err != nil {
		logger.Error("Failed to create certificate manager", "error", err)
		return fmt.Errorf("failed to create certificate manager: %v", err)
	}
	// Setup TLS to get cert path for jailer
	tlsConfig, err := certManager.SetupTLSAndWriteCACert()
	if err != nil {
		return fmt.Errorf("failed to setup TLS and CA certificate: %v", err)
	}
	// Create jailer with cert path from TLS setup
	jailer, err := nsjail.NewLinuxJail(nsjail.Config{
		Logger:                           logger,
		HttpProxyPort:                    int(config.ProxyPort),
		HomeDir:                          config.UserInfo.HomeDir,
		ConfigDir:                        config.UserInfo.ConfigDir,
		CACertPath:                       config.UserInfo.CACertPath(),
		ConfigureDNSForLocalStubResolver: config.ConfigureDNSForLocalStubResolver,
	})
	if err != nil {
		return fmt.Errorf("failed to create jailer: %v", err)
	}
	// Create boundary instance
	nsJailMgr, err := NewNSJailManager(ruleEngine, auditor, tlsConfig, jailer, logger, config)
	if err != nil {
		return fmt.Errorf("failed to create boundary instance: %v", err)
	}
	return nsJailMgr.Run(ctx)
}
@@ -0,0 +1,27 @@
//go:build linux
package nsjail_manager
import (
"context"
"log/slog"
"os"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
)
// isChild reports whether this process is the namespaced child, as signaled
// by the CHILD environment variable the parent sets on the re-executed binary.
func isChild() bool {
	v, ok := os.LookupEnv("CHILD")
	return ok && v == "true"
}
// Run is the main entry point that determines whether to execute as a parent or child process.
// If running as a child (CHILD env var is set to "true"), it sets up networking in the
// namespace and executes the target command. Otherwise, it runs as the parent process,
// setting up the jail, proxy server, and managing the child process lifecycle.
func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
	if isChild() {
		return RunChild(logger, config.TargetCMD)
	}
	return RunParent(ctx, logger, config)
}
+160
View File
@@ -0,0 +1,160 @@
// Package proxy implements HTTP CONNECT method for tunneling HTTPS traffic through a proxy.
//
// # HTTP CONNECT Method Overview
//
// The HTTP CONNECT method is used to establish a tunnel through a proxy server.
// This is essential for HTTPS proxying because HTTPS requires end-to-end encryption
// that cannot be inspected or modified by intermediaries.
//
// # How HTTP_PROXY Works
//
// When a client is configured to use an HTTP proxy (via HTTP_PROXY environment variable
// or proxy settings), it behaves differently for HTTP vs HTTPS requests:
//
// - HTTP requests: The client sends the full request to the proxy, including the
// complete URL. The proxy forwards it to the destination server.
//
// - HTTPS requests: The client cannot send the encrypted request directly because
// the proxy needs to know where to connect. Instead, the client uses CONNECT
// to establish a tunnel, then performs the TLS handshake and sends HTTPS
// requests through that tunnel.
//
// # Non-Transparent Proxy
//
// This proxy is "non-transparent" because:
// - Clients must be explicitly configured to use it (via HTTP_PROXY)
// - Clients send CONNECT requests for HTTPS traffic
// - The proxy terminates TLS, inspects requests, and re-encrypts to the destination
// - Each HTTP request inside the tunnel is processed separately with rule evaluation
//
// # CONNECT Request Flow
//
// The following diagram illustrates how CONNECT works:
//
// Client Proxy (HTTP/1.1 Server) Real Server
// | | |
// |-- CONNECT example.com:443 -->| |
// | | |
// |<-- 200 Connection Established| |
// | | |
// |-- TLS Handshake ------------->| |
// | | |
// |<-- TLS Handshake -------------| |
// | | |
// |-- Request #1: GET /page1 --->| (decrypts) |
// | |-- GET /page1 --------------------->|
// | |<-- Response #1 --------------------|
// |<-- Response #1 --------------| (encrypts) |
// | | |
// |-- Request #2: GET /page2 --->| (decrypts) |
// | |-- GET /page2 --------------------->|
// | |<-- Response #2 --------------------|
// |<-- Response #2 --------------| (encrypts) |
// | | |
// |-- Request #3: GET /api ----->| (decrypts) |
// | |-- GET /api ----------------------->|
// | |<-- Response #3 --------------------|
// |<-- Response #3 --------------| (encrypts) |
// | | |
// | (connection stays open...) | |
// | | |
// |-- [closes connection] ------->| |
//
// Key Points:
//
// 1. CONNECT establishes the tunnel endpoint (e.g., "example.com:443")
// 2. The actual destination for each request is determined by the Host header
// in the HTTP request inside the tunnel, not the CONNECT target
// 3. The proxy acts as a TLS server to decrypt traffic from the client
// 4. Each HTTP request inside the tunnel is evaluated against rules separately
// 5. The connection remains open for multiple requests (HTTP/1.1 keep-alive)
//
// Implementation Details:
//
// - handleCONNECT: Receives the CONNECT request, sends "200 Connection Established"
// - handleCONNECTTunnel: Wraps the connection with TLS, processes requests in a loop
// - Each request uses req.Host to determine the actual destination, not the CONNECT target
//
//nolint:revive,gocritic,errname,unconvert,noctx,errorlint,bodyclose
package proxy
import (
"bufio"
"crypto/tls"
"io"
"net"
"net/http"
)
// handleCONNECT handles HTTP CONNECT requests for tunneling.
//
// When a client wants to make an HTTPS request through the proxy, it first
// sends a CONNECT request with the target hostname:port (for example
// "example.com:443"). The proxy replies "200 Connection Established", after
// which the client performs a TLS handshake over the same connection and the
// encrypted traffic is handed to handleCONNECTTunnel.
func (p *Server) handleCONNECT(conn net.Conn, req *http.Request) {
	p.logger.Debug("🔌 CONNECT request", "target", req.Host)
	// Acknowledge the tunnel; the client starts its TLS handshake on this
	// same connection immediately after reading the response.
	const established = "HTTP/1.1 200 Connection established\r\n\r\n"
	if _, err := conn.Write([]byte(established)); err != nil {
		p.logger.Error("Failed to send CONNECT response", "error", err)
		return
	}
	p.logger.Debug("CONNECT tunnel established", "target", req.Host)
	// Decrypt the TLS stream and process each HTTP request inside it.
	p.handleCONNECTTunnel(conn)
}
// handleCONNECTTunnel handles the tunnel after CONNECT is established.
//
// This function:
//  1. Wraps the connection with tls.Server to decrypt traffic from the client
//  2. Performs the TLS handshake
//  3. Reads HTTP requests from the tunnel in a loop
//  4. Processes each request separately (rule evaluation, forwarding)
//
// Important: The actual destination for each request is determined by the Host
// header in the HTTP request, not the CONNECT target. This allows multiple
// domains to be accessed over the same tunnel.
//
// The connection lifecycle is managed by handleHTTPConnection's defer, which
// closes the connection when this function returns.
func (p *Server) handleCONNECTTunnel(conn net.Conn) {
	// Wrap connection with TLS server to decrypt traffic
	tlsConn := tls.Server(conn, p.tlsConfig)
	// Perform TLS handshake
	if err := tlsConn.Handshake(); err != nil {
		p.logger.Error("TLS handshake failed in CONNECT tunnel", "error", err)
		return
	}
	p.logger.Debug("✅ TLS handshake successful in CONNECT tunnel")
	// Process HTTP requests in a loop
	reader := bufio.NewReader(tlsConn)
	for {
		// Read HTTP request from tunnel
		req, err := http.ReadRequest(reader)
		if err != nil {
			// NOTE(review): direct comparison misses wrapped EOFs (e.g.
			// io.ErrUnexpectedEOF); errors.Is would be stricter — confirm.
			if err == io.EOF {
				p.logger.Debug("CONNECT tunnel closed by client")
				break
			}
			p.logger.Error("Failed to read HTTP request from CONNECT tunnel", "error", err)
			break
		}
		p.logger.Debug("🔒 HTTP Request in CONNECT tunnel", "method", req.Method, "url", req.URL.String(), "target", req.Host)
		// Process this request - check if allowed and forward to target
		p.processHTTPRequest(tlsConn, req, true)
	}
}
+459
View File
@@ -0,0 +1,459 @@
//nolint:revive,gocritic,errname,unconvert,noctx,errorlint,bodyclose
package proxy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
_ "net/http/pprof" // G108: pprof is intentionally exposed for debugging
"net/url"
"strconv"
"strings"
"sync/atomic"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
)
// Server handles HTTP and HTTPS requests with rule-based filtering
type Server struct {
	ruleEngine rulesengine.Engine // decides which requests are allowed
	auditor    audit.Auditor      // receives an audit record per request
	logger     *slog.Logger
	tlsConfig  *tls.Config // CA-backed config used to terminate TLS inside CONNECT tunnels
	httpPort   int         // port the proxy accepts redirected/explicit traffic on
	started    atomic.Bool // true between Start and Stop; makes both idempotent
	listener   net.Listener
	// Optional pprof debug server, started by Start when pprofEnabled is set.
	pprofServer  *http.Server
	pprofEnabled bool
	pprofPort    int
}
// Config holds configuration for the proxy server
type Config struct {
	HTTPPort     int                // listen port for the proxy
	RuleEngine   rulesengine.Engine // allow/deny decisions per request
	Auditor      audit.Auditor      // request audit sink
	Logger       *slog.Logger
	TLSConfig    *tls.Config // used to terminate TLS for intercepted HTTPS traffic
	PprofEnabled bool        // when true, Start also serves pprof on PprofPort
	PprofPort    int
}
// NewProxyServer creates a new, not-yet-started proxy server from config.
// Call Start to begin accepting connections.
func NewProxyServer(config Config) *Server {
	s := Server{
		ruleEngine:   config.RuleEngine,
		auditor:      config.Auditor,
		logger:       config.Logger,
		tlsConfig:    config.TLSConfig,
		httpPort:     config.HTTPPort,
		pprofEnabled: config.PprofEnabled,
		pprofPort:    config.PprofPort,
	}
	return &s
}
// Start starts the HTTP proxy server with TLS termination capability.
// It is idempotent: calling Start on an already-started server is a no-op.
// The accept loop runs in a background goroutine until Stop closes the
// listener. When pprof is enabled, a debug HTTP server is started as well.
//
// Fix: the pprof goroutine previously logged an error only when it WAS
// http.ErrServerClosed (the normal shutdown sentinel) and silently swallowed
// every real failure; the condition is now negated.
func (p *Server) Start() error {
	if p.isStarted() {
		return nil
	}
	p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort)
	// Start pprof server if enabled
	if p.pprofEnabled {
		p.pprofServer = &http.Server{ // G112: pprof server doesn't need ReadHeaderTimeout
			Addr:    fmt.Sprintf(":%d", p.pprofPort),
			Handler: http.DefaultServeMux,
		}
		ln, err := net.Listen("tcp", fmt.Sprintf(":%d", p.pprofPort))
		if err != nil {
			p.logger.Error("failed to listen on port for pprof server", "port", p.pprofPort, "error", err)
			return xerrors.Errorf("failed to listen on port %v for pprof server: %v", p.pprofPort, err)
		}
		go func() {
			p.logger.Info("Serving pprof on existing listener", "port", p.pprofPort)
			// ErrServerClosed is the expected result of Shutdown; log only
			// unexpected errors.
			if err := p.pprofServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
				p.logger.Error("pprof server error", "error", err)
			}
		}()
	}
	var err error
	p.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort))
	if err != nil {
		p.logger.Error("Failed to create HTTP listener", "error", err)
		return err
	}
	p.started.Store(true)
	// Accept loop with per-connection TLS detection.
	go func() {
		for {
			conn, err := p.listener.Accept()
			if err != nil && errors.Is(err, net.ErrClosed) && p.isStopped() {
				// Listener closed by Stop: exit the accept loop quietly.
				return
			}
			if err != nil {
				p.logger.Error("Failed to accept connection", "error", err)
				continue
			}
			// Handle connection with TLS detection
			go p.handleConnectionWithTLSDetection(conn)
		}
	}()
	return nil
}
// Stop stops the proxy server: it closes the accept listener (which unblocks
// the accept loop started by Start) and shuts down the optional pprof server.
// Calling Stop when the server is not running is a no-op.
func (p *Server) Stop() error {
	if p.isStopped() {
		return nil
	}
	p.started.Store(false)
	if p.listener == nil {
		// Should be impossible once started; guard against misuse anyway.
		p.logger.Error("unexpected nil listener")
		return xerrors.New("unexpected nil listener")
	}
	err := p.listener.Close()
	if err != nil {
		p.logger.Error("Failed to close listener", "error", err)
		return err
	}
	// Close pprof server, giving in-flight debug requests up to 5s to finish.
	if p.pprofServer != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := p.pprofServer.Shutdown(ctx); err != nil {
			p.logger.Error("Failed to shutdown pprof server", "error", err)
		}
	}
	return nil
}
// isStarted reports whether the proxy is currently accepting connections.
func (p *Server) isStarted() bool {
	running := p.started.Load()
	return running
}

// isStopped reports whether the proxy is not currently running.
func (p *Server) isStopped() bool {
	return !p.isStarted()
}
// handleConnectionWithTLSDetection peeks at the first byte of conn to decide
// whether the client is speaking TLS, then dispatches to the HTTPS or the
// plain-HTTP handler accordingly.
func (p *Server) handleConnectionWithTLSDetection(conn net.Conn) {
	wrappedConn, isTLS, err := p.isTLSConnection(conn)
	if err != nil {
		p.logger.Error("Failed to check connection type", "error", err)
		if closeErr := conn.Close(); closeErr != nil {
			p.logger.Error("Failed to close connection", "error", closeErr)
		}
		return
	}
	if !isTLS {
		p.logger.Debug("🌐 Detected HTTP connection")
		p.handleHTTPConnection(wrappedConn)
		return
	}
	p.logger.Debug("🔒 Detected TLS connection - handling as HTTPS")
	p.handleTLSConnection(wrappedConn)
}
// isTLSConnection reads the first byte from conn to classify the protocol and
// returns a wrapper connection that replays that byte to later readers.
func (p *Server) isTLSConnection(conn net.Conn) (net.Conn, bool, error) {
	first := make([]byte, 1)
	n, err := conn.Read(first)
	if err != nil || n == 0 {
		return nil, false, xerrors.Errorf("failed to read first byte from connection: %v, read %v bytes", err, n)
	}
	wrapped := &connectionWrapper{conn, first, false}
	// TLS record content types occupy 0x14-0x17:
	// 0x14 Change Cipher Spec, 0x15 Alert, 0x16 Handshake, 0x17 Application Data.
	var isTLS bool
	switch first[0] {
	case 0x14, 0x15, 0x16, 0x17:
		isTLS = true
	}
	if isTLS {
		p.logger.Debug("TLS detected", "first byte", first[0])
	}
	return wrapped, isTLS, nil
}
// handleHTTPConnection serves a single plaintext HTTP request from conn,
// routing CONNECT requests to the tunneling handler and everything else
// through the rule engine.
func (p *Server) handleHTTPConnection(conn net.Conn) {
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			p.logger.Error("Failed to close connection", "error", closeErr)
		}
	}()
	request, err := http.ReadRequest(bufio.NewReader(conn))
	if err != nil {
		p.logger.Error("Failed to read HTTP request", "error", err)
		return
	}
	// CONNECT establishes a tunnel rather than a normal request/response.
	if request.Method == http.MethodConnect {
		p.handleCONNECT(conn, request)
		return
	}
	p.logger.Debug("🌐 HTTP Request", "method", request.Method, "url", request.URL.String())
	p.processHTTPRequest(conn, request, false)
}
// handleTLSConnection terminates TLS on conn using the proxy's certificate
// config, then reads and processes one HTTPS request.
func (p *Server) handleTLSConnection(conn net.Conn) {
	tlsConn := tls.Server(conn, p.tlsConfig)
	defer func() {
		if closeErr := tlsConn.Close(); closeErr != nil {
			p.logger.Error("Failed to close TLS connection", "error", closeErr)
		}
	}()
	// Perform TLS handshake eagerly so failures surface here rather than on
	// the first read.
	if err := tlsConn.Handshake(); err != nil {
		p.logger.Error("TLS handshake failed", "error", err)
		return
	}
	p.logger.Debug("✅ TLS handshake successful")
	request, err := http.ReadRequest(bufio.NewReader(tlsConn))
	if err != nil {
		p.logger.Error("Failed to read HTTPS request", "error", err)
		return
	}
	p.logger.Debug("🔒 HTTPS Request", "method", request.Method, "url", request.URL.String())
	p.processHTTPRequest(tlsConn, request, true)
}
// processHTTPRequest evaluates req against the rule engine, records the
// decision with the auditor, and either forwards the request upstream or
// writes a 403 response. https selects the scheme used when rebuilding the
// full URL and when dialing the upstream server.
func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool) {
	p.logger.Debug(" Host", "host", req.Host)
	p.logger.Debug(" User-Agent", "user-agent", req.Header.Get("User-Agent"))
	// Construct fully qualified URL for rule evaluation and auditing.
	// In boundary's normal transparent proxy operation, req.URL only contains
	// the path since clients don't know they're going through a proxy.
	// When clients explicitly configure a proxy, req.URL contains the full URL.
	fullURL := req.URL.String()
	if req.URL.Scheme == "" {
		if https {
			fullURL = "https://" + req.Host + fullURL
		} else {
			fullURL = "http://" + req.Host + fullURL
		}
	}
	result := p.ruleEngine.Evaluate(req.Method, fullURL)
	p.auditor.AuditRequest(audit.Request{
		Method:  req.Method,
		URL:     fullURL,
		Host:    req.Host,
		Allowed: result.Allowed,
		Rule:    result.Rule,
	})
	if result.Allowed {
		// Forward request to destination
		p.forwardRequest(conn, req, https)
		return
	}
	p.writeBlockedResponse(conn, req)
}
// forwardRequest forwards req to its origin server and copies the response
// back to conn. https selects the upstream scheme. Errors are logged; the
// connection is closed by the caller.
func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) {
	// Create HTTP client; redirects are returned to the downstream client
	// verbatim instead of being followed by the proxy.
	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse // Don't follow redirects
		},
	}
	scheme := "http"
	if https {
		scheme = "https"
	}
	// Create a new request to the target server
	targetURL := &url.URL{
		Scheme:   scheme,
		Host:     req.Host,
		Path:     req.URL.Path,
		RawQuery: req.URL.RawQuery,
	}
	// GET/HEAD requests are forwarded without a body.
	body := req.Body
	if req.Method == http.MethodGet || req.Method == http.MethodHead {
		body = nil
	}
	newReq, err := http.NewRequest(req.Method, targetURL.String(), body)
	if err != nil {
		p.logger.Error("can't create http request", "error", err)
		return
	}
	// Copy headers
	for name, values := range req.Header {
		// Skip connection-specific headers
		if strings.ToLower(name) == "connection" || strings.ToLower(name) == "proxy-connection" {
			continue
		}
		for _, value := range values {
			newReq.Header.Add(name, value)
		}
	}
	// Make request to destination
	resp, err := client.Do(newReq)
	if err != nil {
		p.logger.Error("Failed to forward HTTPS request", "error", err)
		return
	}
	p.logger.Debug("🔒 HTTPS Response", "status code", resp.StatusCode, "status", resp.Status)
	p.logger.Debug("Forwarded Request",
		"method", newReq.Method,
		"host", newReq.Host,
		"URL", newReq.URL,
	)
	// Read the body and explicitly set Content-Length header, otherwise client can hung up on the request.
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		p.logger.Error("can't read response body", "error", err)
		return
	}
	// Set (not Add) so an upstream Content-Length header is replaced rather
	// than duplicated, which would yield an ambiguous response header.
	resp.Header.Set("Content-Length", strconv.Itoa(len(bodyBytes)))
	resp.ContentLength = int64(len(bodyBytes))
	err = resp.Body.Close()
	if err != nil {
		p.logger.Error("Failed to close HTTP response body", "error", err)
		return
	}
	resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
	// The downstream client (Claude) always communicates over HTTP/1.1.
	// However, Go's default HTTP client may negotiate an HTTP/2 connection
	// with the upstream server via ALPN during TLS handshake.
	// This can cause the response's Proto field to be set to "HTTP/2.0",
	// which would produce an invalid response for an HTTP/1.1 client.
	// To prevent this mismatch, we explicitly normalize the response
	// to HTTP/1.1 before writing it back to the client.
	resp.Proto = "HTTP/1.1"
	resp.ProtoMajor = 1
	resp.ProtoMinor = 1
	// Copy response back to client
	err = resp.Write(conn)
	if err != nil {
		p.logger.Error("Failed to forward back HTTP response",
			"error", err,
			"host", req.Host,
			"method", req.Method,
		)
		return
	}
	p.logger.Debug("Successfully wrote to connection")
}
// writeBlockedResponse writes a plain-text 403 Forbidden response to conn
// explaining that the request was denied by boundary and showing the --allow
// flags that would permit it.
func (p *Server) writeBlockedResponse(conn net.Conn, req *http.Request) {
	// Create a response object
	resp := &http.Response{
		Status:        "403 Forbidden",
		StatusCode:    http.StatusForbidden,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        make(http.Header),
		Body:          nil,
		ContentLength: 0,
	}
	// Set headers
	resp.Header.Set("Content-Type", "text/plain")
	// Prefer the URL host (set on explicit-proxy requests); fall back to the
	// Host header for transparent requests.
	host := req.URL.Host
	if host == "" {
		host = req.Host
	}
	body := fmt.Sprintf(`🚫 Request Blocked by Boundary
Request: %s %s
Host: %s
To allow this request, restart boundary with:
--allow "domain=%s" # Allow all methods to this host
--allow "method=%s domain=%s" # Allow only %s requests to this host
For more help: https://github.com/coder/boundary
`,
		req.Method, req.URL.Path, host, host, req.Method, host, req.Method)
	resp.Body = io.NopCloser(strings.NewReader(body))
	resp.ContentLength = int64(len(body))
	// Copy response back to client
	err := resp.Write(conn)
	if err != nil {
		// Log message typo fixed: "blocker" -> "blocked".
		p.logger.Error("Failed to write blocked response", "error", err)
		return
	}
	p.logger.Debug("Successfully wrote to connection")
}
// connectionWrapper lets us "unread" the peeked byte
type connectionWrapper struct {
	net.Conn
	buf     []byte
	bufUsed bool
}

// Read returns the previously peeked bytes on the first call, then delegates
// every subsequent read to the underlying connection.
func (c *connectionWrapper) Read(p []byte) (int, error) {
	if c.bufUsed || len(c.buf) == 0 {
		return c.Conn.Read(p)
	}
	c.bufUsed = true
	return copy(p, c.buf), nil
}
@@ -0,0 +1,134 @@
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
package proxy
import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
)
// capturingAuditor captures all audit requests for test verification.
// It is safe for concurrent use.
type capturingAuditor struct {
	mu       sync.Mutex
	requests []audit.Request
}

// AuditRequest records req so tests can assert on it later.
func (c *capturingAuditor) AuditRequest(req audit.Request) {
	c.mu.Lock()
	c.requests = append(c.requests, req)
	c.mu.Unlock()
}

// getRequests returns a snapshot copy of every request audited so far.
func (c *capturingAuditor) getRequests() []audit.Request {
	c.mu.Lock()
	defer c.mu.Unlock()
	snapshot := make([]audit.Request, len(c.requests))
	copy(snapshot, c.requests)
	return snapshot
}
// TestAuditURLIsFullyFormed_HTTP verifies that audited HTTP requests carry a
// fully qualified URL (scheme://host:port/path?query) rather than just the
// path the transparent proxy sees on the wire.
func TestAuditURLIsFullyFormed_HTTP(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()
	serverURL, err := url.Parse(server.URL)
	require.NoError(t, err)
	auditor := &capturingAuditor{}
	// Only /allowed/* on the local test server's host is permitted; every
	// other request falls through to boundary's default deny.
	pt := NewProxyTest(t,
		WithCertManager(t.TempDir()),
		WithAllowedRule("domain="+serverURL.Hostname()+" path=/allowed/*"),
		WithAuditor(auditor),
	).Start()
	defer pt.Stop()
	t.Run("allowed", func(t *testing.T) {
		resp, err := pt.proxyClient.Get(server.URL + "/allowed/path?q=1")
		require.NoError(t, err)
		defer func() {
			err = resp.Body.Close()
			require.NoError(t, err)
		}()
		require.Equal(t, http.StatusOK, resp.StatusCode)
		requests := auditor.getRequests()
		require.NotEmpty(t, requests)
		// The most recent audit entry corresponds to the request just made.
		req := requests[len(requests)-1]
		require.True(t, req.Allowed)
		expectedURL := "http://" + net.JoinHostPort(serverURL.Hostname(), serverURL.Port()) + "/allowed/path?q=1"
		assert.Equal(t, expectedURL, req.URL)
	})
	t.Run("denied", func(t *testing.T) {
		resp, err := pt.proxyClient.Get(server.URL + "/denied/path")
		require.NoError(t, err)
		defer func() {
			err = resp.Body.Close()
			require.NoError(t, err)
		}()
		require.Equal(t, http.StatusForbidden, resp.StatusCode)
		requests := auditor.getRequests()
		require.NotEmpty(t, requests)
		// Denied requests must still be audited with the full URL.
		req := requests[len(requests)-1]
		require.False(t, req.Allowed)
		expectedURL := "http://" + net.JoinHostPort(serverURL.Hostname(), serverURL.Port()) + "/denied/path"
		assert.Equal(t, expectedURL, req.URL)
	})
}
// TestAuditURLIsFullyFormed_HTTPS verifies that audited in-tunnel HTTPS
// requests carry the https scheme and host even though the request line inside
// the tunnel only contains a path.
// NOTE(review): this test sends live traffic to dev.coder.com and therefore
// requires outbound network access — confirm that is acceptable in CI.
func TestAuditURLIsFullyFormed_HTTPS(t *testing.T) {
	auditor := &capturingAuditor{}
	pt := NewProxyTest(t,
		WithCertManager(t.TempDir()),
		WithAllowedDomain("dev.coder.com"),
		WithAuditor(auditor),
	).Start()
	defer pt.Stop()
	tunnel, err := pt.establishExplicitCONNECT("dev.coder.com:443")
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, tunnel.close())
	}()
	t.Run("allowed", func(t *testing.T) {
		_, err := tunnel.sendRequest("dev.coder.com", "/api/v2?q=1")
		require.NoError(t, err)
		requests := auditor.getRequests()
		require.NotEmpty(t, requests)
		req := requests[len(requests)-1]
		require.True(t, req.Allowed)
		assert.Equal(t, "https://dev.coder.com/api/v2?q=1", req.URL)
	})
	t.Run("denied", func(t *testing.T) {
		// The denied host is never contacted; the proxy rejects in-tunnel.
		err := tunnel.sendRequestAndExpectDeny("blocked.example.com", "/some/path")
		require.NoError(t, err)
		requests := auditor.getRequests()
		require.NotEmpty(t, requests)
		req := requests[len(requests)-1]
		require.False(t, req.Allowed)
		assert.Equal(t, "https://blocked.example.com/some/path", req.URL)
	})
}
@@ -0,0 +1,93 @@
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
package proxy
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestProxyServerImplicitCONNECT tests HTTP CONNECT method for HTTPS tunneling
// CONNECT happens implicitly when using proxy transport with HTTPS requests.
// NOTE(review): exercises live endpoints (dev.coder.com and
// jsonplaceholder.typicode.com) and depends on their exact response bodies;
// requires outbound network access.
func TestProxyServerImplicitCONNECT(t *testing.T) {
	pt := NewProxyTest(t,
		WithCertManager("/tmp/boundary_connect_test"),
		WithAllowedDomain("dev.coder.com"),
		WithAllowedDomain("jsonplaceholder.typicode.com"),
	).
		Start()
	defer pt.Stop()
	// Test HTTPS request through proxy transport (automatic CONNECT)
	t.Run("HTTPSRequestThroughProxyTransport", func(t *testing.T) {
		expectedResponse := `{"message":"👋"}
`
		// Because this is HTTPS, Go will issue CONNECT localhost:8080 → dev.coder.com:443
		pt.ExpectAllowedViaProxy("https://dev.coder.com/api/v2", expectedResponse)
	})
	// Test HTTP request through proxy transport
	t.Run("HTTPRequestThroughProxyTransport", func(t *testing.T) {
		expectedResponse := `{
  "userId": 1,
  "id": 1,
  "title": "delectus aut autem",
  "completed": false
}`
		// For HTTP requests, Go will send the request directly to the proxy
		// The proxy will forward it to the target server
		pt.ExpectAllowedViaProxy("http://jsonplaceholder.typicode.com/todos/1", expectedResponse)
	})
}
// TestMultipleRequestsOverExplicitCONNECT tests explicit CONNECT requests with multiple requests over the same tunnel.
// NOTE(review): exercises live endpoints (dev.coder.com, jsonplaceholder.typicode.com)
// and requires outbound network access.
func TestMultipleRequestsOverExplicitCONNECT(t *testing.T) {
	pt := NewProxyTest(t,
		WithCertManager("/tmp/boundary_explicit_connect_test"),
		WithAllowedDomain("dev.coder.com"),
		WithAllowedDomain("jsonplaceholder.typicode.com"),
	).
		Start()
	defer pt.Stop()
	// Establish explicit CONNECT tunnel
	// Note: The CONNECT target is just the tunnel endpoint. The actual destination
	// for each request is determined by the Host header in the HTTP request inside the tunnel.
	tunnel, err := pt.establishExplicitCONNECT("dev.coder.com:443")
	require.NoError(t, err, "Failed to establish CONNECT tunnel")
	// Check the close error instead of silently discarding it (errcheck).
	defer func() {
		if closeErr := tunnel.close(); closeErr != nil {
			t.Logf("Failed to close tunnel: %v", closeErr)
		}
	}()
	// Positive test: Send first request to dev.coder.com over the tunnel
	t.Run("AllowedRequestToDevCoder", func(t *testing.T) {
		body1, err := tunnel.sendRequest("dev.coder.com", "/api/v2")
		require.NoError(t, err, "Failed to send first request")
		expectedResponse1 := `{"message":"👋"}
`
		require.Equal(t, expectedResponse1, string(body1), "First response does not match")
	})
	// Positive test: Send second request to a different domain (jsonplaceholder.typicode.com) over the same tunnel
	t.Run("AllowedRequestToJsonPlaceholder", func(t *testing.T) {
		body2, err := tunnel.sendRequest("jsonplaceholder.typicode.com", "/todos/1")
		require.NoError(t, err, "Failed to send second request")
		expectedResponse2 := `{
  "userId": 1,
  "id": 1,
  "title": "delectus aut autem",
  "completed": false
}`
		require.Equal(t, expectedResponse2, string(body2), "Second response does not match")
	})
	// Negative test: Try to send request to a blocked domain over the same tunnel
	t.Run("BlockedDomainOverSameTunnel", func(t *testing.T) {
		err := tunnel.sendRequestAndExpectDeny("example.com", "/")
		require.NoError(t, err, "Expected request to be blocked")
	})
	// Negative test: Try to send request to another blocked domain
	t.Run("AnotherBlockedDomainOverSameTunnel", func(t *testing.T) {
		err := tunnel.sendRequestAndExpectDeny("github.com", "/")
		require.NoError(t, err, "Expected request to be blocked")
	})
}
@@ -0,0 +1,438 @@
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
package proxy
import (
"bufio"
"crypto/tls"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"os/user"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
boundary_tls "github.com/coder/coder/v2/enterprise/cli/boundary/tls"
)
// mockAuditor is a no-op audit.Auditor used when a test does not care about
// audit output.
type mockAuditor struct{}

// AuditRequest discards the request.
func (*mockAuditor) AuditRequest(audit.Request) {}
// ProxyTest is a high-level test framework for proxy tests
type ProxyTest struct {
	t           *testing.T
	server      *Server      // the proxy under test
	client      *http.Client // hits the proxy address directly (transparent style)
	proxyClient *http.Client // uses the proxy as an HTTP proxy (implicit CONNECT)
	port        int          // proxy listen port
	useCertManager bool   // when true, Start provisions a CA via the cert manager
	configDir      string // directory where generated CA material is written
	startupDelay   time.Duration // grace period after Start before requests
	allowedRules   []string      // allow specs; empty means deny everything
	auditor        audit.Auditor // optional custom auditor; nil selects the no-op mock
}
// ProxyTestOption is a function that configures ProxyTest
type ProxyTestOption func(*ProxyTest)

// NewProxyTest creates a new ProxyTest instance with sensible defaults
// (port 8080, no cert manager, deny-all rules) and applies opts in order.
func NewProxyTest(t *testing.T, opts ...ProxyTestOption) *ProxyTest {
	pt := &ProxyTest{
		t:              t,
		port:           8080,
		useCertManager: false,
		configDir:      "/tmp/boundary",
		startupDelay:   100 * time.Millisecond,
		allowedRules:   []string{}, // Default: deny all (no rules = deny by default)
	}
	for _, configure := range opts {
		configure(pt)
	}
	return pt
}
// WithProxyPort sets the proxy server port
func WithProxyPort(port int) ProxyTestOption {
	return func(pt *ProxyTest) {
		pt.port = port
	}
}

// WithCertManager enables TLS certificate manager.
// configDir is where the generated CA certificate material is written.
func WithCertManager(configDir string) ProxyTestOption {
	return func(pt *ProxyTest) {
		pt.useCertManager = true
		pt.configDir = configDir
	}
}

// WithStartupDelay sets how long to wait after starting server before making requests
func WithStartupDelay(delay time.Duration) ProxyTestOption {
	return func(pt *ProxyTest) {
		pt.startupDelay = delay
	}
}

// WithAllowedDomain adds an allowed domain rule (all methods, all paths).
func WithAllowedDomain(domain string) ProxyTestOption {
	return func(pt *ProxyTest) {
		pt.allowedRules = append(pt.allowedRules, fmt.Sprintf("domain=%s", domain))
	}
}

// WithAllowedRule adds a full allow rule (e.g., "method=GET domain=example.com path=/api/*")
func WithAllowedRule(rule string) ProxyTestOption {
	return func(pt *ProxyTest) {
		pt.allowedRules = append(pt.allowedRules, rule)
	}
}

// WithAuditor sets a custom auditor for capturing audit requests
func WithAuditor(auditor audit.Auditor) ProxyTestOption {
	return func(pt *ProxyTest) {
		pt.auditor = auditor
	}
}
// Start starts the proxy server and builds the two HTTP clients used by the
// Expect* helpers: `client` targets the proxy address directly, while
// `proxyClient` routes through the proxy as a real HTTP proxy (so HTTPS
// requests use CONNECT). Returns pt for chaining.
func (pt *ProxyTest) Start() *ProxyTest {
	pt.t.Helper()
	// Error-level logging keeps test output quiet unless something breaks.
	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: slog.LevelError,
	}))
	testRules, err := rulesengine.ParseAllowSpecs(pt.allowedRules)
	require.NoError(pt.t, err, "Failed to parse test rules")
	ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
	// Use custom auditor if provided, otherwise use no-op mock
	auditor := pt.auditor
	if auditor == nil {
		auditor = &mockAuditor{}
	}
	var tlsConfig *tls.Config
	if pt.useCertManager {
		currentUser, err := user.Current()
		require.NoError(pt.t, err, "Failed to get current user")
		// NOTE(review): Atoi errors are deliberately ignored; on parse
		// failure uid/gid fall back to 0, which is acceptable for tests.
		uid, _ := strconv.Atoi(currentUser.Uid)
		gid, _ := strconv.Atoi(currentUser.Gid)
		certManager, err := boundary_tls.NewCertificateManager(boundary_tls.Config{
			Logger:    logger,
			ConfigDir: pt.configDir,
			Uid:       uid,
			Gid:       gid,
		})
		require.NoError(pt.t, err, "Failed to create certificate manager")
		tlsConfig, err = certManager.SetupTLSAndWriteCACert()
		require.NoError(pt.t, err, "Failed to setup TLS")
	} else {
		tlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
		}
	}
	pt.server = NewProxyServer(Config{
		HTTPPort:   pt.port,
		RuleEngine: ruleEngine,
		Auditor:    auditor,
		Logger:     logger,
		TLSConfig:  tlsConfig,
	})
	err = pt.server.Start()
	require.NoError(pt.t, err, "Failed to start server")
	// Give server time to start
	time.Sleep(pt.startupDelay)
	// Create HTTP client for direct proxy requests
	pt.client = &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true, // G402: Skip cert verification for testing
			},
		},
		Timeout: 5 * time.Second,
	}
	// Create HTTP client for proxy transport (implicit CONNECT)
	proxyURL, err := url.Parse("http://localhost:" + strconv.Itoa(pt.port))
	require.NoError(pt.t, err, "Failed to parse proxy URL")
	pt.proxyClient = &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true, // G402: Skip cert verification for testing
			},
		},
		Timeout: 10 * time.Second,
	}
	return pt
}
// Stop gracefully stops the proxy server; a stop failure is logged but does
// not fail the test.
func (pt *ProxyTest) Stop() {
	if pt.server == nil {
		return
	}
	if err := pt.server.Stop(); err != nil {
		pt.t.Logf("Failed to stop proxy server: %v", err)
	}
}
// ExpectAllowed makes a request through the proxy and expects it to be allowed with the given response body.
// hostHeader is set as the Host header so the proxy routes to that upstream.
func (pt *ProxyTest) ExpectAllowed(proxyURL, hostHeader, expectedBody string) {
	pt.t.Helper()
	req, err := http.NewRequest("GET", proxyURL, nil)
	require.NoError(pt.t, err, "Failed to create request")
	req.Host = hostHeader
	resp, err := pt.client.Do(req)
	require.NoError(pt.t, err, "Failed to make request")
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.NoError(pt.t, err, "Failed to read response body")
	require.Equal(pt.t, expectedBody, string(body), "Expected response body does not match")
}

// ExpectAllowedContains makes a request through the proxy and expects it to be allowed, checking that response contains the given text
func (pt *ProxyTest) ExpectAllowedContains(proxyURL, hostHeader, containsText string) {
	pt.t.Helper()
	req, err := http.NewRequest("GET", proxyURL, nil)
	require.NoError(pt.t, err, "Failed to create request")
	req.Host = hostHeader
	resp, err := pt.client.Do(req)
	require.NoError(pt.t, err, "Failed to make request")
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.NoError(pt.t, err, "Failed to read response body")
	require.Contains(pt.t, string(body), containsText, "Response does not contain expected text")
}

// ExpectDeny makes a request through the proxy and expects it to be denied
// with a 403 and boundary's standard blocked-request message.
func (pt *ProxyTest) ExpectDeny(proxyURL, hostHeader string) {
	pt.t.Helper()
	req, err := http.NewRequest("GET", proxyURL, nil)
	require.NoError(pt.t, err, "Failed to create request")
	req.Host = hostHeader
	resp, err := pt.client.Do(req)
	require.NoError(pt.t, err, "Failed to make request")
	defer resp.Body.Close()
	require.Equal(pt.t, http.StatusForbidden, resp.StatusCode, "Expected 403 Forbidden status")
	body, err := io.ReadAll(resp.Body)
	require.NoError(pt.t, err, "Failed to read response body")
	require.Contains(pt.t, string(body), "Request Blocked by Boundary", "Expected request to be blocked")
}
// ExpectDenyViaProxy makes a request through the proxy using proxy transport (implicit CONNECT for HTTPS)
// and expects it to be denied with a 403 and boundary's blocked-request message.
func (pt *ProxyTest) ExpectDenyViaProxy(targetURL string) {
	pt.t.Helper()
	resp, err := pt.proxyClient.Get(targetURL)
	require.NoError(pt.t, err, "Failed to make request via proxy")
	defer resp.Body.Close()
	require.Equal(pt.t, http.StatusForbidden, resp.StatusCode, "Expected 403 Forbidden status")
	body, err := io.ReadAll(resp.Body)
	require.NoError(pt.t, err, "Failed to read response body")
	require.Contains(pt.t, string(body), "Request Blocked by Boundary", "Expected request to be blocked")
}

// ExpectAllowedViaProxy makes a request through the proxy using proxy transport (implicit CONNECT for HTTPS)
// and expects it to be allowed with the given response body
func (pt *ProxyTest) ExpectAllowedViaProxy(targetURL, expectedBody string) {
	pt.t.Helper()
	resp, err := pt.proxyClient.Get(targetURL)
	require.NoError(pt.t, err, "Failed to make request via proxy")
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.NoError(pt.t, err, "Failed to read response body")
	require.Equal(pt.t, expectedBody, string(body), "Expected response body does not match")
}

// ExpectAllowedContainsViaProxy makes a request through the proxy using proxy transport (implicit CONNECT for HTTPS)
// and expects it to be allowed, checking that response contains the given text
func (pt *ProxyTest) ExpectAllowedContainsViaProxy(targetURL, containsText string) {
	pt.t.Helper()
	resp, err := pt.proxyClient.Get(targetURL)
	require.NoError(pt.t, err, "Failed to make request via proxy")
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.NoError(pt.t, err, "Failed to read response body")
	require.Contains(pt.t, string(body), containsText, "Response does not contain expected text")
}
// explicitCONNECTTunnel represents an established CONNECT tunnel
type explicitCONNECTTunnel struct {
	tlsConn *tls.Conn     // TLS session running inside the CONNECT tunnel
	reader  *bufio.Reader // buffered reader reused across requests on the tunnel
}
// establishExplicitCONNECT establishes a CONNECT tunnel and returns a tunnel object.
// targetHost should be in format "hostname:port" (e.g., "dev.coder.com:443").
// The caller owns the returned tunnel and must call close() when done.
func (pt *ProxyTest) establishExplicitCONNECT(targetHost string) (*explicitCONNECTTunnel, error) {
	pt.t.Helper()
	// Extract hostname for TLS ServerName (remove port if present)
	hostParts := strings.Split(targetHost, ":")
	serverName := hostParts[0]
	// Connect to proxy
	conn, err := net.Dial("tcp", "localhost:"+strconv.Itoa(pt.port))
	if err != nil {
		return nil, err
	}
	// Send explicit CONNECT request
	connectReq := "CONNECT " + targetHost + " HTTP/1.1\r\n" +
		"Host: " + targetHost + "\r\n" +
		"\r\n"
	_, err = conn.Write([]byte(connectReq))
	if err != nil {
		conn.Close()
		return nil, err
	}
	// Read CONNECT response
	reader := bufio.NewReader(conn)
	resp, err := http.ReadResponse(reader, nil)
	if err != nil {
		conn.Close()
		return nil, err
	}
	// Use the named status constant rather than the magic number 200.
	if resp.StatusCode != http.StatusOK {
		conn.Close()
		return nil, xerrors.Errorf("CONNECT failed with status: %d", resp.StatusCode)
	}
	// Wrap connection with TLS client
	tlsConn := tls.Client(conn, &tls.Config{
		InsecureSkipVerify: true, // G402: Skip cert verification for testing
		ServerName:         serverName,
	})
	// Perform TLS handshake
	err = tlsConn.Handshake()
	if err != nil {
		conn.Close()
		return nil, err
	}
	return &explicitCONNECTTunnel{
		tlsConn: tlsConn,
		reader:  bufio.NewReader(tlsConn),
	}, nil
}
// sendRequest sends an HTTP request over the tunnel and returns the response body
func (tunnel *explicitCONNECTTunnel) sendRequest(targetHost, path string) ([]byte, error) {
	// Write a raw HTTP/1.1 keep-alive request inside the TLS tunnel; the Host
	// header selects the destination the proxy routes to.
	request := "GET " + path + " HTTP/1.1\r\n" +
		"Host: " + targetHost + "\r\n" +
		"Connection: keep-alive\r\n" +
		"\r\n"
	if _, err := tunnel.tlsConn.Write([]byte(request)); err != nil {
		return nil, err
	}
	resp, err := http.ReadResponse(tunnel.reader, nil)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return io.ReadAll(resp.Body)
}
// sendRequestAndExpectDeny sends an HTTP request over the tunnel and expects it to be denied.
// It returns an error if the response is not a 403 carrying boundary's
// blocked-request message.
func (tunnel *explicitCONNECTTunnel) sendRequestAndExpectDeny(targetHost, path string) error {
	request := "GET " + path + " HTTP/1.1\r\n" +
		"Host: " + targetHost + "\r\n" +
		"Connection: keep-alive\r\n" +
		"\r\n"
	if _, err := tunnel.tlsConn.Write([]byte(request)); err != nil {
		return err
	}
	resp, err := http.ReadResponse(tunnel.reader, nil)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusForbidden {
		return xerrors.Errorf("expected 403 Forbidden, got %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	if !strings.Contains(string(body), "Request Blocked by Boundary") {
		return xerrors.Errorf("expected blocked response, got: %s", string(body))
	}
	return nil
}

// close closes the tunnel connection
func (tunnel *explicitCONNECTTunnel) close() error {
	return tunnel.tlsConn.Close()
}
@@ -0,0 +1,49 @@
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
package proxy
import (
"testing"
)
// TestProxyServerBasicHTTP tests basic HTTP request handling.
// NOTE(review): hits the live jsonplaceholder.typicode.com endpoint and
// depends on its exact response body; requires outbound network access.
func TestProxyServerBasicHTTP(t *testing.T) {
	pt := NewProxyTest(t,
		WithAllowedDomain("jsonplaceholder.typicode.com"),
	).
		Start()
	defer pt.Stop()
	t.Run("BasicHTTPRequest", func(t *testing.T) {
		expectedResponse := `{
  "userId": 1,
  "id": 1,
  "title": "delectus aut autem",
  "completed": false
}`
		pt.ExpectAllowed("http://localhost:8080/todos/1", "jsonplaceholder.typicode.com", expectedResponse)
	})
	t.Run("BlockedHTTPRequest", func(t *testing.T) {
		// example.com is not in the allow list, so the proxy must deny it.
		pt.ExpectDeny("http://localhost:8080/", "example.com")
	})
}
// TestProxyServerBasicHTTPS tests basic HTTPS request handling.
// NOTE(review): hits the live dev.coder.com endpoint and depends on its exact
// response body; requires outbound network access.
func TestProxyServerBasicHTTPS(t *testing.T) {
	pt := NewProxyTest(t,
		WithCertManager("/tmp/boundary"),
		WithAllowedDomain("dev.coder.com"),
	).
		Start()
	defer pt.Stop()
	t.Run("BasicHTTPSRequest", func(t *testing.T) {
		expectedResponse := `{"message":"👋"}
`
		pt.ExpectAllowed("https://localhost:8080/api/v2", "dev.coder.com", expectedResponse)
	})
	t.Run("BlockedHTTPSRequest", func(t *testing.T) {
		// example.com is not in the allow list, so the proxy must deny it.
		pt.ExpectDeny("https://localhost:8080/", "example.com")
	})
}
@@ -0,0 +1,151 @@
//nolint:revive,gocritic,errname,unconvert
package rulesengine
import (
"log/slog"
neturl "net/url"
"strings"
)
// Engine evaluates HTTP requests against a set of rules.
type Engine struct {
	rules  []Rule // allow rules; a request matching none of them is denied
	logger *slog.Logger
}

// NewRuleEngine creates a new rule engine
func NewRuleEngine(rules []Rule, logger *slog.Logger) Engine {
	return Engine{
		rules:  rules,
		logger: logger,
	}
}

// Result contains the result of rule evaluation
type Result struct {
	Allowed bool
	Rule    string // The rule that matched (if any)
}
// Evaluate evaluates a request against the configured allow rules and returns
// the decision along with the raw text of the first matching rule, if any.
func (re *Engine) Evaluate(method, url string) Result {
	for _, rule := range re.rules {
		if !re.matches(rule, method, url) {
			continue
		}
		return Result{Allowed: true, Rule: rule.Raw}
	}
	// No allow rule matched: the engine is default-deny.
	return Result{Allowed: false, Rule: ""}
}
// Matches checks if the rule matches the given method and URL using wildcard patterns.
// A rule matches only when ALL of its present pattern sets (method, host,
// path) match; an absent pattern set imposes no constraint.
func (re *Engine) matches(r Rule, method, url string) bool {
	// Check method patterns if they exist
	if r.MethodPatterns != nil {
		methodMatches := false
		for mp := range r.MethodPatterns {
			// "*" matches any method; otherwise an exact match is required.
			if string(mp) == method || mp == "*" {
				methodMatches = true
				break
			}
		}
		if !methodMatches {
			re.logger.Debug("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url)
			return false
		}
	}
	// If the provided url doesn't have a scheme parsing will fail. This can happen when you do something like `curl google.com`
	if !strings.Contains(url, "://") {
		// This is just for parsing, we won't use the scheme.
		url = "https://" + url
	}
	parsedURL, err := neturl.Parse(url)
	if err != nil {
		re.logger.Debug("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err)
		return false
	}
	if r.HostPattern != nil {
		// For a host pattern to match, every label has to match or be an `*`.
		// Subdomains also match automatically, meaning if the pattern is "example.com"
		// and the real is "api.example.com", it should match. We check this by comparing
		// from the end of the actual hostname with the pattern (which is in normal order).
		labels := strings.Split(parsedURL.Hostname(), ".")
		// If the host pattern is longer than the actual host, it's definitely not a match
		if len(r.HostPattern) > len(labels) {
			re.logger.Debug("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels))
			return false
		}
		// Since host patterns cannot end with asterisk, we only need to handle:
		// "example.com" or "*.example.com" - match from the end (allowing subdomains)
		for i, lp := range r.HostPattern {
			// labelIndex aligns the pattern against the tail of the hostname,
			// so extra leading labels (subdomains) are allowed for free.
			labelIndex := len(labels) - len(r.HostPattern) + i
			if string(lp) != labels[labelIndex] && lp != "*" {
				re.logger.Debug("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex])
				return false
			}
		}
	}
	if r.PathPattern != nil {
		segments := strings.Split(parsedURL.Path, "/")
		// Skip the first empty segment if the path starts with "/"
		if len(segments) > 0 && segments[0] == "" {
			segments = segments[1:]
		}
		// Check if any of the path patterns match
		pathMatches := false
		for _, pattern := range r.PathPattern {
			// If the path pattern is longer than the actual path, definitely not a match
			if len(pattern) > len(segments) {
				continue
			}
			// Each segment in the pattern must be either as asterisk or match the actual path segment
			patternMatches := true
			for i, sp := range pattern {
				if sp != segments[i] && sp != "*" {
					patternMatches = false
					break
				}
			}
			if !patternMatches {
				continue
			}
			// If the path is longer than the path pattern, it should only match if:
			// 1. The pattern is empty (root path matches any path), OR
			// 2. The final segment of the pattern is an asterisk
			if len(segments) > len(pattern) && len(pattern) > 0 && pattern[len(pattern)-1] != "*" {
				continue
			}
			pathMatches = true
			break
		}
		if !pathMatches {
			re.logger.Debug("rule does not match", "reason", "no path pattern matches", "rule", r.Raw, "method", method, "url", url)
			return false
		}
	}
	re.logger.Debug("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url)
	return true
}
@@ -0,0 +1,299 @@
//nolint:paralleltest,testpackage
package rulesengine
import (
"log/slog"
"testing"
)
// TestEngineMatches exercises Engine.matches directly with a table of
// method, host, and path pattern combinations, including wildcard and
// edge cases (empty rule, invalid URL, root path, localhost).
func TestEngineMatches(t *testing.T) {
	logger := slog.Default()
	engine := NewRuleEngine(nil, logger)
	tests := []struct {
		name     string
		rule     Rule
		method   string
		url      string
		expected bool
	}{
		// Method pattern tests
		{
			name: "method matches exact",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"GET": {}},
			},
			method:   "GET",
			url:      "https://example.com/api",
			expected: true,
		},
		{
			name: "method does not match",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"POST": {}},
			},
			method:   "GET",
			url:      "https://example.com/api",
			expected: false,
		},
		{
			name: "method wildcard matches any",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"*": {}},
			},
			method:   "PUT",
			url:      "https://example.com/api",
			expected: true,
		},
		{
			name: "no method pattern allows all methods",
			rule: Rule{
				HostPattern: []string{"example", "com"},
			},
			method:   "DELETE",
			url:      "https://example.com/api",
			expected: true,
		},
		// Host pattern tests
		{
			name: "host matches exact",
			rule: Rule{
				HostPattern: []string{"example", "com"},
			},
			method:   "GET",
			url:      "https://example.com/api",
			expected: true,
		},
		{
			name: "host does not match",
			rule: Rule{
				HostPattern: []string{"example", "org"},
			},
			method:   "GET",
			url:      "https://example.com/api",
			expected: false,
		},
		{
			name: "subdomain matches",
			rule: Rule{
				HostPattern: []string{"example", "com"},
			},
			method:   "GET",
			url:      "https://api.example.com/users",
			expected: true,
		},
		{
			name: "host pattern too long",
			rule: Rule{
				HostPattern: []string{"v1", "api", "example", "com"},
			},
			method:   "GET",
			url:      "https://api.example.com/users",
			expected: false,
		},
		{
			name: "host wildcard matches",
			rule: Rule{
				HostPattern: []string{"*", "com"},
			},
			method:   "GET",
			url:      "https://test.com/api",
			expected: true,
		},
		{
			name: "multiple host wildcards",
			rule: Rule{
				HostPattern: []string{"*", "*"},
			},
			method:   "GET",
			url:      "https://api.example.com/users",
			expected: true,
		},
		// Path pattern tests
		{
			name: "path matches exact",
			rule: Rule{
				PathPattern: [][]string{{"api", "users"}},
			},
			method:   "GET",
			url:      "https://example.com/api/users",
			expected: true,
		},
		{
			name: "path does not match",
			rule: Rule{
				PathPattern: [][]string{{"api", "posts"}},
			},
			method:   "GET",
			url:      "https://example.com/api/users",
			expected: false,
		},
		{
			name: "subpath does not implicitly match",
			rule: Rule{
				PathPattern: [][]string{{"api"}},
			},
			method:   "GET",
			url:      "https://example.com/api/users/123",
			expected: false,
		},
		{
			name: "asterisk matches in path",
			rule: Rule{
				PathPattern: [][]string{{"api", "*"}},
			},
			method:   "GET",
			url:      "https://example.com/api/users/123",
			expected: true,
		},
		{
			name: "one asterisk at end matches any number of trailing segments",
			rule: Rule{
				PathPattern: [][]string{{"api", "*"}},
			},
			method:   "GET",
			url:      "https://example.com/api/foo/bar/baz",
			expected: true,
		},
		{
			name: "asterisk in middle of path only matches one segment",
			rule: Rule{
				PathPattern: [][]string{{"api", "*", "foo"}},
			},
			method:   "GET",
			url:      "https://example.com/api/users/admin/foo",
			expected: false,
		},
		{
			name: "path pattern too long",
			rule: Rule{
				PathPattern: [][]string{{"api", "v1", "users", "profile"}},
			},
			method:   "GET",
			url:      "https://example.com/api/v1/users",
			expected: false,
		},
		{
			name: "path wildcard matches",
			rule: Rule{
				PathPattern: [][]string{{"api", "*", "profile"}},
			},
			method:   "GET",
			url:      "https://example.com/api/users/profile",
			expected: true,
		},
		{
			name: "multiple path wildcards",
			rule: Rule{
				PathPattern: [][]string{{"*", "*"}},
			},
			method:   "GET",
			url:      "https://example.com/api/users/123",
			expected: true,
		},
		// Combined pattern tests
		{
			name: "all patterns match",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"POST": {}},
				HostPattern:    []string{"api", "com"},
				PathPattern:    [][]string{{"users"}},
			},
			method:   "POST",
			url:      "https://api.com/users",
			expected: true,
		},
		{
			name: "method fails combined test",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"POST": {}},
				HostPattern:    []string{"api", "com"},
				PathPattern:    [][]string{{"users"}},
			},
			method:   "GET",
			url:      "https://api.com/users",
			expected: false,
		},
		{
			name: "host fails combined test",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"POST": {}},
				HostPattern:    []string{"api", "org"},
				PathPattern:    [][]string{{"users"}},
			},
			method:   "POST",
			url:      "https://api.com/users",
			expected: false,
		},
		{
			name: "path fails combined test",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"POST": {}},
				HostPattern:    []string{"api", "com"},
				PathPattern:    [][]string{{"posts"}},
			},
			method:   "POST",
			url:      "https://api.com/users",
			expected: false,
		},
		{
			name: "all wildcards match",
			rule: Rule{
				MethodPatterns: map[string]struct{}{"*": {}},
				HostPattern:    []string{"*", "*"},
				PathPattern:    [][]string{{"*", "*"}},
			},
			method:   "PATCH",
			url:      "https://test.example.com/api/users/123",
			expected: true,
		},
		// Edge cases
		{
			name:     "empty rule matches everything",
			rule:     Rule{},
			method:   "GET",
			url:      "https://example.com/api/users",
			expected: true,
		},
		{
			name: "invalid URL",
			rule: Rule{
				HostPattern: []string{"example", "com"},
			},
			method:   "GET",
			url:      "not-a-valid-url",
			expected: false,
		},
		{
			name: "root path",
			rule: Rule{
				PathPattern: [][]string{{}},
			},
			method:   "GET",
			url:      "https://example.com/",
			expected: true,
		},
		{
			name: "localhost host",
			rule: Rule{
				HostPattern: []string{"localhost"},
			},
			method:   "GET",
			url:      "http://localhost:8080/api",
			expected: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := engine.matches(tt.rule, tt.method, tt.url)
			if result != tt.expected {
				t.Errorf("expected %v, got %v", tt.expected, result)
			}
		})
	}
}
@@ -0,0 +1,320 @@
//nolint:paralleltest,testpackage
package rulesengine
import (
"log/slog"
"os"
"testing"
"github.com/stretchr/testify/require"
)
// TestRoundTrip parses textual --allow specs with ParseAllowSpecs and then
// evaluates requests against the resulting engine, covering the full
// parse-then-match pipeline (wildcards, schemes, query strings, subdomains).
func TestRoundTrip(t *testing.T) {
	tcs := []struct {
		name        string
		rules       []string
		url         string
		method      string
		expectParse bool
		expectMatch bool
	}{
		{
			name:        "basic all three",
			rules:       []string{"method=GET,HEAD domain=github.com path=/wibble/wobble"},
			url:         "https://github.com/wibble/wobble",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "method rejects properly",
			rules:       []string{"method=GET"},
			url:         "https://github.com/wibble/wobble",
			method:      "POST",
			expectParse: true,
			expectMatch: false,
		},
		{
			name:        "domain rejects properly",
			rules:       []string{"domain=github.com"},
			url:         "https://example.com/wibble/wobble",
			method:      "GET",
			expectParse: true,
			expectMatch: false,
		},
		{
			name:        "path rejects properly",
			rules:       []string{"path=/wibble/wobble"},
			url:         "https://github.com/different/path",
			method:      "GET",
			expectParse: true,
			expectMatch: false,
		},
		{
			name:        "multiple rules - one matches",
			rules:       []string{"domain=github.com", "domain=example.com"},
			url:         "https://github.com/wibble/wobble",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "method wildcard matches anything",
			rules:       []string{"method=*"},
			url:         "https://github.com/wibble/wobble",
			method:      "POST",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain wildcard matches anything",
			rules:       []string{"domain=*"},
			url:         "https://example.com/wibble/wobble",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "path wildcard matches anything",
			rules:       []string{"path=*"},
			url:         "https://github.com/any/path/here",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "all three wildcards match anything",
			rules:       []string{"method=* domain=* path=*"},
			url:         "https://example.com/some/random/path",
			method:      "DELETE",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "query parameters don't break matching",
			rules:       []string{"domain=github.com path=/wibble/wobble"},
			url:         "https://github.com/wibble/wobble?param1=value1&param2=value2",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain wildcard segment matches",
			rules:       []string{"domain=*.github.com"},
			url:         "https://api.github.com/repos",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain cannot end with asterisk",
			rules:       []string{"domain=github.*"},
			url:         "https://github.com/repos",
			method:      "GET",
			expectParse: false,
			expectMatch: false,
		},
		{
			name:        "domain asterisk in middle matches",
			rules:       []string{"domain=github.*.com"},
			url:         "https://github.api.com/repos",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain wildcard matches multiple subdomains",
			rules:       []string{"domain=*.github.com"},
			url:         "https://v1.api.github.com/repos",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "path asterisk in middle matches",
			rules:       []string{"path=/api/*/users"},
			url:         "https://github.com/api/v1/users",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "path asterisk at start matches",
			rules:       []string{"path=/*/users"},
			url:         "https://github.com/api/users",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "path asterisk doesn't match multiple segments",
			rules:       []string{"path=/api/*/users"},
			url:         "https://github.com/api/../admin/users",
			method:      "GET",
			expectParse: true,
			expectMatch: false,
		},
		{
			name:        "path asterisk at end matches",
			rules:       []string{"path=/api/v1/*"},
			url:         "https://github.com/api/v1/users",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "path asterisk at end matches multiple segments",
			rules:       []string{"path=/api/*"},
			url:         "https://github.com/api/v1/users/123/details",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "subpaths do not match automatically",
			rules:       []string{"path=/api"},
			url:         "https://github.com/api/users",
			method:      "GET",
			expectParse: true,
			expectMatch: false,
		},
		{
			name:        "multiple rules match specific path and subpaths",
			rules:       []string{"path=/wibble/wobble,/wibble/wobble/*"},
			url:         "https://github.com/wibble/wobble/sub",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain matches without scheme - example.com case",
			rules:       []string{"domain=example.com"},
			url:         "example.com",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain matches without scheme - jsonplaceholder case",
			rules:       []string{"domain=jsonplaceholder.typicode.com"},
			url:         "jsonplaceholder.typicode.com",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain matches without scheme - dev.coder.com case",
			rules:       []string{"domain=dev.coder.com"},
			url:         "dev.coder.com",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
	}
	logHandler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: slog.LevelInfo,
	})
	logger := slog.New(logHandler)
	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			rules, err := ParseAllowSpecs(tc.rules)
			if tc.expectParse {
				require.Nil(t, err)
				engine := NewRuleEngine(rules, logger)
				result := engine.Evaluate(tc.method, tc.url)
				require.Equal(t, tc.expectMatch, result.Allowed)
			} else {
				require.NotNil(t, err)
			}
		})
	}
}
// TestRoundTripExtraRules covers additional parse/evaluate cases:
// ports in rules vs. URLs, trailing path wildcards, implicit subdomain
// matching, and single-label domain wildcards.
func TestRoundTripExtraRules(t *testing.T) {
	tcs := []struct {
		name        string
		rules       []string
		url         string
		method      string
		expectParse bool
		expectMatch bool
	}{
		{
			name:        "domain=* allows everything",
			rules:       []string{"domain=*"},
			url:         "https://github.com/wibble/wobble",
			method:      "DELETE",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "specifying port in Domain key is NOT allowed",
			rules:       []string{"domain=github.com:8080"},
			url:         "https://github.com/wibble/wobble",
			method:      "DELETE",
			expectParse: false,
			expectMatch: false,
		},
		{
			name:        "specifying port in URL is allowed",
			rules:       []string{"domain=github.com"},
			url:         "https://github.com:8080/wibble/wobble",
			method:      "DELETE",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "wildcard symbol at the end of path",
			rules:       []string{"method=GET,POST,PUT domain=github.com path=/api/issues/*"},
			url:         "https://github.com/api/issues/123/edit",
			method:      "POST",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "wildcard symbol at the end of path doesn't match base path",
			rules:       []string{"method=GET domain=github.com path=/api/issues/*"},
			url:         "https://github.com/api/issues",
			method:      "GET",
			expectParse: true,
			expectMatch: false,
		},
		{
			name:        "includes all subdomains by default",
			rules:       []string{"domain=github.com"},
			url:         "https://x.users.api.github.com",
			method:      "GET",
			expectParse: true,
			expectMatch: true,
		},
		{
			name:        "domain wildcard in the middle matches exactly one label",
			rules:       []string{"domain=api.*.com"},
			url:         "https://api.v1.github.com",
			method:      "POST",
			expectParse: true,
			expectMatch: false,
		},
	}
	logHandler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: slog.LevelInfo,
	})
	logger := slog.New(logHandler)
	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			rules, err := ParseAllowSpecs(tc.rules)
			if tc.expectParse {
				require.Nil(t, err)
				engine := NewRuleEngine(rules, logger)
				result := engine.Evaluate(tc.method, tc.url)
				require.Equal(t, tc.expectMatch, result.Allowed)
			} else {
				require.NotNil(t, err)
			}
		})
	}
}
@@ -0,0 +1,403 @@
//nolint:revive,gocritic,errname,unconvert
package rulesengine
import (
"strings"
"golang.org/x/xerrors"
)
// Rule represents an allow rule passed to the cli with --allow or read from the config file.
// Rules have a specific grammar that we need to parse carefully.
// Example: --allow="method=GET,PATCH domain=wibble.wobble.com, path=/posts/*"
type Rule struct {
	// PathPattern holds the path patterns that can match for this rule.
	// - nil means all paths allowed
	// - Each []string represents one path pattern (its list of segments)
	// - a path segment of `*` acts as a wild card.
	PathPattern [][]string
	// HostPattern holds the labels of the host, i.e. ["google", "com"].
	// - nil means all hosts allowed
	// - A label of `*` acts as a wild card.
	// - subdomains automatically match
	HostPattern []string
	// MethodPatterns is the set of allowed http methods.
	// - nil means all methods allowed
	MethodPatterns map[string]struct{}
	// Raw is the original rule string, kept for logging.
	Raw string
}
// ParseAllowSpecs parses a slice of --allow specs into allow Rules.
// It stops at, and reports, the first spec that fails to parse.
func ParseAllowSpecs(allowStrings []string) ([]Rule, error) {
	var out []Rule
	for _, s := range allowStrings {
		r, err := parseAllowRule(s)
		if err != nil {
			// Wrap with %w (not %v) so callers can inspect the underlying
			// parse error via errors.Is / errors.As; the message is unchanged.
			return nil, xerrors.Errorf("failed to parse allow '%s': %w", s, err)
		}
		out = append(out, r)
	}
	return out, nil
}
// parseAllowRule takes an allow rule string and tries to parse it as a Rule.
func parseAllowRule(ruleStr string) (Rule, error) {
	rule := Rule{
		Raw: ruleStr,
	}
	// Functions called by this function use a really common pattern: recursive descent parsing.
	// All the helper functions for parsing an allow rule will be called like `thing, rest, err := parseThing(rest)`.
	// What's going on here is that we try to parse some expected text from the front of the string.
	// If we succeed, we get back the thing we parsed and the remaining text. If we fail, we get back a non nil error.
	rest := ruleStr
	var key string
	var err error
	// An allow rule can have as many key=value pairs as needed, we go until there's no more text in the rule.
	for rest != "" {
		// Parse the key (one of "method", "domain", "path").
		key, rest, err = parseKey(rest)
		if err != nil {
			return Rule{}, xerrors.Errorf("failed to parse key: %v", err)
		}
		// Parse the value based on the key type.
		switch key {
		case "method":
			// Initialize Methods map if needed.
			if rule.MethodPatterns == nil {
				rule.MethodPatterns = make(map[string]struct{})
			}
			var method string
			// Methods may be comma-separated: method=GET,POST,...
			for {
				method, rest, err = parseMethodPattern(rest)
				if err != nil {
					return Rule{}, xerrors.Errorf("failed to parse method: %v", err)
				}
				rule.MethodPatterns[method] = struct{}{}
				// Check if there's a comma for more methods.
				if rest != "" && rest[0] == ',' {
					rest = rest[1:] // Skip the comma
					continue
				}
				break
			}
		case "domain":
			var host []string
			host, rest, err = parseHostPattern(rest)
			if err != nil {
				return Rule{}, xerrors.Errorf("failed to parse domain: %v", err)
			}
			// Convert labels to strings.
			rule.HostPattern = append(rule.HostPattern, host...)
		case "path":
			// Paths may be comma-separated: path=/a,/a/*
			for {
				var segments []string
				segments, rest, err = parsePathPattern(rest)
				if err != nil {
					return Rule{}, xerrors.Errorf("failed to parse path: %v", err)
				}
				// Add this path pattern to the list of patterns.
				rule.PathPattern = append(rule.PathPattern, segments)
				// Check if there's a comma for more paths.
				if rest != "" && rest[0] == ',' {
					rest = rest[1:] // Skip the comma
					continue
				}
				break
			}
		default:
			return Rule{}, xerrors.Errorf("unknown key: %s", key)
		}
		// Skip whitespace separators (only support mac and linux so \r\n shouldn't be a thing).
		for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == '\n') {
			rest = rest[1:]
		}
	}
	return rule, nil
}
// parseMethodPattern pulls one HTTP method token off the front of token,
// returning the method, the remaining text, and an error for empty input.
// Beyond the 9 methods defined in HTTP 1.1, there actually are many more
// seldom used extension methods by various systems.
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6
func parseMethodPattern(token string) (string, string, error) {
	if token == "" {
		return "", "", xerrors.New("expected http token, got empty string")
	}
	// Consume valid HTTP token characters from the front and stop at the
	// first byte that cannot be part of a token.
	end := 0
	for end < len(token) && isHTTPTokenChar(token[end]) {
		end++
	}
	return token[:end], token[end:], nil
}
// isHTTPTokenChar reports whether c may appear in an HTTP token
// (the lexer/parser kind of token) per RFC 7230 section 3.2.6.
func isHTTPTokenChar(c byte) bool {
	// Alphanumerics are always fine.
	if ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || ('0' <= c && c <= '9') {
		return true
	}
	// These special characters are also allowed, unbelievably.
	return strings.IndexByte("!#$%&'*+-.^_`|~", c) >= 0
}
// parseHostPattern parses a host pattern (dot-separated labels, possibly
// containing `*` wildcards) from the front of input. It returns the labels,
// the remaining text, and an error when no valid host can be read.
// https://datatracker.ietf.org/doc/html/rfc952
// https://datatracker.ietf.org/doc/html/rfc1123#page-13
func parseHostPattern(input string) ([]string, string, error) {
	if input == "" {
		return nil, "", xerrors.New("expected host, got empty string")
	}
	// A host is one or more labels separated by `.` characters; there must
	// be at least one.
	label, rest, err := parseLabelPattern(input)
	if err != nil {
		return nil, "", err
	}
	host := []string{label}
	for {
		after, ok := strings.CutPrefix(rest, ".")
		if !ok {
			break
		}
		label, rest, err = parseLabelPattern(after)
		if err != nil {
			return nil, "", err
		}
		host = append(host, label)
	}
	// A single standalone asterisk means "matches anything" and is allowed;
	// any longer pattern must not end with an asterisk.
	if host[len(host)-1] == "*" && len(host) > 1 {
		return nil, "", xerrors.New("host patterns cannot end with asterisk")
	}
	return host, rest, nil
}
// parseLabelPattern pulls one host label (or a `*` wildcard) off the front
// of rest, returning the label and the remaining text.
func parseLabelPattern(rest string) (string, string, error) {
	if rest == "" {
		return "", "", xerrors.New("expected label, got empty string")
	}
	// A lone asterisk is a wildcard label.
	if rest[0] == '*' {
		return "*", rest[1:], nil
	}
	// The leading character must be a valid label character other than a
	// hyphen (labels may not start with a hyphen).
	if rest[0] == '-' || !isValidLabelChar(rest[0]) {
		return "", "", xerrors.Errorf("could not pull label from front of string: %s", rest)
	}
	// Extend through all remaining valid label characters.
	end := 1
	for end < len(rest) && isValidLabelChar(rest[end]) {
		end++
	}
	// Labels may not end with a hyphen either.
	if rest[end-1] == '-' {
		return "", "", xerrors.Errorf("invalid label: %s", rest[:end])
	}
	return rest[:end], rest[end:], nil
}
// isValidLabelChar reports whether c may appear inside a host label:
// letters, digits, and hyphens.
func isValidLabelChar(c byte) bool {
	return ('A' <= c && c <= 'Z') ||
		('a' <= c && c <= 'z') ||
		('0' <= c && c <= '9') ||
		c == '-'
}
// parsePathPattern parses one path pattern (a sequence of /-separated
// segments, each possibly a `*` wildcard) from the front of input.
// It returns the segments, the remaining text, and any segment parse error.
// Parsing stops at the first byte that cannot continue the path (e.g. a
// comma separating multiple path patterns, or a space before the next key).
func parsePathPattern(input string) ([]string, string, error) {
	if input == "" {
		return nil, "", nil
	}
	rest := input
	var segments []string
	var err error
	// If the path doesn't start with '/', it's not a valid absolute path
	// But we'll be flexible and parse relative paths too
	for {
		// Skip leading slash if present
		if rest != "" && rest[0] == '/' {
			rest = rest[1:]
		}
		// If we've consumed all input, we're done
		if rest == "" {
			break
		}
		// Parse the next segment
		var segment string
		segment, rest, err = parsePathSegmentPattern(rest)
		if err != nil {
			return nil, "", err
		}
		// If we got an empty segment and there's still input,
		// it means we hit an invalid character
		if segment == "" && rest != "" {
			break
		}
		segments = append(segments, segment)
		// If there's no slash after the segment, we're done parsing the path
		if rest == "" || rest[0] != '/' {
			break
		}
	}
	return segments, rest, nil
}
// parsePathSegmentPattern pulls one path segment pattern off the front of
// input. A segment is either a standalone `*` wildcard or a run of pchar /
// percent-encoded characters; parsing stops at the first invalid byte.
func parsePathSegmentPattern(input string) (string, string, error) {
	if input == "" {
		return "", "", nil
	}
	// A wildcard must occupy the entire segment: `*` may only be followed
	// by a slash or the end of input.
	if input[0] == '*' {
		if len(input) > 1 && input[1] != '/' {
			return "", "", xerrors.Errorf("path segment wildcards must be for the entire segment, got: %s", input)
		}
		return "*", input[1:], nil
	}
	i := 0
	for i < len(input) {
		switch c := input[i]; {
		case c == '%':
			// Percent-encoded byte (%XX): two hex digits must follow.
			if i+2 >= len(input) || !isHexDigit(input[i+1]) || !isHexDigit(input[i+2]) {
				return input[:i], input[i:], nil
			}
			i += 3
		case isPChar(c):
			i++
		default:
			// First invalid character ends the segment.
			return input[:i], input[i:], nil
		}
	}
	return input, "", nil
}
// isUnreserved reports whether c is unreserved per RFC 3986:
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
func isUnreserved(c byte) bool {
	switch {
	case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9':
		return true
	case c == '-', c == '.', c == '_', c == '~':
		return true
	default:
		return false
	}
}
// isPChar reports whether c is valid in a path segment (percent-encoding is
// handled separately by the caller):
// pchar = unreserved / sub-delims / ":" / "@"
// Note: comma is excluded from sub-delims here so comma-separated path
// lists in allow rules parse correctly.
func isPChar(c byte) bool {
	if c == ':' || c == '@' {
		return true
	}
	return isUnreserved(c) || isSubDelimExceptComma(c)
}
// isSubDelimExceptComma reports whether c is an RFC 3986 sub-delimiter,
// with comma deliberately excluded.
func isSubDelimExceptComma(c byte) bool {
	return strings.IndexByte("!$&'()*+;=", c) >= 0
}
// isHexDigit reports whether c is a hexadecimal digit (0-9, A-F, a-f).
func isHexDigit(c byte) bool {
	switch {
	case '0' <= c && c <= '9', 'A' <= c && c <= 'F', 'a' <= c && c <= 'f':
		return true
	}
	return false
}
// parseKey parses one of the predefined keys the cli can handle ("method",
// "domain", "path") from the front of rule and strips the `=` that follows.
func parseKey(rule string) (string, string, error) {
	// These are the current keys we support.
	for _, key := range [...]string{"method", "domain", "path"} {
		if rest, ok := strings.CutPrefix(rule, key+"="); ok {
			return key, rest, nil
		}
	}
	// Covers both empty input and an unrecognized prefix.
	return "", "", xerrors.New("expected key")
}
File diff suppressed because it is too large Load Diff
+24
View File
@@ -0,0 +1,24 @@
//go:build linux
package run
import (
"context"
"fmt"
"log/slog"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
"github.com/coder/coder/v2/enterprise/cli/boundary/landjail"
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager"
)
// Run dispatches to the jail implementation selected by cfg.JailType
// (nsjail or landjail) and returns its result; unknown jail types are
// reported as an error.
func Run(ctx context.Context, logger *slog.Logger, cfg config.AppConfig) error {
	switch cfg.JailType {
	case config.NSJailType:
		return nsjail_manager.Run(ctx, logger, cfg)
	case config.LandjailType:
		return landjail.Run(ctx, logger, cfg)
	default:
		return fmt.Errorf("unknown jail type: %s", cfg.JailType)
	}
}
+361
View File
@@ -0,0 +1,361 @@
//nolint:revive,gocritic,errname,unconvert
package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"log/slog"
"math/big"
"net"
"os"
"path/filepath"
"sync"
"time"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
)
// Manager sets up TLS for the proxy and writes the CA certificate to disk
// so other processes can trust it.
//
// NOTE(review): the only visible implementation,
// (*CertificateManager).SetupTLSAndWriteCACert, returns (*tls.Config, error);
// the previous interface signature with two extra string results was never
// satisfied by it. Aligned the interface to the implementation — confirm no
// other implementation relied on the old shape.
type Manager interface {
	SetupTLSAndWriteCACert() (*tls.Config, error)
}
// Config holds the inputs needed to construct a CertificateManager.
type Config struct {
	// Logger receives debug/warn/error output from certificate operations.
	Logger *slog.Logger
	// ConfigDir is the directory where the CA key and certificate are stored.
	ConfigDir string
	// Uid and Gid identify the original user that should own the config
	// directory (applied via chown when the CA is generated).
	Uid int
	Gid int
}
// CertificateManager manages TLS certificates for the proxy: it holds a CA
// key/cert pair and mints per-hostname server certificates on demand.
type CertificateManager struct {
	caKey  *rsa.PrivateKey
	caCert *x509.Certificate
	// certCache maps hostname -> generated leaf certificate; guarded by mutex.
	certCache map[string]*tls.Certificate
	mutex     sync.RWMutex
	logger    *slog.Logger
	// configDir is where the CA key/cert files are persisted.
	configDir string
	// uid/gid: ownership applied to configDir when generating a new CA.
	uid int
	gid int
}
// NewCertificateManager creates a new certificate manager, loading an
// existing CA from cfg.ConfigDir or generating and persisting a fresh one.
// The parameter is named cfg (not config) so it does not shadow the
// imported config package.
func NewCertificateManager(cfg Config) (*CertificateManager, error) {
	cm := &CertificateManager{
		certCache: make(map[string]*tls.Certificate),
		logger:    cfg.Logger,
		configDir: cfg.ConfigDir,
		uid:       cfg.Uid,
		gid:       cfg.Gid,
	}
	// Load or generate the CA up front so later TLS handshakes can mint
	// leaf certificates without extra setup.
	err := cm.loadOrGenerateCA()
	if err != nil {
		return nil, xerrors.Errorf("failed to load or generate CA: %v", err)
	}
	return cm, nil
}
// SetupTLSAndWriteCACert builds the on-demand TLS config and writes the CA
// certificate (PEM, mode 0600) to <configDir>/<config.CACertName>.
// It returns the TLS config; the CA cert path is derived from the config
// directory supplied at construction time.
func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, error) {
	// Get TLS config whose certificates are generated per-hostname on demand.
	tlsConfig := cm.getTLSConfig()
	// Get CA certificate PEM.
	caCertPEM, err := cm.getCACertPEM()
	if err != nil {
		return nil, xerrors.Errorf("failed to get CA certificate: %v", err)
	}
	// Write CA certificate to file so jailed processes can trust it.
	caCertPath := filepath.Join(cm.configDir, config.CACertName)
	err = os.WriteFile(caCertPath, caCertPEM, 0o600)
	if err != nil {
		return nil, xerrors.Errorf("failed to write CA certificate file: %v", err)
	}
	return tlsConfig, nil
}
// loadOrGenerateCA populates cm.caKey/cm.caCert, preferring a valid CA
// already on disk and falling back to generating (and persisting) a new one.
func (cm *CertificateManager) loadOrGenerateCA() error {
	keyPath := filepath.Join(cm.configDir, config.CAKeyName)
	certPath := filepath.Join(cm.configDir, config.CACertName)
	cm.logger.Debug("paths", "cm.configDir", cm.configDir, "caCertPath", certPath)
	// Reuse the existing CA when it loads and has not expired.
	if cm.loadExistingCA(keyPath, certPath) {
		cm.logger.Debug("Loaded existing CA certificate")
		return nil
	}
	// Otherwise mint a fresh CA and write it to disk.
	cm.logger.Info("Generating new CA certificate")
	return cm.generateCA(keyPath, certPath)
}
// getTLSConfig returns a TLS config whose server certificates are generated
// on demand, keyed by the SNI hostname of each handshake.
func (cm *CertificateManager) getTLSConfig() *tls.Config {
	cfg := &tls.Config{
		MinVersion:     tls.VersionTLS12,
		GetCertificate: cm.getCertificate,
	}
	return cfg
}
// getCACertPEM returns the CA certificate in PEM format. The error result
// is currently always nil but kept for signature stability.
func (cm *CertificateManager) getCACertPEM() ([]byte, error) {
	block := pem.Block{
		Type:  "CERTIFICATE",
		Bytes: cm.caCert.Raw,
	}
	return pem.EncodeToMemory(&block), nil
}
// loadExistingCA attempts to load the CA key/cert pair from disk. It returns
// true only when both files exist, parse, and the certificate has not
// expired; on any failure it logs a warning (where informative) and returns
// false so a new CA gets generated.
func (cm *CertificateManager) loadExistingCA(keyPath, certPath string) bool {
	// Check if files exist.
	if _, err := os.Stat(keyPath); os.IsNotExist(err) {
		return false
	}
	if _, err := os.Stat(certPath); os.IsNotExist(err) {
		return false
	}
	// Load private key (PEM-encoded PKCS#1).
	keyData, err := os.ReadFile(keyPath)
	if err != nil {
		cm.logger.Warn("Failed to read CA key", "error", err)
		return false
	}
	keyBlock, _ := pem.Decode(keyData)
	if keyBlock == nil {
		cm.logger.Warn("Failed to decode CA key PEM")
		return false
	}
	privateKey, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
	if err != nil {
		cm.logger.Warn("Failed to parse CA private key", "error", err)
		return false
	}
	// Load certificate.
	certData, err := os.ReadFile(certPath)
	if err != nil {
		cm.logger.Warn("Failed to read CA cert", "error", err)
		return false
	}
	certBlock, _ := pem.Decode(certData)
	if certBlock == nil {
		cm.logger.Warn("Failed to decode CA cert PEM")
		return false
	}
	cert, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		cm.logger.Warn("Failed to parse CA certificate", "error", err)
		return false
	}
	// Check if certificate is still valid; an expired CA is rejected so a
	// fresh one is generated instead.
	if time.Now().After(cert.NotAfter) {
		cm.logger.Warn("CA certificate has expired")
		return false
	}
	cm.caKey = privateKey
	cm.caCert = cert
	return true
}
// generateCA generates a new self-signed CA certificate and RSA key,
// persists them to keyPath/certPath (key 0600, cert 0644), and stores them
// on the manager. The config directory is created (0700) and chown'd to the
// configured uid/gid if needed.
func (cm *CertificateManager) generateCA(keyPath, certPath string) error {
	// Create config directory if it doesn't exist.
	err := os.MkdirAll(cm.configDir, 0o700)
	if err != nil {
		return xerrors.Errorf("failed to create config directory at %s: %v", cm.configDir, err)
	}
	// Ensure the directory is owned by the original user; failure is
	// non-fatal and only logged.
	err = os.Chown(cm.configDir, cm.uid, cm.gid)
	if err != nil {
		cm.logger.Warn("Failed to change config directory ownership", "error", err)
	}
	// Generate private key.
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return xerrors.Errorf("failed to generate private key: %v", err)
	}
	// Create certificate template for a self-signed CA.
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization:  []string{"coder"},
			Country:       []string{"US"},
			Province:      []string{""},
			Locality:      []string{""},
			StreetAddress: []string{""},
			PostalCode:    []string{""},
			CommonName:    "coder CA",
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(365 * 24 * time.Hour), // 1 year
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	// Create certificate (self-signed: template is both subject and issuer).
	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		return xerrors.Errorf("failed to create certificate: %v", err)
	}
	// Parse certificate back so cm.caCert holds the final parsed form.
	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		return xerrors.Errorf("failed to parse certificate: %v", err)
	}
	// Save private key (owner-only permissions).
	keyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return xerrors.Errorf("failed to create key file: %v", err)
	}
	defer func() {
		err := keyFile.Close()
		if err != nil {
			cm.logger.Error("Failed to close key file", "error", err)
		}
	}()
	err = pem.Encode(keyFile, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})
	if err != nil {
		return xerrors.Errorf("failed to write key to file: %v", err)
	}
	// Save certificate (world-readable so clients can pick it up).
	certFile, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
	if err != nil {
		return xerrors.Errorf("failed to create cert file: %v", err)
	}
	defer func() {
		err := certFile.Close()
		if err != nil {
			cm.logger.Error("Failed to close cert file", "error", err)
		}
	}()
	err = pem.Encode(certFile, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certDER,
	})
	if err != nil {
		return xerrors.Errorf("failed to write cert to file: %v", err)
	}
	cm.caKey = privateKey
	cm.caCert = cert
	return nil
}
// getCertificate generates or retrieves a cached server certificate for the
// SNI hostname in hello. It is installed as tls.Config.GetCertificate and
// uses read-lock-then-write-lock with a re-check to avoid duplicate work.
func (cm *CertificateManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	hostname := hello.ServerName
	if hostname == "" {
		return nil, xerrors.New("no server name provided")
	}
	// Check cache first under the cheaper read lock.
	cm.mutex.RLock()
	if cert, exists := cm.certCache[hostname]; exists {
		cm.mutex.RUnlock()
		return cert, nil
	}
	cm.mutex.RUnlock()
	// Generate new certificate under the write lock.
	cm.mutex.Lock()
	defer cm.mutex.Unlock()
	// Double-check cache (another goroutine might have generated it
	// between the RUnlock above and taking the write lock).
	if cert, exists := cm.certCache[hostname]; exists {
		return cert, nil
	}
	cert, err := cm.generateServerCertificate(hostname)
	if err != nil {
		return nil, xerrors.Errorf("failed to generate certificate for %s: %v", hostname, err)
	}
	cm.certCache[hostname] = cert
	cm.logger.Debug("Generated certificate", "hostname", hostname)
	return cert, nil
}
// generateServerCertificate generates a short-lived (1 day) server
// certificate for the given hostname, signed by the manager's CA key.
// If hostname parses as an IP address it is also added to IPAddresses so
// clients validating by IP succeed. The returned tls.Certificate carries
// the DER leaf and its RSA private key.
func (cm *CertificateManager) generateServerCertificate(hostname string) (*tls.Certificate, error) {
	// Generate the leaf's private key.
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, xerrors.Errorf("failed to generate private key: %v", err)
	}
	// Use a random 128-bit serial number instead of time.Now().UnixNano():
	// nanosecond-derived serials can collide on coarse clocks, and serials
	// are expected to be unique per issuing CA (RFC 5280 §4.1.2.2).
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialLimit)
	if err != nil {
		return nil, xerrors.Errorf("failed to generate serial number: %v", err)
	}
	// Create certificate template.
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization:  []string{"coder"},
			Country:       []string{"US"},
			Province:      []string{""},
			Locality:      []string{""},
			StreetAddress: []string{""},
			PostalCode:    []string{""},
			CommonName:    hostname,
		},
		NotBefore:   time.Now(),
		NotAfter:    time.Now().Add(24 * time.Hour), // 1 day
		KeyUsage:    x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:    []string{hostname},
	}
	// Add IP address if hostname is an IP.
	if ip := net.ParseIP(hostname); ip != nil {
		template.IPAddresses = []net.IP{ip}
	}
	// Sign the leaf with the CA certificate and key.
	certDER, err := x509.CreateCertificate(rand.Reader, &template, cm.caCert, &privateKey.PublicKey, cm.caKey)
	if err != nil {
		return nil, xerrors.Errorf("failed to create certificate: %v", err)
	}
	// Package as a TLS certificate for use in a tls.Config.
	tlsCert := &tls.Certificate{
		Certificate: [][]byte{certDER},
		PrivateKey:  privateKey,
	}
	cm.logger.Debug("Generated certificate", "hostname", hostname)
	return tlsCert, nil
}
+10
View File
@@ -0,0 +1,10 @@
//nolint:paralleltest,testpackage,revive,gocritic
package tls
import "testing"
// Stub test file - tests removed
// TestStub keeps the package's test binary compiling now that the real
// tests have been removed; it skips immediately and exercises nothing.
func TestStub(t *testing.T) {
	// This is a stub test; skip so it never reports a false pass/fail.
	t.Skip("stub test file")
}
+25
View File
@@ -0,0 +1,25 @@
//nolint:revive,gocritic,errname,unconvert
package util
import "strings"
// MergeEnvs merges extra key/value pairs into a base environment given in
// "KEY=VALUE" form, with extra taking precedence over base and later base
// entries overriding earlier ones for the same key. Base entries without a
// "=" separator are dropped. The order of the returned slice is unspecified
// because it is produced by map iteration.
func MergeEnvs(base []string, extra map[string]string) []string {
	envMap := make(map[string]string, len(base)+len(extra))
	for _, env := range base {
		// strings.Cut splits on the first "=" only, so values containing
		// "=" (e.g. "FOO=a=b") are preserved intact.
		key, value, ok := strings.Cut(env, "=")
		if ok {
			envMap[key] = value
		}
	}
	for key, value := range extra {
		envMap[key] = value
	}
	merged := make([]string, 0, len(envMap))
	for key, value := range envMap {
		merged = append(merged, key+"="+value)
	}
	return merged
}
+13
View File
@@ -0,0 +1,13 @@
package cli
import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/enterprise/cli/boundary"
"github.com/coder/serpent"
)
// boundary wires the imported boundary base command into the coder CLI.
func (*RootCmd) boundary() *serpent.Command {
	base := boundary.BaseCommand(buildinfo.Version())
	// The base command reads `boundary -- command`. Serpent renders the
	// flags portion itself, so only the trailing args need to be appended
	// to the usage string.
	base.Use = base.Use + " [args...]"
	return base
}
+4 -2
View File
@@ -29,8 +29,10 @@ func (r *RootCmd) enterpriseOnly() []*serpent.Command {
}
}
func (*RootCmd) enterpriseExperimental() []*serpent.Command {
return []*serpent.Command{}
func (r *RootCmd) enterpriseExperimental() []*serpent.Command {
return []*serpent.Command{
r.boundary(),
}
}
func (r *RootCmd) EnterpriseSubcommands() []*serpent.Command {
+3 -3
View File
@@ -453,7 +453,7 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
howett.net/plist v1.0.0 // indirect
kernel.org/pub/linux/libs/security/libcap/psx v1.2.73 // indirect
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 // indirect
sigs.k8s.io/yaml v1.5.0 // indirect
)
@@ -472,10 +472,10 @@ require (
require (
github.com/anthropics/anthropic-sdk-go v1.19.0
github.com/brianvoe/gofakeit/v7 v7.14.0
github.com/cenkalti/backoff/v5 v5.0.3
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
github.com/coder/aibridge v0.3.1-0.20260105111716-7535a71e91a1
github.com/coder/aisdk-go v0.0.9
github.com/coder/boundary v0.0.1-alpha
github.com/coder/preview v1.0.4
github.com/danieljoos/wincred v1.2.3
github.com/dgraph-io/ristretto/v2 v2.3.0
@@ -483,6 +483,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/go-git/go-git/v5 v5.16.2
github.com/icholy/replace v0.6.0
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c
github.com/mark3labs/mcp-go v0.38.0
gonum.org/v1/gonum v0.17.0
)
@@ -516,7 +517,6 @@ require (
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect
github.com/bits-and-blooms/bitset v1.24.4 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
+4 -3
View File
@@ -931,8 +931,6 @@ github.com/coder/aibridge v0.3.1-0.20260105111716-7535a71e91a1 h1:cr2K36NgU1fHKt
github.com/coder/aibridge v0.3.1-0.20260105111716-7535a71e91a1/go.mod h1:5Ztcl+9HF0tog85iEEuFdaBkBe8EkxJe5XjbMOFviQs=
github.com/coder/aisdk-go v0.0.9 h1:Vzo/k2qwVGLTR10ESDeP2Ecek1SdPfZlEjtTfMveiVo=
github.com/coder/aisdk-go v0.0.9/go.mod h1:KF6/Vkono0FJJOtWtveh5j7yfNrSctVTpwgweYWSp5M=
github.com/coder/boundary v0.0.1-alpha h1:6shUQ2zkrWrfbgVcqWvpV2ibljOQvPvYqTctWBkKoUA=
github.com/coder/boundary v0.0.1-alpha/go.mod h1:d1AMFw81rUgrGHuZzWdPNhkY0G8w7pvLNLYF0e3ceC4=
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwuWwPHPYoCZ/KLAjHv5g4h2MS4f2/MTI=
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
github.com/coder/clistat v1.2.0 h1:37KJKqiCllJsRvWqTHf3qiLIXX0JB6oqE5oxcqgdLkY=
@@ -1544,6 +1542,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/kyokomi/emoji/v2 v2.2.13 h1:GhTfQa67venUUvmleTNFnb+bi7S3aocF7ZCXU9fSO7U=
github.com/kyokomi/emoji/v2 v2.2.13/go.mod h1:JUcn42DTdsXJo1SWanHh4HKDEyPaR5CqkmoirZZP9qE=
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c h1:QcKqiunpt7hooa/xIx0iyepA6Cs2BgKexaYOxHvHNCs=
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c/go.mod h1:stwyhp9tfeEy3A4bRJLdOEvjW/CetRJg/vcijNG8M5A=
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo=
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
@@ -2845,8 +2845,9 @@ k8s.io/utils v0.0.0-20241210054802-24370beab758 h1:sdbE21q2nlQtFh65saZY+rRM6x6aJ
k8s.io/utils v0.0.0-20241210054802-24370beab758/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
kernel.org/pub/linux/libs/security/libcap/cap v1.2.73 h1:Th2b8jljYqkyZKS3aD3N9VpYsQpHuXLgea+SZUIfODA=
kernel.org/pub/linux/libs/security/libcap/cap v1.2.73/go.mod h1:hbeKwKcboEsxARYmcy/AdPVN11wmT/Wnpgv4k4ftyqY=
kernel.org/pub/linux/libs/security/libcap/psx v1.2.73 h1:SEAEUiPVylTD4vqqi+vtGkSnXeP2FcRO3FoZB1MklMw=
kernel.org/pub/linux/libs/security/libcap/psx v1.2.73/go.mod h1:+l6Ee2F59XiJ2I6WR5ObpC1utCQJZ/VLsEbQCD8RG24=
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 h1:Z06sMOzc0GNCwp6efaVrIrz4ywGJ1v+DP0pjVkOfDuA=
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77/go.mod h1:+l6Ee2F59XiJ2I6WR5ObpC1utCQJZ/VLsEbQCD8RG24=
lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
modernc.org/cc/v3 v3.36.0/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI=
+536 -593
View File
File diff suppressed because it is too large Load Diff
-5
View File
@@ -244,11 +244,6 @@ message Devcontainer {
string workspace_folder = 1;
string config_path = 2;
string name = 3;
optional bytes id = 4;
optional bytes subagent_id = 5;
repeated App apps = 6;
repeated Script scripts = 7;
repeated Env envs = 8;
}
enum AppOpenIn {
-20
View File
@@ -306,11 +306,6 @@ export interface Devcontainer {
workspaceFolder: string;
configPath: string;
name: string;
id?: Uint8Array | undefined;
subagentId?: Uint8Array | undefined;
apps: App[];
scripts: Script[];
envs: Env[];
}
/** App represents a dev-accessible application on the workspace. */
@@ -1100,21 +1095,6 @@ export const Devcontainer = {
if (message.name !== "") {
writer.uint32(26).string(message.name);
}
if (message.id !== undefined) {
writer.uint32(34).bytes(message.id);
}
if (message.subagentId !== undefined) {
writer.uint32(42).bytes(message.subagentId);
}
for (const v of message.apps) {
App.encode(v!, writer.uint32(50).fork()).ldelim();
}
for (const v of message.scripts) {
Script.encode(v!, writer.uint32(58).fork()).ldelim();
}
for (const v of message.envs) {
Env.encode(v!, writer.uint32(66).fork()).ldelim();
}
return writer;
},
};