Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ce2aed9002 |
@@ -211,6 +211,14 @@ issues:
|
||||
- path: scripts/rules.go
|
||||
linters:
|
||||
- ALL
|
||||
# Boundary code is imported from github.com/coder/boundary and has different
|
||||
# lint standards. Suppress lint issues in this imported code.
|
||||
- path: enterprise/cli/boundary/
|
||||
linters:
|
||||
- revive
|
||||
- gocritic
|
||||
- gosec
|
||||
- errorlint
|
||||
|
||||
fix: true
|
||||
max-issues-per-linter: 0
|
||||
|
||||
+10
-4
@@ -1,12 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
boundarycli "github.com/coder/boundary/cli"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) boundary() *serpent.Command {
|
||||
cmd := boundarycli.BaseCommand() // Package coder/boundary/cli exports a "base command" designed to be integrated as a subcommand.
|
||||
cmd.Use += " [args...]" // The base command looks like `boundary -- command`. Serpent adds the flags piece, but we need to add the args.
|
||||
return cmd
|
||||
return &serpent.Command{
|
||||
Use: "boundary",
|
||||
Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests (enterprise)",
|
||||
Long: `boundary creates an isolated network environment for target processes. This is an enterprise feature.`,
|
||||
Handler: func(_ *serpent.Invocation) error {
|
||||
return xerrors.New("boundary is an enterprise feature; upgrade to use this command")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,15 +5,13 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
boundarycli "github.com/coder/boundary/cli"
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// Actually testing the functionality of coder/boundary takes place in the
|
||||
// coder/boundary repo, since it's a dependency of coder.
|
||||
// Here we want to test basically that integrating it as a subcommand doesn't break anything.
|
||||
// Here we want to test that integrating boundary as a subcommand doesn't break anything.
|
||||
// The full boundary functionality is tested in enterprise/cli.
|
||||
func TestBoundarySubcommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
@@ -27,7 +25,5 @@ func TestBoundarySubcommand(t *testing.T) {
|
||||
}()
|
||||
|
||||
// Expect the --help output to include the short description.
|
||||
// We're simply confirming that `coder boundary --help` ran without a runtime error as
|
||||
// a good chunk of serpents self validation logic happens at runtime.
|
||||
pty.ExpectMatch(boundarycli.BaseCommand().Short)
|
||||
pty.ExpectMatch("Network isolation tool")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
import "log/slog"
|
||||
|
||||
// LogAuditor implements proxy.Auditor by logging to slog
|
||||
type LogAuditor struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewLogAuditor creates a new LogAuditor
|
||||
func NewLogAuditor(logger *slog.Logger) *LogAuditor {
|
||||
return &LogAuditor{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// AuditRequest logs the request using structured logging
|
||||
func (a *LogAuditor) AuditRequest(req Request) {
|
||||
if req.Allowed {
|
||||
a.logger.Info("ALLOW",
|
||||
"method", req.Method,
|
||||
"url", req.URL,
|
||||
"host", req.Host,
|
||||
"rule", req.Rule)
|
||||
} else {
|
||||
a.logger.Warn("DENY",
|
||||
"method", req.Method,
|
||||
"url", req.URL,
|
||||
"host", req.Host,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic
|
||||
package audit
|
||||
|
||||
import "testing"
|
||||
|
||||
// Stub test file - tests removed
|
||||
func TestStub(t *testing.T) {
|
||||
// This is a stub test
|
||||
t.Skip("stub test file")
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// MultiAuditor wraps multiple auditors and sends audit events to all of them.
|
||||
type MultiAuditor struct {
|
||||
auditors []Auditor
|
||||
}
|
||||
|
||||
// NewMultiAuditor creates a new MultiAuditor that sends to all provided auditors.
|
||||
func NewMultiAuditor(auditors ...Auditor) *MultiAuditor {
|
||||
return &MultiAuditor{auditors: auditors}
|
||||
}
|
||||
|
||||
// AuditRequest sends the request to all wrapped auditors.
|
||||
func (m *MultiAuditor) AuditRequest(req Request) {
|
||||
for _, a := range m.auditors {
|
||||
a.AuditRequest(req)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupAuditor creates and configures the appropriate auditors based on the
|
||||
// provided configuration. It always includes a LogAuditor for stderr logging,
|
||||
// and conditionally adds a SocketAuditor if audit logs are enabled and the
|
||||
// workspace agent's log proxy socket exists.
|
||||
func SetupAuditor(ctx context.Context, logger *slog.Logger, disableAuditLogs bool, logProxySocketPath string) (Auditor, error) {
|
||||
stderrAuditor := NewLogAuditor(logger)
|
||||
auditors := []Auditor{stderrAuditor}
|
||||
|
||||
if !disableAuditLogs {
|
||||
if logProxySocketPath == "" {
|
||||
return nil, xerrors.New("log proxy socket path is undefined")
|
||||
}
|
||||
// Since boundary is separately versioned from a Coder deployment, it's possible
|
||||
// Coder is on an older version that will not create the socket and listen for
|
||||
// the audit logs. Here we check for the socket to determine if the workspace
|
||||
// agent is on a new enough version to prevent boundary application log spam from
|
||||
// trying to connect to the agent. This assumes the agent will run and start the
|
||||
// log proxy server before boundary runs.
|
||||
_, err := os.Stat(logProxySocketPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return nil, xerrors.Errorf("failed to stat log proxy socket: %v", err)
|
||||
}
|
||||
agentWillProxy := !os.IsNotExist(err)
|
||||
if agentWillProxy {
|
||||
socketAuditor := NewSocketAuditor(logger, logProxySocketPath)
|
||||
go socketAuditor.Loop(ctx)
|
||||
auditors = append(auditors, socketAuditor)
|
||||
} else {
|
||||
logger.Warn("Audit logs are disabled; workspace agent has not created log proxy socket",
|
||||
"socket", logProxySocketPath)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Audit logs are disabled by configuration")
|
||||
}
|
||||
|
||||
return NewMultiAuditor(auditors...), nil
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockAuditor struct {
|
||||
onAudit func(req Request)
|
||||
}
|
||||
|
||||
func (m *mockAuditor) AuditRequest(req Request) {
|
||||
if m.onAudit != nil {
|
||||
m.onAudit(req)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_DisabledAuditLogs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
auditor, err := SetupAuditor(ctx, logger, true, "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
multi, ok := auditor.(*MultiAuditor)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MultiAuditor, got %T", auditor)
|
||||
}
|
||||
|
||||
if len(multi.auditors) != 1 {
|
||||
t.Errorf("expected 1 auditor, got %d", len(multi.auditors))
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[0].(*LogAuditor); !ok {
|
||||
t.Errorf("expected *LogAuditor, got %T", multi.auditors[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_EmptySocketPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := SetupAuditor(ctx, logger, false, "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty socket path, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_SocketDoesNotExist(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
auditor, err := SetupAuditor(ctx, logger, false, "/nonexistent/socket/path")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
multi, ok := auditor.(*MultiAuditor)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MultiAuditor, got %T", auditor)
|
||||
}
|
||||
|
||||
if len(multi.auditors) != 1 {
|
||||
t.Errorf("expected 1 auditor, got %d", len(multi.auditors))
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[0].(*LogAuditor); !ok {
|
||||
t.Errorf("expected *LogAuditor, got %T", multi.auditors[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupAuditor_SocketExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a temporary file to simulate the socket existing
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
f, err := os.Create(socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp file: %v", err)
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close temp file: %v", err)
|
||||
}
|
||||
|
||||
auditor, err := SetupAuditor(ctx, logger, false, socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
multi, ok := auditor.(*MultiAuditor)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MultiAuditor, got %T", auditor)
|
||||
}
|
||||
|
||||
if len(multi.auditors) != 2 {
|
||||
t.Errorf("expected 2 auditors, got %d", len(multi.auditors))
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[0].(*LogAuditor); !ok {
|
||||
t.Errorf("expected first auditor to be *LogAuditor, got %T", multi.auditors[0])
|
||||
}
|
||||
|
||||
if _, ok := multi.auditors[1].(*SocketAuditor); !ok {
|
||||
t.Errorf("expected second auditor to be *SocketAuditor, got %T", multi.auditors[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiAuditor_AuditRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var called1, called2 bool
|
||||
auditor1 := &mockAuditor{onAudit: func(req Request) { called1 = true }}
|
||||
auditor2 := &mockAuditor{onAudit: func(req Request) { called2 = true }}
|
||||
|
||||
multi := NewMultiAuditor(auditor1, auditor2)
|
||||
multi.AuditRequest(Request{Method: "GET", URL: "https://example.com"})
|
||||
|
||||
if !called1 {
|
||||
t.Error("expected first auditor to be called")
|
||||
}
|
||||
if !called2 {
|
||||
t.Error("expected second auditor to be called")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
type Auditor interface {
|
||||
AuditRequest(req Request)
|
||||
}
|
||||
|
||||
// Request represents information about an HTTP request for auditing
|
||||
type Request struct {
|
||||
Method string
|
||||
URL string // The fully qualified request URL (scheme, domain, optional path).
|
||||
Host string
|
||||
Allowed bool
|
||||
Rule string // The rule that matched (if any)
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// The batch size and timer duration are chosen to provide reasonable responsiveness
|
||||
// for consumers of the aggregated logs while still minimizing the agent <-> coderd
|
||||
// network I/O when an AI agent is actively making network requests.
|
||||
defaultBatchSize = 10
|
||||
defaultBatchTimerDuration = 5 * time.Second
|
||||
)
|
||||
|
||||
// SocketAuditor implements the Auditor interface. It sends logs to the
|
||||
// workspace agent's boundary log proxy socket. It queues logs and sends
|
||||
// them in batches using a batch size and timer. The internal queue operates
|
||||
// as a FIFO i.e., logs are sent in the order they are received and dropped
|
||||
// if the queue is full.
|
||||
type SocketAuditor struct {
|
||||
dial func() (net.Conn, error)
|
||||
logger *slog.Logger
|
||||
logCh chan *agentproto.BoundaryLog
|
||||
batchSize int
|
||||
batchTimerDuration time.Duration
|
||||
socketPath string
|
||||
|
||||
// onFlushAttempt is called after each flush attempt (intended for testing).
|
||||
onFlushAttempt func()
|
||||
}
|
||||
|
||||
// NewSocketAuditor creates a new SocketAuditor that sends logs to the agent's
|
||||
// boundary log proxy socket after SocketAuditor.Loop is called. The socket path
|
||||
// is read from EnvAuditSocketPath, falling back to defaultAuditSocketPath.
|
||||
func NewSocketAuditor(logger *slog.Logger, socketPath string) *SocketAuditor {
|
||||
// This channel buffer size intends to allow enough buffering for bursty
|
||||
// AI agent network requests while a batch is being sent to the workspace
|
||||
// agent.
|
||||
const logChBufSize = 2 * defaultBatchSize
|
||||
|
||||
return &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
return net.Dial("unix", socketPath)
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, logChBufSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: defaultBatchTimerDuration,
|
||||
socketPath: socketPath,
|
||||
}
|
||||
}
|
||||
|
||||
// AuditRequest implements the Auditor interface. It queues the log to be sent to the
|
||||
// agent in a batch.
|
||||
func (s *SocketAuditor) AuditRequest(req Request) {
|
||||
httpReq := &agentproto.BoundaryLog_HttpRequest{
|
||||
Method: req.Method,
|
||||
Url: req.URL,
|
||||
}
|
||||
// Only include the matched rule for allowed requests. Boundary is deny by
|
||||
// default, so rules are what allow requests.
|
||||
if req.Allowed {
|
||||
httpReq.MatchedRule = req.Rule
|
||||
}
|
||||
|
||||
log := &agentproto.BoundaryLog{
|
||||
Allowed: req.Allowed,
|
||||
Time: timestamppb.Now(),
|
||||
Resource: &agentproto.BoundaryLog_HttpRequest_{HttpRequest: httpReq},
|
||||
}
|
||||
|
||||
select {
|
||||
case s.logCh <- log:
|
||||
default:
|
||||
s.logger.Warn("audit log dropped, channel full")
|
||||
}
|
||||
}
|
||||
|
||||
// flushErr represents an error from flush, distinguishing between
|
||||
// permanent errors (bad data) and transient errors (network issues).
|
||||
type flushErr struct {
|
||||
err error
|
||||
permanent bool
|
||||
}
|
||||
|
||||
func (e *flushErr) Error() string { return e.err.Error() }
|
||||
|
||||
// flush sends the current batch of logs to the given connection.
|
||||
func flush(conn net.Conn, logs []*agentproto.BoundaryLog) *flushErr {
|
||||
if len(logs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &agentproto.ReportBoundaryLogsRequest{
|
||||
Logs: logs,
|
||||
}
|
||||
|
||||
data, err := proto.Marshal(req)
|
||||
if err != nil {
|
||||
return &flushErr{err: err, permanent: true}
|
||||
}
|
||||
|
||||
err = codec.WriteFrame(conn, codec.TagV1, data)
|
||||
if err != nil {
|
||||
return &flushErr{err: xerrors.Errorf("write frame: %x", err)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop handles the I/O to send audit logs to the agent.
|
||||
func (s *SocketAuditor) Loop(ctx context.Context) {
|
||||
var conn net.Conn
|
||||
batch := make([]*agentproto.BoundaryLog, 0, s.batchSize)
|
||||
t := time.NewTimer(0)
|
||||
t.Stop()
|
||||
|
||||
// connect attempts to establish a connection to the socket.
|
||||
connect := func() {
|
||||
if conn != nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
conn, err = s.dial()
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to connect to audit socket", "path", s.socketPath, "error", err)
|
||||
conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeConn closes the current connection if open.
|
||||
closeConn := func() {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// clearBatch resets the length of the batch and frees memory while preserving
|
||||
// the batch slice backing array.
|
||||
clearBatch := func() {
|
||||
for i := range len(batch) {
|
||||
batch[i] = nil
|
||||
}
|
||||
batch = batch[:0]
|
||||
}
|
||||
|
||||
// doFlush flushes the batch and handles errors by reconnecting.
|
||||
doFlush := func() {
|
||||
t.Stop()
|
||||
defer func() {
|
||||
if s.onFlushAttempt != nil {
|
||||
s.onFlushAttempt()
|
||||
}
|
||||
}()
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
connect()
|
||||
if conn == nil {
|
||||
// No connection: logs will be retried on next flush.
|
||||
s.logger.Warn("no connection to flush; resetting batch timer",
|
||||
"duration_sec", s.batchTimerDuration.Seconds(),
|
||||
"batch_size", len(batch))
|
||||
// Reset the timer so we aren't stuck waiting for the batch to fill
|
||||
// or a new log to arrive before the next attempt.
|
||||
t.Reset(s.batchTimerDuration)
|
||||
return
|
||||
}
|
||||
|
||||
if err := flush(conn, batch); err != nil {
|
||||
if err.permanent {
|
||||
// Data error: discard batch to avoid infinite retries.
|
||||
s.logger.Warn("dropping batch due to data error on flush attempt",
|
||||
"error", err, "batch_size", len(batch))
|
||||
clearBatch()
|
||||
} else {
|
||||
// Network error: close connection but keep batch and retry.
|
||||
s.logger.Warn("failed to flush audit logs; resetting batch timer to reconnect and retry",
|
||||
"error", err, "duration_sec", s.batchTimerDuration.Seconds(),
|
||||
"batch_size", len(batch))
|
||||
closeConn()
|
||||
// Reset the timer so we aren't stuck waiting for a new log to
|
||||
// arrive before the next attempt.
|
||||
t.Reset(s.batchTimerDuration)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
clearBatch()
|
||||
}
|
||||
|
||||
connect()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Drain any pending logs before the last flush. Not concerned about
|
||||
// growing the batch slice here since we're exiting.
|
||||
drain:
|
||||
for {
|
||||
select {
|
||||
case log := <-s.logCh:
|
||||
batch = append(batch, log)
|
||||
default:
|
||||
break drain
|
||||
}
|
||||
}
|
||||
|
||||
doFlush()
|
||||
closeConn()
|
||||
return
|
||||
case <-t.C:
|
||||
doFlush()
|
||||
case log := <-s.logCh:
|
||||
// If batch is at capacity, attempt flushing first and drop the log if
|
||||
// the batch still full.
|
||||
if len(batch) >= s.batchSize {
|
||||
doFlush()
|
||||
if len(batch) >= s.batchSize {
|
||||
s.logger.Warn("audit log dropped, batch full")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
batch = append(batch, log)
|
||||
|
||||
if len(batch) == 1 {
|
||||
t.Reset(s.batchTimerDuration)
|
||||
}
|
||||
|
||||
if len(batch) >= s.batchSize {
|
||||
doFlush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,373 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy/codec"
|
||||
agentproto "github.com/coder/coder/v2/agent/proto"
|
||||
)
|
||||
|
||||
func TestSocketAuditor_AuditRequest_QueuesLog(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := setupSocketAuditor(t)
|
||||
|
||||
auditor.AuditRequest(Request{
|
||||
Method: "GET",
|
||||
URL: "https://example.com",
|
||||
Host: "example.com",
|
||||
Allowed: true,
|
||||
Rule: "allow-all",
|
||||
})
|
||||
|
||||
select {
|
||||
case log := <-auditor.logCh:
|
||||
if log.Allowed != true {
|
||||
t.Errorf("expected Allowed=true, got %v", log.Allowed)
|
||||
}
|
||||
httpReq := log.GetHttpRequest()
|
||||
if httpReq == nil {
|
||||
t.Fatal("expected HttpRequest, got nil")
|
||||
}
|
||||
if httpReq.Method != "GET" {
|
||||
t.Errorf("expected Method=GET, got %s", httpReq.Method)
|
||||
}
|
||||
if httpReq.Url != "https://example.com" {
|
||||
t.Errorf("expected URL=https://example.com, got %s", httpReq.Url)
|
||||
}
|
||||
// Rule should be set for allowed requests
|
||||
if httpReq.MatchedRule != "allow-all" {
|
||||
t.Errorf("unexpected MatchedRule %v", httpReq.MatchedRule)
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected log in channel, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_AuditRequest_AllowIncludesRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := setupSocketAuditor(t)
|
||||
|
||||
auditor.AuditRequest(Request{
|
||||
Method: "POST",
|
||||
URL: "https://evil.com",
|
||||
Host: "evil.com",
|
||||
Allowed: true,
|
||||
Rule: "allow-evil",
|
||||
})
|
||||
|
||||
select {
|
||||
case log := <-auditor.logCh:
|
||||
if log.Allowed != true {
|
||||
t.Errorf("expected Allowed=false, got %v", log.Allowed)
|
||||
}
|
||||
httpReq := log.GetHttpRequest()
|
||||
if httpReq == nil {
|
||||
t.Fatal("expected HttpRequest, got nil")
|
||||
}
|
||||
if httpReq.MatchedRule != "allow-evil" {
|
||||
t.Errorf("expected MatchedRule=allow-evil, got %s", httpReq.MatchedRule)
|
||||
}
|
||||
default:
|
||||
t.Fatal("expected log in channel, got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_AuditRequest_DropsWhenFull(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor := setupSocketAuditor(t)
|
||||
|
||||
// Fill the channel (capacity is 2*batchSize = 20)
|
||||
for i := 0; i < 2*auditor.batchSize; i++ {
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
}
|
||||
|
||||
// This should not block and drop the log
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://dropped.com", Allowed: true})
|
||||
|
||||
// Drain the channel and verify all entries are from the original batch (dropped.com was dropped)
|
||||
for i := 0; i < 2*auditor.batchSize; i++ {
|
||||
v := <-auditor.logCh
|
||||
resource, ok := v.Resource.(*agentproto.BoundaryLog_HttpRequest_)
|
||||
if !ok {
|
||||
t.Fatal("unexpected resource type")
|
||||
}
|
||||
if resource.HttpRequest.Url != "https://example.com" {
|
||||
t.Errorf("expected batch to be FIFO, got %s", resource.HttpRequest.Url)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case v := <-auditor.logCh:
|
||||
t.Errorf("expected empty channel, got %v", v)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_FlushesOnBatchSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor, serverConn := setupTestAuditor(t)
|
||||
auditor.batchTimerDuration = time.Hour // Ensure timer doesn't interfere with the test
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
go auditor.Loop(t.Context())
|
||||
|
||||
// Send exactly a full batch of logs to trigger a flush
|
||||
for i := 0; i < auditor.batchSize; i++ {
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
}
|
||||
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != auditor.batchSize {
|
||||
t.Errorf("expected %d logs, got %d", auditor.batchSize, len(req.Logs))
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_FlushesOnTimer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor, serverConn := setupTestAuditor(t)
|
||||
auditor.batchTimerDuration = 3 * time.Second
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
go auditor.Loop(t.Context())
|
||||
|
||||
// A single log should start the timer
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
|
||||
// Should flush after the timer duration elapses
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != 1 {
|
||||
t.Errorf("expected 1 log, got %d", len(req.Logs))
|
||||
}
|
||||
case <-time.After(2 * auditor.batchTimerDuration):
|
||||
t.Fatal("timeout waiting for timer flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_FlushesOnContextCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auditor, serverConn := setupTestAuditor(t)
|
||||
// Make the timer long to always exercise the context cancellation case
|
||||
auditor.batchTimerDuration = time.Hour
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
auditor.Loop(ctx)
|
||||
}()
|
||||
|
||||
// Send a log but don't fill the batch
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://example.com", Allowed: true})
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != 1 {
|
||||
t.Errorf("expected 1 log, got %d", len(req.Logs))
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for shutdown flush")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSocketAuditor_Loop_RetriesOnConnectionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
err := clientConn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("close client connection: %v", err)
|
||||
}
|
||||
err = serverConn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("close server connection: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
var dialCount atomic.Int32
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
auditor := &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
// First dial attempt fails, subsequent ones succeed
|
||||
if dialCount.Add(1) == 1 {
|
||||
return nil, xerrors.New("connection refused")
|
||||
}
|
||||
return clientConn, nil
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: time.Hour, // Ensure timer doesn't interfere with the test
|
||||
}
|
||||
|
||||
// Set up hook to detect flush attempts
|
||||
flushed := make(chan struct{}, 1)
|
||||
auditor.onFlushAttempt = func() {
|
||||
select {
|
||||
case flushed <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
received := make(chan *agentproto.ReportBoundaryLogsRequest, 1)
|
||||
go readFromConn(t, serverConn, received)
|
||||
|
||||
go auditor.Loop(t.Context())
|
||||
|
||||
// Send batchSize+1 logs so we can verify the last log here gets dropped.
|
||||
for i := 0; i < auditor.batchSize+1; i++ {
|
||||
auditor.AuditRequest(Request{Method: "GET", URL: "https://servernotup.com", Allowed: true})
|
||||
}
|
||||
|
||||
// Wait for the first flush attempt (which will fail)
|
||||
select {
|
||||
case <-flushed:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for first flush attempt")
|
||||
}
|
||||
|
||||
// Send one more log - batch is at capacity, so this triggers flush first
|
||||
// The flush succeeds (dial now works), sending the retained batch.
|
||||
auditor.AuditRequest(Request{Method: "POST", URL: "https://serverup.com", Allowed: true})
|
||||
|
||||
// Should receive the retained batch (the new log goes into a fresh batch)
|
||||
select {
|
||||
case req := <-received:
|
||||
if len(req.Logs) != auditor.batchSize {
|
||||
t.Errorf("expected %d logs from retry, got %d", auditor.batchSize, len(req.Logs))
|
||||
}
|
||||
for _, log := range req.Logs {
|
||||
resource, ok := log.Resource.(*agentproto.BoundaryLog_HttpRequest_)
|
||||
if !ok {
|
||||
t.Fatal("unexpected resource type")
|
||||
}
|
||||
if resource.HttpRequest.Url != "https://servernotup.com" {
|
||||
t.Errorf("expected URL https://servernotup.com, got %v", resource.HttpRequest.Url)
|
||||
}
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for retry flush")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlush_EmptyBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := flush(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for empty batch, got %v", err)
|
||||
}
|
||||
|
||||
err = flush(nil, []*agentproto.BoundaryLog{})
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for empty slice, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// setupSocketAuditor creates a SocketAuditor for tests that only exercise
|
||||
// the queueing behavior (no connection needed).
|
||||
func setupSocketAuditor(t *testing.T) *SocketAuditor {
|
||||
t.Helper()
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
return &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
return nil, xerrors.New("not connected")
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: defaultBatchTimerDuration,
|
||||
}
|
||||
}
|
||||
|
||||
// setupTestAuditor creates a SocketAuditor with an in-memory connection using
|
||||
// net.Pipe(). Returns the auditor and the server-side connection for reading.
|
||||
func setupTestAuditor(t *testing.T) (*SocketAuditor, net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
err := clientConn.Close()
|
||||
if err != nil {
|
||||
t.Error("Failed to close client connection", "error", err)
|
||||
}
|
||||
err = serverConn.Close()
|
||||
if err != nil {
|
||||
t.Error("Failed to close server connection", "error", err)
|
||||
}
|
||||
})
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
auditor := &SocketAuditor{
|
||||
dial: func() (net.Conn, error) {
|
||||
return clientConn, nil
|
||||
},
|
||||
logger: logger,
|
||||
logCh: make(chan *agentproto.BoundaryLog, 2*defaultBatchSize),
|
||||
batchSize: defaultBatchSize,
|
||||
batchTimerDuration: defaultBatchTimerDuration,
|
||||
}
|
||||
|
||||
return auditor, serverConn
|
||||
}
|
||||
|
||||
// readFromConn reads length-prefixed protobuf messages from a connection and
|
||||
// sends them to the received channel.
|
||||
func readFromConn(t *testing.T, conn net.Conn, received chan<- *agentproto.ReportBoundaryLogsRequest) {
|
||||
t.Helper()
|
||||
|
||||
buf := make([]byte, 1<<10)
|
||||
for {
|
||||
tag, data, err := codec.ReadFrame(conn, buf)
|
||||
if err != nil {
|
||||
return // connection closed
|
||||
}
|
||||
|
||||
if tag != codec.TagV1 {
|
||||
t.Errorf("invalid tag: %d", tag)
|
||||
}
|
||||
|
||||
var req agentproto.ReportBoundaryLogsRequest
|
||||
if err := proto.Unmarshal(data, &req); err != nil {
|
||||
t.Errorf("failed to unmarshal: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
received <- &req
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
//go:build linux
|
||||
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
|
||||
package boundary
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/coder/coder/v2/agent/boundarylogproxy"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/log"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/run"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// printVersion prints version information.
|
||||
func printVersion(version string) {
|
||||
fmt.Println(version)
|
||||
}
|
||||
|
||||
// NewCommand creates and returns the root serpent command
|
||||
func NewCommand(version string) *serpent.Command {
|
||||
// To make the top level boundary command, we just make some minor changes to the base command
|
||||
cmd := BaseCommand(version)
|
||||
cmd.Use = "boundary [flags] -- command [args...]" // Add the flags and args pieces to usage.
|
||||
|
||||
// Add example usage to the long description. This is different from usage as a subcommand because it
|
||||
// may be called something different when used as a subcommand / there will be a leading binary (i.e. `coder boundary` vs. `boundary`).
|
||||
cmd.Long += `Examples:
|
||||
# Allow only requests to github.com
|
||||
boundary --allow "domain=github.com" -- curl https://github.com
|
||||
|
||||
# Monitor all requests to specific domains (allow only those)
|
||||
boundary --allow "domain=github.com path=/api/issues/*" --allow "method=GET,HEAD domain=github.com" -- npm install
|
||||
|
||||
# Use allowlist from config file with additional CLI allow rules
|
||||
boundary --allow "domain=example.com" -- curl https://example.com
|
||||
|
||||
# Block everything by default (implicit)`
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Base command returns the boundary serpent command without the information involved in making it the
|
||||
// *top level* serpent command. We are creating this split to make it easier to integrate into the coder
|
||||
// CLI if needed.
|
||||
func BaseCommand(version string) *serpent.Command {
|
||||
cliConfig := config.CliConfig{}
|
||||
var showVersion serpent.Bool
|
||||
|
||||
// Set default config path if file exists - serpent will load it automatically
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
defaultPath := filepath.Join(home, ".config", "coder_boundary", "config.yaml")
|
||||
if _, err := os.Stat(defaultPath); err == nil {
|
||||
cliConfig.Config = serpent.YAMLConfigPath(defaultPath)
|
||||
}
|
||||
}
|
||||
|
||||
return &serpent.Command{
|
||||
Use: "boundary",
|
||||
Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests",
|
||||
Long: `boundary creates an isolated network environment for target processes, intercepting HTTP/HTTPS traffic through a transparent proxy that enforces user-defined allow rules.`,
|
||||
Options: []serpent.Option{
|
||||
{
|
||||
Flag: "config",
|
||||
Env: "BOUNDARY_CONFIG",
|
||||
Description: "Path to YAML config file.",
|
||||
Value: &cliConfig.Config,
|
||||
YAML: "",
|
||||
},
|
||||
{
|
||||
Flag: "allow",
|
||||
Env: "BOUNDARY_ALLOW",
|
||||
Description: "Allow rule (repeatable). These are merged with allowlist from config file. Format: \"pattern\" or \"METHOD[,METHOD] pattern\".",
|
||||
Value: &cliConfig.AllowStrings,
|
||||
YAML: "", // CLI only, not loaded from YAML
|
||||
},
|
||||
{
|
||||
Flag: "allowlist",
|
||||
Description: "Allowlist rules from config file (YAML only).",
|
||||
Value: &cliConfig.AllowListStrings,
|
||||
YAML: "allowlist",
|
||||
Hidden: true, // Hidden because it's primarily for YAML config
|
||||
},
|
||||
{
|
||||
Flag: "log-level",
|
||||
Env: "BOUNDARY_LOG_LEVEL",
|
||||
Description: "Set log level (error, warn, info, debug).",
|
||||
Default: "warn",
|
||||
Value: &cliConfig.LogLevel,
|
||||
YAML: "log_level",
|
||||
},
|
||||
{
|
||||
Flag: "log-dir",
|
||||
Env: "BOUNDARY_LOG_DIR",
|
||||
Description: "Set a directory to write logs to rather than stderr.",
|
||||
Value: &cliConfig.LogDir,
|
||||
YAML: "log_dir",
|
||||
},
|
||||
{
|
||||
Flag: "proxy-port",
|
||||
Env: "PROXY_PORT",
|
||||
Description: "Set a port for HTTP proxy.",
|
||||
Default: "8080",
|
||||
Value: &cliConfig.ProxyPort,
|
||||
YAML: "proxy_port",
|
||||
},
|
||||
{
|
||||
Flag: "pprof",
|
||||
Env: "BOUNDARY_PPROF",
|
||||
Description: "Enable pprof profiling server.",
|
||||
Value: &cliConfig.PprofEnabled,
|
||||
YAML: "pprof_enabled",
|
||||
},
|
||||
{
|
||||
Flag: "pprof-port",
|
||||
Env: "BOUNDARY_PPROF_PORT",
|
||||
Description: "Set port for pprof profiling server.",
|
||||
Default: "6060",
|
||||
Value: &cliConfig.PprofPort,
|
||||
YAML: "pprof_port",
|
||||
},
|
||||
{
|
||||
Flag: "configure-dns-for-local-stub-resolver",
|
||||
Env: "BOUNDARY_CONFIGURE_DNS_FOR_LOCAL_STUB_RESOLVER",
|
||||
Description: "Configure DNS for local stub resolver (e.g., systemd-resolved). Only needed when /etc/resolv.conf contains nameserver 127.0.0.53.",
|
||||
Value: &cliConfig.ConfigureDNSForLocalStubResolver,
|
||||
YAML: "configure_dns_for_local_stub_resolver",
|
||||
},
|
||||
{
|
||||
Flag: "jail-type",
|
||||
Env: "BOUNDARY_JAIL_TYPE",
|
||||
Description: "Jail type to use for network isolation. Options: nsjail (default), landjail.",
|
||||
Default: "nsjail",
|
||||
Value: &cliConfig.JailType,
|
||||
YAML: "jail_type",
|
||||
},
|
||||
{
|
||||
Flag: "disable-audit-logs",
|
||||
Env: "DISABLE_AUDIT_LOGS",
|
||||
Description: "Disable sending of audit logs to the workspace agent when set to true.",
|
||||
Value: &cliConfig.DisableAuditLogs,
|
||||
YAML: "disable_audit_logs",
|
||||
},
|
||||
{
|
||||
Flag: "log-proxy-socket-path",
|
||||
Description: "Path to the socket where the boundary log proxy server listens for audit logs.",
|
||||
// Important: this default must be the same default path used by the
|
||||
// workspace agent to ensure agreement of the default socket path without
|
||||
// explicit configuration.
|
||||
Default: boundarylogproxy.DefaultSocketPath(),
|
||||
// Important: this must be the same variable name used by the workspace agent
|
||||
// to allow a single environment variable to configure both boundary and the
|
||||
// workspace agent.
|
||||
Env: "CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH",
|
||||
Value: &cliConfig.LogProxySocketPath,
|
||||
YAML: "", // CLI only, not loaded from YAML
|
||||
},
|
||||
{
|
||||
Flag: "version",
|
||||
Description: "Print version information and exit.",
|
||||
Value: &showVersion,
|
||||
YAML: "", // CLI only
|
||||
},
|
||||
},
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
// Handle --version flag early
|
||||
if showVersion.Value() {
|
||||
printVersion(version)
|
||||
return nil
|
||||
}
|
||||
appConfig, err := config.NewAppConfigFromCliConfig(cliConfig, inv.Args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse cli config file: %v", err)
|
||||
}
|
||||
|
||||
// Get command arguments
|
||||
if len(appConfig.TargetCMD) == 0 {
|
||||
return fmt.Errorf("no command specified")
|
||||
}
|
||||
|
||||
logger, err := log.SetupLogging(appConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not set up logging: %v", err)
|
||||
}
|
||||
|
||||
appConfigInJSON, err := json.Marshal(appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Debug("Application config", "config", appConfigInJSON)
|
||||
|
||||
return run.Run(inv.Context(), logger, appConfig)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
//go:build !linux
|
||||
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
|
||||
package boundary
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// BaseCommand returns the boundary serpent command. On non-Linux platforms,
|
||||
// boundary is not supported and returns an error.
|
||||
func BaseCommand(_ string) *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "boundary",
|
||||
Short: "Network isolation tool for monitoring and restricting HTTP/HTTPS requests",
|
||||
Long: `boundary creates an isolated network environment for target processes, intercepting HTTP/HTTPS traffic through a transparent proxy that enforces user-defined allow rules.`,
|
||||
Handler: func(_ *serpent.Invocation) error {
|
||||
return xerrors.Errorf("boundary is only supported on Linux (current OS: %s)", runtime.GOOS)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
// JailType represents the type of jail to use for network isolation
|
||||
type JailType string
|
||||
|
||||
const (
|
||||
NSJailType JailType = "nsjail"
|
||||
LandjailType JailType = "landjail"
|
||||
)
|
||||
|
||||
func NewJailTypeFromString(str string) (JailType, error) {
|
||||
switch str {
|
||||
case "nsjail":
|
||||
return NSJailType, nil
|
||||
case "landjail":
|
||||
return LandjailType, nil
|
||||
default:
|
||||
return NSJailType, xerrors.Errorf("invalid JailType: %s", str)
|
||||
}
|
||||
}
|
||||
|
||||
// AllowStringsArray is a custom type that implements pflag.Value to support
|
||||
// repeatable --allow flags without splitting on commas. This allows comma-separated
|
||||
// paths within a single allow rule (e.g., "path=/todos/1,/todos/2").
|
||||
type AllowStringsArray []string
|
||||
|
||||
var _ pflag.Value = (*AllowStringsArray)(nil)
|
||||
|
||||
// Set implements pflag.Value. It appends the value to the slice without splitting on commas.
|
||||
func (a *AllowStringsArray) Set(value string) error {
|
||||
*a = append(*a, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String implements pflag.Value.
|
||||
func (a AllowStringsArray) String() string {
|
||||
return strings.Join(a, ",")
|
||||
}
|
||||
|
||||
// Type implements pflag.Value.
|
||||
func (a AllowStringsArray) Type() string {
|
||||
return "string"
|
||||
}
|
||||
|
||||
// Value returns the underlying slice of strings.
|
||||
func (a AllowStringsArray) Value() []string {
|
||||
return []string(a)
|
||||
}
|
||||
|
||||
type CliConfig struct {
|
||||
Config serpent.YAMLConfigPath `yaml:"-"`
|
||||
AllowListStrings serpent.StringArray `yaml:"allowlist"` // From config file
|
||||
AllowStrings AllowStringsArray `yaml:"-"` // From CLI flags only
|
||||
LogLevel serpent.String `yaml:"log_level"`
|
||||
LogDir serpent.String `yaml:"log_dir"`
|
||||
ProxyPort serpent.Int64 `yaml:"proxy_port"`
|
||||
PprofEnabled serpent.Bool `yaml:"pprof_enabled"`
|
||||
PprofPort serpent.Int64 `yaml:"pprof_port"`
|
||||
ConfigureDNSForLocalStubResolver serpent.Bool `yaml:"configure_dns_for_local_stub_resolver"`
|
||||
JailType serpent.String `yaml:"jail_type"`
|
||||
DisableAuditLogs serpent.Bool `yaml:"disable_audit_logs"`
|
||||
LogProxySocketPath serpent.String `yaml:"log_proxy_socket_path"`
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
AllowRules []string
|
||||
LogLevel string
|
||||
LogDir string
|
||||
ProxyPort int64
|
||||
PprofEnabled bool
|
||||
PprofPort int64
|
||||
ConfigureDNSForLocalStubResolver bool
|
||||
JailType JailType
|
||||
TargetCMD []string
|
||||
UserInfo *UserInfo
|
||||
DisableAuditLogs bool
|
||||
LogProxySocketPath string
|
||||
}
|
||||
|
||||
func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, error) {
|
||||
// Merge allowlist from config file with allow from CLI flags
|
||||
allowListStrings := cfg.AllowListStrings.Value()
|
||||
allowStrings := cfg.AllowStrings.Value()
|
||||
|
||||
// Combine allowlist (config file) with allow (CLI flags)
|
||||
allAllowStrings := append(allowListStrings, allowStrings...)
|
||||
|
||||
jailType, err := NewJailTypeFromString(cfg.JailType.Value())
|
||||
if err != nil {
|
||||
return AppConfig{}, err
|
||||
}
|
||||
|
||||
userInfo := GetUserInfo()
|
||||
|
||||
return AppConfig{
|
||||
AllowRules: allAllowStrings,
|
||||
LogLevel: cfg.LogLevel.Value(),
|
||||
LogDir: cfg.LogDir.Value(),
|
||||
ProxyPort: cfg.ProxyPort.Value(),
|
||||
PprofEnabled: cfg.PprofEnabled.Value(),
|
||||
PprofPort: cfg.PprofPort.Value(),
|
||||
ConfigureDNSForLocalStubResolver: cfg.ConfigureDNSForLocalStubResolver.Value(),
|
||||
JailType: jailType,
|
||||
TargetCMD: targetCMD,
|
||||
UserInfo: userInfo,
|
||||
DisableAuditLogs: cfg.DisableAuditLogs.Value(),
|
||||
LogProxySocketPath: cfg.LogProxySocketPath.Value(),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
CAKeyName = "ca-key.pem"
|
||||
CACertName = "ca-cert.pem"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
SudoUser string
|
||||
Uid int
|
||||
Gid int
|
||||
HomeDir string
|
||||
ConfigDir string
|
||||
}
|
||||
|
||||
// GetUserInfo returns information about the current user, handling sudo scenarios
|
||||
func GetUserInfo() *UserInfo {
|
||||
// Only consider SUDO_USER if we're actually running with elevated privileges
|
||||
// In environments like Coder workspaces, SUDO_USER may be set to 'root'
|
||||
// but we're not actually running under sudo
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" && os.Geteuid() == 0 && sudoUser != "root" {
|
||||
// We're actually running under sudo with a non-root original user
|
||||
user, err := user.Lookup(sudoUser)
|
||||
if err != nil {
|
||||
return getCurrentUserInfo() // Fallback to current user
|
||||
}
|
||||
|
||||
uid, _ := strconv.Atoi(os.Getenv("SUDO_UID"))
|
||||
gid, _ := strconv.Atoi(os.Getenv("SUDO_GID"))
|
||||
|
||||
// If we couldn't get UID/GID from env, parse from user info
|
||||
if uid == 0 {
|
||||
if parsedUID, err := strconv.Atoi(user.Uid); err == nil {
|
||||
uid = parsedUID
|
||||
}
|
||||
}
|
||||
if gid == 0 {
|
||||
if parsedGID, err := strconv.Atoi(user.Gid); err == nil {
|
||||
gid = parsedGID
|
||||
}
|
||||
}
|
||||
|
||||
configDir := getConfigDir(user.HomeDir)
|
||||
|
||||
return &UserInfo{
|
||||
SudoUser: sudoUser,
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
HomeDir: user.HomeDir,
|
||||
ConfigDir: configDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Not actually running under sudo, use current user
|
||||
return getCurrentUserInfo()
|
||||
}
|
||||
|
||||
// getCurrentUserInfo gets information for the current user
|
||||
func getCurrentUserInfo() *UserInfo {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
// Fallback with empty values if we can't get user info
|
||||
return &UserInfo{}
|
||||
}
|
||||
|
||||
uid, _ := strconv.Atoi(currentUser.Uid)
|
||||
gid, _ := strconv.Atoi(currentUser.Gid)
|
||||
|
||||
configDir := getConfigDir(currentUser.HomeDir)
|
||||
|
||||
return &UserInfo{
|
||||
SudoUser: currentUser.Username,
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
HomeDir: currentUser.HomeDir,
|
||||
ConfigDir: configDir,
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigDir determines the config directory based on XDG_CONFIG_HOME or fallback
|
||||
func getConfigDir(homeDir string) string {
|
||||
// Use XDG_CONFIG_HOME if set, otherwise fallback to ~/.config/coder_boundary
|
||||
if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" {
|
||||
return filepath.Join(xdgConfigHome, "coder_boundary")
|
||||
}
|
||||
return filepath.Join(homeDir, ".config", "coder_boundary")
|
||||
}
|
||||
|
||||
func (u *UserInfo) CAKeyPath() string {
|
||||
return filepath.Join(u.ConfigDir, CAKeyName)
|
||||
}
|
||||
|
||||
func (u *UserInfo) CACertPath() string {
|
||||
return filepath.Join(u.ConfigDir, CACertName)
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/landlock-lsm/go-landlock/landlock"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/util"
|
||||
)
|
||||
|
||||
type LandlockConfig struct {
|
||||
// TODO(yevhenii):
|
||||
// - should it be able to bind to any port?
|
||||
// - should it be able to connect to any port on localhost?
|
||||
// BindTCPPorts []int
|
||||
ConnectTCPPorts []int
|
||||
}
|
||||
|
||||
func ApplyLandlockRestrictions(logger *slog.Logger, cfg LandlockConfig) error {
|
||||
// Get the Landlock version which works for Kernel 6.7+
|
||||
llCfg := landlock.V4
|
||||
|
||||
// Collect our rules
|
||||
var netRules []landlock.Rule
|
||||
|
||||
// Add rules for TCP connections
|
||||
for _, port := range cfg.ConnectTCPPorts {
|
||||
logger.Debug("Adding TCP connect port", "port", port)
|
||||
netRules = append(netRules, landlock.ConnectTCP(uint16(port)))
|
||||
}
|
||||
|
||||
err := llCfg.RestrictNet(netRules...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply Landlock network restrictions: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunChild(logger *slog.Logger, config config.AppConfig) error {
|
||||
landjailCfg := LandlockConfig{
|
||||
ConnectTCPPorts: []int{int(config.ProxyPort)},
|
||||
}
|
||||
|
||||
err := ApplyLandlockRestrictions(logger, landjailCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to apply Landlock network restrictions: %v", err)
|
||||
}
|
||||
|
||||
// Build command
|
||||
cmd := exec.Command(config.TargetCMD[0], config.TargetCMD[1:]...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
logger.Info("Executing target command", "command", config.TargetCMD)
|
||||
|
||||
// Run the command - this will block until it completes
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
// Check if this is a normal exit with non-zero status code
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode := exitError.ExitCode()
|
||||
logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
|
||||
return fmt.Errorf("command exited with code %d", exitCode)
|
||||
}
|
||||
// This is an unexpected error
|
||||
logger.Error("Command execution failed", "error", err)
|
||||
return fmt.Errorf("command execution failed: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("Command completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns environment variables intended to be set on the child process,
|
||||
// so they can later be inherited by the target process.
|
||||
func getEnvsForTargetProcess(configDir string, caCertPath string, httpProxyPort int) []string {
|
||||
e := os.Environ()
|
||||
|
||||
proxyAddr := fmt.Sprintf("http://localhost:%d", httpProxyPort)
|
||||
e = util.MergeEnvs(e, map[string]string{
|
||||
// Set standard CA certificate environment variables for common tools
|
||||
// This makes tools like curl, git, etc. trust our dynamically generated CA
|
||||
"SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools
|
||||
"SSL_CERT_DIR": configDir, // OpenSSL certificate directory
|
||||
"CURL_CA_BUNDLE": caCertPath, // curl
|
||||
"GIT_SSL_CAINFO": caCertPath, // Git
|
||||
"REQUESTS_CA_BUNDLE": caCertPath, // Python requests
|
||||
"NODE_EXTRA_CA_CERTS": caCertPath, // Node.js
|
||||
|
||||
"HTTP_PROXY": proxyAddr,
|
||||
"HTTPS_PROXY": proxyAddr,
|
||||
"http_proxy": proxyAddr,
|
||||
"https_proxy": proxyAddr,
|
||||
})
|
||||
|
||||
return e
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/proxy"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
)
|
||||
|
||||
type LandJail struct {
|
||||
proxyServer *proxy.Server
|
||||
logger *slog.Logger
|
||||
config config.AppConfig
|
||||
}
|
||||
|
||||
func NewLandJail(
|
||||
ruleEngine rulesengine.Engine,
|
||||
auditor audit.Auditor,
|
||||
tlsConfig *tls.Config,
|
||||
logger *slog.Logger,
|
||||
config config.AppConfig,
|
||||
) (*LandJail, error) {
|
||||
// Create proxy server
|
||||
proxyServer := proxy.NewProxyServer(proxy.Config{
|
||||
HTTPPort: int(config.ProxyPort),
|
||||
RuleEngine: ruleEngine,
|
||||
Auditor: auditor,
|
||||
Logger: logger,
|
||||
TLSConfig: tlsConfig,
|
||||
PprofEnabled: config.PprofEnabled,
|
||||
PprofPort: int(config.PprofPort),
|
||||
})
|
||||
|
||||
return &LandJail{
|
||||
config: config,
|
||||
proxyServer: proxyServer,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *LandJail) Run(ctx context.Context) error {
|
||||
b.logger.Info("Start landjail manager")
|
||||
err := b.startProxy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start landjail manager: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
b.logger.Info("Stop landjail manager")
|
||||
err := b.stopProxy()
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to stop landjail manager", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
err := b.RunChildProcess(os.Args)
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to run child process", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Setup signal handling BEFORE any setup
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Wait for signal or context cancellation
|
||||
select {
|
||||
case sig := <-sigChan:
|
||||
b.logger.Info("Received signal, shutting down...", "signal", sig)
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
// Context canceled by command completion
|
||||
b.logger.Info("Command completed, shutting down...")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LandJail) RunChildProcess(command []string) error {
|
||||
childCmd := b.getChildCommand(command)
|
||||
|
||||
b.logger.Debug("Executing command in boundary", "command", strings.Join(os.Args, " "))
|
||||
err := childCmd.Start()
|
||||
if err != nil {
|
||||
b.logger.Error("Command failed to start", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
b.logger.Debug("waiting on a child process to finish")
|
||||
err = childCmd.Wait()
|
||||
if err != nil {
|
||||
// Check if this is a normal exit with non-zero status code
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode := exitError.ExitCode()
|
||||
// Log at debug level for non-zero exits (normal behavior)
|
||||
b.logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
|
||||
return err
|
||||
}
|
||||
|
||||
// This is an unexpected error (not just a non-zero exit)
|
||||
b.logger.Error("Command execution failed", "error", err)
|
||||
return err
|
||||
}
|
||||
b.logger.Debug("Command completed successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LandJail) getChildCommand(command []string) *exec.Cmd {
|
||||
cmd := exec.Command(command[0], command[1:]...)
|
||||
// Set env vars for the child process; they will be inherited by the target process.
|
||||
cmd.Env = getEnvsForTargetProcess(b.config.UserInfo.ConfigDir, b.config.UserInfo.CACertPath(), int(b.config.ProxyPort))
|
||||
cmd.Env = append(cmd.Env, "CHILD=true")
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (b *LandJail) startProxy() error {
|
||||
// Start proxy server in background
|
||||
err := b.proxyServer.Start()
|
||||
if err != nil {
|
||||
b.logger.Error("Proxy server error", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Give proxy time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *LandJail) stopProxy() error {
|
||||
// Stop proxy server
|
||||
if b.proxyServer != nil {
|
||||
err := b.proxyServer.Stop()
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to stop proxy server", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/tls"
|
||||
)
|
||||
|
||||
func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
|
||||
if len(config.AllowRules) == 0 {
|
||||
logger.Warn("No allow rules specified; all network traffic will be denied by default")
|
||||
}
|
||||
|
||||
// Parse allow rules
|
||||
allowRules, err := rulesengine.ParseAllowSpecs(config.AllowRules)
|
||||
if err != nil {
|
||||
logger.Error("Failed to parse allow rules", "error", err)
|
||||
return fmt.Errorf("failed to parse allow rules: %v", err)
|
||||
}
|
||||
|
||||
// Create rule engine
|
||||
ruleEngine := rulesengine.NewRuleEngine(allowRules, logger)
|
||||
|
||||
// Create auditor
|
||||
auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup auditor: %v", err)
|
||||
}
|
||||
|
||||
// Create TLS certificate manager
|
||||
certManager, err := tls.NewCertificateManager(tls.Config{
|
||||
Logger: logger,
|
||||
ConfigDir: config.UserInfo.ConfigDir,
|
||||
Uid: config.UserInfo.Uid,
|
||||
Gid: config.UserInfo.Gid,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to create certificate manager", "error", err)
|
||||
return fmt.Errorf("failed to create certificate manager: %v", err)
|
||||
}
|
||||
|
||||
// Setup TLS to get cert path for jailer
|
||||
tlsConfig, err := certManager.SetupTLSAndWriteCACert()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS and CA certificate: %v", err)
|
||||
}
|
||||
|
||||
landjail, err := NewLandJail(ruleEngine, auditor, tlsConfig, logger, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create landjail: %v", err)
|
||||
}
|
||||
|
||||
return landjail.Run(ctx)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
//go:build linux
|
||||
|
||||
package landjail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
)
|
||||
|
||||
func isChild() bool {
|
||||
return os.Getenv("CHILD") == "true"
|
||||
}
|
||||
|
||||
// Run is the main entry point that determines whether to execute as a parent or child process.
|
||||
// If running as a child (CHILD env var is set), it applies landlock restrictions
|
||||
// and executes the target command. Otherwise, it runs as the parent process, sets up the proxy server,
|
||||
// and manages the child process lifecycle.
|
||||
func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
|
||||
if isChild() {
|
||||
return RunChild(logger, config)
|
||||
}
|
||||
|
||||
return RunParent(ctx, logger, config)
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package log
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
)
|
||||
|
||||
// SetupLogging creates a slog logger with the specified level
|
||||
func SetupLogging(config config.AppConfig) (*slog.Logger, error) {
|
||||
var level slog.Level
|
||||
switch strings.ToLower(config.LogLevel) {
|
||||
case "error":
|
||||
level = slog.LevelError
|
||||
case "warn":
|
||||
level = slog.LevelWarn
|
||||
case "info":
|
||||
level = slog.LevelInfo
|
||||
case "debug":
|
||||
level = slog.LevelDebug
|
||||
default:
|
||||
level = slog.LevelWarn // Default to warn if invalid level
|
||||
}
|
||||
|
||||
logTarget := os.Stderr
|
||||
|
||||
logDir := config.LogDir
|
||||
if logDir != "" {
|
||||
// Set up the logging directory if it doesn't exist yet
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return nil, xerrors.Errorf("could not set up log dir %s: %v", logDir, err)
|
||||
}
|
||||
|
||||
// Create a logfile (timestamp and pid to avoid race conditions with multiple boundary calls running)
|
||||
logFilePath := fmt.Sprintf("boundary-%s-%d.log",
|
||||
time.Now().Format("2006-01-02_15-04-05"),
|
||||
os.Getpid())
|
||||
|
||||
logFile, err := os.Create(filepath.Join(logDir, logFilePath))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("could not create log file %s: %v", logFilePath, err)
|
||||
}
|
||||
|
||||
// Set the log target to the file rather than stderr.
|
||||
logTarget = logFile
|
||||
}
|
||||
|
||||
// Create a standard slog logger with the appropriate level
|
||||
handler := slog.NewTextHandler(logTarget, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
})
|
||||
|
||||
return slog.New(handler), nil
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail_manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager/nsjail"
|
||||
)
|
||||
|
||||
// waitForInterface waits for a network interface to appear in the namespace.
|
||||
// It retries checking for the interface with exponential backoff up to the specified timeout.
|
||||
func waitForInterface(interfaceName string, timeout time.Duration) error {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.InitialInterval = 50 * time.Millisecond
|
||||
b.MaxInterval = 500 * time.Millisecond
|
||||
b.Multiplier = 2.0
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
operation := func() (bool, error) {
|
||||
cmd := exec.Command("ip", "link", "show", interfaceName)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
AmbientCaps: []uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
}
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("interface %s not found: %w", interfaceName, err)
|
||||
}
|
||||
// Interface exists
|
||||
return true, nil
|
||||
}
|
||||
|
||||
_, err := backoff.Retry(ctx, operation, backoff.WithBackOff(b))
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface %s did not appear within %v: %w", interfaceName, timeout, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunChild(logger *slog.Logger, targetCMD []string) error {
|
||||
logger.Info("boundary CHILD process is started")
|
||||
|
||||
vethNetJail := os.Getenv("VETH_JAIL_NAME")
|
||||
if vethNetJail == "" {
|
||||
return fmt.Errorf("VETH_JAIL_NAME environment variable is not set")
|
||||
}
|
||||
|
||||
// Wait for the veth interface to be moved into the namespace by the parent process
|
||||
if err := waitForInterface(vethNetJail, 5*time.Second); err != nil {
|
||||
return fmt.Errorf("failed to wait for interface %s: %w", vethNetJail, err)
|
||||
}
|
||||
|
||||
err := nsjail.SetupChildNetworking(vethNetJail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup child networking: %v", err)
|
||||
}
|
||||
logger.Info("child networking is successfully configured")
|
||||
|
||||
if os.Getenv("CONFIGURE_DNS_FOR_LOCAL_STUB_RESOLVER") == "true" {
|
||||
err = nsjail.ConfigureDNSForLocalStubResolver()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure DNS in namespace: %v", err)
|
||||
}
|
||||
logger.Info("DNS in namespace is configured successfully")
|
||||
}
|
||||
|
||||
// Program to run
|
||||
bin := targetCMD[0]
|
||||
args := targetCMD[1:]
|
||||
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
// Check if this is a normal exit with non-zero status code
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode := exitError.ExitCode()
|
||||
// Log at debug level for non-zero exits (normal behavior)
|
||||
logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
|
||||
// Exit with the same code as the command - don't log as error
|
||||
// This is normal behavior (commands can exit with any code)
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
// This is an unexpected error (not just a non-zero exit)
|
||||
// Only log actual errors like "command not found" or "permission denied"
|
||||
logger.Error("Command execution failed", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Command exited successfully
|
||||
logger.Debug("Command completed successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail_manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager/nsjail"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/proxy"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
)
|
||||
|
||||
type NSJailManager struct {
|
||||
jailer nsjail.Jailer
|
||||
proxyServer *proxy.Server
|
||||
logger *slog.Logger
|
||||
config config.AppConfig
|
||||
}
|
||||
|
||||
func NewNSJailManager(
|
||||
ruleEngine rulesengine.Engine,
|
||||
auditor audit.Auditor,
|
||||
tlsConfig *tls.Config,
|
||||
jailer nsjail.Jailer,
|
||||
logger *slog.Logger,
|
||||
config config.AppConfig,
|
||||
) (*NSJailManager, error) {
|
||||
// Create proxy server
|
||||
proxyServer := proxy.NewProxyServer(proxy.Config{
|
||||
HTTPPort: int(config.ProxyPort),
|
||||
RuleEngine: ruleEngine,
|
||||
Auditor: auditor,
|
||||
Logger: logger,
|
||||
TLSConfig: tlsConfig,
|
||||
PprofEnabled: config.PprofEnabled,
|
||||
PprofPort: int(config.PprofPort),
|
||||
})
|
||||
|
||||
return &NSJailManager{
|
||||
config: config,
|
||||
jailer: jailer,
|
||||
proxyServer: proxyServer,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *NSJailManager) Run(ctx context.Context) error {
|
||||
b.logger.Info("Start namespace-jail manager")
|
||||
err := b.setupHostAndStartProxy()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start namespace-jail manager: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
b.logger.Info("Stop namespace-jail manager")
|
||||
err := b.stopProxyAndCleanupHost()
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to stop namespace-jail manager", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
b.RunChildProcess(os.Args)
|
||||
}()
|
||||
|
||||
// Setup signal handling BEFORE any setup
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Wait for signal or context cancellation
|
||||
select {
|
||||
case sig := <-sigChan:
|
||||
b.logger.Info("Received signal, shutting down...", "signal", sig)
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
// Context canceled by command completion
|
||||
b.logger.Info("Command completed, shutting down...")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *NSJailManager) RunChildProcess(command []string) {
|
||||
cmd := b.jailer.Command(command)
|
||||
|
||||
b.logger.Debug("Executing command in boundary", "command", strings.Join(os.Args, " "))
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
b.logger.Error("Command failed to start", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = b.jailer.ConfigureHostNsCommunication(cmd.Process.Pid)
|
||||
if err != nil {
|
||||
b.logger.Error("configuration after command execution failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
b.logger.Debug("waiting on a child process to finish")
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
// Check if this is a normal exit with non-zero status code
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
exitCode := exitError.ExitCode()
|
||||
// Log at debug level for non-zero exits (normal behavior)
|
||||
b.logger.Debug("Command exited with non-zero status", "exit_code", exitCode)
|
||||
} else {
|
||||
// This is an unexpected error (not just a non-zero exit)
|
||||
b.logger.Error("Command execution failed", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
b.logger.Debug("Command completed successfully")
|
||||
}
|
||||
|
||||
func (b *NSJailManager) setupHostAndStartProxy() error {
|
||||
// Configure the jailer (network isolation)
|
||||
err := b.jailer.ConfigureHost()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start jailer: %v", err)
|
||||
}
|
||||
|
||||
// Start proxy server in background
|
||||
err = b.proxyServer.Start()
|
||||
if err != nil {
|
||||
b.logger.Error("Proxy server error", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Give proxy time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *NSJailManager) stopProxyAndCleanupHost() error {
|
||||
// Stop proxy server
|
||||
if b.proxyServer != nil {
|
||||
err := b.proxyServer.Stop()
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to stop proxy server", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close jailer
|
||||
return b.jailer.Close()
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type command struct {
|
||||
description string
|
||||
cmd *exec.Cmd
|
||||
ambientCaps []uintptr
|
||||
|
||||
// If ignoreErr isn't empty and this specific error occurs, suppress it (don’t log it, don’t return it).
|
||||
ignoreErr string
|
||||
}
|
||||
|
||||
func newCommand(
|
||||
description string,
|
||||
cmd *exec.Cmd,
|
||||
ambientCaps []uintptr,
|
||||
) *command {
|
||||
return newCommandWithIgnoreErr(description, cmd, ambientCaps, "")
|
||||
}
|
||||
|
||||
func newCommandWithIgnoreErr(
|
||||
description string,
|
||||
cmd *exec.Cmd,
|
||||
ambientCaps []uintptr,
|
||||
ignoreErr string,
|
||||
) *command {
|
||||
return &command{
|
||||
description: description,
|
||||
cmd: cmd,
|
||||
ambientCaps: ambientCaps,
|
||||
ignoreErr: ignoreErr,
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *command) isIgnorableError(err string) bool {
|
||||
return cmd.ignoreErr != "" && strings.Contains(err, cmd.ignoreErr)
|
||||
}
|
||||
|
||||
type commandRunner struct {
|
||||
commands []*command
|
||||
}
|
||||
|
||||
func newCommandRunner(commands []*command) *commandRunner {
|
||||
return &commandRunner{
|
||||
commands: commands,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *commandRunner) run() error {
|
||||
for _, command := range r.commands {
|
||||
command.cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
AmbientCaps: command.ambientCaps,
|
||||
}
|
||||
|
||||
output, err := command.cmd.CombinedOutput()
|
||||
if err != nil && !command.isIgnorableError(err.Error()) && !command.isIgnorableError(string(output)) {
|
||||
return fmt.Errorf("failed to %s: %v, output: %s", command.description, err, output)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *commandRunner) runIgnoreErrors() error {
|
||||
for _, command := range r.commands {
|
||||
command.cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
AmbientCaps: command.ambientCaps,
|
||||
}
|
||||
|
||||
output, err := command.cmd.CombinedOutput()
|
||||
if err != nil && !command.isIgnorableError(err.Error()) && !command.isIgnorableError(string(output)) {
|
||||
log.Printf("err: %v", err)
|
||||
log.Printf("")
|
||||
|
||||
log.Printf("failed to %s: %v, output: %s", command.description, err, output)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/util"
|
||||
)
|
||||
|
||||
// Returns environment variables intended to be set on the child process,
|
||||
// so they can later be inherited by the target process.
|
||||
func getEnvsForTargetProcess(configDir string, caCertPath string) []string {
|
||||
e := os.Environ()
|
||||
|
||||
e = util.MergeEnvs(e, map[string]string{
|
||||
// Set standard CA certificate environment variables for common tools
|
||||
// This makes tools like curl, git, etc. trust our dynamically generated CA
|
||||
"SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools
|
||||
"SSL_CERT_DIR": configDir, // OpenSSL certificate directory
|
||||
"CURL_CA_BUNDLE": caCertPath, // curl
|
||||
"GIT_SSL_CAINFO": caCertPath, // Git
|
||||
"REQUESTS_CA_BUNDLE": caCertPath, // Python requests
|
||||
"NODE_EXTRA_CA_CERTS": caCertPath, // Node.js
|
||||
})
|
||||
|
||||
return e
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type Jailer interface {
|
||||
ConfigureHost() error
|
||||
Command(command []string) *exec.Cmd
|
||||
ConfigureHostNsCommunication(processPID int) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Logger *slog.Logger
|
||||
HttpProxyPort int
|
||||
HomeDir string
|
||||
ConfigDir string
|
||||
CACertPath string
|
||||
ConfigureDNSForLocalStubResolver bool
|
||||
}
|
||||
|
||||
// LinuxJail implements Jailer using Linux network namespaces
|
||||
type LinuxJail struct {
|
||||
logger *slog.Logger
|
||||
vethHostName string // Host-side veth interface name for iptables rules
|
||||
vethJailName string // Jail-side veth interface name for iptables rules
|
||||
httpProxyPort int
|
||||
configDir string
|
||||
caCertPath string
|
||||
configureDNSForLocalStubResolver bool
|
||||
}
|
||||
|
||||
func NewLinuxJail(config Config) (*LinuxJail, error) {
|
||||
return &LinuxJail{
|
||||
logger: config.Logger,
|
||||
httpProxyPort: config.HttpProxyPort,
|
||||
configDir: config.ConfigDir,
|
||||
caCertPath: config.CACertPath,
|
||||
configureDNSForLocalStubResolver: config.ConfigureDNSForLocalStubResolver,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ConfigureBeforeCommandExecution prepares the jail environment before the target
|
||||
// process is launched. It sets environment variables, creates the veth pair, and
|
||||
// installs iptables rules on the host. At this stage, the target PID and its netns
|
||||
// are not yet known.
|
||||
func (l *LinuxJail) ConfigureHost() error {
|
||||
if err := l.configureHostNetworkBeforeCmdExec(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := l.configureIptables(); err != nil {
|
||||
return fmt.Errorf("failed to configure iptables: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Command returns an exec.Cmd configured to run within the network namespace.
|
||||
func (l *LinuxJail) Command(command []string) *exec.Cmd {
|
||||
l.logger.Debug("Creating command with namespace")
|
||||
|
||||
cmd := exec.Command(command[0], command[1:]...)
|
||||
// Set env vars for the child process; they will be inherited by the target process.
|
||||
cmd.Env = getEnvsForTargetProcess(l.configDir, l.caCertPath)
|
||||
cmd.Env = append(cmd.Env, "CHILD=true")
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("VETH_JAIL_NAME=%v", l.vethJailName))
|
||||
if l.configureDNSForLocalStubResolver {
|
||||
cmd.Env = append(cmd.Env, "CONFIGURE_DNS_FOR_LOCAL_STUB_RESOLVER=true")
|
||||
}
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
l.logger.Debug("os.Getuid()", "os.Getuid()", os.Getuid())
|
||||
l.logger.Debug("os.Getgid()", "os.Getgid()", os.Getgid())
|
||||
currentUid := os.Getuid()
|
||||
currentGid := os.Getgid()
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET,
|
||||
UidMappings: []syscall.SysProcIDMap{
|
||||
{ContainerID: currentUid, HostID: currentUid, Size: 1},
|
||||
},
|
||||
GidMappings: []syscall.SysProcIDMap{
|
||||
{ContainerID: currentGid, HostID: currentGid, Size: 1},
|
||||
},
|
||||
AmbientCaps: []uintptr{unix.CAP_NET_ADMIN},
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ConfigureHostNsCommunication finalizes host-side networking after the target
|
||||
// process has started. It moves the jail-side veth into the target process's network
|
||||
// namespace using the provided PID. This requires the process to be running so
|
||||
// its PID (and thus its netns) are available.
|
||||
func (l *LinuxJail) ConfigureHostNsCommunication(pidInt int) error {
|
||||
PID := fmt.Sprintf("%v", pidInt)
|
||||
|
||||
runner := newCommandRunner([]*command{
|
||||
// Move the jail-side veth interface into the target network namespace.
|
||||
// This isolates the interface so that it becomes visible only inside the
|
||||
// jail's netns. From this point on, the jail will configure its end of
|
||||
// the veth pair (IP address, routes, etc.) independently of the host.
|
||||
newCommand(
|
||||
"Move jail-side veth into network namespace",
|
||||
exec.Command("ip", "link", "set", l.vethJailName, "netns", PID),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
})
|
||||
if err := runner.run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close removes the network namespace and iptables rules
|
||||
func (l *LinuxJail) Close() error {
|
||||
// Clean up iptables rules
|
||||
err := l.cleanupIptables()
|
||||
if err != nil {
|
||||
l.logger.Error("Failed to clean up iptables rules", "error", err)
|
||||
// Continue with other cleanup even if this fails
|
||||
}
|
||||
|
||||
// Clean up networking
|
||||
err = l.cleanupNetworking()
|
||||
if err != nil {
|
||||
l.logger.Error("Failed to clean up networking", "error", err)
|
||||
// Continue with other cleanup even if this fails
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// ConfigureDNSForLocalStubResolver configures DNS redirection from the network namespace
|
||||
// to the host's local stub resolver. This function should only be called when the host
|
||||
// runs a local stub resolver such as systemd-resolved, and /etc/resolv.conf contains
|
||||
// "nameserver 127.0.0.53" (listening on localhost). It redirects DNS requests from the
|
||||
// namespace to the host by setting up iptables NAT rules. Additionally, /etc/systemd/resolved.conf
|
||||
// should be configured with DNSStubListener=yes and DNSStubListenerExtra=192.168.100.1:53
|
||||
// to listen on the additional server address.
|
||||
// NOTE: it's called inside network namespace.
|
||||
func ConfigureDNSForLocalStubResolver() error {
|
||||
runner := newCommandRunner([]*command{
|
||||
// Redirect all DNS queries inside the namespace to the host DNS listener.
|
||||
// Needed because systemd-resolved listens on a host-side IP, not inside the namespace.
|
||||
newCommand(
|
||||
"Redirect DNS queries (DNAT 53 → host DNS)",
|
||||
exec.Command("iptables", "-t", "nat", "-A", "OUTPUT", "-p", "udp", "--dport", "53", "-j", "DNAT", "--to-destination", "192.168.100.1:53"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Rewrite the SOURCE IP of redirected DNS packets.
|
||||
// Required because DNS queries originating as 127.0.0.1 inside the namespace
|
||||
// must not leave the namespace with a loopback source (kernel drops them).
|
||||
// SNAT ensures packets arrive at systemd-resolved with a valid, routable source.
|
||||
newCommand(
|
||||
"Fix DNS source IP (SNAT 127.0.0.x → 192.168.100.2)",
|
||||
exec.Command("iptables", "-t", "nat", "-A", "POSTROUTING", "-p", "udp", "--dport", "53", "-d", "192.168.100.1", "-j", "SNAT", "--to-source", "192.168.100.2"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Allow packets destined for 127.0.0.0/8 to go through routing and NAT.
|
||||
// Without this, DNS queries to 127.0.0.53 never hit iptables OUTPUT
|
||||
// and cannot be redirected to the host.
|
||||
newCommand(
|
||||
"Allow loopback-destined traffic to pass through NAT (route_localnet)",
|
||||
// TODO(yevhenii): consider replacing with specific interfaces instead of all
|
||||
exec.Command("sysctl", "-w", "net.ipv4.conf.all.route_localnet=1"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
})
|
||||
if err := runner.run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// configureHostNetworkBeforeCmdExec prepares host-side networking before the target
|
||||
// process is started. At this point the target process is not running, so its PID and network
|
||||
// namespace ID are not yet known.
|
||||
func (l *LinuxJail) configureHostNetworkBeforeCmdExec() error {
|
||||
// Create veth pair with short names (Linux interface names limited to 15 chars)
|
||||
// Generate unique ID to avoid conflicts
|
||||
uniqueID := fmt.Sprintf("%d", time.Now().UnixNano()%10000000) // 7 digits max
|
||||
vethHostName := fmt.Sprintf("veth_h_%s", uniqueID) // veth_h_1234567 = 14 chars
|
||||
vethJailName := fmt.Sprintf("veth_n_%s", uniqueID) // veth_n_1234567 = 14 chars
|
||||
|
||||
// Store veth interface name for iptables rules
|
||||
l.vethHostName = vethHostName
|
||||
l.vethJailName = vethJailName
|
||||
|
||||
runner := newCommandRunner([]*command{
|
||||
// Create a virtual Ethernet (veth) pair that forms a point-to-point link
|
||||
// between the host and the jail namespace. One end stays on the host,
|
||||
// the other will be moved into the jail. This provides a dedicated,
|
||||
// isolated L2 network for the jail.
|
||||
newCommand(
|
||||
"Create host–jail veth interface pair",
|
||||
exec.Command("ip", "link", "add", vethHostName, "type", "veth", "peer", "name", vethJailName),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Assign an IP address to the host side of the veth pair. The /24 mask
|
||||
// implicitly defines the jail's entire subnet as 192.168.100.0/24.
|
||||
// The host address (192.168.100.1) becomes the default gateway for
|
||||
// processes inside the jail and is used by NAT and interception rules
|
||||
// to route traffic out of the namespace.
|
||||
newCommand(
|
||||
"Assign IP to host-side veth",
|
||||
exec.Command("ip", "addr", "add", "192.168.100.1/24", "dev", vethHostName),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
newCommand(
|
||||
"Activate host-side veth interface",
|
||||
exec.Command("ip", "link", "set", vethHostName, "up"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
})
|
||||
if err := runner.run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupIptables configures iptables rules for comprehensive TCP traffic interception
|
||||
func (l *LinuxJail) configureIptables() error {
|
||||
runner := newCommandRunner([]*command{
|
||||
// Enable IPv4 packet forwarding so the host can route packets between
|
||||
// the jail's veth interface and the outside network. Without this,
|
||||
// NAT and forwarding rules would have no effect because the kernel
|
||||
// would drop transit packets.
|
||||
newCommand(
|
||||
"enable IP forwarding",
|
||||
exec.Command("sysctl", "-w", "net.ipv4.ip_forward=1"),
|
||||
[]uintptr{},
|
||||
),
|
||||
// Apply source NAT (MASQUERADE) for all traffic leaving the jail’s
|
||||
// private subnet. This rewrites the source IP of packets originating
|
||||
// from 192.168.100.0/24 to the host’s external interface IP. It enables:
|
||||
//
|
||||
// - outbound connectivity for jailed processes,
|
||||
// - correct return routing from external endpoints,
|
||||
// - avoidance of static IP assignment for the host interface.
|
||||
//
|
||||
// MASQUERADE is used instead of SNAT so it works even when the host IP
|
||||
// changes dynamically.
|
||||
newCommand(
|
||||
"NAT rules for outgoing traffic (MASQUERADE for return traffic)",
|
||||
exec.Command("iptables", "-t", "nat", "-A", "POSTROUTING", "-s", "192.168.100.0/24", "-j", "MASQUERADE"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Redirect *ALL TCP traffic* coming from the jail’s veth interface
|
||||
// to the local HTTP/TLS-intercepting proxy. This causes *every* TCP
|
||||
// connection (HTTP, HTTPS, plain TCP protocols) initiated by jailed
|
||||
// processes to be transparently intercepted.
|
||||
//
|
||||
// The HTTP proxy will intelligently handle both HTTP and TLS traffic.
|
||||
//
|
||||
// PREROUTING is used so redirection happens before routing decisions.
|
||||
// REDIRECT rewrites the destination IP to 127.0.0.1 and the destination
|
||||
// port to the HTTP proxy's port, forcing traffic through the proxy without
|
||||
// requiring any configuration inside the jail.
|
||||
newCommand(
|
||||
"Route ALL TCP traffic to HTTP proxy",
|
||||
exec.Command("iptables", "-t", "nat", "-A", "PREROUTING", "-i", l.vethHostName, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpProxyPort)),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Allow forwarding of non-TCP packets originating from the jail’s subnet.
|
||||
// This rule is primarily needed for traffic that is *not* intercepted by
|
||||
// the TCP REDIRECT rule — for example:
|
||||
//
|
||||
// - DNS queries (UDP/53)
|
||||
// - ICMP (ping, errors)
|
||||
// - Any other UDP or non-TCP protocols
|
||||
//
|
||||
// Redirected TCP flows never reach the FORWARD chain (they are locally
|
||||
// redirected in PREROUTING), so this rule does not apply to TCP traffic.
|
||||
newCommand(
|
||||
"Allow outbound non-TCP traffic from jail subnet",
|
||||
exec.Command("iptables", "-A", "FORWARD", "-s", "192.168.100.0/24", "-j", "ACCEPT"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Allow forwarding of return traffic destined for the jail’s subnet for
|
||||
// non-TCP flows. This complements the previous FORWARD rule and ensures
|
||||
// that responses to DNS (UDP) or ICMP packets can reach the jail.
|
||||
//
|
||||
// As with the previous rule, this has no effect on TCP traffic because
|
||||
// all TCP connections from the jail are intercepted and redirected to
|
||||
// the local proxy before reaching the forwarding path.
|
||||
newCommand(
|
||||
"Allow inbound return traffic to jail subnet (non-TCP)",
|
||||
exec.Command("iptables", "-A", "FORWARD", "-d", "192.168.100.0/24", "-j", "ACCEPT"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
})
|
||||
if err := runner.run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.logger.Debug("Comprehensive TCP boundarying enabled", "interface", l.vethHostName, "proxy_port", l.httpProxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupNetworking removes networking configuration
|
||||
func (l *LinuxJail) cleanupNetworking() error {
|
||||
runner := newCommandRunner([]*command{
|
||||
newCommandWithIgnoreErr(
|
||||
"delete veth pair",
|
||||
exec.Command("ip", "link", "del", l.vethHostName),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
"Cannot find device",
|
||||
),
|
||||
})
|
||||
if err := runner.runIgnoreErrors(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupIptables removes iptables rules
|
||||
func (l *LinuxJail) cleanupIptables() error {
|
||||
runner := newCommandRunner([]*command{
|
||||
newCommand(
|
||||
"Remove: NAT rules for outgoing traffic (MASQUERADE for return traffic)",
|
||||
exec.Command("iptables", "-t", "nat", "-D", "POSTROUTING", "-s", "192.168.100.0/24", "-j", "MASQUERADE"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
newCommand(
|
||||
"Remove: Route ALL TCP traffic to HTTP proxy",
|
||||
exec.Command("iptables", "-t", "nat", "-D", "PREROUTING", "-i", l.vethHostName, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpProxyPort)),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
newCommand(
|
||||
"Remove: Allow outbound non-TCP traffic from jail subnet",
|
||||
exec.Command("iptables", "-D", "FORWARD", "-s", "192.168.100.0/24", "-j", "ACCEPT"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
newCommand(
|
||||
"Remove: Allow inbound return traffic to jail subnet (non-TCP)",
|
||||
exec.Command("iptables", "-D", "FORWARD", "-d", "192.168.100.0/24", "-j", "ACCEPT"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
})
|
||||
if err := runner.runIgnoreErrors(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// SetupChildNetworking configures networking within the target process's network
|
||||
// namespace. This runs inside the child process after it has been
|
||||
// created and moved to its own network namespace.
|
||||
func SetupChildNetworking(vethNetJail string) error {
|
||||
runner := newCommandRunner([]*command{
|
||||
// Assign an IP address to the jail-side veth interface. The /24 mask
|
||||
// matches the subnet defined on the host side (192.168.100.0/24),
|
||||
// ensuring both interfaces appear on the same L2 network. This address
|
||||
// (192.168.100.2) will serve as the jail's primary outbound source IP.
|
||||
newCommand(
|
||||
"Assign IP to jail-side veth",
|
||||
exec.Command("ip", "addr", "add", "192.168.100.2/24", "dev", vethNetJail),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Bring the jail-side veth interface up. Until the interface is set UP,
|
||||
// the jail cannot send or receive any packets on this link, even if the
|
||||
// IP address and routes are configured correctly.
|
||||
newCommand(
|
||||
"Activate jail-side veth interface",
|
||||
exec.Command("ip", "link", "set", vethNetJail, "up"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Bring the jail-side veth interface up. Until the interface is set UP,
|
||||
// the jail cannot send or receive any packets on this link, even if the
|
||||
// IP address and routes are configured correctly.
|
||||
newCommand(
|
||||
"Enable loopback interface in jail",
|
||||
exec.Command("ip", "link", "set", "lo", "up"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
// Set the default route for all outbound traffic inside the jail. The
|
||||
// gateway is the host-side veth address (192.168.100.1), which performs
|
||||
// NAT and transparent TCP interception. This ensures that packets not
|
||||
// destined for the jail subnet are routed to the host for processing.
|
||||
newCommand(
|
||||
"Configure default gateway for jail",
|
||||
exec.Command("ip", "route", "add", "default", "via", "192.168.100.1"),
|
||||
[]uintptr{uintptr(unix.CAP_NET_ADMIN)},
|
||||
),
|
||||
})
|
||||
if err := runner.run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail_manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager/nsjail"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/tls"
|
||||
)
|
||||
|
||||
func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
|
||||
if len(config.AllowRules) == 0 {
|
||||
logger.Warn("No allow rules specified; all network traffic will be denied by default")
|
||||
}
|
||||
|
||||
// Parse allow rules
|
||||
allowRules, err := rulesengine.ParseAllowSpecs(config.AllowRules)
|
||||
if err != nil {
|
||||
logger.Error("Failed to parse allow rules", "error", err)
|
||||
return fmt.Errorf("failed to parse allow rules: %v", err)
|
||||
}
|
||||
|
||||
// Create rule engine
|
||||
ruleEngine := rulesengine.NewRuleEngine(allowRules, logger)
|
||||
|
||||
// Create auditor
|
||||
auditor, err := audit.SetupAuditor(ctx, logger, config.DisableAuditLogs, config.LogProxySocketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup auditor: %v", err)
|
||||
}
|
||||
|
||||
// Create TLS certificate manager
|
||||
certManager, err := tls.NewCertificateManager(tls.Config{
|
||||
Logger: logger,
|
||||
ConfigDir: config.UserInfo.ConfigDir,
|
||||
Uid: config.UserInfo.Uid,
|
||||
Gid: config.UserInfo.Gid,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to create certificate manager", "error", err)
|
||||
return fmt.Errorf("failed to create certificate manager: %v", err)
|
||||
}
|
||||
|
||||
// Setup TLS to get cert path for jailer
|
||||
tlsConfig, err := certManager.SetupTLSAndWriteCACert()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS and CA certificate: %v", err)
|
||||
}
|
||||
|
||||
// Create jailer with cert path from TLS setup
|
||||
jailer, err := nsjail.NewLinuxJail(nsjail.Config{
|
||||
Logger: logger,
|
||||
HttpProxyPort: int(config.ProxyPort),
|
||||
HomeDir: config.UserInfo.HomeDir,
|
||||
ConfigDir: config.UserInfo.ConfigDir,
|
||||
CACertPath: config.UserInfo.CACertPath(),
|
||||
ConfigureDNSForLocalStubResolver: config.ConfigureDNSForLocalStubResolver,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create jailer: %v", err)
|
||||
}
|
||||
|
||||
// Create boundary instance
|
||||
nsJailMgr, err := NewNSJailManager(ruleEngine, auditor, tlsConfig, jailer, logger, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create boundary instance: %v", err)
|
||||
}
|
||||
|
||||
return nsJailMgr.Run(ctx)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
//go:build linux
|
||||
|
||||
package nsjail_manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
)
|
||||
|
||||
func isChild() bool {
|
||||
return os.Getenv("CHILD") == "true"
|
||||
}
|
||||
|
||||
// Run is the main entry point that determines whether to execute as a parent or child process.
|
||||
// If running as a child (CHILD env var is set), it sets up networking in the namespace
|
||||
// and executes the target command. Otherwise, it runs as the parent process, setting up the jail,
|
||||
// proxy server, and managing the child process lifecycle.
|
||||
func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig) error {
|
||||
if isChild() {
|
||||
return RunChild(logger, config.TargetCMD)
|
||||
}
|
||||
|
||||
return RunParent(ctx, logger, config)
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
// Package proxy implements HTTP CONNECT method for tunneling HTTPS traffic through a proxy.
|
||||
//
|
||||
// # HTTP CONNECT Method Overview
|
||||
//
|
||||
// The HTTP CONNECT method is used to establish a tunnel through a proxy server.
|
||||
// This is essential for HTTPS proxying because HTTPS requires end-to-end encryption
|
||||
// that cannot be inspected or modified by intermediaries.
|
||||
//
|
||||
// How HTTP_PROXY Works
|
||||
//
|
||||
// When a client is configured to use an HTTP proxy (via HTTP_PROXY environment variable
|
||||
// or proxy settings), it behaves differently for HTTP vs HTTPS requests:
|
||||
//
|
||||
// - HTTP requests: The client sends the full request to the proxy, including the
|
||||
// complete URL. The proxy forwards it to the destination server.
|
||||
//
|
||||
// - HTTPS requests: The client cannot send the encrypted request directly because
|
||||
// the proxy needs to know where to connect. Instead, the client uses CONNECT
|
||||
// to establish a tunnel, then performs the TLS handshake and sends HTTPS
|
||||
// requests through that tunnel.
|
||||
//
|
||||
// # Non-Transparent Proxy
|
||||
//
|
||||
// This proxy is "non-transparent" because:
|
||||
// - Clients must be explicitly configured to use it (via HTTP_PROXY)
|
||||
// - Clients send CONNECT requests for HTTPS traffic
|
||||
// - The proxy terminates TLS, inspects requests, and re-encrypts to the destination
|
||||
// - Each HTTP request inside the tunnel is processed separately with rule evaluation
|
||||
//
|
||||
// # CONNECT Request Flow
|
||||
//
|
||||
// The following diagram illustrates how CONNECT works:
|
||||
//
|
||||
// Client Proxy (HTTP/1.1 Server) Real Server
|
||||
// | | |
|
||||
// |-- CONNECT example.com:443 -->| |
|
||||
// | | |
|
||||
// |<-- 200 Connection Established| |
|
||||
// | | |
|
||||
// |-- TLS Handshake ------------->| |
|
||||
// | | |
|
||||
// |<-- TLS Handshake -------------| |
|
||||
// | | |
|
||||
// |-- Request #1: GET /page1 --->| (decrypts) |
|
||||
// | |-- GET /page1 --------------------->|
|
||||
// | |<-- Response #1 --------------------|
|
||||
// |<-- Response #1 --------------| (encrypts) |
|
||||
// | | |
|
||||
// |-- Request #2: GET /page2 --->| (decrypts) |
|
||||
// | |-- GET /page2 --------------------->|
|
||||
// | |<-- Response #2 --------------------|
|
||||
// |<-- Response #2 --------------| (encrypts) |
|
||||
// | | |
|
||||
// |-- Request #3: GET /api ----->| (decrypts) |
|
||||
// | |-- GET /api ----------------------->|
|
||||
// | |<-- Response #3 --------------------|
|
||||
// |<-- Response #3 --------------| (encrypts) |
|
||||
// | | |
|
||||
// | (connection stays open...) | |
|
||||
// | | |
|
||||
// |-- [closes connection] ------->| |
|
||||
//
|
||||
// Key Points:
|
||||
//
|
||||
// 1. CONNECT establishes the tunnel endpoint (e.g., "example.com:443")
|
||||
// 2. The actual destination for each request is determined by the Host header
|
||||
// in the HTTP request inside the tunnel, not the CONNECT target
|
||||
// 3. The proxy acts as a TLS server to decrypt traffic from the client
|
||||
// 4. Each HTTP request inside the tunnel is evaluated against rules separately
|
||||
// 5. The connection remains open for multiple requests (HTTP/1.1 keep-alive)
|
||||
//
|
||||
// Implementation Details:
|
||||
//
|
||||
// - handleCONNECT: Receives the CONNECT request, sends "200 Connection Established"
|
||||
// - handleCONNECTTunnel: Wraps the connection with TLS, processes requests in a loop
|
||||
// - Each request uses req.Host to determine the actual destination, not the CONNECT target
|
||||
//
|
||||
//nolint:revive,gocritic,errname,unconvert,noctx,errorlint,bodyclose
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// handleCONNECT handles HTTP CONNECT requests for tunneling.
|
||||
//
|
||||
// When a client wants to make an HTTPS request through the proxy, it first sends
|
||||
// a CONNECT request with the target hostname:port (e.g., "example.com:443").
|
||||
// The proxy responds with "200 Connection Established" and then the client
|
||||
// performs a TLS handshake over the same connection.
|
||||
//
|
||||
// After the tunnel is established, handleCONNECTTunnel processes the encrypted
|
||||
// traffic and handles each HTTP request inside the tunnel separately.
|
||||
func (p *Server) handleCONNECT(conn net.Conn, req *http.Request) {
|
||||
p.logger.Debug("🔌 CONNECT request", "target", req.Host)
|
||||
|
||||
// Send 200 Connection established response
|
||||
response := "HTTP/1.1 200 Connection established\r\n\r\n"
|
||||
_, err := conn.Write([]byte(response))
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to send CONNECT response", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("CONNECT tunnel established", "target", req.Host)
|
||||
|
||||
// Handle the tunnel - decrypt TLS and process each HTTP request
|
||||
p.handleCONNECTTunnel(conn)
|
||||
}
|
||||
|
||||
// handleCONNECTTunnel handles the tunnel after CONNECT is established.
|
||||
//
|
||||
// This function:
|
||||
// 1. Wraps the connection with TLS.Server to decrypt traffic from the client
|
||||
// 2. Performs the TLS handshake
|
||||
// 3. Reads HTTP requests from the tunnel in a loop
|
||||
// 4. Processes each request separately (rule evaluation, forwarding)
|
||||
//
|
||||
// Important: The actual destination for each request is determined by the Host
|
||||
// header in the HTTP request, not the CONNECT target. This allows multiple
|
||||
// domains to be accessed over the same tunnel.
|
||||
//
|
||||
// The connection lifecycle is managed by handleHTTPConnection's defer, which
|
||||
// closes the connection when this function returns.
|
||||
func (p *Server) handleCONNECTTunnel(conn net.Conn) {
|
||||
// Wrap connection with TLS server to decrypt traffic
|
||||
tlsConn := tls.Server(conn, p.tlsConfig)
|
||||
|
||||
// Perform TLS handshake
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
p.logger.Error("TLS handshake failed in CONNECT tunnel", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("✅ TLS handshake successful in CONNECT tunnel")
|
||||
|
||||
// Process HTTP requests in a loop
|
||||
reader := bufio.NewReader(tlsConn)
|
||||
for {
|
||||
// Read HTTP request from tunnel
|
||||
req, err := http.ReadRequest(reader)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
p.logger.Debug("CONNECT tunnel closed by client")
|
||||
break
|
||||
}
|
||||
p.logger.Error("Failed to read HTTP request from CONNECT tunnel", "error", err)
|
||||
break
|
||||
}
|
||||
|
||||
p.logger.Debug("🔒 HTTP Request in CONNECT tunnel", "method", req.Method, "url", req.URL.String(), "target", req.Host)
|
||||
|
||||
// Process this request - check if allowed and forward to target
|
||||
p.processHTTPRequest(tlsConn, req, true)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,459 @@
|
||||
//nolint:revive,gocritic,errname,unconvert,noctx,errorlint,bodyclose
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof" // G108: pprof is intentionally exposed for debugging
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
)
|
||||
|
||||
// Server handles HTTP and HTTPS requests with rule-based filtering
|
||||
type Server struct {
|
||||
ruleEngine rulesengine.Engine
|
||||
auditor audit.Auditor
|
||||
logger *slog.Logger
|
||||
tlsConfig *tls.Config
|
||||
httpPort int
|
||||
started atomic.Bool
|
||||
|
||||
listener net.Listener
|
||||
pprofServer *http.Server
|
||||
pprofEnabled bool
|
||||
pprofPort int
|
||||
}
|
||||
|
||||
// Config holds configuration for the proxy server
|
||||
type Config struct {
|
||||
HTTPPort int
|
||||
RuleEngine rulesengine.Engine
|
||||
Auditor audit.Auditor
|
||||
Logger *slog.Logger
|
||||
TLSConfig *tls.Config
|
||||
PprofEnabled bool
|
||||
PprofPort int
|
||||
}
|
||||
|
||||
// NewProxyServer creates a new proxy server instance
|
||||
func NewProxyServer(config Config) *Server {
|
||||
return &Server{
|
||||
ruleEngine: config.RuleEngine,
|
||||
auditor: config.Auditor,
|
||||
logger: config.Logger,
|
||||
tlsConfig: config.TLSConfig,
|
||||
httpPort: config.HTTPPort,
|
||||
pprofEnabled: config.PprofEnabled,
|
||||
pprofPort: config.PprofPort,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the HTTP proxy server with TLS termination capability
|
||||
func (p *Server) Start() error {
|
||||
if p.isStarted() {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort)
|
||||
|
||||
// Start pprof server if enabled
|
||||
if p.pprofEnabled {
|
||||
p.pprofServer = &http.Server{ // G112: pprof server doesn't need ReadHeaderTimeout
|
||||
Addr: fmt.Sprintf(":%d", p.pprofPort),
|
||||
Handler: http.DefaultServeMux,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", p.pprofPort))
|
||||
if err != nil {
|
||||
p.logger.Error("failed to listen on port for pprof server", "port", p.pprofPort, "error", err)
|
||||
return xerrors.Errorf("failed to listen on port %v for pprof server: %v", p.pprofPort, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.logger.Info("Serving pprof on existing listener", "port", p.pprofPort)
|
||||
if err := p.pprofServer.Serve(ln); err != nil && errors.Is(err, http.ErrServerClosed) {
|
||||
p.logger.Error("pprof server error", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
var err error
|
||||
p.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort))
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to create HTTP listener", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.started.Store(true)
|
||||
|
||||
// Start HTTP server with custom listener for TLS detection
|
||||
go func() {
|
||||
for {
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil && errors.Is(err, net.ErrClosed) && p.isStopped() {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to accept connection", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle connection with TLS detection
|
||||
go p.handleConnectionWithTLSDetection(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stops proxy server
|
||||
func (p *Server) Stop() error {
|
||||
if p.isStopped() {
|
||||
return nil
|
||||
}
|
||||
p.started.Store(false)
|
||||
|
||||
if p.listener == nil {
|
||||
p.logger.Error("unexpected nil listener")
|
||||
return xerrors.New("unexpected nil listener")
|
||||
}
|
||||
|
||||
err := p.listener.Close()
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to close listener", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close pprof server
|
||||
if p.pprofServer != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := p.pprofServer.Shutdown(ctx); err != nil {
|
||||
p.logger.Error("Failed to shutdown pprof server", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Server) isStarted() bool {
|
||||
return p.started.Load()
|
||||
}
|
||||
|
||||
func (p *Server) isStopped() bool {
|
||||
return !p.started.Load()
|
||||
}
|
||||
|
||||
func (p *Server) handleConnectionWithTLSDetection(conn net.Conn) {
|
||||
// Detect protocol using TLS handshake detection
|
||||
wrappedConn, isTLS, err := p.isTLSConnection(conn)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to check connection type", "error", err)
|
||||
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to close connection", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if isTLS {
|
||||
p.logger.Debug("🔒 Detected TLS connection - handling as HTTPS")
|
||||
p.handleTLSConnection(wrappedConn)
|
||||
} else {
|
||||
p.logger.Debug("🌐 Detected HTTP connection")
|
||||
p.handleHTTPConnection(wrappedConn)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) isTLSConnection(conn net.Conn) (net.Conn, bool, error) {
|
||||
// Read first byte to detect TLS
|
||||
buf := make([]byte, 1)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil || n == 0 {
|
||||
return nil, false, xerrors.Errorf("failed to read first byte from connection: %v, read %v bytes", err, n)
|
||||
}
|
||||
|
||||
connWrapper := &connectionWrapper{conn, buf, false}
|
||||
|
||||
// TLS detection based on first byte:
|
||||
// 0x16 (22) = TLS Handshake
|
||||
// 0x17 (23) = TLS Application Data
|
||||
// 0x14 (20) = TLS Change Cipher Spec
|
||||
// 0x15 (21) = TLS Alert
|
||||
isTLS := buf[0] == 0x16 || buf[0] == 0x17 || buf[0] == 0x14 || buf[0] == 0x15
|
||||
|
||||
if isTLS {
|
||||
p.logger.Debug("TLS detected", "first byte", buf[0])
|
||||
}
|
||||
|
||||
return connWrapper, isTLS, nil
|
||||
}
|
||||
|
||||
func (p *Server) handleHTTPConnection(conn net.Conn) {
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to close connection", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Read HTTP request
|
||||
req, err := http.ReadRequest(bufio.NewReader(conn))
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to read HTTP request", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Method == http.MethodConnect {
|
||||
p.handleCONNECT(conn, req)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("🌐 HTTP Request", "method", req.Method, "url", req.URL.String())
|
||||
p.processHTTPRequest(conn, req, false)
|
||||
}
|
||||
|
||||
func (p *Server) handleTLSConnection(conn net.Conn) {
|
||||
// Create TLS connection
|
||||
tlsConn := tls.Server(conn, p.tlsConfig)
|
||||
|
||||
defer func() {
|
||||
err := tlsConn.Close()
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to close TLS connection", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Perform TLS handshake
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
p.logger.Error("TLS handshake failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("✅ TLS handshake successful")
|
||||
|
||||
// Read HTTP request over TLS
|
||||
req, err := http.ReadRequest(bufio.NewReader(tlsConn))
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to read HTTPS request", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("🔒 HTTPS Request", "method", req.Method, "url", req.URL.String())
|
||||
p.processHTTPRequest(tlsConn, req, true)
|
||||
}
|
||||
|
||||
func (p *Server) processHTTPRequest(conn net.Conn, req *http.Request, https bool) {
|
||||
p.logger.Debug(" Host", "host", req.Host)
|
||||
p.logger.Debug(" User-Agent", "user-agent", req.Header.Get("User-Agent"))
|
||||
|
||||
// Construct fully qualified URL for rule evaluation and auditing.
|
||||
// In boundary's normal transparent proxy operation, req.URL only contains
|
||||
// the path since clients don't know they're going through a proxy.
|
||||
// When clients explicitly configure a proxy, req.URL contains the full URL.
|
||||
fullURL := req.URL.String()
|
||||
if req.URL.Scheme == "" {
|
||||
scheme := "http"
|
||||
if https {
|
||||
scheme = "https"
|
||||
}
|
||||
fullURL = scheme + "://" + req.Host + fullURL
|
||||
}
|
||||
|
||||
result := p.ruleEngine.Evaluate(req.Method, fullURL)
|
||||
|
||||
p.auditor.AuditRequest(audit.Request{
|
||||
Method: req.Method,
|
||||
URL: fullURL,
|
||||
Host: req.Host,
|
||||
Allowed: result.Allowed,
|
||||
Rule: result.Rule,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
p.writeBlockedResponse(conn, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward request to destination
|
||||
p.forwardRequest(conn, req, https)
|
||||
}
|
||||
|
||||
func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) {
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse // Don't follow redirects
|
||||
},
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if https {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// Create a new request to the target server
|
||||
targetURL := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: req.Host,
|
||||
Path: req.URL.Path,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
}
|
||||
|
||||
body := req.Body
|
||||
if req.Method == http.MethodGet || req.Method == http.MethodHead {
|
||||
body = nil
|
||||
}
|
||||
newReq, err := http.NewRequest(req.Method, targetURL.String(), body)
|
||||
if err != nil {
|
||||
p.logger.Error("can't create http request", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Copy headers
|
||||
for name, values := range req.Header {
|
||||
// Skip connection-specific headers
|
||||
if strings.ToLower(name) == "connection" || strings.ToLower(name) == "proxy-connection" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
newReq.Header.Add(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Make request to destination
|
||||
resp, err := client.Do(newReq)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to forward HTTPS request", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("🔒 HTTPS Response", "status code", resp.StatusCode, "status", resp.Status)
|
||||
|
||||
p.logger.Debug("Forwarded Request",
|
||||
"method", newReq.Method,
|
||||
"host", newReq.Host,
|
||||
"URL", newReq.URL,
|
||||
)
|
||||
|
||||
// Read the body and explicitly set Content-Length header, otherwise client can hung up on the request.
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
p.logger.Error("can't read response body", "error", err)
|
||||
return
|
||||
}
|
||||
resp.Header.Add("Content-Length", strconv.Itoa(len(bodyBytes)))
|
||||
resp.ContentLength = int64(len(bodyBytes))
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to close HTTP response body", "error", err)
|
||||
return
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// The downstream client (Claude) always communicates over HTTP/1.1.
|
||||
// However, Go's default HTTP client may negotiate an HTTP/2 connection
|
||||
// with the upstream server via ALPN during TLS handshake.
|
||||
// This can cause the response's Proto field to be set to "HTTP/2.0",
|
||||
// which would produce an invalid response for an HTTP/1.1 client.
|
||||
// To prevent this mismatch, we explicitly normalize the response
|
||||
// to HTTP/1.1 before writing it back to the client.
|
||||
resp.Proto = "HTTP/1.1"
|
||||
resp.ProtoMajor = 1
|
||||
resp.ProtoMinor = 1
|
||||
|
||||
// Copy response back to client
|
||||
err = resp.Write(conn)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to forward back HTTP response",
|
||||
"error", err,
|
||||
"host", req.Host,
|
||||
"method", req.Method,
|
||||
// "bodyBytes", string(bodyBytes),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("Successfully wrote to connection")
|
||||
}
|
||||
|
||||
func (p *Server) writeBlockedResponse(conn net.Conn, req *http.Request) {
|
||||
// Create a response object
|
||||
resp := &http.Response{
|
||||
Status: "403 Forbidden",
|
||||
StatusCode: http.StatusForbidden,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
Body: nil,
|
||||
ContentLength: 0,
|
||||
}
|
||||
|
||||
// Set headers
|
||||
resp.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
// Create the response body
|
||||
host := req.URL.Host
|
||||
if host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
|
||||
body := fmt.Sprintf(`🚫 Request Blocked by Boundary
|
||||
|
||||
Request: %s %s
|
||||
Host: %s
|
||||
|
||||
To allow this request, restart boundary with:
|
||||
--allow "domain=%s" # Allow all methods to this host
|
||||
--allow "method=%s domain=%s" # Allow only %s requests to this host
|
||||
|
||||
For more help: https://github.com/coder/boundary
|
||||
`,
|
||||
req.Method, req.URL.Path, host, host, req.Method, host, req.Method)
|
||||
|
||||
resp.Body = io.NopCloser(strings.NewReader(body))
|
||||
resp.ContentLength = int64(len(body))
|
||||
|
||||
// Copy response back to client
|
||||
err := resp.Write(conn)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to write blocker response", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("Successfully wrote to connection")
|
||||
}
|
||||
|
||||
// connectionWrapper lets us "unread" the peeked byte
|
||||
type connectionWrapper struct {
|
||||
net.Conn
|
||||
buf []byte
|
||||
bufUsed bool
|
||||
}
|
||||
|
||||
func (c *connectionWrapper) Read(p []byte) (int, error) {
|
||||
if !c.bufUsed && len(c.buf) > 0 {
|
||||
n := copy(p, c.buf)
|
||||
c.bufUsed = true
|
||||
return n, nil
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
)
|
||||
|
||||
// capturingAuditor captures all audit requests for test verification.
|
||||
type capturingAuditor struct {
|
||||
mu sync.Mutex
|
||||
requests []audit.Request
|
||||
}
|
||||
|
||||
func (c *capturingAuditor) AuditRequest(req audit.Request) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.requests = append(c.requests, req)
|
||||
}
|
||||
|
||||
func (c *capturingAuditor) getRequests() []audit.Request {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return append([]audit.Request{}, c.requests...)
|
||||
}
|
||||
|
||||
func TestAuditURLIsFullyFormed_HTTP(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
serverURL, err := url.Parse(server.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
auditor := &capturingAuditor{}
|
||||
|
||||
pt := NewProxyTest(t,
|
||||
WithCertManager(t.TempDir()),
|
||||
WithAllowedRule("domain="+serverURL.Hostname()+" path=/allowed/*"),
|
||||
WithAuditor(auditor),
|
||||
).Start()
|
||||
defer pt.Stop()
|
||||
|
||||
t.Run("allowed", func(t *testing.T) {
|
||||
resp, err := pt.proxyClient.Get(server.URL + "/allowed/path?q=1")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
requests := auditor.getRequests()
|
||||
require.NotEmpty(t, requests)
|
||||
|
||||
req := requests[len(requests)-1]
|
||||
require.True(t, req.Allowed)
|
||||
|
||||
expectedURL := "http://" + net.JoinHostPort(serverURL.Hostname(), serverURL.Port()) + "/allowed/path?q=1"
|
||||
assert.Equal(t, expectedURL, req.URL)
|
||||
})
|
||||
|
||||
t.Run("denied", func(t *testing.T) {
|
||||
resp, err := pt.proxyClient.Get(server.URL + "/denied/path")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
|
||||
requests := auditor.getRequests()
|
||||
require.NotEmpty(t, requests)
|
||||
|
||||
req := requests[len(requests)-1]
|
||||
require.False(t, req.Allowed)
|
||||
|
||||
expectedURL := "http://" + net.JoinHostPort(serverURL.Hostname(), serverURL.Port()) + "/denied/path"
|
||||
assert.Equal(t, expectedURL, req.URL)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuditURLIsFullyFormed_HTTPS(t *testing.T) {
|
||||
auditor := &capturingAuditor{}
|
||||
|
||||
pt := NewProxyTest(t,
|
||||
WithCertManager(t.TempDir()),
|
||||
WithAllowedDomain("dev.coder.com"),
|
||||
WithAuditor(auditor),
|
||||
).Start()
|
||||
defer pt.Stop()
|
||||
|
||||
tunnel, err := pt.establishExplicitCONNECT("dev.coder.com:443")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, tunnel.close())
|
||||
}()
|
||||
|
||||
t.Run("allowed", func(t *testing.T) {
|
||||
_, err := tunnel.sendRequest("dev.coder.com", "/api/v2?q=1")
|
||||
require.NoError(t, err)
|
||||
|
||||
requests := auditor.getRequests()
|
||||
require.NotEmpty(t, requests)
|
||||
|
||||
req := requests[len(requests)-1]
|
||||
require.True(t, req.Allowed)
|
||||
|
||||
assert.Equal(t, "https://dev.coder.com/api/v2?q=1", req.URL)
|
||||
})
|
||||
|
||||
t.Run("denied", func(t *testing.T) {
|
||||
err := tunnel.sendRequestAndExpectDeny("blocked.example.com", "/some/path")
|
||||
require.NoError(t, err)
|
||||
|
||||
requests := auditor.getRequests()
|
||||
require.NotEmpty(t, requests)
|
||||
|
||||
req := requests[len(requests)-1]
|
||||
require.False(t, req.Allowed)
|
||||
|
||||
assert.Equal(t, "https://blocked.example.com/some/path", req.URL)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestProxyServerImplicitCONNECT tests HTTP CONNECT method for HTTPS tunneling
|
||||
// CONNECT happens implicitly when using proxy transport with HTTPS requests
|
||||
func TestProxyServerImplicitCONNECT(t *testing.T) {
|
||||
pt := NewProxyTest(t,
|
||||
WithCertManager("/tmp/boundary_connect_test"),
|
||||
WithAllowedDomain("dev.coder.com"),
|
||||
WithAllowedDomain("jsonplaceholder.typicode.com"),
|
||||
).
|
||||
Start()
|
||||
defer pt.Stop()
|
||||
|
||||
// Test HTTPS request through proxy transport (automatic CONNECT)
|
||||
t.Run("HTTPSRequestThroughProxyTransport", func(t *testing.T) {
|
||||
expectedResponse := `{"message":"👋"}
|
||||
`
|
||||
// Because this is HTTPS, Go will issue CONNECT localhost:8080 → dev.coder.com:443
|
||||
pt.ExpectAllowedViaProxy("https://dev.coder.com/api/v2", expectedResponse)
|
||||
})
|
||||
|
||||
// Test HTTP request through proxy transport
|
||||
t.Run("HTTPRequestThroughProxyTransport", func(t *testing.T) {
|
||||
expectedResponse := `{
|
||||
"userId": 1,
|
||||
"id": 1,
|
||||
"title": "delectus aut autem",
|
||||
"completed": false
|
||||
}`
|
||||
// For HTTP requests, Go will send the request directly to the proxy
|
||||
// The proxy will forward it to the target server
|
||||
pt.ExpectAllowedViaProxy("http://jsonplaceholder.typicode.com/todos/1", expectedResponse)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMultipleRequestsOverExplicitCONNECT tests explicit CONNECT requests with multiple requests over the same tunnel
|
||||
func TestMultipleRequestsOverExplicitCONNECT(t *testing.T) {
|
||||
pt := NewProxyTest(t,
|
||||
WithCertManager("/tmp/boundary_explicit_connect_test"),
|
||||
WithAllowedDomain("dev.coder.com"),
|
||||
WithAllowedDomain("jsonplaceholder.typicode.com"),
|
||||
).
|
||||
Start()
|
||||
defer pt.Stop()
|
||||
|
||||
// Establish explicit CONNECT tunnel
|
||||
// Note: The CONNECT target is just the tunnel endpoint. The actual destination
|
||||
// for each request is determined by the Host header in the HTTP request inside the tunnel.
|
||||
tunnel, err := pt.establishExplicitCONNECT("dev.coder.com:443")
|
||||
require.NoError(t, err, "Failed to establish CONNECT tunnel")
|
||||
defer tunnel.close()
|
||||
|
||||
// Positive test: Send first request to dev.coder.com over the tunnel
|
||||
t.Run("AllowedRequestToDevCoder", func(t *testing.T) {
|
||||
body1, err := tunnel.sendRequest("dev.coder.com", "/api/v2")
|
||||
require.NoError(t, err, "Failed to send first request")
|
||||
expectedResponse1 := `{"message":"👋"}
|
||||
`
|
||||
require.Equal(t, expectedResponse1, string(body1), "First response does not match")
|
||||
})
|
||||
|
||||
// Positive test: Send second request to a different domain (jsonplaceholder.typicode.com) over the same tunnel
|
||||
t.Run("AllowedRequestToJsonPlaceholder", func(t *testing.T) {
|
||||
body2, err := tunnel.sendRequest("jsonplaceholder.typicode.com", "/todos/1")
|
||||
require.NoError(t, err, "Failed to send second request")
|
||||
expectedResponse2 := `{
|
||||
"userId": 1,
|
||||
"id": 1,
|
||||
"title": "delectus aut autem",
|
||||
"completed": false
|
||||
}`
|
||||
require.Equal(t, expectedResponse2, string(body2), "Second response does not match")
|
||||
})
|
||||
|
||||
// Negative test: Try to send request to a blocked domain over the same tunnel
|
||||
t.Run("BlockedDomainOverSameTunnel", func(t *testing.T) {
|
||||
err := tunnel.sendRequestAndExpectDeny("example.com", "/")
|
||||
require.NoError(t, err, "Expected request to be blocked")
|
||||
})
|
||||
|
||||
// Negative test: Try to send request to another blocked domain
|
||||
t.Run("AnotherBlockedDomainOverSameTunnel", func(t *testing.T) {
|
||||
err := tunnel.sendRequestAndExpectDeny("github.com", "/")
|
||||
require.NoError(t, err, "Expected request to be blocked")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,438 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/audit"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/rulesengine"
|
||||
boundary_tls "github.com/coder/coder/v2/enterprise/cli/boundary/tls"
|
||||
)
|
||||
|
||||
// mockAuditor is a simple mock auditor for testing
|
||||
type mockAuditor struct{}
|
||||
|
||||
func (m *mockAuditor) AuditRequest(req audit.Request) {
|
||||
// No-op for testing
|
||||
}
|
||||
|
||||
// ProxyTest is a high-level test framework for proxy tests
|
||||
type ProxyTest struct {
|
||||
t *testing.T
|
||||
server *Server
|
||||
client *http.Client
|
||||
proxyClient *http.Client
|
||||
port int
|
||||
useCertManager bool
|
||||
configDir string
|
||||
startupDelay time.Duration
|
||||
allowedRules []string
|
||||
auditor audit.Auditor
|
||||
}
|
||||
|
||||
// ProxyTestOption is a function that configures ProxyTest
|
||||
type ProxyTestOption func(*ProxyTest)
|
||||
|
||||
// NewProxyTest creates a new ProxyTest instance
|
||||
func NewProxyTest(t *testing.T, opts ...ProxyTestOption) *ProxyTest {
|
||||
pt := &ProxyTest{
|
||||
t: t,
|
||||
port: 8080,
|
||||
useCertManager: false,
|
||||
configDir: "/tmp/boundary",
|
||||
startupDelay: 100 * time.Millisecond,
|
||||
allowedRules: []string{}, // Default: deny all (no rules = deny by default)
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
opt(pt)
|
||||
}
|
||||
|
||||
return pt
|
||||
}
|
||||
|
||||
// WithProxyPort sets the proxy server port
|
||||
func WithProxyPort(port int) ProxyTestOption {
|
||||
return func(pt *ProxyTest) {
|
||||
pt.port = port
|
||||
}
|
||||
}
|
||||
|
||||
// WithCertManager enables TLS certificate manager
|
||||
func WithCertManager(configDir string) ProxyTestOption {
|
||||
return func(pt *ProxyTest) {
|
||||
pt.useCertManager = true
|
||||
pt.configDir = configDir
|
||||
}
|
||||
}
|
||||
|
||||
// WithStartupDelay sets how long to wait after starting server before making requests
|
||||
func WithStartupDelay(delay time.Duration) ProxyTestOption {
|
||||
return func(pt *ProxyTest) {
|
||||
pt.startupDelay = delay
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowedDomain adds an allowed domain rule
|
||||
func WithAllowedDomain(domain string) ProxyTestOption {
|
||||
return func(pt *ProxyTest) {
|
||||
pt.allowedRules = append(pt.allowedRules, fmt.Sprintf("domain=%s", domain))
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowedRule adds a full allow rule (e.g., "method=GET domain=example.com path=/api/*")
|
||||
func WithAllowedRule(rule string) ProxyTestOption {
|
||||
return func(pt *ProxyTest) {
|
||||
pt.allowedRules = append(pt.allowedRules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuditor sets a custom auditor for capturing audit requests
|
||||
func WithAuditor(auditor audit.Auditor) ProxyTestOption {
|
||||
return func(pt *ProxyTest) {
|
||||
pt.auditor = auditor
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the proxy server
|
||||
func (pt *ProxyTest) Start() *ProxyTest {
|
||||
pt.t.Helper()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelError,
|
||||
}))
|
||||
|
||||
testRules, err := rulesengine.ParseAllowSpecs(pt.allowedRules)
|
||||
require.NoError(pt.t, err, "Failed to parse test rules")
|
||||
|
||||
ruleEngine := rulesengine.NewRuleEngine(testRules, logger)
|
||||
|
||||
// Use custom auditor if provided, otherwise use no-op mock
|
||||
auditor := pt.auditor
|
||||
if auditor == nil {
|
||||
auditor = &mockAuditor{}
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if pt.useCertManager {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(pt.t, err, "Failed to get current user")
|
||||
|
||||
uid, _ := strconv.Atoi(currentUser.Uid)
|
||||
gid, _ := strconv.Atoi(currentUser.Gid)
|
||||
|
||||
certManager, err := boundary_tls.NewCertificateManager(boundary_tls.Config{
|
||||
Logger: logger,
|
||||
ConfigDir: pt.configDir,
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
})
|
||||
require.NoError(pt.t, err, "Failed to create certificate manager")
|
||||
|
||||
tlsConfig, err = certManager.SetupTLSAndWriteCACert()
|
||||
require.NoError(pt.t, err, "Failed to setup TLS")
|
||||
} else {
|
||||
tlsConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
pt.server = NewProxyServer(Config{
|
||||
HTTPPort: pt.port,
|
||||
RuleEngine: ruleEngine,
|
||||
Auditor: auditor,
|
||||
Logger: logger,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
|
||||
err = pt.server.Start()
|
||||
require.NoError(pt.t, err, "Failed to start server")
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(pt.startupDelay)
|
||||
|
||||
// Create HTTP client for direct proxy requests
|
||||
pt.client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, // G402: Skip cert verification for testing
|
||||
},
|
||||
},
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Create HTTP client for proxy transport (implicit CONNECT)
|
||||
proxyURL, err := url.Parse("http://localhost:" + strconv.Itoa(pt.port))
|
||||
require.NoError(pt.t, err, "Failed to parse proxy URL")
|
||||
|
||||
pt.proxyClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, // G402: Skip cert verification for testing
|
||||
},
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
return pt
|
||||
}
|
||||
|
||||
// Stop gracefully stops the proxy server
|
||||
func (pt *ProxyTest) Stop() {
|
||||
if pt.server != nil {
|
||||
err := pt.server.Stop()
|
||||
if err != nil {
|
||||
pt.t.Logf("Failed to stop proxy server: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExpectAllowed makes a request through the proxy and expects it to be allowed with the given response body
|
||||
func (pt *ProxyTest) ExpectAllowed(proxyURL, hostHeader, expectedBody string) {
|
||||
pt.t.Helper()
|
||||
|
||||
req, err := http.NewRequest("GET", proxyURL, nil)
|
||||
require.NoError(pt.t, err, "Failed to create request")
|
||||
req.Host = hostHeader
|
||||
|
||||
resp, err := pt.client.Do(req)
|
||||
require.NoError(pt.t, err, "Failed to make request")
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(pt.t, err, "Failed to read response body")
|
||||
|
||||
require.Equal(pt.t, expectedBody, string(body), "Expected response body does not match")
|
||||
}
|
||||
|
||||
// ExpectAllowedContains makes a request through the proxy and expects it to be allowed, checking that response contains the given text
|
||||
func (pt *ProxyTest) ExpectAllowedContains(proxyURL, hostHeader, containsText string) {
|
||||
pt.t.Helper()
|
||||
|
||||
req, err := http.NewRequest("GET", proxyURL, nil)
|
||||
require.NoError(pt.t, err, "Failed to create request")
|
||||
req.Host = hostHeader
|
||||
|
||||
resp, err := pt.client.Do(req)
|
||||
require.NoError(pt.t, err, "Failed to make request")
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(pt.t, err, "Failed to read response body")
|
||||
|
||||
require.Contains(pt.t, string(body), containsText, "Response does not contain expected text")
|
||||
}
|
||||
|
||||
// ExpectDeny makes a request through the proxy and expects it to be denied
|
||||
func (pt *ProxyTest) ExpectDeny(proxyURL, hostHeader string) {
|
||||
pt.t.Helper()
|
||||
|
||||
req, err := http.NewRequest("GET", proxyURL, nil)
|
||||
require.NoError(pt.t, err, "Failed to create request")
|
||||
req.Host = hostHeader
|
||||
|
||||
resp, err := pt.client.Do(req)
|
||||
require.NoError(pt.t, err, "Failed to make request")
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(pt.t, http.StatusForbidden, resp.StatusCode, "Expected 403 Forbidden status")
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(pt.t, err, "Failed to read response body")
|
||||
|
||||
require.Contains(pt.t, string(body), "Request Blocked by Boundary", "Expected request to be blocked")
|
||||
}
|
||||
|
||||
// ExpectDenyViaProxy makes a request through the proxy using proxy transport (implicit CONNECT for HTTPS)
|
||||
// and expects it to be denied
|
||||
func (pt *ProxyTest) ExpectDenyViaProxy(targetURL string) {
|
||||
pt.t.Helper()
|
||||
|
||||
resp, err := pt.proxyClient.Get(targetURL)
|
||||
require.NoError(pt.t, err, "Failed to make request via proxy")
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(pt.t, http.StatusForbidden, resp.StatusCode, "Expected 403 Forbidden status")
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(pt.t, err, "Failed to read response body")
|
||||
|
||||
require.Contains(pt.t, string(body), "Request Blocked by Boundary", "Expected request to be blocked")
|
||||
}
|
||||
|
||||
// ExpectAllowedViaProxy makes a request through the proxy using proxy transport (implicit CONNECT for HTTPS)
|
||||
// and expects it to be allowed with the given response body
|
||||
func (pt *ProxyTest) ExpectAllowedViaProxy(targetURL, expectedBody string) {
|
||||
pt.t.Helper()
|
||||
|
||||
resp, err := pt.proxyClient.Get(targetURL)
|
||||
require.NoError(pt.t, err, "Failed to make request via proxy")
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(pt.t, err, "Failed to read response body")
|
||||
|
||||
require.Equal(pt.t, expectedBody, string(body), "Expected response body does not match")
|
||||
}
|
||||
|
||||
// ExpectAllowedContainsViaProxy makes a request through the proxy using proxy transport (implicit CONNECT for HTTPS)
|
||||
// and expects it to be allowed, checking that response contains the given text
|
||||
func (pt *ProxyTest) ExpectAllowedContainsViaProxy(targetURL, containsText string) {
|
||||
pt.t.Helper()
|
||||
|
||||
resp, err := pt.proxyClient.Get(targetURL)
|
||||
require.NoError(pt.t, err, "Failed to make request via proxy")
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(pt.t, err, "Failed to read response body")
|
||||
|
||||
require.Contains(pt.t, string(body), containsText, "Response does not contain expected text")
|
||||
}
|
||||
|
||||
// explicitCONNECTTunnel represents an established CONNECT tunnel
|
||||
type explicitCONNECTTunnel struct {
|
||||
tlsConn *tls.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
// establishExplicitCONNECT establishes a CONNECT tunnel and returns a tunnel object
|
||||
// targetHost should be in format "hostname:port" (e.g., "dev.coder.com:443")
|
||||
func (pt *ProxyTest) establishExplicitCONNECT(targetHost string) (*explicitCONNECTTunnel, error) {
|
||||
pt.t.Helper()
|
||||
|
||||
// Extract hostname for TLS ServerName (remove port if present)
|
||||
hostParts := strings.Split(targetHost, ":")
|
||||
serverName := hostParts[0]
|
||||
|
||||
// Connect to proxy
|
||||
conn, err := net.Dial("tcp", "localhost:"+strconv.Itoa(pt.port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send explicit CONNECT request
|
||||
connectReq := "CONNECT " + targetHost + " HTTP/1.1\r\n" +
|
||||
"Host: " + targetHost + "\r\n" +
|
||||
"\r\n"
|
||||
_, err = conn.Write([]byte(connectReq))
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read CONNECT response
|
||||
reader := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(reader, nil)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
conn.Close()
|
||||
return nil, xerrors.Errorf("CONNECT failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Wrap connection with TLS client
|
||||
tlsConn := tls.Client(conn, &tls.Config{
|
||||
InsecureSkipVerify: true, // G402: Skip cert verification for testing
|
||||
ServerName: serverName,
|
||||
})
|
||||
|
||||
// Perform TLS handshake
|
||||
err = tlsConn.Handshake()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &explicitCONNECTTunnel{
|
||||
tlsConn: tlsConn,
|
||||
reader: bufio.NewReader(tlsConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sendRequest sends an HTTP request over the tunnel and returns the response body
|
||||
func (tunnel *explicitCONNECTTunnel) sendRequest(targetHost, path string) ([]byte, error) {
|
||||
// Send HTTP request over the tunnel
|
||||
httpReq := "GET " + path + " HTTP/1.1\r\n" +
|
||||
"Host: " + targetHost + "\r\n" +
|
||||
"Connection: keep-alive\r\n" +
|
||||
"\r\n"
|
||||
_, err := tunnel.tlsConn.Write([]byte(httpReq))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read HTTP response
|
||||
httpResp, err := http.ReadResponse(tunnel.reader, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// sendRequestAndExpectDeny sends an HTTP request over the tunnel and expects it to be denied
|
||||
func (tunnel *explicitCONNECTTunnel) sendRequestAndExpectDeny(targetHost, path string) error {
|
||||
// Send HTTP request over the tunnel
|
||||
httpReq := "GET " + path + " HTTP/1.1\r\n" +
|
||||
"Host: " + targetHost + "\r\n" +
|
||||
"Connection: keep-alive\r\n" +
|
||||
"\r\n"
|
||||
_, err := tunnel.tlsConn.Write([]byte(httpReq))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read HTTP response
|
||||
httpResp, err := http.ReadResponse(tunnel.reader, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
if httpResp.StatusCode != http.StatusForbidden {
|
||||
return xerrors.Errorf("expected 403 Forbidden, got %d", httpResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), "Request Blocked by Boundary") {
|
||||
return xerrors.Errorf("expected blocked response, got: %s", string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// close closes the tunnel connection
|
||||
func (tunnel *explicitCONNECTTunnel) close() error {
|
||||
return tunnel.tlsConn.Close()
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic,noctx,bodyclose
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProxyServerBasicHTTP tests basic HTTP request handling
|
||||
func TestProxyServerBasicHTTP(t *testing.T) {
|
||||
pt := NewProxyTest(t,
|
||||
WithAllowedDomain("jsonplaceholder.typicode.com"),
|
||||
).
|
||||
Start()
|
||||
defer pt.Stop()
|
||||
|
||||
t.Run("BasicHTTPRequest", func(t *testing.T) {
|
||||
expectedResponse := `{
|
||||
"userId": 1,
|
||||
"id": 1,
|
||||
"title": "delectus aut autem",
|
||||
"completed": false
|
||||
}`
|
||||
pt.ExpectAllowed("http://localhost:8080/todos/1", "jsonplaceholder.typicode.com", expectedResponse)
|
||||
})
|
||||
|
||||
t.Run("BlockedHTTPRequest", func(t *testing.T) {
|
||||
pt.ExpectDeny("http://localhost:8080/", "example.com")
|
||||
})
|
||||
}
|
||||
|
||||
// TestProxyServerBasicHTTPS tests basic HTTPS request handling
|
||||
func TestProxyServerBasicHTTPS(t *testing.T) {
|
||||
pt := NewProxyTest(t,
|
||||
WithCertManager("/tmp/boundary"),
|
||||
WithAllowedDomain("dev.coder.com"),
|
||||
).
|
||||
Start()
|
||||
defer pt.Stop()
|
||||
|
||||
t.Run("BasicHTTPSRequest", func(t *testing.T) {
|
||||
expectedResponse := `{"message":"👋"}
|
||||
`
|
||||
pt.ExpectAllowed("https://localhost:8080/api/v2", "dev.coder.com", expectedResponse)
|
||||
})
|
||||
|
||||
t.Run("BlockedHTTPSRequest", func(t *testing.T) {
|
||||
pt.ExpectDeny("https://localhost:8080/", "example.com")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package rulesengine
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
neturl "net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Engine evaluates HTTP requests against a set of rules.
|
||||
type Engine struct {
|
||||
rules []Rule
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewRuleEngine creates a new rule engine
|
||||
func NewRuleEngine(rules []Rule, logger *slog.Logger) Engine {
|
||||
return Engine{
|
||||
rules: rules,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Result contains the result of rule evaluation
|
||||
type Result struct {
|
||||
Allowed bool
|
||||
Rule string // The rule that matched (if any)
|
||||
}
|
||||
|
||||
// Evaluate evaluates a request and returns both result and matching rule
|
||||
func (re *Engine) Evaluate(method, url string) Result {
|
||||
// Check if any allow rule matches
|
||||
for _, rule := range re.rules {
|
||||
if re.matches(rule, method, url) {
|
||||
return Result{
|
||||
Allowed: true,
|
||||
Rule: rule.Raw,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default deny if no allow rules match
|
||||
return Result{
|
||||
Allowed: false,
|
||||
Rule: "",
|
||||
}
|
||||
}
|
||||
|
||||
// Matches checks if the rule matches the given method and URL using wildcard patterns
|
||||
func (re *Engine) matches(r Rule, method, url string) bool {
|
||||
// Check method patterns if they exist
|
||||
if r.MethodPatterns != nil {
|
||||
methodMatches := false
|
||||
for mp := range r.MethodPatterns {
|
||||
if string(mp) == method || mp == "*" {
|
||||
methodMatches = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !methodMatches {
|
||||
re.logger.Debug("rule does not match", "reason", "method pattern mismatch", "rule", r.Raw, "method", method, "url", url)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// If the provided url doesn't have a scheme parsing will fail. This can happen when you do something like `curl google.com`
|
||||
|
||||
if !strings.Contains(url, "://") {
|
||||
// This is just for parsing, we won't use the scheme.
|
||||
url = "https://" + url
|
||||
}
|
||||
parsedURL, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
re.logger.Debug("rule does not match", "reason", "invalid URL", "rule", r.Raw, "method", method, "url", url, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if r.HostPattern != nil {
|
||||
// For a host pattern to match, every label has to match or be an `*`.
|
||||
// Subdomains also match automatically, meaning if the pattern is "example.com"
|
||||
// and the real is "api.example.com", it should match. We check this by comparing
|
||||
// from the end of the actual hostname with the pattern (which is in normal order).
|
||||
|
||||
labels := strings.Split(parsedURL.Hostname(), ".")
|
||||
|
||||
// If the host pattern is longer than the actual host, it's definitely not a match
|
||||
if len(r.HostPattern) > len(labels) {
|
||||
re.logger.Debug("rule does not match", "reason", "host pattern too long", "rule", r.Raw, "method", method, "url", url, "pattern_length", len(r.HostPattern), "hostname_labels", len(labels))
|
||||
return false
|
||||
}
|
||||
|
||||
// Since host patterns cannot end with asterisk, we only need to handle:
|
||||
// "example.com" or "*.example.com" - match from the end (allowing subdomains)
|
||||
for i, lp := range r.HostPattern {
|
||||
labelIndex := len(labels) - len(r.HostPattern) + i
|
||||
if string(lp) != labels[labelIndex] && lp != "*" {
|
||||
re.logger.Debug("rule does not match", "reason", "host pattern label mismatch", "rule", r.Raw, "method", method, "url", url, "expected", string(lp), "actual", labels[labelIndex])
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if r.PathPattern != nil {
|
||||
segments := strings.Split(parsedURL.Path, "/")
|
||||
|
||||
// Skip the first empty segment if the path starts with "/"
|
||||
if len(segments) > 0 && segments[0] == "" {
|
||||
segments = segments[1:]
|
||||
}
|
||||
|
||||
// Check if any of the path patterns match
|
||||
pathMatches := false
|
||||
for _, pattern := range r.PathPattern {
|
||||
// If the path pattern is longer than the actual path, definitely not a match
|
||||
if len(pattern) > len(segments) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Each segment in the pattern must be either as asterisk or match the actual path segment
|
||||
patternMatches := true
|
||||
for i, sp := range pattern {
|
||||
if sp != segments[i] && sp != "*" {
|
||||
patternMatches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !patternMatches {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the path is longer than the path pattern, it should only match if:
|
||||
// 1. The pattern is empty (root path matches any path), OR
|
||||
// 2. The final segment of the pattern is an asterisk
|
||||
if len(segments) > len(pattern) && len(pattern) > 0 && pattern[len(pattern)-1] != "*" {
|
||||
continue
|
||||
}
|
||||
|
||||
pathMatches = true
|
||||
break
|
||||
}
|
||||
|
||||
if !pathMatches {
|
||||
re.logger.Debug("rule does not match", "reason", "no path pattern matches", "rule", r.Raw, "method", method, "url", url)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
re.logger.Debug("rule matches", "reason", "all patterns matched", "rule", r.Raw, "method", method, "url", url)
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
//nolint:paralleltest,testpackage
|
||||
package rulesengine
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEngineMatches(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
engine := NewRuleEngine(nil, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rule Rule
|
||||
method string
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
// Method pattern tests
|
||||
{
|
||||
name: "method matches exact",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"GET": {}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "method does not match",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"POST": {}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "method wildcard matches any",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"*": {}},
|
||||
},
|
||||
method: "PUT",
|
||||
url: "https://example.com/api",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no method pattern allows all methods",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"example", "com"},
|
||||
},
|
||||
method: "DELETE",
|
||||
url: "https://example.com/api",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Host pattern tests
|
||||
{
|
||||
name: "host matches exact",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"example", "com"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "host does not match",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"example", "org"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "subdomain matches",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"example", "com"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://api.example.com/users",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "host pattern too long",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"v1", "api", "example", "com"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://api.example.com/users",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "host wildcard matches",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"*", "com"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://test.com/api",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple host wildcards",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"*", "*"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://api.example.com/users",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Path pattern tests
|
||||
{
|
||||
name: "path matches exact",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api", "users"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "path does not match",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api", "posts"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "subpath does not implicitly match",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users/123",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "asterisk matches in path",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api", "*"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users/123",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "one asterisk at end matches any number of trailing segments",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api", "*"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/foo/bar/baz",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "asterisk in middle of path only matches one segment",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api", "*", "foo"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users/admin/foo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "path pattern too long",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api", "v1", "users", "profile"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/v1/users",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "path wildcard matches",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"api", "*", "profile"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users/profile",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "multiple path wildcards",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{"*", "*"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users/123",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Combined pattern tests
|
||||
{
|
||||
name: "all patterns match",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"POST": {}},
|
||||
HostPattern: []string{"api", "com"},
|
||||
PathPattern: [][]string{{"users"}},
|
||||
},
|
||||
method: "POST",
|
||||
url: "https://api.com/users",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "method fails combined test",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"POST": {}},
|
||||
HostPattern: []string{"api", "com"},
|
||||
PathPattern: [][]string{{"users"}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://api.com/users",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "host fails combined test",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"POST": {}},
|
||||
HostPattern: []string{"api", "org"},
|
||||
PathPattern: [][]string{{"users"}},
|
||||
},
|
||||
method: "POST",
|
||||
url: "https://api.com/users",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "path fails combined test",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"POST": {}},
|
||||
HostPattern: []string{"api", "com"},
|
||||
PathPattern: [][]string{{"posts"}},
|
||||
},
|
||||
method: "POST",
|
||||
url: "https://api.com/users",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "all wildcards match",
|
||||
rule: Rule{
|
||||
MethodPatterns: map[string]struct{}{"*": {}},
|
||||
HostPattern: []string{"*", "*"},
|
||||
PathPattern: [][]string{{"*", "*"}},
|
||||
},
|
||||
method: "PATCH",
|
||||
url: "https://test.example.com/api/users/123",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty rule matches everything",
|
||||
rule: Rule{},
|
||||
method: "GET",
|
||||
url: "https://example.com/api/users",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "invalid URL",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"example", "com"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "not-a-valid-url",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "root path",
|
||||
rule: Rule{
|
||||
PathPattern: [][]string{{}},
|
||||
},
|
||||
method: "GET",
|
||||
url: "https://example.com/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "localhost host",
|
||||
rule: Rule{
|
||||
HostPattern: []string{"localhost"},
|
||||
},
|
||||
method: "GET",
|
||||
url: "http://localhost:8080/api",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := engine.matches(tt.rule, tt.method, tt.url)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
//nolint:paralleltest,testpackage
|
||||
package rulesengine
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRoundTrip(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
rules []string
|
||||
url string
|
||||
method string
|
||||
expectParse bool
|
||||
expectMatch bool
|
||||
}{
|
||||
{
|
||||
name: "basic all three",
|
||||
rules: []string{"method=GET,HEAD domain=github.com path=/wibble/wobble"},
|
||||
url: "https://github.com/wibble/wobble",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "method rejects properly",
|
||||
rules: []string{"method=GET"},
|
||||
url: "https://github.com/wibble/wobble",
|
||||
method: "POST",
|
||||
expectParse: true,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "domain rejects properly",
|
||||
rules: []string{"domain=github.com"},
|
||||
url: "https://example.com/wibble/wobble",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "path rejects properly",
|
||||
rules: []string{"path=/wibble/wobble"},
|
||||
url: "https://github.com/different/path",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "multiple rules - one matches",
|
||||
rules: []string{"domain=github.com", "domain=example.com"},
|
||||
url: "https://github.com/wibble/wobble",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "method wildcard matches anything",
|
||||
rules: []string{"method=*"},
|
||||
url: "https://github.com/wibble/wobble",
|
||||
method: "POST",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard matches anything",
|
||||
rules: []string{"domain=*"},
|
||||
url: "https://example.com/wibble/wobble",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "path wildcard matches anything",
|
||||
rules: []string{"path=*"},
|
||||
url: "https://github.com/any/path/here",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "all three wildcards match anything",
|
||||
rules: []string{"method=* domain=* path=*"},
|
||||
url: "https://example.com/some/random/path",
|
||||
method: "DELETE",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "query parameters don't break matching",
|
||||
rules: []string{"domain=github.com path=/wibble/wobble"},
|
||||
url: "https://github.com/wibble/wobble?param1=value1¶m2=value2",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard segment matches",
|
||||
rules: []string{"domain=*.github.com"},
|
||||
url: "https://api.github.com/repos",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain cannot end with asterisk",
|
||||
rules: []string{"domain=github.*"},
|
||||
url: "https://github.com/repos",
|
||||
method: "GET",
|
||||
expectParse: false,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "domain asterisk in middle matches",
|
||||
rules: []string{"domain=github.*.com"},
|
||||
url: "https://github.api.com/repos",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard matches multiple subdomains",
|
||||
rules: []string{"domain=*.github.com"},
|
||||
url: "https://v1.api.github.com/repos",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "path asterisk in middle matches",
|
||||
rules: []string{"path=/api/*/users"},
|
||||
url: "https://github.com/api/v1/users",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "path asterisk at start matches",
|
||||
rules: []string{"path=/*/users"},
|
||||
url: "https://github.com/api/users",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "path asterisk doesn't match multiple segments",
|
||||
rules: []string{"path=/api/*/users"},
|
||||
url: "https://github.com/api/../admin/users",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "path asterisk at end matches",
|
||||
rules: []string{"path=/api/v1/*"},
|
||||
url: "https://github.com/api/v1/users",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "path asterisk at end matches multiple segments",
|
||||
rules: []string{"path=/api/*"},
|
||||
url: "https://github.com/api/v1/users/123/details",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "subpaths do not match automatically",
|
||||
rules: []string{"path=/api"},
|
||||
url: "https://github.com/api/users",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "multiple rules match specific path and subpaths",
|
||||
rules: []string{"path=/wibble/wobble,/wibble/wobble/*"},
|
||||
url: "https://github.com/wibble/wobble/sub",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain matches without scheme - example.com case",
|
||||
rules: []string{"domain=example.com"},
|
||||
url: "example.com",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain matches without scheme - jsonplaceholder case",
|
||||
rules: []string{"domain=jsonplaceholder.typicode.com"},
|
||||
url: "jsonplaceholder.typicode.com",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain matches without scheme - dev.coder.com case",
|
||||
rules: []string{"domain=dev.coder.com"},
|
||||
url: "dev.coder.com",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
logHandler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
})
|
||||
logger := slog.New(logHandler)
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rules, err := ParseAllowSpecs(tc.rules)
|
||||
if tc.expectParse {
|
||||
require.Nil(t, err)
|
||||
engine := NewRuleEngine(rules, logger)
|
||||
result := engine.Evaluate(tc.method, tc.url)
|
||||
require.Equal(t, tc.expectMatch, result.Allowed)
|
||||
} else {
|
||||
require.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTripExtraRules(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
rules []string
|
||||
url string
|
||||
method string
|
||||
expectParse bool
|
||||
expectMatch bool
|
||||
}{
|
||||
{
|
||||
name: "domain=* allows everything",
|
||||
rules: []string{"domain=*"},
|
||||
url: "https://github.com/wibble/wobble",
|
||||
method: "DELETE",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "specifying port in Domain key is NOT allowed",
|
||||
rules: []string{"domain=github.com:8080"},
|
||||
url: "https://github.com/wibble/wobble",
|
||||
method: "DELETE",
|
||||
expectParse: false,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "specifying port in URL is allowed",
|
||||
rules: []string{"domain=github.com"},
|
||||
url: "https://github.com:8080/wibble/wobble",
|
||||
method: "DELETE",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard symbol at the end of path",
|
||||
rules: []string{"method=GET,POST,PUT domain=github.com path=/api/issues/*"},
|
||||
url: "https://github.com/api/issues/123/edit",
|
||||
method: "POST",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard symbol at the end of path doesn't match base path",
|
||||
rules: []string{"method=GET domain=github.com path=/api/issues/*"},
|
||||
url: "https://github.com/api/issues",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: false,
|
||||
},
|
||||
{
|
||||
name: "includes all subdomains by default",
|
||||
rules: []string{"domain=github.com"},
|
||||
url: "https://x.users.api.github.com",
|
||||
method: "GET",
|
||||
expectParse: true,
|
||||
expectMatch: true,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard in the middle matches exactly one label",
|
||||
rules: []string{"domain=api.*.com"},
|
||||
url: "https://api.v1.github.com",
|
||||
method: "POST",
|
||||
expectParse: true,
|
||||
expectMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
logHandler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
})
|
||||
logger := slog.New(logHandler)
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rules, err := ParseAllowSpecs(tc.rules)
|
||||
if tc.expectParse {
|
||||
require.Nil(t, err)
|
||||
engine := NewRuleEngine(rules, logger)
|
||||
result := engine.Evaluate(tc.method, tc.url)
|
||||
require.Equal(t, tc.expectMatch, result.Allowed)
|
||||
} else {
|
||||
require.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,403 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package rulesengine
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Rule represents an allow rule passed to the cli with --allow or read from the config file.
|
||||
// Rules have a specific grammar that we need to parse carefully.
|
||||
// Example: --allow="method=GET,PATCH domain=wibble.wobble.com, path=/posts/*"
|
||||
type Rule struct {
|
||||
// The path patterns that can match for this rule.
|
||||
// - nil means all paths allowed
|
||||
// - Each []string represents a path pattern (list of segments)
|
||||
// - a path segment of `*` acts as a wild card.
|
||||
PathPattern [][]string
|
||||
|
||||
// The labels of the host, i.e. ["google", "com"].
|
||||
// - nil means all hosts allowed
|
||||
// - A label of `*` acts as a wild card.
|
||||
// - subdomains automatically match
|
||||
HostPattern []string
|
||||
|
||||
// The allowed http methods.
|
||||
// - nil means all methods allowed
|
||||
MethodPatterns map[string]struct{}
|
||||
|
||||
// Raw rule string for logging
|
||||
Raw string
|
||||
}
|
||||
|
||||
// ParseAllowSpecs parses a slice of --allow specs into allow Rules.
|
||||
func ParseAllowSpecs(allowStrings []string) ([]Rule, error) {
|
||||
var out []Rule
|
||||
for _, s := range allowStrings {
|
||||
r, err := parseAllowRule(s)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse allow '%s': %v", s, err)
|
||||
}
|
||||
out = append(out, r)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// parseAllowRule takes an allow rule string and tries to parse it as a rule.
|
||||
func parseAllowRule(ruleStr string) (Rule, error) {
|
||||
rule := Rule{
|
||||
Raw: ruleStr,
|
||||
}
|
||||
|
||||
// Functions called by this function used a really common pattern: recursive descent parsing.
|
||||
// All the helper functions for parsing an allow rule will be called like `thing, rest, err := parseThing(rest)`.
|
||||
// What's going on here is that we try to parse some expected text from the front of the string.
|
||||
// If we succeed, we get back the thing we parsed and the remaining text. If we fail, we get back a non nil error.
|
||||
rest := ruleStr
|
||||
var key string
|
||||
var err error
|
||||
|
||||
// Ann allow rule can have as many key=value pairs as needed, we go until there's no more text in the rule.
|
||||
for rest != "" {
|
||||
// Parse the key
|
||||
key, rest, err = parseKey(rest)
|
||||
if err != nil {
|
||||
return Rule{}, xerrors.Errorf("failed to parse key: %v", err)
|
||||
}
|
||||
|
||||
// Parse the value based on the key type
|
||||
switch key {
|
||||
case "method":
|
||||
// Initialize Methods map if needed
|
||||
if rule.MethodPatterns == nil {
|
||||
rule.MethodPatterns = make(map[string]struct{})
|
||||
}
|
||||
|
||||
var method string
|
||||
for {
|
||||
method, rest, err = parseMethodPattern(rest)
|
||||
if err != nil {
|
||||
return Rule{}, xerrors.Errorf("failed to parse method: %v", err)
|
||||
}
|
||||
|
||||
rule.MethodPatterns[method] = struct{}{}
|
||||
|
||||
// Check if there's a comma for more methods
|
||||
if rest != "" && rest[0] == ',' {
|
||||
rest = rest[1:] // Skip the comma
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
case "domain":
|
||||
var host []string
|
||||
host, rest, err = parseHostPattern(rest)
|
||||
if err != nil {
|
||||
return Rule{}, xerrors.Errorf("failed to parse domain: %v", err)
|
||||
}
|
||||
|
||||
// Convert labels to strings
|
||||
rule.HostPattern = append(rule.HostPattern, host...)
|
||||
|
||||
case "path":
|
||||
for {
|
||||
var segments []string
|
||||
segments, rest, err = parsePathPattern(rest)
|
||||
if err != nil {
|
||||
return Rule{}, xerrors.Errorf("failed to parse path: %v", err)
|
||||
}
|
||||
|
||||
// Add this path pattern to the list of patterns
|
||||
rule.PathPattern = append(rule.PathPattern, segments)
|
||||
|
||||
// Check if there's a comma for more paths
|
||||
if rest != "" && rest[0] == ',' {
|
||||
rest = rest[1:] // Skip the comma
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
default:
|
||||
return Rule{}, xerrors.Errorf("unknown key: %s", key)
|
||||
}
|
||||
|
||||
// Skip whitespace separators (only support mac and linux so \r\n shouldn't be a thing)
|
||||
for rest != "" && (rest[0] == ' ' || rest[0] == '\t' || rest[0] == '\n') {
|
||||
rest = rest[1:]
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// Beyond the 9 methods defined in HTTP 1.1, there actually are many more seldom used extension methods by
|
||||
// various systems.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.6
|
||||
func parseMethodPattern(token string) (string, string, error) {
|
||||
if token == "" {
|
||||
return "", "", xerrors.New("expected http token, got empty string")
|
||||
}
|
||||
|
||||
// Find the first invalid HTTP token character
|
||||
for i := 0; i < len(token); i++ {
|
||||
if !isHTTPTokenChar(token[i]) {
|
||||
return token[:i], token[i:], nil
|
||||
}
|
||||
}
|
||||
|
||||
// Entire string is a valid HTTP token
|
||||
return token, "", nil
|
||||
}
|
||||
|
||||
// The valid characters that can be in an http token (like the lexer/parser kind of token).
|
||||
func isHTTPTokenChar(c byte) bool {
|
||||
switch {
|
||||
// Alpha numeric is fine.
|
||||
case c >= 'A' && c <= 'Z':
|
||||
return true
|
||||
case c >= 'a' && c <= 'z':
|
||||
return true
|
||||
case c >= '0' && c <= '9':
|
||||
return true
|
||||
|
||||
// These special characters are also allowed unbelievably.
|
||||
case c == '!' || c == '#' || c == '$' || c == '%' || c == '&' ||
|
||||
c == '\'' || c == '*' || c == '+' || c == '-' || c == '.' ||
|
||||
c == '^' || c == '_' || c == '`' || c == '|' || c == '~':
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Represents a valid host.
|
||||
// https://datatracker.ietf.org/doc/html/rfc952
|
||||
// https://datatracker.ietf.org/doc/html/rfc1123#page-13
|
||||
func parseHostPattern(input string) ([]string, string, error) {
|
||||
rest := input
|
||||
var host []string
|
||||
var err error
|
||||
|
||||
if input == "" {
|
||||
return nil, "", xerrors.New("expected host, got empty string")
|
||||
}
|
||||
|
||||
// There should be at least one label.
|
||||
var label string
|
||||
label, rest, err = parseLabelPattern(rest)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
host = append(host, label)
|
||||
|
||||
// A host is just a bunch of labels separated by `.` characters.
|
||||
var found bool
|
||||
for {
|
||||
rest, found = strings.CutPrefix(rest, ".")
|
||||
if !found {
|
||||
break
|
||||
}
|
||||
|
||||
label, rest, err = parseLabelPattern(rest)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
host = append(host, label)
|
||||
}
|
||||
|
||||
// If the host is a single standalone asterisk, that's the same as "matches anything"
|
||||
if len(host) == 1 && host[0] == "*" {
|
||||
return host, rest, nil
|
||||
}
|
||||
|
||||
// Validate: host patterns other than a single `*` cannot end with asterisk
|
||||
if len(host) > 0 && host[len(host)-1] == "*" {
|
||||
return nil, "", xerrors.New("host patterns cannot end with asterisk")
|
||||
}
|
||||
|
||||
return host, rest, nil
|
||||
}
|
||||
|
||||
func parseLabelPattern(rest string) (string, string, error) {
|
||||
if rest == "" {
|
||||
return "", "", xerrors.New("expected label, got empty string")
|
||||
}
|
||||
|
||||
// If the label is simply an asterisk, good to go.
|
||||
if rest[0] == '*' {
|
||||
return "*", rest[1:], nil
|
||||
}
|
||||
|
||||
// First try to get a valid leading char. Leading char in a label cannot be a hyphen.
|
||||
if !isValidLabelChar(rest[0]) || rest[0] == '-' {
|
||||
return "", "", xerrors.Errorf("could not pull label from front of string: %s", rest)
|
||||
}
|
||||
|
||||
// Go until the next character is not a valid char
|
||||
var i int
|
||||
for i = 1; i < len(rest) && isValidLabelChar(rest[i]); i++ {
|
||||
}
|
||||
|
||||
// Final char in a label cannot be a hyphen.
|
||||
if rest[i-1] == '-' {
|
||||
return "", "", xerrors.Errorf("invalid label: %s", rest[:i])
|
||||
}
|
||||
|
||||
return rest[:i], rest[i:], nil
|
||||
}
|
||||
|
||||
func isValidLabelChar(c byte) bool {
|
||||
switch {
|
||||
// Alpha numeric is fine.
|
||||
case c >= 'A' && c <= 'Z':
|
||||
return true
|
||||
case c >= 'a' && c <= 'z':
|
||||
return true
|
||||
case c >= '0' && c <= '9':
|
||||
return true
|
||||
|
||||
// Hyphens are good
|
||||
case c == '-':
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// https://myfileserver.com/"my file"
|
||||
|
||||
func parsePathPattern(input string) ([]string, string, error) {
|
||||
if input == "" {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
rest := input
|
||||
var segments []string
|
||||
var err error
|
||||
|
||||
// If the path doesn't start with '/', it's not a valid absolute path
|
||||
// But we'll be flexible and parse relative paths too
|
||||
for {
|
||||
// Skip leading slash if present
|
||||
if rest != "" && rest[0] == '/' {
|
||||
rest = rest[1:]
|
||||
}
|
||||
|
||||
// If we've consumed all input, we're done
|
||||
if rest == "" {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse the next segment
|
||||
var segment string
|
||||
segment, rest, err = parsePathSegmentPattern(rest)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// If we got an empty segment and there's still input,
|
||||
// it means we hit an invalid character
|
||||
if segment == "" && rest != "" {
|
||||
break
|
||||
}
|
||||
|
||||
segments = append(segments, segment)
|
||||
|
||||
// If there's no slash after the segment, we're done parsing the path
|
||||
if rest == "" || rest[0] != '/' {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return segments, rest, nil
|
||||
}
|
||||
|
||||
func parsePathSegmentPattern(input string) (string, string, error) {
|
||||
if input == "" {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
if len(input) > 0 && input[0] == '*' {
|
||||
if len(input) > 1 && input[1] != '/' {
|
||||
return "", "", xerrors.Errorf("path segment wildcards must be for the entire segment, got: %s", input)
|
||||
}
|
||||
|
||||
return "*", input[1:], nil
|
||||
}
|
||||
|
||||
var i int
|
||||
for i = 0; i < len(input); i++ {
|
||||
c := input[i]
|
||||
|
||||
// Check for percent-encoded characters (%XX)
|
||||
if c == '%' {
|
||||
if i+2 >= len(input) || !isHexDigit(input[i+1]) || !isHexDigit(input[i+2]) {
|
||||
break
|
||||
}
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for valid pchar characters
|
||||
if !isPChar(c) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return input[:i], input[i:], nil
|
||||
}
|
||||
|
||||
// isUnreserved returns true if the character is unreserved per RFC 3986
|
||||
// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
|
||||
func isUnreserved(c byte) bool {
|
||||
return (c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '-' || c == '.' || c == '_' || c == '~'
|
||||
}
|
||||
|
||||
// isPChar returns true if the character is valid in a path segment (excluding percent-encoded)
|
||||
// pchar = unreserved / sub-delims / ":" / "@"
|
||||
// Note: We exclude comma from sub-delims for our rule parsing to support comma-separated paths
|
||||
func isPChar(c byte) bool {
|
||||
return isUnreserved(c) || isSubDelimExceptComma(c) || c == ':' || c == '@'
|
||||
}
|
||||
|
||||
// isSubDelimExceptComma returns true if the character is a sub-delimiter except comma
|
||||
func isSubDelimExceptComma(c byte) bool {
|
||||
return c == '!' || c == '$' || c == '&' || c == '\'' ||
|
||||
c == '(' || c == ')' || c == '*' || c == '+' ||
|
||||
c == ';' || c == '='
|
||||
}
|
||||
|
||||
// isHexDigit returns true if the character is a hexadecimal digit
|
||||
func isHexDigit(c byte) bool {
|
||||
return (c >= '0' && c <= '9') ||
|
||||
(c >= 'A' && c <= 'F') ||
|
||||
(c >= 'a' && c <= 'f')
|
||||
}
|
||||
|
||||
// parseKey parses the predefined keys that the cli can handle. Also strips the `=` following the key.
|
||||
func parseKey(rule string) (string, string, error) {
|
||||
if rule == "" {
|
||||
return "", "", xerrors.New("expected key")
|
||||
}
|
||||
|
||||
// These are the current keys we support.
|
||||
keys := []string{"method", "domain", "path"}
|
||||
|
||||
for _, key := range keys {
|
||||
if rest, found := strings.CutPrefix(rule, key+"="); found {
|
||||
return key, rest, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", xerrors.New("expected key")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,24 @@
|
||||
//go:build linux
|
||||
|
||||
package run
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/landjail"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/nsjail_manager"
|
||||
)
|
||||
|
||||
func Run(ctx context.Context, logger *slog.Logger, cfg config.AppConfig) error {
|
||||
switch cfg.JailType {
|
||||
case config.NSJailType:
|
||||
return nsjail_manager.Run(ctx, logger, cfg)
|
||||
case config.LandjailType:
|
||||
return landjail.Run(ctx, logger, cfg)
|
||||
default:
|
||||
return fmt.Errorf("unknown jail type: %s", cfg.JailType)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,361 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary/config"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
SetupTLSAndWriteCACert() (*tls.Config, string, string, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Logger *slog.Logger
|
||||
ConfigDir string
|
||||
Uid int
|
||||
Gid int
|
||||
}
|
||||
|
||||
// CertificateManager manages TLS certificates for the proxy
|
||||
type CertificateManager struct {
|
||||
caKey *rsa.PrivateKey
|
||||
caCert *x509.Certificate
|
||||
certCache map[string]*tls.Certificate
|
||||
mutex sync.RWMutex
|
||||
logger *slog.Logger
|
||||
configDir string
|
||||
uid int
|
||||
gid int
|
||||
}
|
||||
|
||||
// NewCertificateManager creates a new certificate manager
|
||||
func NewCertificateManager(config Config) (*CertificateManager, error) {
|
||||
cm := &CertificateManager{
|
||||
certCache: make(map[string]*tls.Certificate),
|
||||
logger: config.Logger,
|
||||
configDir: config.ConfigDir,
|
||||
uid: config.Uid,
|
||||
gid: config.Gid,
|
||||
}
|
||||
|
||||
// Load or generate CA certificate
|
||||
err := cm.loadOrGenerateCA()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to load or generate CA: %v", err)
|
||||
}
|
||||
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
// SetupTLSAndWriteCACert sets up TLS config and writes CA certificate to file
|
||||
// Returns the TLS config, CA cert path, and config directory
|
||||
func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, error) {
|
||||
// Get TLS config
|
||||
tlsConfig := cm.getTLSConfig()
|
||||
|
||||
// Get CA certificate PEM
|
||||
caCertPEM, err := cm.getCACertPEM()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to get CA certificate: %v", err)
|
||||
}
|
||||
|
||||
// Write CA certificate to file
|
||||
caCertPath := filepath.Join(cm.configDir, config.CACertName)
|
||||
err = os.WriteFile(caCertPath, caCertPEM, 0o600)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to write CA certificate file: %v", err)
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// loadOrGenerateCA loads existing CA or generates a new one
|
||||
func (cm *CertificateManager) loadOrGenerateCA() error {
|
||||
caKeyPath := filepath.Join(cm.configDir, config.CAKeyName)
|
||||
caCertPath := filepath.Join(cm.configDir, config.CACertName)
|
||||
|
||||
cm.logger.Debug("paths", "cm.configDir", cm.configDir, "caCertPath", caCertPath)
|
||||
|
||||
// Try to load existing CA
|
||||
if cm.loadExistingCA(caKeyPath, caCertPath) {
|
||||
cm.logger.Debug("Loaded existing CA certificate")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate new CA
|
||||
cm.logger.Info("Generating new CA certificate")
|
||||
return cm.generateCA(caKeyPath, caCertPath)
|
||||
}
|
||||
|
||||
// getTLSConfig returns a TLS config that generates certificates on-demand
|
||||
func (cm *CertificateManager) getTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: cm.getCertificate,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
// getCACertPEM returns the CA certificate in PEM format
|
||||
func (cm *CertificateManager) getCACertPEM() ([]byte, error) {
|
||||
return pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cm.caCert.Raw,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// loadExistingCA attempts to load existing CA files
|
||||
func (cm *CertificateManager) loadExistingCA(keyPath, certPath string) bool {
|
||||
// Check if files exist
|
||||
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
if _, err := os.Stat(certPath); os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Load private key
|
||||
keyData, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
cm.logger.Warn("Failed to read CA key", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
keyBlock, _ := pem.Decode(keyData)
|
||||
if keyBlock == nil {
|
||||
cm.logger.Warn("Failed to decode CA key PEM")
|
||||
return false
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
|
||||
if err != nil {
|
||||
cm.logger.Warn("Failed to parse CA private key", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Load certificate
|
||||
certData, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
cm.logger.Warn("Failed to read CA cert", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
certBlock, _ := pem.Decode(certData)
|
||||
if certBlock == nil {
|
||||
cm.logger.Warn("Failed to decode CA cert PEM")
|
||||
return false
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||
if err != nil {
|
||||
cm.logger.Warn("Failed to parse CA certificate", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if certificate is still valid
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
cm.logger.Warn("CA certificate has expired")
|
||||
return false
|
||||
}
|
||||
|
||||
cm.caKey = privateKey
|
||||
cm.caCert = cert
|
||||
return true
|
||||
}
|
||||
|
||||
// generateCA generates a new CA certificate and key
|
||||
func (cm *CertificateManager) generateCA(keyPath, certPath string) error {
|
||||
// Create config directory if it doesn't exist
|
||||
err := os.MkdirAll(cm.configDir, 0o700)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create config directory at %s: %v", cm.configDir, err)
|
||||
}
|
||||
|
||||
// ensure the directory is owned by the original user
|
||||
err = os.Chown(cm.configDir, cm.uid, cm.gid)
|
||||
if err != nil {
|
||||
cm.logger.Warn("Failed to change config directory ownership", "error", err)
|
||||
}
|
||||
|
||||
// Generate private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// Create certificate template
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"coder"},
|
||||
Country: []string{"US"},
|
||||
Province: []string{""},
|
||||
Locality: []string{""},
|
||||
StreetAddress: []string{""},
|
||||
PostalCode: []string{""},
|
||||
CommonName: "coder CA",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
// Parse certificate
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to parse certificate: %v", err)
|
||||
}
|
||||
|
||||
// Save private key
|
||||
keyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create key file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := keyFile.Close()
|
||||
if err != nil {
|
||||
cm.logger.Error("Failed to close key file", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = pem.Encode(keyFile, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write key to file: %v", err)
|
||||
}
|
||||
|
||||
// Save certificate
|
||||
certFile, err := os.OpenFile(certPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create cert file: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := certFile.Close()
|
||||
if err != nil {
|
||||
cm.logger.Error("Failed to close cert file", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = pem.Encode(certFile, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write cert to file: %v", err)
|
||||
}
|
||||
|
||||
cm.caKey = privateKey
|
||||
cm.caCert = cert
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCertificate generates or retrieves a certificate for the given hostname
|
||||
func (cm *CertificateManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
hostname := hello.ServerName
|
||||
if hostname == "" {
|
||||
return nil, xerrors.New("no server name provided")
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
cm.mutex.RLock()
|
||||
if cert, exists := cm.certCache[hostname]; exists {
|
||||
cm.mutex.RUnlock()
|
||||
return cert, nil
|
||||
}
|
||||
cm.mutex.RUnlock()
|
||||
|
||||
// Generate new certificate
|
||||
cm.mutex.Lock()
|
||||
defer cm.mutex.Unlock()
|
||||
|
||||
// Double-check cache (another goroutine might have generated it)
|
||||
if cert, exists := cm.certCache[hostname]; exists {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
cert, err := cm.generateServerCertificate(hostname)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to generate certificate for %s: %v", hostname, err)
|
||||
}
|
||||
|
||||
cm.certCache[hostname] = cert
|
||||
cm.logger.Debug("Generated certificate", "hostname", hostname)
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// generateServerCertificate generates a server certificate for the given hostname
|
||||
func (cm *CertificateManager) generateServerCertificate(hostname string) (*tls.Certificate, error) {
|
||||
// Generate private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// Create certificate template
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(time.Now().UnixNano()),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"coder"},
|
||||
Country: []string{"US"},
|
||||
Province: []string{""},
|
||||
Locality: []string{""},
|
||||
StreetAddress: []string{""},
|
||||
PostalCode: []string{""},
|
||||
CommonName: hostname,
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour), // 1 day
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{hostname},
|
||||
}
|
||||
|
||||
// Add IP address if hostname is an IP
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
template.IPAddresses = []net.IP{ip}
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, cm.caCert, &privateKey.PublicKey, cm.caKey)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
// Create TLS certificate
|
||||
tlsCert := &tls.Certificate{
|
||||
Certificate: [][]byte{certDER},
|
||||
PrivateKey: privateKey,
|
||||
}
|
||||
|
||||
cm.logger.Debug("Generated certificate", "hostname", hostname)
|
||||
|
||||
return tlsCert, nil
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
//nolint:paralleltest,testpackage,revive,gocritic
|
||||
package tls
|
||||
|
||||
import "testing"
|
||||
|
||||
// Stub test file - tests removed
|
||||
func TestStub(t *testing.T) {
|
||||
// This is a stub test
|
||||
t.Skip("stub test file")
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
//nolint:revive,gocritic,errname,unconvert
|
||||
package util
|
||||
|
||||
import "strings"
|
||||
|
||||
func MergeEnvs(base []string, extra map[string]string) []string {
|
||||
envMap := make(map[string]string)
|
||||
for _, env := range base {
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
envMap[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range extra {
|
||||
envMap[key] = value
|
||||
}
|
||||
|
||||
merged := make([]string, 0, len(envMap))
|
||||
for key, value := range envMap {
|
||||
merged = append(merged, key+"="+value)
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/enterprise/cli/boundary"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (*RootCmd) boundary() *serpent.Command {
|
||||
cmd := boundary.BaseCommand(buildinfo.Version())
|
||||
cmd.Use += " [args...]" // The base command looks like `boundary -- command`. Serpent adds the flags piece, but we need to add the args.
|
||||
return cmd
|
||||
}
|
||||
@@ -29,8 +29,10 @@ func (r *RootCmd) enterpriseOnly() []*serpent.Command {
|
||||
}
|
||||
}
|
||||
|
||||
func (*RootCmd) enterpriseExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{}
|
||||
func (r *RootCmd) enterpriseExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{
|
||||
r.boundary(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) EnterpriseSubcommands() []*serpent.Command {
|
||||
|
||||
@@ -453,7 +453,7 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
howett.net/plist v1.0.0 // indirect
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.73 // indirect
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 // indirect
|
||||
sigs.k8s.io/yaml v1.5.0 // indirect
|
||||
)
|
||||
|
||||
@@ -472,10 +472,10 @@ require (
|
||||
require (
|
||||
github.com/anthropics/anthropic-sdk-go v1.19.0
|
||||
github.com/brianvoe/gofakeit/v7 v7.14.0
|
||||
github.com/cenkalti/backoff/v5 v5.0.3
|
||||
github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225
|
||||
github.com/coder/aibridge v0.3.1-0.20260105111716-7535a71e91a1
|
||||
github.com/coder/aisdk-go v0.0.9
|
||||
github.com/coder/boundary v0.0.1-alpha
|
||||
github.com/coder/preview v1.0.4
|
||||
github.com/danieljoos/wincred v1.2.3
|
||||
github.com/dgraph-io/ristretto/v2 v2.3.0
|
||||
@@ -483,6 +483,7 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-git/go-git/v5 v5.16.2
|
||||
github.com/icholy/replace v0.6.0
|
||||
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c
|
||||
github.com/mark3labs/mcp-go v0.38.0
|
||||
gonum.org/v1/gonum v0.17.0
|
||||
)
|
||||
@@ -516,7 +517,6 @@ require (
|
||||
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect
|
||||
github.com/bits-and-blooms/bitset v1.24.4 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
||||
|
||||
@@ -931,8 +931,6 @@ github.com/coder/aibridge v0.3.1-0.20260105111716-7535a71e91a1 h1:cr2K36NgU1fHKt
|
||||
github.com/coder/aibridge v0.3.1-0.20260105111716-7535a71e91a1/go.mod h1:5Ztcl+9HF0tog85iEEuFdaBkBe8EkxJe5XjbMOFviQs=
|
||||
github.com/coder/aisdk-go v0.0.9 h1:Vzo/k2qwVGLTR10ESDeP2Ecek1SdPfZlEjtTfMveiVo=
|
||||
github.com/coder/aisdk-go v0.0.9/go.mod h1:KF6/Vkono0FJJOtWtveh5j7yfNrSctVTpwgweYWSp5M=
|
||||
github.com/coder/boundary v0.0.1-alpha h1:6shUQ2zkrWrfbgVcqWvpV2ibljOQvPvYqTctWBkKoUA=
|
||||
github.com/coder/boundary v0.0.1-alpha/go.mod h1:d1AMFw81rUgrGHuZzWdPNhkY0G8w7pvLNLYF0e3ceC4=
|
||||
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41 h1:SBN/DA63+ZHwuWwPHPYoCZ/KLAjHv5g4h2MS4f2/MTI=
|
||||
github.com/coder/bubbletea v1.2.2-0.20241212190825-007a1cdb2c41/go.mod h1:I9ULxr64UaOSUv7hcb3nX4kowodJCVS7vt7VVJk/kW4=
|
||||
github.com/coder/clistat v1.2.0 h1:37KJKqiCllJsRvWqTHf3qiLIXX0JB6oqE5oxcqgdLkY=
|
||||
@@ -1544,6 +1542,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/kyokomi/emoji/v2 v2.2.13 h1:GhTfQa67venUUvmleTNFnb+bi7S3aocF7ZCXU9fSO7U=
|
||||
github.com/kyokomi/emoji/v2 v2.2.13/go.mod h1:JUcn42DTdsXJo1SWanHh4HKDEyPaR5CqkmoirZZP9qE=
|
||||
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c h1:QcKqiunpt7hooa/xIx0iyepA6Cs2BgKexaYOxHvHNCs=
|
||||
github.com/landlock-lsm/go-landlock v0.0.0-20251103212306-430f8e5cd97c/go.mod h1:stwyhp9tfeEy3A4bRJLdOEvjW/CetRJg/vcijNG8M5A=
|
||||
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo=
|
||||
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
@@ -2845,8 +2845,9 @@ k8s.io/utils v0.0.0-20241210054802-24370beab758 h1:sdbE21q2nlQtFh65saZY+rRM6x6aJ
|
||||
k8s.io/utils v0.0.0-20241210054802-24370beab758/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
|
||||
kernel.org/pub/linux/libs/security/libcap/cap v1.2.73 h1:Th2b8jljYqkyZKS3aD3N9VpYsQpHuXLgea+SZUIfODA=
|
||||
kernel.org/pub/linux/libs/security/libcap/cap v1.2.73/go.mod h1:hbeKwKcboEsxARYmcy/AdPVN11wmT/Wnpgv4k4ftyqY=
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.73 h1:SEAEUiPVylTD4vqqi+vtGkSnXeP2FcRO3FoZB1MklMw=
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.73/go.mod h1:+l6Ee2F59XiJ2I6WR5ObpC1utCQJZ/VLsEbQCD8RG24=
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77 h1:Z06sMOzc0GNCwp6efaVrIrz4ywGJ1v+DP0pjVkOfDuA=
|
||||
kernel.org/pub/linux/libs/security/libcap/psx v1.2.77/go.mod h1:+l6Ee2F59XiJ2I6WR5ObpC1utCQJZ/VLsEbQCD8RG24=
|
||||
lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
|
||||
lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk=
|
||||
modernc.org/cc/v3 v3.36.0/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI=
|
||||
|
||||
Reference in New Issue
Block a user